-rw-r--r--CODEOWNERS5
-rw-r--r--README.md21
-rw-r--r--RELEASE.md79
-rw-r--r--configure.py252
-rw-r--r--tensorflow/BUILD160
-rw-r--r--tensorflow/api_template.__init__.py15
-rw-r--r--tensorflow/c/BUILD1
-rw-r--r--tensorflow/c/c_api_experimental.cc49
-rw-r--r--tensorflow/c/c_api_experimental.h8
-rw-r--r--tensorflow/c/c_api_experimental_test.cc46
-rw-r--r--tensorflow/c/eager/BUILD5
-rwxr-xr-xtensorflow/c/eager/c_api.cc11
-rwxr-xr-xtensorflow/c/eager/c_api.h2
-rw-r--r--tensorflow/c/eager/tape.h130
-rw-r--r--tensorflow/c/python_api.cc7
-rw-r--r--tensorflow/c/python_api.h13
-rw-r--r--tensorflow/cc/BUILD28
-rw-r--r--tensorflow/compiler/aot/tests/BUILD15
-rw-r--r--tensorflow/compiler/aot/tests/make_test_graphs.py12
-rw-r--r--tensorflow/compiler/aot/tests/test_graph_tftop_k.config.pbtxt13
-rw-r--r--tensorflow/compiler/aot/tests/tfcompile_test.cc25
-rw-r--r--tensorflow/compiler/aot/tfcompile.bzl1
-rw-r--r--tensorflow/compiler/jit/BUILD36
-rw-r--r--tensorflow/compiler/jit/build_xla_launch_ops_pass.cc142
-rw-r--r--tensorflow/compiler/jit/build_xla_ops_pass.cc182
-rw-r--r--tensorflow/compiler/jit/build_xla_ops_pass.h (renamed from tensorflow/compiler/jit/build_xla_launch_ops_pass.h)10
-rw-r--r--tensorflow/compiler/jit/build_xla_ops_pass_test.cc112
-rw-r--r--tensorflow/compiler/jit/create_xla_launch_op.cc2
-rw-r--r--tensorflow/compiler/jit/jit_compilation_pass_registration.cc4
-rw-r--r--tensorflow/compiler/jit/kernels/BUILD7
-rw-r--r--tensorflow/compiler/jit/kernels/xla_launch_op.cc276
-rw-r--r--tensorflow/compiler/jit/kernels/xla_launch_op.h87
-rw-r--r--tensorflow/compiler/jit/kernels/xla_ops.cc499
-rw-r--r--tensorflow/compiler/jit/kernels/xla_ops.h168
-rw-r--r--tensorflow/compiler/jit/mark_for_compilation_pass.cc78
-rw-r--r--tensorflow/compiler/jit/mark_for_compilation_pass_test.cc66
-rw-r--r--tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc21
-rw-r--r--tensorflow/compiler/jit/ops/BUILD8
-rw-r--r--tensorflow/compiler/jit/ops/xla_ops.cc39
-rw-r--r--tensorflow/compiler/jit/partially_decluster_pass_test.cc11
-rw-r--r--tensorflow/compiler/jit/xla_compile_on_demand_op.cc7
-rw-r--r--tensorflow/compiler/jit/xla_cpu_device.cc10
-rw-r--r--tensorflow/compiler/jit/xla_device.cc12
-rw-r--r--tensorflow/compiler/jit/xla_device.h12
-rw-r--r--tensorflow/compiler/jit/xla_device_ops.h16
-rw-r--r--tensorflow/compiler/jit/xla_gpu_device.cc11
-rw-r--r--tensorflow/compiler/jit/xla_interpreter_device.cc6
-rw-r--r--tensorflow/compiler/jit/xla_launch_util.cc18
-rw-r--r--tensorflow/compiler/jit/xla_launch_util.h17
-rw-r--r--tensorflow/compiler/tests/BUILD15
-rw-r--r--tensorflow/compiler/tests/argminmax_test.py4
-rw-r--r--tensorflow/compiler/tests/binary_ops_test.py25
-rw-r--r--tensorflow/compiler/tests/build_defs.bzl10
-rw-r--r--tensorflow/compiler/tests/dense_layer_test.py25
-rw-r--r--tensorflow/compiler/tests/fused_batchnorm_test.py40
-rw-r--r--tensorflow/compiler/tests/gather_test.py14
-rw-r--r--tensorflow/compiler/tests/image_ops_test.py55
-rw-r--r--tensorflow/compiler/tests/jit_test.py48
-rw-r--r--tensorflow/compiler/tests/lstm.py2
-rw-r--r--tensorflow/compiler/tests/quantized_ops_test.py48
-rw-r--r--tensorflow/compiler/tests/random_ops_test.py19
-rw-r--r--tensorflow/compiler/tests/reverse_sequence_op_test.py2
-rw-r--r--tensorflow/compiler/tests/sort_ops_test.py20
-rw-r--r--tensorflow/compiler/tests/stateless_random_ops_test.py7
-rw-r--r--tensorflow/compiler/tests/ternary_ops_test.py3
-rw-r--r--tensorflow/compiler/tests/unary_ops_test.py7
-rw-r--r--tensorflow/compiler/tests/xla_ops_test.py2
-rw-r--r--tensorflow/compiler/tests/xla_test.py19
-rw-r--r--tensorflow/compiler/tf2xla/const_analysis.cc12
-rw-r--r--tensorflow/compiler/tf2xla/functionalize_control_flow.cc57
-rw-r--r--tensorflow/compiler/tf2xla/kernels/BUILD22
-rw-r--r--tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc12
-rw-r--r--tensorflow/compiler/tf2xla/kernels/binary_ops.cc37
-rw-r--r--tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc5
-rw-r--r--tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc509
-rw-r--r--tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h69
-rw-r--r--tensorflow/compiler/tf2xla/kernels/conv_ops.cc551
-rw-r--r--tensorflow/compiler/tf2xla/kernels/image_ops.cc9
-rw-r--r--tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc76
-rw-r--r--tensorflow/compiler/tf2xla/kernels/shape_op.cc8
-rw-r--r--tensorflow/compiler/tf2xla/ops/xla_ops.cc7
-rw-r--r--tensorflow/compiler/tf2xla/shape_util.cc14
-rw-r--r--tensorflow/compiler/tf2xla/shape_util.h5
-rw-r--r--tensorflow/compiler/tf2xla/type_util.h8
-rw-r--r--tensorflow/compiler/tf2xla/xla_compiler.cc6
-rw-r--r--tensorflow/compiler/tf2xla/xla_cpu_backend.cc15
-rw-r--r--tensorflow/compiler/tf2xla/xla_op_registry.cc24
-rw-r--r--tensorflow/compiler/tf2xla/xla_op_registry.h31
-rw-r--r--tensorflow/compiler/xla/BUILD1
-rw-r--r--tensorflow/compiler/xla/client/lib/testing.cc12
-rw-r--r--tensorflow/compiler/xla/client/xla_builder.cc8
-rw-r--r--tensorflow/compiler/xla/client/xla_builder.h24
-rw-r--r--tensorflow/compiler/xla/executable_run_options.cc10
-rw-r--r--tensorflow/compiler/xla/executable_run_options.h8
-rw-r--r--tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc13
-rw-r--r--tensorflow/compiler/xla/literal.h16
-rw-r--r--tensorflow/compiler/xla/python/local_computation_builder.cc6
-rw-r--r--tensorflow/compiler/xla/python/local_computation_builder.h3
-rw-r--r--tensorflow/compiler/xla/python/xla_client.py19
-rw-r--r--tensorflow/compiler/xla/python/xla_client_test.py24
-rw-r--r--tensorflow/compiler/xla/rpc/BUILD13
-rw-r--r--tensorflow/compiler/xla/rpc/grpc_service_main.cc21
-rw-r--r--tensorflow/compiler/xla/service/BUILD26
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier.cc3
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier.h2
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier_test.cc17
-rw-r--r--tensorflow/compiler/xla/service/batch_dot_simplification.h2
-rw-r--r--tensorflow/compiler/xla/service/batchnorm_expander.h2
-rw-r--r--tensorflow/compiler/xla/service/bfloat16_conversion_folding.h2
-rw-r--r--tensorflow/compiler/xla/service/bfloat16_normalization.h4
-rw-r--r--tensorflow/compiler/xla/service/bfloat16_propagation.h2
-rw-r--r--tensorflow/compiler/xla/service/buffer_assignment.cc26
-rw-r--r--tensorflow/compiler/xla/service/call_inliner.h2
-rw-r--r--tensorflow/compiler/xla/service/conditional_simplifier.h2
-rw-r--r--tensorflow/compiler/xla/service/convolution_feature_group_converter.h2
-rw-r--r--tensorflow/compiler/xla/service/copy_insertion.h2
-rw-r--r--tensorflow/compiler/xla/service/cpu/BUILD17
-rw-r--r--tensorflow/compiler/xla/service/cpu/conv_canonicalization.h2
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h2
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h2
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_runtime.cc122
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_runtime.h44
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc10
-rw-r--r--tensorflow/compiler/xla/service/cpu/ir_emitter.cc171
-rw-r--r--tensorflow/compiler/xla/service/cpu/ir_emitter.h6
-rw-r--r--tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc1
-rw-r--r--tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h2
-rw-r--r--tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc236
-rw-r--r--tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h88
-rw-r--r--tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc13
-rw-r--r--tensorflow/compiler/xla/service/cpu/tests/BUILD14
-rw-r--r--tensorflow/compiler/xla/service/cpu/tests/cpu_key_value_sort_test.cc54
-rw-r--r--tensorflow/compiler/xla/service/cpu/xfeed_manager_test.cc18
-rw-r--r--tensorflow/compiler/xla/service/defuser.h2
-rw-r--r--tensorflow/compiler/xla/service/despecializer.cc2
-rw-r--r--tensorflow/compiler/xla/service/despecializer.h2
-rw-r--r--tensorflow/compiler/xla/service/dot_decomposer.h2
-rw-r--r--tensorflow/compiler/xla/service/elemental_ir_emitter.cc169
-rw-r--r--tensorflow/compiler/xla/service/flatten_call_graph.h2
-rw-r--r--tensorflow/compiler/xla/service/gather_expander.h2
-rw-r--r--tensorflow/compiler/xla/service/gpu/BUILD33
-rw-r--r--tensorflow/compiler/xla/service/gpu/backend_configs.proto14
-rw-r--r--tensorflow/compiler/xla/service/gpu/convolution_thunk.cc43
-rw-r--r--tensorflow/compiler/xla/service/gpu/convolution_thunk.h25
-rw-r--r--tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h2
-rw-r--r--tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc123
-rw-r--r--tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h4
-rw-r--r--tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc56
-rw-r--r--tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h2
-rw-r--r--tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc194
-rw-r--r--tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h55
-rw-r--r--tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.cc278
-rw-r--r--tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.h37
-rw-r--r--tensorflow/compiler/xla/service/gpu/fusion_merger.h2
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc9
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h11
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h2
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc94
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h3
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc118
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emission_utils.h56
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc31
-rw-r--r--tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc4
-rw-r--r--tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc30
-rw-r--r--tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h2
-rw-r--r--tensorflow/compiler/xla/service/gpu/pad_insertion.cc35
-rw-r--r--tensorflow/compiler/xla/service/gpu/pad_insertion.h2
-rw-r--r--tensorflow/compiler/xla/service/gpu/tests/BUILD60
-rw-r--r--tensorflow/compiler/xla/service/gpu/tests/cudnn_fused_convolution_rewriter_test.cc283
-rw-r--r--tensorflow/compiler/xla/service/heap_simulator.cc205
-rw-r--r--tensorflow/compiler/xla/service/heap_simulator.h62
-rw-r--r--tensorflow/compiler/xla/service/heap_simulator_test.cc130
-rw-r--r--tensorflow/compiler/xla/service/hlo.proto8
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation.cc17
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation.h9
-rw-r--r--tensorflow/compiler/xla/service/hlo_constant_folding.h2
-rw-r--r--tensorflow/compiler/xla/service/hlo_creation_utils.cc39
-rw-r--r--tensorflow/compiler/xla/service/hlo_creation_utils.h30
-rw-r--r--tensorflow/compiler/xla/service/hlo_cse.h2
-rw-r--r--tensorflow/compiler/xla/service/hlo_dce.h2
-rw-r--r--tensorflow/compiler/xla/service/hlo_domain_isolator.h2
-rw-r--r--tensorflow/compiler/xla/service/hlo_domain_remover.h2
-rw-r--r--tensorflow/compiler/xla/service/hlo_domain_verifier.h2
-rw-r--r--tensorflow/compiler/xla/service/hlo_element_type_converter.h2
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator.cc199
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator.h4
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator_test.cc122
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h160
-rw-r--r--tensorflow/compiler/xla/service/hlo_graph_dumper.cc2
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.cc31
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.h23
-rw-r--r--tensorflow/compiler/xla/service/hlo_instructions.cc14
-rw-r--r--tensorflow/compiler/xla/service/hlo_instructions.h8
-rw-r--r--tensorflow/compiler/xla/service/hlo_liveness_analysis.cc35
-rw-r--r--tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc84
-rw-r--r--tensorflow/compiler/xla/service/hlo_memory_scheduler.cc2
-rw-r--r--tensorflow/compiler/xla/service/hlo_memory_scheduler.h4
-rw-r--r--tensorflow/compiler/xla/service/hlo_module.h4
-rw-r--r--tensorflow/compiler/xla/service/hlo_module_dce.cc4
-rw-r--r--tensorflow/compiler/xla/service/hlo_module_dce.h2
-rw-r--r--tensorflow/compiler/xla/service/hlo_module_group_metadata.cc20
-rw-r--r--tensorflow/compiler/xla/service/hlo_module_group_metadata.h9
-rw-r--r--tensorflow/compiler/xla/service/hlo_module_group_test.cc64
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser.cc90
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser.h13
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser_test.cc12
-rw-r--r--tensorflow/compiler/xla/service/hlo_pass_interface.h35
-rw-r--r--tensorflow/compiler/xla/service/hlo_pass_pipeline.cc191
-rw-r--r--tensorflow/compiler/xla/service/hlo_pass_pipeline.h38
-rw-r--r--tensorflow/compiler/xla/service/hlo_pass_pipeline_test.cc259
-rw-r--r--tensorflow/compiler/xla/service/hlo_rematerialization.cc6
-rw-r--r--tensorflow/compiler/xla/service/hlo_rematerialization.h2
-rw-r--r--tensorflow/compiler/xla/service/hlo_subcomputation_unification.h2
-rw-r--r--tensorflow/compiler/xla/service/hlo_value.cc1
-rw-r--r--tensorflow/compiler/xla/service/hlo_verifier.cc1
-rw-r--r--tensorflow/compiler/xla/service/hlo_verifier.h2
-rw-r--r--tensorflow/compiler/xla/service/implicit_broadcast_remover.h2
-rw-r--r--tensorflow/compiler/xla/service/indexed_array_analysis.h2
-rw-r--r--tensorflow/compiler/xla/service/inliner.h2
-rw-r--r--tensorflow/compiler/xla/service/instruction_fusion.cc29
-rw-r--r--tensorflow/compiler/xla/service/instruction_fusion.h13
-rw-r--r--tensorflow/compiler/xla/service/layout_assignment.h2
-rw-r--r--tensorflow/compiler/xla/service/logical_buffer_analysis.cc2
-rw-r--r--tensorflow/compiler/xla/service/multi_output_fusion.h2
-rw-r--r--tensorflow/compiler/xla/service/name_uniquer.cc4
-rw-r--r--tensorflow/compiler/xla/service/pattern_matcher.h762
-rw-r--r--tensorflow/compiler/xla/service/pattern_matcher_test.cc183
-rw-r--r--tensorflow/compiler/xla/service/platform_util.cc10
-rw-r--r--tensorflow/compiler/xla/service/reduce_precision_insertion.h2
-rw-r--r--tensorflow/compiler/xla/service/reshape_mover.h2
-rw-r--r--tensorflow/compiler/xla/service/scatter_expander.cc78
-rw-r--r--tensorflow/compiler/xla/service/scatter_expander.h2
-rw-r--r--tensorflow/compiler/xla/service/shape_inference.cc5
-rw-r--r--tensorflow/compiler/xla/service/stream_pool.cc10
-rw-r--r--tensorflow/compiler/xla/service/stream_pool_test.cc34
-rw-r--r--tensorflow/compiler/xla/service/transpose_folding.h2
-rw-r--r--tensorflow/compiler/xla/service/tuple_simplifier.h2
-rw-r--r--tensorflow/compiler/xla/service/while_loop_constant_sinking.h2
-rw-r--r--tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h2
-rw-r--r--tensorflow/compiler/xla/service/while_loop_simplifier.cc5
-rw-r--r--tensorflow/compiler/xla/service/while_loop_simplifier.h2
-rw-r--r--tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h2
-rw-r--r--tensorflow/compiler/xla/shape_util.cc7
-rw-r--r--tensorflow/compiler/xla/shape_util.h4
-rw-r--r--tensorflow/compiler/xla/tests/BUILD48
-rw-r--r--tensorflow/compiler/xla/tests/build_defs.bzl488
-rw-r--r--tensorflow/compiler/xla/tests/hlo_verified_test_base.cc78
-rw-r--r--tensorflow/compiler/xla/tests/hlo_verified_test_base.h63
-rw-r--r--tensorflow/compiler/xla/tests/hlo_verified_test_base_test.cc158
-rw-r--r--tensorflow/compiler/xla/tests/multiple_devices_on_host_test.cc120
-rw-r--r--tensorflow/compiler/xla/tests/reduce_window_test.cc8
-rw-r--r--tensorflow/compiler/xla/tests/scatter_test.cc30
-rw-r--r--tensorflow/compiler/xla/tests/slice_test.cc1
-rw-r--r--tensorflow/compiler/xla/tests/while_test.cc4
-rw-r--r--tensorflow/compiler/xla/xla.proto9
-rw-r--r--tensorflow/compiler/xrt/tests/BUILD6
-rw-r--r--tensorflow/contrib/BUILD61
-rw-r--r--tensorflow/contrib/__init__.py1
-rw-r--r--tensorflow/contrib/all_reduce/python/all_reduce_test.py2
-rw-r--r--tensorflow/contrib/autograph/README.md145
-rw-r--r--tensorflow/contrib/batching/python/ops/batch_ops_test.py29
-rw-r--r--tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py18
-rw-r--r--tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py29
-rw-r--r--tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc1
-rw-r--r--tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc1
-rw-r--r--tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc1
-rw-r--r--tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc1
-rw-r--r--tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc1
-rw-r--r--tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc1
-rw-r--r--tensorflow/contrib/bigtable/python/kernel_tests/bigtable_ops_test.py18
-rw-r--r--tensorflow/contrib/bigtable/python/ops/bigtable_api.py8
-rw-r--r--tensorflow/contrib/boosted_trees/estimator_batch/BUILD1
-rw-r--r--tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py5
-rw-r--r--tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py3
-rw-r--r--tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py4
-rw-r--r--tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc30
-rw-r--r--tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py88
-rw-r--r--tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py18
-rw-r--r--tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py66
-rw-r--r--tensorflow/contrib/boosted_trees/python/utils/losses_test.py4
-rw-r--r--tensorflow/contrib/cmake/CMakeLists.txt2
-rw-r--r--tensorflow/contrib/cmake/README.md2
-rw-r--r--tensorflow/contrib/cmake/python_modules.txt5
-rw-r--r--tensorflow/contrib/cmake/python_protos.txt1
-rw-r--r--tensorflow/contrib/cmake/tf_tests.cmake1
-rw-r--r--tensorflow/contrib/coder/python/ops/coder_ops_test.py2
-rw-r--r--tensorflow/contrib/compiler/BUILD11
-rw-r--r--tensorflow/contrib/compiler/jit_test.py2
-rw-r--r--tensorflow/contrib/compiler/xla.py441
-rw-r--r--tensorflow/contrib/constrained_optimization/python/external_regret_optimizer.py4
-rw-r--r--tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer.py9
-rw-r--r--tensorflow/contrib/copy_graph/python/util/copy_elements.py6
-rw-r--r--tensorflow/contrib/copy_graph/python/util/copy_test.py4
-rw-r--r--tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py25
-rw-r--r--tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py2
-rw-r--r--tensorflow/contrib/data/BUILD38
-rw-r--r--tensorflow/contrib/data/__init__.py6
-rw-r--r--tensorflow/contrib/data/ops/dataset_ops.cc284
-rw-r--r--tensorflow/contrib/data/ops/indexed_dataset_ops.cc80
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/BUILD11
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py8
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py12
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py2
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py4
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/optimization/BUILD33
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py6
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/optimization/hoist_random_uniform_test.py102
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/optimization/latency_all_edges_test.py4
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py2
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py40
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/optimization/model_dataset_op_test.py27
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/optimization/noop_elimination_test.py57
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py12
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py151
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/stats_dataset_serialization_test.py11
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py66
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py17
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py7
-rw-r--r--tensorflow/contrib/data/python/ops/BUILD57
-rw-r--r--tensorflow/contrib/data/python/ops/batching.py12
-rw-r--r--tensorflow/contrib/data/python/ops/error_ops.py9
-rw-r--r--tensorflow/contrib/data/python/ops/grouping.py63
-rw-r--r--tensorflow/contrib/data/python/ops/indexed_dataset_ops.py28
-rw-r--r--tensorflow/contrib/data/python/ops/interleave_ops.py16
-rw-r--r--tensorflow/contrib/data/python/ops/optimization.py19
-rw-r--r--tensorflow/contrib/data/python/ops/parsing_ops.py4
-rw-r--r--tensorflow/contrib/data/python/ops/prefetching_ops.py228
-rw-r--r--tensorflow/contrib/data/python/ops/random_ops.py2
-rw-r--r--tensorflow/contrib/data/python/ops/readers.py12
-rw-r--r--tensorflow/contrib/data/python/ops/scan_ops.py4
-rw-r--r--tensorflow/contrib/data/python/ops/shuffle_ops.py11
-rw-r--r--tensorflow/contrib/data/python/ops/sliding.py8
-rw-r--r--tensorflow/contrib/data/python/ops/stats_ops.py39
-rw-r--r--tensorflow/contrib/data/python/ops/threadpool.py13
-rw-r--r--tensorflow/contrib/data/python/ops/unique.py9
-rw-r--r--tensorflow/contrib/deprecated/summaries_test.py10
-rw-r--r--tensorflow/contrib/distribute/README.md3
-rw-r--r--tensorflow/contrib/distribute/__init__.py7
-rw-r--r--tensorflow/contrib/distribute/python/BUILD36
-rw-r--r--tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py12
-rw-r--r--tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py78
-rw-r--r--tensorflow/contrib/distribute/python/combinations.py3
-rw-r--r--tensorflow/contrib/distribute/python/cross_tower_utils.py15
-rw-r--r--tensorflow/contrib/distribute/python/estimator_training_test.py248
-rw-r--r--tensorflow/contrib/distribute/python/examples/simple_estimator_example.py21
-rw-r--r--tensorflow/contrib/distribute/python/input_ops_test.py20
-rw-r--r--tensorflow/contrib/distribute/python/keras_test.py210
-rw-r--r--tensorflow/contrib/distribute/python/metrics_v1_test.py3
-rw-r--r--tensorflow/contrib/distribute/python/minimize_loss_test.py26
-rw-r--r--tensorflow/contrib/distribute/python/mirrored_strategy.py11
-rw-r--r--tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py24
-rw-r--r--tensorflow/contrib/distribute/python/monitor.py1
-rw-r--r--tensorflow/contrib/distribute/python/multi_worker_test_base.py53
-rw-r--r--tensorflow/contrib/distribute/python/optimizer_v2_test.py8
-rw-r--r--tensorflow/contrib/distribute/python/prefetching_ops_v2.py228
-rw-r--r--tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py90
-rw-r--r--tensorflow/contrib/distribute/python/step_fn.py7
-rw-r--r--tensorflow/contrib/distribute/python/step_fn_test.py1
-rw-r--r--tensorflow/contrib/distribute/python/tpu_strategy.py19
-rw-r--r--tensorflow/contrib/distribute/python/values.py56
-rw-r--r--tensorflow/contrib/distribute/python/values_test.py22
-rw-r--r--tensorflow/contrib/distributions/BUILD54
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/softsign_test.py4
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/distribution_test.py20
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/moving_stats_test.py6
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/util/BUILD51
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes.py98
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes_lib.py323
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes_test.py150
-rw-r--r--tensorflow/contrib/distributions/python/ops/autoregressive.py7
-rw-r--r--tensorflow/contrib/distributions/python/ops/batch_reshape.py3
-rw-r--r--tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py9
-rw-r--r--tensorflow/contrib/distributions/python/ops/bijectors/permute.py5
-rw-r--r--tensorflow/contrib/distributions/python/ops/bijectors/real_nvp.py5
-rw-r--r--tensorflow/contrib/distributions/python/ops/bijectors/reshape.py5
-rw-r--r--tensorflow/contrib/distributions/python/ops/bijectors/scale_tril.py5
-rw-r--r--tensorflow/contrib/distributions/python/ops/cauchy.py3
-rw-r--r--tensorflow/contrib/distributions/python/ops/deterministic.py10
-rw-r--r--tensorflow/contrib/distributions/python/ops/gumbel.py3
-rw-r--r--tensorflow/contrib/distributions/python/ops/half_normal.py7
-rw-r--r--tensorflow/contrib/distributions/python/ops/independent.py3
-rw-r--r--tensorflow/contrib/distributions/python/ops/inverse_gamma.py4
-rw-r--r--tensorflow/contrib/distributions/python/ops/logistic.py3
-rw-r--r--tensorflow/contrib/distributions/python/ops/mixture.py4
-rw-r--r--tensorflow/contrib/distributions/python/ops/mixture_same_family.py7
-rw-r--r--tensorflow/contrib/distributions/python/ops/mvn_diag.py3
-rw-r--r--tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py3
-rw-r--r--tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py3
-rw-r--r--tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py3
-rw-r--r--tensorflow/contrib/distributions/python/ops/mvn_tril.py7
-rw-r--r--tensorflow/contrib/distributions/python/ops/poisson_lognormal.py3
-rw-r--r--tensorflow/contrib/distributions/python/ops/quantized_distribution.py5
-rw-r--r--tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/statistical_testing.py42
-rw-r--r--tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py3
-rw-r--r--tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py3
-rw-r--r--tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py3
-rw-r--r--tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py3
-rw-r--r--tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py3
-rw-r--r--tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/vector_student_t.py3
-rw-r--r--tensorflow/contrib/distributions/python/ops/wishart.py18
-rw-r--r--tensorflow/contrib/eager/README.md7
-rw-r--r--tensorflow/contrib/eager/python/BUILD14
-rw-r--r--tensorflow/contrib/eager/python/examples/BUILD1
-rw-r--r--tensorflow/contrib/eager/python/examples/gan/BUILD1
-rw-r--r--tensorflow/contrib/eager/python/examples/l2hmc/l2hmc_test.py182
-rw-r--r--tensorflow/contrib/eager/python/examples/linear_regression/BUILD1
-rw-r--r--tensorflow/contrib/eager/python/examples/rnn_colorbot/BUILD1
-rw-r--r--tensorflow/contrib/eager/python/examples/rnn_ptb/BUILD1
-rw-r--r--tensorflow/contrib/eager/python/parameter_server.py289
-rw-r--r--tensorflow/contrib/eager/python/remote_test.py20
-rw-r--r--tensorflow/contrib/estimator/BUILD46
-rw-r--r--tensorflow/contrib/estimator/__init__.py2
-rw-r--r--tensorflow/contrib/estimator/python/estimator/boosted_trees.py30
-rw-r--r--tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py74
-rw-r--r--tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations.py94
-rw-r--r--tensorflow/contrib/estimator/python/estimator/early_stopping.py39
-rw-r--r--tensorflow/contrib/estimator/python/estimator/hooks.py1
-rw-r--r--tensorflow/contrib/estimator/python/estimator/hooks_test.py2
-rw-r--r--tensorflow/contrib/factorization/BUILD1
-rw-r--r--tensorflow/contrib/framework/python/framework/tensor_util_test.py2
-rw-r--r--tensorflow/contrib/framework/python/ops/variables_test.py28
-rw-r--r--tensorflow/contrib/fused_conv/BUILD43
-rw-r--r--tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc4
-rw-r--r--tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py893
-rw-r--r--tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test_base.py945
-rw-r--r--tensorflow/contrib/gan/python/losses/python/losses_impl.py6
-rw-r--r--tensorflow/contrib/gan/python/namedtuples.py6
-rw-r--r--tensorflow/contrib/gan/python/train_test.py4
-rw-r--r--tensorflow/contrib/gdr/gdr_memory_manager.cc102
-rw-r--r--tensorflow/contrib/graph_editor/tests/transform_test.py2
-rw-r--r--tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py28
-rw-r--r--tensorflow/contrib/hadoop/python/kernel_tests/hadoop_test.py2
-rw-r--r--tensorflow/contrib/hadoop/python/ops/hadoop_dataset_ops.py4
-rw-r--r--tensorflow/contrib/input_pipeline/python/ops/input_pipeline_ops_test.py8
-rw-r--r--tensorflow/contrib/kafka/python/kernel_tests/kafka_test.py2
-rw-r--r--tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py6
-rw-r--r--tensorflow/contrib/kernel_methods/python/losses_test.py38
-rw-r--r--tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features_test.py12
-rw-r--r--tensorflow/contrib/kinesis/python/kernel_tests/kinesis_test.py4
-rw-r--r--tensorflow/contrib/kinesis/python/ops/kinesis_dataset_ops.py6
-rw-r--r--tensorflow/contrib/layers/python/kernel_tests/sparse_feature_cross_op_test.py34
-rw-r--r--tensorflow/contrib/layers/python/layers/layers_test.py4
-rw-r--r--tensorflow/contrib/layers/python/layers/optimizers.py7
-rw-r--r--tensorflow/contrib/layers/python/layers/optimizers_test.py36
-rw-r--r--tensorflow/contrib/layers/python/layers/target_column.py4
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/head.py10
-rw-r--r--tensorflow/contrib/learn/python/learn/graph_actions_test.py14
-rw-r--r--tensorflow/contrib/learn/python/learn/monitors_test.py10
-rw-r--r--tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py38
-rw-r--r--tensorflow/contrib/linalg/BUILD44
-rw-r--r--tensorflow/contrib/linalg/__init__.py58
-rw-r--r--tensorflow/contrib/linalg/python/kernel_tests/linear_operator_addition_test.py412
-rw-r--r--tensorflow/contrib/linalg/python/ops/linear_operator_addition.py432
-rw-r--r--tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py101
-rw-r--r--tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py26
-rw-r--r--tensorflow/contrib/lite/README.md4
-rw-r--r--tensorflow/contrib/lite/build_def.bzl43
-rw-r--r--tensorflow/contrib/lite/builtin_ops.h2
-rw-r--r--tensorflow/contrib/lite/c/c_api_internal.c25
-rw-r--r--tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc94
-rw-r--r--tensorflow/contrib/lite/core/api/flatbuffer_conversions.h22
-rw-r--r--tensorflow/contrib/lite/core/api/flatbuffer_conversions_test.cc26
-rw-r--r--tensorflow/contrib/lite/delegates/flex/BUILD (renamed from tensorflow/contrib/lite/delegates/eager/BUILD)0
-rw-r--r--tensorflow/contrib/lite/delegates/flex/buffer_map.cc (renamed from tensorflow/contrib/lite/delegates/eager/buffer_map.cc)8
-rw-r--r--tensorflow/contrib/lite/delegates/flex/buffer_map.h (renamed from tensorflow/contrib/lite/delegates/eager/buffer_map.h)12
-rw-r--r--tensorflow/contrib/lite/delegates/flex/buffer_map_test.cc (renamed from tensorflow/contrib/lite/delegates/eager/buffer_map_test.cc)6
-rw-r--r--tensorflow/contrib/lite/delegates/flex/delegate.cc (renamed from tensorflow/contrib/lite/delegates/eager/delegate.cc)34
-rw-r--r--tensorflow/contrib/lite/delegates/flex/delegate.h (renamed from tensorflow/contrib/lite/delegates/eager/delegate.h)26
-rw-r--r--tensorflow/contrib/lite/delegates/flex/delegate_data.cc (renamed from tensorflow/contrib/lite/delegates/eager/delegate_data.cc)6
-rw-r--r--tensorflow/contrib/lite/delegates/flex/delegate_data.h (renamed from tensorflow/contrib/lite/delegates/eager/delegate_data.h)16
-rw-r--r--tensorflow/contrib/lite/delegates/flex/delegate_data_test.cc (renamed from tensorflow/contrib/lite/delegates/eager/delegate_data_test.cc)6
-rw-r--r--tensorflow/contrib/lite/delegates/flex/delegate_test.cc (renamed from tensorflow/contrib/lite/delegates/eager/delegate_test.cc)14
-rw-r--r--tensorflow/contrib/lite/delegates/flex/kernel.cc (renamed from tensorflow/contrib/lite/delegates/eager/kernel.cc)32
-rw-r--r--tensorflow/contrib/lite/delegates/flex/kernel.h (renamed from tensorflow/contrib/lite/delegates/eager/kernel.h)12
-rw-r--r--tensorflow/contrib/lite/delegates/flex/kernel_test.cc (renamed from tensorflow/contrib/lite/delegates/eager/kernel_test.cc)16
-rw-r--r--tensorflow/contrib/lite/delegates/flex/test_util.cc (renamed from tensorflow/contrib/lite/delegates/eager/test_util.cc)49
-rw-r--r--tensorflow/contrib/lite/delegates/flex/test_util.h (renamed from tensorflow/contrib/lite/delegates/eager/test_util.h)20
-rw-r--r--tensorflow/contrib/lite/delegates/flex/util.cc (renamed from tensorflow/contrib/lite/delegates/eager/util.cc)6
-rw-r--r--tensorflow/contrib/lite/delegates/flex/util.h (renamed from tensorflow/contrib/lite/delegates/eager/util.h)10
-rw-r--r--tensorflow/contrib/lite/delegates/flex/util_test.cc (renamed from tensorflow/contrib/lite/delegates/eager/util_test.cc)6
-rw-r--r--tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc2
-rw-r--r--tensorflow/contrib/lite/examples/android/app/README.md37
-rw-r--r--tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h3
-rw-r--r--tensorflow/contrib/lite/experimental/c/BUILD12
-rw-r--r--tensorflow/contrib/lite/experimental/c/c_api.cc50
-rw-r--r--tensorflow/contrib/lite/experimental/c/c_api.h15
-rw-r--r--tensorflow/contrib/lite/experimental/c/c_api_experimental.cc21
-rw-r--r--tensorflow/contrib/lite/experimental/c/c_api_experimental.h27
-rw-r--r--tensorflow/contrib/lite/experimental/c/c_api_experimental_test.cc25
-rw-r--r--tensorflow/contrib/lite/experimental/c/c_api_internal.h16
-rw-r--r--tensorflow/contrib/lite/experimental/c/c_api_test.cc31
-rw-r--r--tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder.cc2
-rw-r--r--tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder_test.cc2
-rw-r--r--tensorflow/contrib/lite/experimental/writer/option_writer_generator.cc2
-rw-r--r--tensorflow/contrib/lite/g3doc/_book.yaml72
-rw-r--r--tensorflow/contrib/lite/g3doc/_index.yaml220
-rw-r--r--tensorflow/contrib/lite/g3doc/_project.yaml8
-rw-r--r--tensorflow/contrib/lite/g3doc/api_docs/python/_toc.yaml6
-rw-r--r--tensorflow/contrib/lite/g3doc/devguide.md9
-rw-r--r--tensorflow/contrib/lite/g3doc/images/landing-page/assistant_logo.png (bin, 0 -> 10942 bytes)
-rw-r--r--tensorflow/contrib/lite/g3doc/images/landing-page/detect_crop_disease_in_africa.png (bin, 0 -> 578440 bytes)
-rw-r--r--tensorflow/contrib/lite/g3doc/images/landing-page/fishbrain_logo.png (bin, 0 -> 7764 bytes)
-rw-r--r--tensorflow/contrib/lite/g3doc/images/landing-page/fishbrain_logo_big.png (bin, 0 -> 16308 bytes)
-rw-r--r--tensorflow/contrib/lite/g3doc/images/landing-page/gboard_logo.png (bin, 0 -> 20159 bytes)
-rw-r--r--tensorflow/contrib/lite/g3doc/images/landing-page/gmail_logo.png (bin, 0 -> 35371 bytes)
-rw-r--r--tensorflow/contrib/lite/g3doc/images/landing-page/loseit_logo.png (bin, 0 -> 12002 bytes)
-rw-r--r--tensorflow/contrib/lite/g3doc/images/landing-page/loseit_logo_big.png (bin, 0 -> 25868 bytes)
-rw-r--r--tensorflow/contrib/lite/g3doc/images/landing-page/nest_logo.png (bin, 0 -> 7839 bytes)
-rw-r--r--tensorflow/contrib/lite/g3doc/images/landing-page/photos_logo.png (bin, 0 -> 27152 bytes)
-rw-r--r--tensorflow/contrib/lite/g3doc/images/landing-page/shazam_logo.png (bin, 0 -> 17783 bytes)
-rw-r--r--tensorflow/contrib/lite/g3doc/images/landing-page/vsco_logo.png (bin, 0 -> 17249 bytes)
-rw-r--r--tensorflow/contrib/lite/g3doc/ios.md7
-rw-r--r--tensorflow/contrib/lite/g3doc/models.md17
-rw-r--r--tensorflow/contrib/lite/g3doc/overview.md2
-rw-r--r--tensorflow/contrib/lite/g3doc/performance.md186
-rw-r--r--tensorflow/contrib/lite/g3doc/performance_benchmarks.md174
-rw-r--r--tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md11
-rw-r--r--tensorflow/contrib/lite/g3doc/tfmobile/android_build.md2
-rw-r--r--tensorflow/contrib/lite/g3doc/tfmobile/index.md2
-rw-r--r--tensorflow/contrib/lite/interpreter.cc9
-rw-r--r--tensorflow/contrib/lite/interpreter.h7
-rw-r--r--tensorflow/contrib/lite/java/demo/README.md4
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java26
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java29
-rw-r--r--tensorflow/contrib/lite/java/ovic/BUILD1
-rw-r--r--tensorflow/contrib/lite/java/ovic/README.md2
-rw-r--r--tensorflow/contrib/lite/java/ovic/demo/app/BUILD1
-rw-r--r--tensorflow/contrib/lite/java/ovic/demo/app/OvicBenchmarkerActivity.java2
-rw-r--r--tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicBenchmarker.java (renamed from tensorflow/contrib/lite/java/ovic/demo/app/OvicBenchmarker.java)4
-rw-r--r--tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicClassifier.java2
-rw-r--r--tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java104
-rw-r--r--tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java48
-rw-r--r--tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc15
-rw-r--r--tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h9
-rw-r--r--tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java48
-rw-r--r--tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java9
-rw-r--r--tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java15
-rw-r--r--tensorflow/contrib/lite/kernels/BUILD16
-rw-r--r--tensorflow/contrib/lite/kernels/activations.cc113
-rw-r--r--tensorflow/contrib/lite/kernels/audio_spectrogram.cc2
-rw-r--r--tensorflow/contrib/lite/kernels/audio_spectrogram_test.cc2
-rw-r--r--tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc169
-rw-r--r--tensorflow/contrib/lite/kernels/comparisons.cc51
-rw-r--r--tensorflow/contrib/lite/kernels/concatenation.cc39
-rw-r--r--tensorflow/contrib/lite/kernels/conv.cc171
-rw-r--r--tensorflow/contrib/lite/kernels/depthwise_conv.cc94
-rw-r--r--tensorflow/contrib/lite/kernels/depthwise_conv_test.cc162
-rw-r--r--tensorflow/contrib/lite/kernels/dequantize.cc14
-rw-r--r--tensorflow/contrib/lite/kernels/detection_postprocess.cc2
-rw-r--r--tensorflow/contrib/lite/kernels/detection_postprocess_test.cc2
-rw-r--r--tensorflow/contrib/lite/kernels/div.cc27
-rw-r--r--tensorflow/contrib/lite/kernels/fake_quant.cc13
-rw-r--r--tensorflow/contrib/lite/kernels/fully_connected.cc66
-rw-r--r--tensorflow/contrib/lite/kernels/gather.cc14
-rw-r--r--tensorflow/contrib/lite/kernels/internal/BUILD14
-rw-r--r--tensorflow/contrib/lite/kernels/internal/compatibility.h32
-rw-r--r--tensorflow/contrib/lite/kernels/internal/depthwiseconv_float_test.cc107
-rw-r--r--tensorflow/contrib/lite/kernels/internal/depthwiseconv_quantized_test.cc147
-rw-r--r--tensorflow/contrib/lite/kernels/internal/logsoftmax_quantized_test.cc32
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/cblas_conv.h61
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h151
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h211
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h94
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h941
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h60
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc2
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h897
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h99
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h130
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/fully_connected.h326
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h1067
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc2
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h1763
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/softmax.h179
-rw-r--r--tensorflow/contrib/lite/kernels/internal/softmax_quantized_test.cc28
-rw-r--r--tensorflow/contrib/lite/kernels/internal/tensor.h24
-rw-r--r--tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h33
-rw-r--r--tensorflow/contrib/lite/kernels/internal/tensor_test.cc36
-rw-r--r--tensorflow/contrib/lite/kernels/internal/test_util.cc56
-rw-r--r--tensorflow/contrib/lite/kernels/internal/test_util.h11
-rw-r--r--tensorflow/contrib/lite/kernels/internal/types.h51
-rw-r--r--tensorflow/contrib/lite/kernels/kernel_util.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/layer_norm_lstm.cc2
-rw-r--r--tensorflow/contrib/lite/kernels/layer_norm_lstm_test.cc2
-rw-r--r--tensorflow/contrib/lite/kernels/log_softmax_test.cc5
-rw-r--r--tensorflow/contrib/lite/kernels/lstm.cc48
-rw-r--r--tensorflow/contrib/lite/kernels/mfcc.cc2
-rw-r--r--tensorflow/contrib/lite/kernels/mfcc_test.cc2
-rw-r--r--tensorflow/contrib/lite/kernels/op_macros.h46
-rw-r--r--tensorflow/contrib/lite/kernels/pack.cc9
-rw-r--r--tensorflow/contrib/lite/kernels/reduce.cc52
-rw-r--r--tensorflow/contrib/lite/kernels/reduce_test.cc12
-rw-r--r--tensorflow/contrib/lite/kernels/register.cc14
-rw-r--r--tensorflow/contrib/lite/kernels/relu1_test.cc2
-rw-r--r--tensorflow/contrib/lite/kernels/select.cc12
-rw-r--r--tensorflow/contrib/lite/kernels/softmax_test.cc12
-rw-r--r--tensorflow/contrib/lite/kernels/sparse_to_dense.cc5
-rw-r--r--tensorflow/contrib/lite/kernels/split.cc27
-rw-r--r--tensorflow/contrib/lite/kernels/strided_slice.cc48
-rw-r--r--tensorflow/contrib/lite/kernels/test_util.cc2
-rw-r--r--tensorflow/contrib/lite/kernels/transpose.cc23
-rw-r--r--tensorflow/contrib/lite/kernels/transpose_conv.cc21
-rw-r--r--tensorflow/contrib/lite/kernels/transpose_test.cc24
-rw-r--r--tensorflow/contrib/lite/kernels/unpack.cc9
-rw-r--r--tensorflow/contrib/lite/kernels/zeros_like.cc73
-rw-r--r--tensorflow/contrib/lite/kernels/zeros_like_test.cc78
-rw-r--r--tensorflow/contrib/lite/model.cc20
-rw-r--r--tensorflow/contrib/lite/mutable_op_resolver.cc15
-rw-r--r--tensorflow/contrib/lite/mutable_op_resolver.h8
-rw-r--r--tensorflow/contrib/lite/mutable_op_resolver_test.cc34
-rw-r--r--tensorflow/contrib/lite/nnapi_delegate.cc6
-rw-r--r--tensorflow/contrib/lite/optional_debug_tools.cc10
-rw-r--r--tensorflow/contrib/lite/python/BUILD2
-rw-r--r--tensorflow/contrib/lite/python/convert.py21
-rw-r--r--tensorflow/contrib/lite/python/convert_saved_model.py12
-rw-r--r--tensorflow/contrib/lite/python/interpreter.py4
-rw-r--r--tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc4
-rw-r--r--tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h2
-rw-r--r--tensorflow/contrib/lite/python/lite.py94
-rw-r--r--tensorflow/contrib/lite/python/lite_test.py177
-rw-r--r--tensorflow/contrib/lite/python/tflite_convert.py12
-rw-r--r--tensorflow/contrib/lite/schema/BUILD4
-rw-r--r--tensorflow/contrib/lite/schema/flatbuffer_compatibility_test.cc2
-rw-r--r--tensorflow/contrib/lite/schema/schema.fbs10
-rwxr-xr-xtensorflow/contrib/lite/schema/schema_generated.h379
-rw-r--r--tensorflow/contrib/lite/testing/BUILD31
-rw-r--r--tensorflow/contrib/lite/testing/generate_examples.py37
-rw-r--r--tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib.py249
-rw-r--r--tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib_test.py130
-rw-r--r--tensorflow/contrib/lite/testing/tflite_diff_flags.h4
-rw-r--r--tensorflow/contrib/lite/testing/tflite_diff_util.h2
-rw-r--r--tensorflow/contrib/lite/testing/tflite_driver.cc8
-rw-r--r--tensorflow/contrib/lite/testing/tflite_driver.h4
-rw-r--r--tensorflow/contrib/lite/toco/args.h4
-rw-r--r--tensorflow/contrib/lite/toco/export_tensorflow.cc28
-rw-r--r--tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md2
-rw-r--r--tensorflow/contrib/lite/toco/g3doc/python_api.md13
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h12
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc117
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc19
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc81
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc80
-rw-r--r--tensorflow/contrib/lite/toco/import_tensorflow.cc22
-rw-r--r--tensorflow/contrib/lite/toco/import_tensorflow.h2
-rw-r--r--tensorflow/contrib/lite/toco/model.h12
-rw-r--r--tensorflow/contrib/lite/toco/python/BUILD1
-rw-r--r--tensorflow/contrib/lite/toco/tflite/export.cc22
-rw-r--r--tensorflow/contrib/lite/toco/tflite/export.h4
-rw-r--r--tensorflow/contrib/lite/toco/tflite/export_test.cc2
-rw-r--r--tensorflow/contrib/lite/toco/tflite/operator.cc42
-rw-r--r--tensorflow/contrib/lite/toco/tflite/operator.h6
-rw-r--r--tensorflow/contrib/lite/toco/tflite/operator_test.cc2
-rw-r--r--tensorflow/contrib/lite/toco/toco_cmdline_flags.cc24
-rw-r--r--tensorflow/contrib/lite/toco/toco_flags.proto16
-rw-r--r--tensorflow/contrib/lite/toco/toco_tooling.cc21
-rw-r--r--tensorflow/contrib/lite/toco/tooling_util.cc1
-rw-r--r--tensorflow/contrib/lite/tools/benchmark/BUILD12
-rw-r--r--tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc14
-rw-r--r--tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h12
-rw-r--r--tensorflow/contrib/lite/tools/make/Makefile1
-rwxr-xr-xtensorflow/contrib/lite/tools/make/download_dependencies.sh2
-rw-r--r--tensorflow/contrib/lite/tutorials/post_training_quant.ipynb4
-rw-r--r--tensorflow/contrib/lite/util.cc6
-rw-r--r--tensorflow/contrib/lite/util.h8
-rw-r--r--tensorflow/contrib/lite/util_test.cc16
-rw-r--r--tensorflow/contrib/losses/python/metric_learning/metric_loss_ops_test.py16
-rw-r--r--tensorflow/contrib/makefile/proto_text_pb_cc_files.txt1
-rw-r--r--tensorflow/contrib/makefile/proto_text_pb_h_files.txt1
-rw-r--r--tensorflow/contrib/makefile/tf_op_files.txt3
-rw-r--r--tensorflow/contrib/makefile/tf_pb_text_files.txt1
-rw-r--r--tensorflow/contrib/makefile/tf_proto_files.txt1
-rw-r--r--tensorflow/contrib/metrics/python/kernel_tests/histogram_ops_test.py10
-rw-r--r--tensorflow/contrib/metrics/python/metrics/classification_test.py28
-rw-r--r--tensorflow/contrib/metrics/python/ops/metric_ops_test.py19
-rw-r--r--tensorflow/contrib/mixed_precision/python/loss_scale_optimizer.py5
-rw-r--r--tensorflow/contrib/model_pruning/python/pruning.py3
-rw-r--r--tensorflow/contrib/model_pruning/python/pruning_test.py22
-rw-r--r--tensorflow/contrib/mpi/mpi_rendezvous_mgr.cc6
-rw-r--r--tensorflow/contrib/mpi/mpi_rendezvous_mgr.h2
-rw-r--r--tensorflow/contrib/opt/BUILD22
-rw-r--r--tensorflow/contrib/opt/__init__.py5
-rw-r--r--tensorflow/contrib/opt/python/training/addsign_test.py12
-rw-r--r--tensorflow/contrib/opt/python/training/agn_optimizer.py262
-rw-r--r--tensorflow/contrib/opt/python/training/agn_optimizer_test.py281
-rw-r--r--tensorflow/contrib/opt/python/training/drop_stale_gradient_optimizer_test.py4
-rw-r--r--tensorflow/contrib/opt/python/training/external_optimizer_test.py22
-rw-r--r--tensorflow/contrib/opt/python/training/lazy_adam_optimizer_test.py6
-rw-r--r--tensorflow/contrib/opt/python/training/model_average_optimizer_test.py3
-rw-r--r--tensorflow/contrib/opt/python/training/powersign_test.py12
-rw-r--r--tensorflow/contrib/optimizer_v2/adagrad.py2
-rw-r--r--tensorflow/contrib/predictor/BUILD3
-rw-r--r--tensorflow/contrib/quantization/README.md2
-rw-r--r--tensorflow/contrib/quantize/BUILD4
-rw-r--r--tensorflow/contrib/quantize/README.md2
-rw-r--r--tensorflow/contrib/quantize/python/common.py4
-rw-r--r--tensorflow/contrib/quantize/python/common_test.py59
-rw-r--r--tensorflow/contrib/quantize/python/fold_batch_norms.py97
-rw-r--r--tensorflow/contrib/quantize/python/quantize.py130
-rw-r--r--tensorflow/contrib/quantize/python/quantize_graph_test.py37
-rw-r--r--tensorflow/contrib/quantize/python/quantize_parameterized_test.py282
-rw-r--r--tensorflow/contrib/rate/rate_test.py4
-rw-r--r--tensorflow/contrib/recurrent/python/ops/functional_rnn.py106
-rw-r--r--tensorflow/contrib/recurrent/python/ops/recurrent.py37
-rw-r--r--tensorflow/contrib/resampler/python/ops/resampler_ops_test.py8
-rw-r--r--tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py4
-rw-r--r--tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py32
-rw-r--r--tensorflow/contrib/saved_model/BUILD6
-rw-r--r--tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py2
-rw-r--r--tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py39
-rw-r--r--tensorflow/contrib/session_bundle/bundle_shim.cc9
-rw-r--r--tensorflow/contrib/session_bundle/bundle_shim.h6
-rw-r--r--tensorflow/contrib/session_bundle/bundle_shim_test.cc14
-rw-r--r--tensorflow/contrib/session_bundle/exporter_test.py6
-rw-r--r--tensorflow/contrib/stat_summarizer/python/stat_summarizer_test.py2
-rw-r--r--tensorflow/contrib/summary/summary_ops_graph_test.py28
-rw-r--r--tensorflow/contrib/tensor_forest/BUILD2
-rw-r--r--tensorflow/contrib/tensor_forest/client/eval_metrics_test.py8
-rw-r--r--tensorflow/contrib/tensor_forest/client/random_forest.py13
-rw-r--r--tensorflow/contrib/tensor_forest/python/kernel_tests/scatter_add_ndim_op_test.py24
-rw-r--r--tensorflow/contrib/tensor_forest/python/tensor_forest_test.py2
-rw-r--r--tensorflow/contrib/tensorboard/BUILD31
-rw-r--r--tensorflow/contrib/tensorboard/db/loader.cc6
-rw-r--r--tensorflow/contrib/tensorboard/plugins/__init__.py2
-rw-r--r--tensorflow/contrib/tensorboard/plugins/trace/trace.py167
-rw-r--r--tensorflow/contrib/tensorboard/plugins/trace/trace_info.proto60
-rw-r--r--tensorflow/contrib/tensorboard/plugins/trace/trace_test.py95
-rw-r--r--tensorflow/contrib/tensorrt/BUILD21
-rw-r--r--tensorflow/contrib/tensorrt/README.md2
-rw-r--r--tensorflow/contrib/tensorrt/convert/convert_graph.cc14
-rw-r--r--tensorflow/contrib/tensorrt/convert/convert_nodes.cc19
-rw-r--r--tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc13
-rw-r--r--tensorflow/contrib/tensorrt/python/trt_convert_test.py6
-rw-r--r--tensorflow/contrib/tensorrt/resources/trt_allocator.cc18
-rw-r--r--tensorflow/contrib/tensorrt/resources/trt_allocator.h2
-rw-r--r--tensorflow/contrib/tensorrt/resources/trt_allocator_test.cc21
-rw-r--r--tensorflow/contrib/tensorrt/test/base_test.py6
-rw-r--r--tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py8
-rw-r--r--tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py7
-rw-r--r--tensorflow/contrib/text/python/ops/skip_gram_ops_test.py32
-rw-r--r--tensorflow/contrib/timeseries/examples/BUILD1
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/ar_model.py65
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/estimators.py157
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/estimators_test.py35
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model.py2
-rw-r--r--tensorflow/contrib/tpu/BUILD36
-rw-r--r--tensorflow/contrib/tpu/__init__.py1
-rw-r--r--tensorflow/contrib/tpu/ops/cross_replica_ops.cc20
-rw-r--r--tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc628
-rw-r--r--tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc3
-rw-r--r--tensorflow/contrib/tpu/profiler/op_profile.proto8
-rw-r--r--tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py7
-rw-r--r--tensorflow/contrib/tpu/profiler/pip_package/setup.py2
-rw-r--r--tensorflow/contrib/tpu/profiler/version.h2
-rw-r--r--tensorflow/contrib/tpu/proto/BUILD18
-rw-r--r--tensorflow/contrib/tpu/proto/tpu_embedding_config.proto66
-rw-r--r--tensorflow/contrib/tpu/proto/tpu_embedding_configuration.proto95
-rw-r--r--tensorflow/contrib/tpu/proto/tpu_embedding_output_layout.proto75
-rw-r--r--tensorflow/contrib/tpu/python/ops/tpu_ops.py27
-rw-r--r--tensorflow/contrib/tpu/python/tpu/async_checkpoint.py202
-rw-r--r--tensorflow/contrib/tpu/python/tpu/device_assignment.py158
-rw-r--r--tensorflow/contrib/tpu/python/tpu/keras_support.py681
-rw-r--r--tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py53
-rw-r--r--tensorflow/contrib/tpu/python/tpu/session_support.py58
-rw-r--r--tensorflow/contrib/tpu/python/tpu/topology.py15
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu.py31
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu_config.py7
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu_config_test.py2
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu_context.py30
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu_estimator.py4
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu_feed.py22
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu_function.py8
-rw-r--r--tensorflow/contrib/tpu/utils/BUILD30
-rw-r--r--tensorflow/contrib/tpu/utils/tpu_embedding_optimization_parameters_utils.cc255
-rw-r--r--tensorflow/contrib/tpu/utils/tpu_embedding_optimization_parameters_utils.h90
-rw-r--r--tensorflow/contrib/tpu/utils/tpu_embedding_output_layout_utils.cc98
-rw-r--r--tensorflow/contrib/tpu/utils/tpu_embedding_output_layout_utils.h38
-rw-r--r--tensorflow/contrib/training/python/training/device_setter_test.py8
-rw-r--r--tensorflow/contrib/training/python/training/tensor_queue_dataset.py4
-rw-r--r--tensorflow/contrib/verbs/rdma_mgr.cc81
-rw-r--r--tensorflow/contrib/verbs/rdma_mgr.h1
-rw-r--r--tensorflow/contrib/verbs/verbs_server_lib.cc5
-rw-r--r--tensorflow/core/BUILD109
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ExperimentalAssertNextDataset.pbtxt4
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ExperimentalCSVDataset.pbtxt4
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ExperimentalDirectedInterleaveDataset.pbtxt21
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ExperimentalFunctionBufferingResource.pbtxt58
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ExperimentalFunctionBufferingResourceGetNext.pbtxt25
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ExperimentalFunctionBufferingResourceReset.pbtxt13
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ExperimentalIdentityIndexedDataset.pbtxt4
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ExperimentalIgnoreErrorsDataset.pbtxt8
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ExperimentalIndexedDatasetGet.pbtxt4
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ExperimentalIndexedDatasetMaterialize.pbtxt4
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ExperimentalIteratorGetDevice.pbtxt8
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ExperimentalLMDBDataset.pbtxt4
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ExperimentalMaterializedIndexDatasetHandle.pbtxt4
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ExperimentalThreadPoolDataset.pbtxt13
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ExperimentalThreadPoolHandle.pbtxt35
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ExperimentalUniqueDataset.pbtxt8
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ExtractVolumePatches.pbtxt49
-rw-r--r--tensorflow/core/api_def/base_api/api_def_Igamma.pbtxt2
-rw-r--r--tensorflow/core/api_def/base_api/api_def_LowerBound.pbtxt45
-rw-r--r--tensorflow/core/api_def/base_api/api_def_MultiDeviceIterator.pbtxt43
-rw-r--r--tensorflow/core/api_def/base_api/api_def_MultiDeviceIteratorFromStringHandle.pbtxt29
-rw-r--r--tensorflow/core/api_def/base_api/api_def_MultiDeviceIteratorGetNextFromShard.pbtxt41
-rw-r--r--tensorflow/core/api_def/base_api/api_def_MultiDeviceIteratorInit.pbtxt30
-rw-r--r--tensorflow/core/api_def/base_api/api_def_MultiDeviceIteratorToStringHandle.pbtxt17
-rw-r--r--tensorflow/core/api_def/base_api/api_def_PrintV2.pbtxt19
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ReduceDataset.pbtxt26
-rw-r--r--tensorflow/core/api_def/base_api/api_def_StringFormat.pbtxt38
-rw-r--r--tensorflow/core/api_def/base_api/api_def_StringLength.pbtxt10
-rw-r--r--tensorflow/core/api_def/base_api/api_def_UnicodeScript.pbtxt28
-rw-r--r--tensorflow/core/api_def/base_api/api_def_UpperBound.pbtxt45
-rw-r--r--tensorflow/core/api_def/base_api/api_def_WindowDataset.pbtxt23
-rw-r--r--tensorflow/core/api_def/base_api/api_def_Xdivy.pbtxt4
-rw-r--r--tensorflow/core/api_def/base_api/api_def_Xlogy.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_BatchToSpaceND.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_GatherNd.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_PrintV2.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Reshape.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_ReverseV2.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_ScatterNd.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_SpaceToBatchND.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_StringFormat.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_StringLength.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Tile.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_UnicodeScript.pbtxt6
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Xdivy.pbtxt6
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Xlogy.pbtxt6
-rw-r--r--tensorflow/core/common_runtime/bfc_allocator.cc21
-rw-r--r--tensorflow/core/common_runtime/bfc_allocator.h16
-rw-r--r--tensorflow/core/common_runtime/constant_folding.cc39
-rw-r--r--tensorflow/core/common_runtime/copy_tensor.cc89
-rw-r--r--tensorflow/core/common_runtime/device.h10
-rw-r--r--tensorflow/core/common_runtime/direct_session.cc8
-rw-r--r--tensorflow/core/common_runtime/direct_session.h20
-rw-r--r--tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc19
-rw-r--r--tensorflow/core/common_runtime/eager/context.cc16
-rw-r--r--tensorflow/core/common_runtime/eager/execute.cc41
-rw-r--r--tensorflow/core/common_runtime/eager/tensor_handle.cc16
-rw-r--r--tensorflow/core/common_runtime/eager/tensor_handle.h1
-rw-r--r--tensorflow/core/common_runtime/executor.cc139
-rw-r--r--tensorflow/core/common_runtime/executor.h6
-rw-r--r--tensorflow/core/common_runtime/gpu/cuda_host_allocator.h12
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc50
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h45
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc146
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.cc15
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.h12
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_debug_allocator.cc30
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h22
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_debug_allocator_test.cc80
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_device.cc293
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_device.h36
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_device_test.cc19
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_id.h32
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_id_manager.cc38
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_id_manager.h12
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_id_manager_test.cc32
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_id_utils.h37
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_process_state.cc175
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_process_state.h58
-rw-r--r--tensorflow/core/common_runtime/gpu/pool_allocator_test.cc68
-rw-r--r--tensorflow/core/common_runtime/graph_optimizer.cc4
-rw-r--r--tensorflow/core/common_runtime/graph_optimizer.h5
-rw-r--r--tensorflow/core/common_runtime/local_device.cc2
-rw-r--r--tensorflow/core/common_runtime/local_device.h3
-rw-r--r--tensorflow/core/common_runtime/mkl_cpu_allocator.h51
-rw-r--r--tensorflow/core/common_runtime/mkl_cpu_allocator_test.cc4
-rw-r--r--tensorflow/core/common_runtime/parallel_concat_optimizer.cc6
-rw-r--r--tensorflow/core/common_runtime/pool_allocator.cc45
-rw-r--r--tensorflow/core/common_runtime/pool_allocator.h27
-rw-r--r--tensorflow/core/common_runtime/process_state.cc71
-rw-r--r--tensorflow/core/common_runtime/process_state.h15
-rw-r--r--tensorflow/core/common_runtime/renamed_device.h16
-rw-r--r--tensorflow/core/common_runtime/ring_reducer.cc75
-rw-r--r--tensorflow/core/common_runtime/ring_reducer_test.cc83
-rw-r--r--tensorflow/core/common_runtime/session_ref.cc170
-rw-r--r--tensorflow/core/common_runtime/step_stats_collector.cc182
-rw-r--r--tensorflow/core/common_runtime/step_stats_collector.h137
-rw-r--r--tensorflow/core/common_runtime/threadpool_device.cc5
-rw-r--r--tensorflow/core/common_runtime/tracing_device.h60
-rw-r--r--tensorflow/core/common_runtime/visitable_allocator.h79
-rw-r--r--tensorflow/core/distributed_runtime/graph_mgr.cc5
-rw-r--r--tensorflow/core/example/feature_util.h5
-rw-r--r--tensorflow/core/framework/allocator.cc20
-rw-r--r--tensorflow/core/framework/allocator.h28
-rw-r--r--tensorflow/core/framework/cancellation.cc10
-rw-r--r--tensorflow/core/framework/cancellation.h9
-rw-r--r--tensorflow/core/framework/cancellation_test.cc52
-rw-r--r--tensorflow/core/framework/common_shape_fns.cc107
-rw-r--r--tensorflow/core/framework/common_shape_fns.h3
-rw-r--r--tensorflow/core/framework/dataset.cc1
-rw-r--r--tensorflow/core/framework/dataset.h92
-rw-r--r--tensorflow/core/framework/device_base.h13
-rw-r--r--tensorflow/core/framework/function_testlib.cc17
-rw-r--r--tensorflow/core/framework/model.cc301
-rw-r--r--tensorflow/core/framework/model.h632
-rw-r--r--tensorflow/core/framework/model.proto30
-rw-r--r--tensorflow/core/framework/node_def_util.cc20
-rw-r--r--tensorflow/core/framework/node_def_util.h8
-rw-r--r--tensorflow/core/framework/node_def_util_test.cc42
-rw-r--r--tensorflow/core/framework/op_kernel.cc9
-rw-r--r--tensorflow/core/framework/op_kernel.h31
-rw-r--r--tensorflow/core/framework/resource_mgr.cc9
-rw-r--r--tensorflow/core/framework/resource_mgr.h106
-rw-r--r--tensorflow/core/framework/tensor.cc134
-rw-r--r--tensorflow/core/framework/tensor.h23
-rw-r--r--tensorflow/core/framework/tensor_test.cc96
-rw-r--r--tensorflow/core/graph/graph_constructor.cc4
-rw-r--r--tensorflow/core/graph/graph_constructor_test.cc9
-rw-r--r--tensorflow/core/graph/mkl_layout_pass.cc23
-rw-r--r--tensorflow/core/graph/mkl_layout_pass_test.cc24
-rw-r--r--tensorflow/core/graph/mkl_tfconversion_pass.cc2
-rw-r--r--tensorflow/core/graph/mkl_tfconversion_pass_test.cc4
-rw-r--r--tensorflow/core/graph/testlib.h2
-rw-r--r--tensorflow/core/grappler/clusters/cluster.cc1
-rw-r--r--tensorflow/core/grappler/clusters/single_machine.cc6
-rw-r--r--tensorflow/core/grappler/clusters/utils.cc13
-rw-r--r--tensorflow/core/grappler/clusters/utils.h2
-rw-r--r--tensorflow/core/grappler/clusters/utils_test.cc22
-rw-r--r--tensorflow/core/grappler/costs/utils.cc8
-rw-r--r--tensorflow/core/grappler/graph_view.cc35
-rw-r--r--tensorflow/core/grappler/graph_view.h10
-rw-r--r--tensorflow/core/grappler/graph_view_test.cc83
-rw-r--r--tensorflow/core/grappler/grappler_item_builder.cc8
-rw-r--r--tensorflow/core/grappler/grappler_item_builder.h2
-rw-r--r--tensorflow/core/grappler/grappler_item_builder_test.cc23
-rw-r--r--tensorflow/core/grappler/optimizers/BUILD75
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc162
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc61
-rw-r--r--tensorflow/core/grappler/optimizers/constant_folding.cc30
-rw-r--r--tensorflow/core/grappler/optimizers/constant_folding_test.cc23
-rw-r--r--tensorflow/core/grappler/optimizers/data/BUILD144
-rw-r--r--tensorflow/core/grappler/optimizers/data/filter_fusion.cc13
-rw-r--r--tensorflow/core/grappler/optimizers/data/filter_fusion_test.cc11
-rw-r--r--tensorflow/core/grappler/optimizers/data/function_utils.cc176
-rw-r--r--tensorflow/core/grappler/optimizers/data/function_utils.h108
-rw-r--r--tensorflow/core/grappler/optimizers/data/function_utils_test.cc164
-rw-r--r--tensorflow/core/grappler/optimizers/data/fusion_utils.cc3
-rw-r--r--tensorflow/core/grappler/optimizers/data/fusion_utils_test.cc5
-rw-r--r--tensorflow/core/grappler/optimizers/data/graph_test_utils.cc49
-rw-r--r--tensorflow/core/grappler/optimizers/data/graph_test_utils.h36
-rw-r--r--tensorflow/core/grappler/optimizers/data/graph_utils.cc110
-rw-r--r--tensorflow/core/grappler/optimizers/data/graph_utils.h52
-rw-r--r--tensorflow/core/grappler/optimizers/data/graph_utils_test.cc94
-rw-r--r--tensorflow/core/grappler/optimizers/data/hoist_random_uniform.cc289
-rw-r--r--tensorflow/core/grappler/optimizers/data/hoist_random_uniform.h55
-rw-r--r--tensorflow/core/grappler/optimizers/data/hoist_random_uniform_test.cc84
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc5
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_and_filter_fusion.cc14
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_and_filter_fusion_test.cc21
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_fusion.cc30
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_fusion_test.cc10
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_parallelization.cc3
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_parallelization_test.cc13
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_vectorization.cc49
-rw-r--r--tensorflow/core/grappler/optimizers/data/noop_elimination.cc16
-rw-r--r--tensorflow/core/grappler/optimizers/data/noop_elimination_test.cc43
-rw-r--r--tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc2
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization/BUILD69
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc54
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc61
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h49
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.cc47
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h75
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc50
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization_utils.cc292
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization_utils.h90
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc600
-rw-r--r--tensorflow/core/grappler/optimizers/debug_stripper.cc4
-rw-r--r--tensorflow/core/grappler/optimizers/debug_stripper_test.cc29
-rw-r--r--tensorflow/core/grappler/optimizers/experimental_implementation_selector.cc48
-rw-r--r--tensorflow/core/grappler/optimizers/experimental_implementation_selector_test.cc5
-rw-r--r--tensorflow/core/grappler/optimizers/meta_optimizer.cc58
-rw-r--r--tensorflow/core/grappler/optimizers/meta_optimizer.h4
-rw-r--r--tensorflow/core/grappler/optimizers/meta_optimizer_test.cc30
-rw-r--r--tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc264
-rw-r--r--tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h62
-rw-r--r--tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc194
-rw-r--r--tensorflow/core/grappler/optimizers/remapper.cc10
-rw-r--r--tensorflow/core/grappler/optimizers/shape_optimizer.cc15
-rw-r--r--tensorflow/core/grappler/utils.cc13
-rw-r--r--tensorflow/core/grappler/utils.h106
-rw-r--r--tensorflow/core/grappler/utils/BUILD29
-rw-r--r--tensorflow/core/grappler/utils/grappler_test.cc9
-rw-r--r--tensorflow/core/grappler/utils/symbolic_shapes.cc (renamed from tensorflow/core/grappler/optimizers/symbolic_shapes.cc)2
-rw-r--r--tensorflow/core/grappler/utils/symbolic_shapes.h (renamed from tensorflow/core/grappler/optimizers/symbolic_shapes.h)6
-rw-r--r--tensorflow/core/grappler/utils/symbolic_shapes_test.cc (renamed from tensorflow/core/grappler/optimizers/symbolic_shapes_test.cc)2
-rw-r--r--tensorflow/core/grappler/utils_test.cc55
-rw-r--r--tensorflow/core/kernels/BUILD150
-rw-r--r--tensorflow/core/kernels/batch_matmul_op_complex.cc10
-rw-r--r--tensorflow/core/kernels/batch_matmul_op_real.cc9
-rw-r--r--tensorflow/core/kernels/batching_util/BUILD20
-rw-r--r--tensorflow/core/kernels/bincount_op_gpu.cu.cc2
-rw-r--r--tensorflow/core/kernels/boosted_trees/boosted_trees.proto13
-rw-r--r--tensorflow/core/kernels/boosted_trees/prediction_ops.cc38
-rw-r--r--tensorflow/core/kernels/boosted_trees/resources.cc26
-rw-r--r--tensorflow/core/kernels/collective_ops.cc21
-rw-r--r--tensorflow/core/kernels/conv_ops.cc333
-rw-r--r--tensorflow/core/kernels/conv_ops.h44
-rw-r--r--tensorflow/core/kernels/conv_ops_3d.cc14
-rw-r--r--tensorflow/core/kernels/conv_ops_gpu.h6
-rw-r--r--tensorflow/core/kernels/cwise_op_gpu_xdivy.cu.cc26
-rw-r--r--tensorflow/core/kernels/cwise_op_gpu_xlogy.cu.cc26
-rw-r--r--tensorflow/core/kernels/cwise_op_xdivy.cc38
-rw-r--r--tensorflow/core/kernels/cwise_op_xlogy.cc41
-rw-r--r--tensorflow/core/kernels/cwise_ops.h45
-rw-r--r--tensorflow/core/kernels/cwise_ops_common.cc4
-rw-r--r--tensorflow/core/kernels/data/BUILD16
-rw-r--r--tensorflow/core/kernels/data/batch_dataset_op.cc2
-rw-r--r--tensorflow/core/kernels/data/captured_function.cc144
-rw-r--r--tensorflow/core/kernels/data/captured_function.h22
-rw-r--r--tensorflow/core/kernels/data/dataset_utils.cc37
-rw-r--r--tensorflow/core/kernels/data/dataset_utils.h10
-rw-r--r--tensorflow/core/kernels/data/experimental/BUILD (renamed from tensorflow/contrib/data/kernels/BUILD)90
-rw-r--r--tensorflow/core/kernels/data/experimental/assert_next_dataset_op.cc (renamed from tensorflow/contrib/data/kernels/assert_next_dataset_op.cc)5
-rw-r--r--tensorflow/core/kernels/data/experimental/csv_dataset_op.cc (renamed from tensorflow/contrib/data/kernels/csv_dataset_op.cc)3
-rw-r--r--tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.cc (renamed from tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc)5
-rw-r--r--tensorflow/core/kernels/data/experimental/identity_indexed_dataset.cc (renamed from tensorflow/contrib/data/kernels/identity_indexed_dataset.cc)7
-rw-r--r--tensorflow/core/kernels/data/experimental/ignore_errors_dataset_op.cc (renamed from tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc)6
-rw-r--r--tensorflow/core/kernels/data/experimental/indexed_dataset.cc (renamed from tensorflow/contrib/data/kernels/indexed_dataset.cc)14
-rw-r--r--tensorflow/core/kernels/data/experimental/indexed_dataset.h (renamed from tensorflow/contrib/data/kernels/indexed_dataset.h)6
-rw-r--r--tensorflow/core/kernels/data/experimental/lmdb_dataset_op.cc (renamed from tensorflow/contrib/data/kernels/lmdb_dataset_op.cc)3
-rw-r--r--tensorflow/core/kernels/data/experimental/prefetching_kernels.cc482
-rw-r--r--tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc (renamed from tensorflow/contrib/data/kernels/threadpool_dataset_op.cc)7
-rw-r--r--tensorflow/core/kernels/data/experimental/unique_dataset_op.cc (renamed from tensorflow/contrib/data/kernels/unique_dataset_op.cc)7
-rw-r--r--tensorflow/core/kernels/data/filter_dataset_op.cc61
-rw-r--r--tensorflow/core/kernels/data/flat_map_dataset_op.cc13
-rw-r--r--tensorflow/core/kernels/data/generator_dataset_op.cc53
-rw-r--r--tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc4
-rw-r--r--tensorflow/core/kernels/data/group_by_window_dataset_op.cc55
-rw-r--r--tensorflow/core/kernels/data/interleave_dataset_op.cc12
-rw-r--r--tensorflow/core/kernels/data/iterator_ops.cc148
-rw-r--r--tensorflow/core/kernels/data/map_and_batch_dataset_op.cc88
-rw-r--r--tensorflow/core/kernels/data/map_dataset_op.cc14
-rw-r--r--tensorflow/core/kernels/data/model_dataset_op.cc60
-rw-r--r--tensorflow/core/kernels/data/multi_device_iterator_ops.cc (renamed from tensorflow/contrib/data/kernels/prefetching_kernels.cc)551
-rw-r--r--tensorflow/core/kernels/data/optional_ops.cc15
-rw-r--r--tensorflow/core/kernels/data/padded_batch_dataset_op.cc2
-rw-r--r--tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc216
-rw-r--r--tensorflow/core/kernels/data/parallel_map_dataset_op.cc16
-rw-r--r--tensorflow/core/kernels/data/parallel_map_iterator.cc65
-rw-r--r--tensorflow/core/kernels/data/parse_example_dataset_op.cc7
-rw-r--r--tensorflow/core/kernels/data/prefetch_dataset_op.cc28
-rw-r--r--tensorflow/core/kernels/data/scan_dataset_op.cc23
-rw-r--r--tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc24
-rw-r--r--tensorflow/core/kernels/data/tensor_dataset_op.cc6
-rw-r--r--tensorflow/core/kernels/data/window_dataset_op.cc215
-rw-r--r--tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc2
-rw-r--r--tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc8
-rw-r--r--tensorflow/core/kernels/eigen_cuboid_convolution.h432
-rw-r--r--tensorflow/core/kernels/eigen_spatial_convolutions.h342
-rw-r--r--tensorflow/core/kernels/extract_volume_patches_op.cc197
-rw-r--r--tensorflow/core/kernels/extract_volume_patches_op.h58
-rw-r--r--tensorflow/core/kernels/extract_volume_patches_op_gpu.cu.cc38
-rw-r--r--tensorflow/core/kernels/fuzzing/BUILD2
-rw-r--r--tensorflow/core/kernels/fuzzing/decode_compressed_fuzz.cc45
-rw-r--r--tensorflow/core/kernels/fuzzing/parse_tensor_op_fuzz.cc32
-rw-r--r--tensorflow/core/kernels/gather_nd_op_cpu_impl.h6
-rw-r--r--tensorflow/core/kernels/histogram_op_gpu.cu.cc2
-rw-r--r--tensorflow/core/kernels/logging_ops.cc54
-rw-r--r--tensorflow/core/kernels/logging_ops_test.cc22
-rw-r--r--tensorflow/core/kernels/matmul_op.cc8
-rw-r--r--tensorflow/core/kernels/mkl_batch_matmul_op.cc2
-rw-r--r--tensorflow/core/kernels/mkl_conv_ops_test.cc407
-rw-r--r--tensorflow/core/kernels/mkl_matmul_op.cc6
-rw-r--r--tensorflow/core/kernels/mkl_slice_op.cc358
-rw-r--r--tensorflow/core/kernels/multinomial_op.cc2
-rw-r--r--tensorflow/core/kernels/partitioned_function_ops.cc12
-rw-r--r--tensorflow/core/kernels/queue_base.h4
-rw-r--r--tensorflow/core/kernels/random_op.cc10
-rw-r--r--tensorflow/core/kernels/reduction_gpu_kernels.cu.h10
-rw-r--r--tensorflow/core/kernels/reduction_ops_max.cc2
-rw-r--r--tensorflow/core/kernels/reduction_ops_sum.cc10
-rw-r--r--tensorflow/core/kernels/resource_variable_ops.cc60
-rw-r--r--tensorflow/core/kernels/resource_variable_ops.h10
-rw-r--r--tensorflow/core/kernels/scatter_nd_op.cc1
-rw-r--r--tensorflow/core/kernels/searchsorted_op.cc249
-rw-r--r--tensorflow/core/kernels/searchsorted_op.h52
-rw-r--r--tensorflow/core/kernels/searchsorted_op_gpu.cu.cc126
-rw-r--r--tensorflow/core/kernels/slice_op.cc14
-rw-r--r--tensorflow/core/kernels/split_lib_gpu.cu.cc1
-rw-r--r--tensorflow/core/kernels/strided_slice_op.cc1
-rw-r--r--tensorflow/core/kernels/string_format_op.cc65
-rw-r--r--tensorflow/core/kernels/string_format_op_test.cc66
-rw-r--r--tensorflow/core/kernels/string_length_op.cc23
-rw-r--r--tensorflow/core/kernels/string_util.cc63
-rw-r--r--tensorflow/core/kernels/string_util.h45
-rw-r--r--tensorflow/core/kernels/tensor_array.cc3
-rw-r--r--tensorflow/core/kernels/tensor_array.h3
-rw-r--r--tensorflow/core/kernels/tensor_array_ops.cc3
-rw-r--r--tensorflow/core/kernels/topk_op_gpu.cu.cc6
-rw-r--r--tensorflow/core/kernels/training_op_helpers.cc44
-rw-r--r--tensorflow/core/kernels/training_op_helpers.h37
-rw-r--r--tensorflow/core/kernels/training_ops.cc8
-rw-r--r--tensorflow/core/kernels/transpose_op.cc10
-rw-r--r--tensorflow/core/kernels/unicode_script_op.cc53
-rw-r--r--tensorflow/core/kernels/where_op_gpu.cu.h8
-rw-r--r--tensorflow/core/lib/core/threadpool.cc49
-rw-r--r--tensorflow/core/lib/core/threadpool.h14
-rw-r--r--tensorflow/core/lib/core/threadpool_test.cc61
-rw-r--r--tensorflow/core/lib/io/record_reader.cc53
-rw-r--r--tensorflow/core/lib/io/record_reader.h25
-rw-r--r--tensorflow/core/lib/io/record_reader_writer_test.cc7
-rw-r--r--tensorflow/core/lib/jpeg/jpeg_mem.cc6
-rw-r--r--tensorflow/core/ops/array_ops.cc260
-rw-r--r--tensorflow/core/ops/boosted_trees_ops.cc2
-rw-r--r--tensorflow/core/ops/compat/ops_history.v1.pbtxt1317
-rw-r--r--tensorflow/core/ops/cudnn_rnn_ops.cc9
-rw-r--r--tensorflow/core/ops/cudnn_rnn_ops_test.cc11
-rw-r--r--tensorflow/core/ops/dataset_ops.cc60
-rw-r--r--tensorflow/core/ops/experimental_dataset_ops.cc207
-rw-r--r--tensorflow/core/ops/logging_ops.cc19
-rw-r--r--tensorflow/core/ops/math_grad.cc34
-rw-r--r--tensorflow/core/ops/math_grad_test.cc40
-rw-r--r--tensorflow/core/ops/math_ops.cc14
-rw-r--r--tensorflow/core/ops/nn_ops.cc18
-rw-r--r--tensorflow/core/ops/ops.pbtxt940
-rw-r--r--tensorflow/core/ops/resource_variable_ops.cc72
-rw-r--r--tensorflow/core/ops/string_ops.cc33
-rw-r--r--tensorflow/core/platform/cloud/gcs_file_system.cc3
-rw-r--r--tensorflow/core/platform/default/build_config.bzl45
-rw-r--r--tensorflow/core/platform/default/build_config_root.bzl86
-rw-r--r--tensorflow/core/platform/default/cord.h5
-rw-r--r--tensorflow/core/platform/default/device_tracer.cc7
-rw-r--r--tensorflow/core/platform/file_system.h3
-rw-r--r--tensorflow/core/platform/tracing.h5
-rw-r--r--tensorflow/core/protobuf/config.proto2
-rw-r--r--tensorflow/core/protobuf/replay_log.proto47
-rw-r--r--tensorflow/core/protobuf/rewriter_config.proto6
-rw-r--r--tensorflow/core/public/version.h4
-rw-r--r--tensorflow/core/util/cuda_kernel_helper.h31
-rw-r--r--tensorflow/core/util/mkl_util.h16
-rw-r--r--tensorflow/core/util/port.cc4
-rw-r--r--tensorflow/core/util/sparse/sparse_tensor.h14
-rw-r--r--tensorflow/core/util/tensor_bundle/BUILD1
-rw-r--r--tensorflow/core/util/tensor_bundle/tensor_bundle.cc52
-rw-r--r--tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc64
-rw-r--r--tensorflow/core/util/tensor_bundle/testdata/old_string_tensors/README3
-rw-r--r--tensorflow/core/util/tensor_bundle/testdata/old_string_tensors/foo.data-00000-of-00001bin0 -> 1080 bytes
-rw-r--r--tensorflow/core/util/tensor_bundle/testdata/old_string_tensors/foo.indexbin0 -> 211 bytes
-rw-r--r--tensorflow/core/util/work_sharder.cc2
-rw-r--r--tensorflow/core/util/work_sharder.h3
-rw-r--r--tensorflow/examples/android/BUILD1
-rw-r--r--tensorflow/examples/autograph/integration_tests/errors_test.py4
-rw-r--r--tensorflow/examples/learn/text_classification_character_cnn.py2
-rw-r--r--tensorflow/examples/tutorials/mnist/BUILD12
-rw-r--r--tensorflow/go/README.md6
-rw-r--r--tensorflow/go/op/wrappers.go6969
-rw-r--r--tensorflow/java/README.md7
-rw-r--r--tensorflow/java/maven/libtensorflow/pom.xml2
-rw-r--r--tensorflow/java/maven/libtensorflow_jni/pom.xml2
-rw-r--r--tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml2
-rw-r--r--tensorflow/java/maven/pom.xml2
-rw-r--r--tensorflow/java/maven/proto/pom.xml2
-rw-r--r--tensorflow/java/maven/spark-tensorflow-connector/pom.xml2
-rw-r--r--tensorflow/java/maven/tensorflow-hadoop/pom.xml2
-rw-r--r--tensorflow/java/maven/tensorflow/pom.xml2
-rw-r--r--tensorflow/python/BUILD53
-rw-r--r--tensorflow/python/autograph/README.md2
-rw-r--r--tensorflow/python/autograph/__init__.py2
-rw-r--r--tensorflow/python/autograph/converters/builtin_functions.py9
-rw-r--r--tensorflow/python/autograph/converters/builtin_functions_test.py16
-rw-r--r--tensorflow/python/autograph/converters/call_trees.py11
-rw-r--r--tensorflow/python/autograph/converters/return_statements.py14
-rw-r--r--tensorflow/python/autograph/converters/return_statements_test.py12
-rw-r--r--tensorflow/python/autograph/core/converter.py8
-rw-r--r--tensorflow/python/autograph/core/converter_testing.py12
-rw-r--r--tensorflow/python/autograph/core/errors.py1
-rw-r--r--tensorflow/python/autograph/core/errors_test.py6
-rw-r--r--tensorflow/python/autograph/impl/api.py92
-rw-r--r--tensorflow/python/autograph/impl/api_test.py50
-rw-r--r--tensorflow/python/autograph/impl/conversion.py1
-rw-r--r--tensorflow/python/autograph/lang/special_functions_test.py4
-rw-r--r--tensorflow/python/autograph/operators/py_builtins.py7
-rw-r--r--tensorflow/python/autograph/operators/py_builtins_test.py23
-rw-r--r--tensorflow/python/autograph/operators/slices_test.py4
-rw-r--r--tensorflow/python/autograph/pyct/cfg.py13
-rw-r--r--tensorflow/python/autograph/pyct/compiler.py13
-rw-r--r--tensorflow/python/autograph/pyct/origin_info.py2
-rw-r--r--tensorflow/python/autograph/pyct/origin_info_test.py59
-rw-r--r--tensorflow/python/autograph/pyct/parser.py15
-rw-r--r--tensorflow/python/autograph/pyct/parser_test.py16
-rw-r--r--tensorflow/python/autograph/pyct/static_analysis/activity.py6
-rw-r--r--tensorflow/python/autograph/pyct/static_analysis/live_values.py12
-rw-r--r--tensorflow/python/autograph/pyct/templates.py13
-rw-r--r--tensorflow/python/autograph/pyct/templates_test.py24
-rw-r--r--tensorflow/python/client/session.py46
-rw-r--r--tensorflow/python/client/session_ref.cc525
-rw-r--r--tensorflow/python/client/session_ref.h (renamed from tensorflow/core/common_runtime/session_ref.h)15
-rw-r--r--tensorflow/python/client/session_test.py100
-rw-r--r--tensorflow/python/client/tf_session.i4
-rw-r--r--tensorflow/python/client/tf_session_helper.cc2
-rw-r--r--tensorflow/python/client/timeline.py3
-rw-r--r--tensorflow/python/client/timeline_test.py4
-rw-r--r--tensorflow/python/compat/compat.py2
-rw-r--r--tensorflow/python/data/BUILD1
-rw-r--r--tensorflow/python/data/kernel_tests/BUILD291
-rw-r--r--tensorflow/python/data/kernel_tests/batch_dataset_op_test.py10
-rw-r--r--tensorflow/python/data/kernel_tests/cache_dataset_op_test.py5
-rw-r--r--tensorflow/python/data/kernel_tests/concatenate_dataset_op_test.py3
-rw-r--r--tensorflow/python/data/kernel_tests/dataset_constructor_op_test.py8
-rw-r--r--tensorflow/python/data/kernel_tests/dataset_from_generator_op_test.py3
-rw-r--r--tensorflow/python/data/kernel_tests/dataset_ops_test.py3
-rw-r--r--tensorflow/python/data/kernel_tests/filter_dataset_op_test.py8
-rw-r--r--tensorflow/python/data/kernel_tests/flat_map_dataset_op_test.py3
-rw-r--r--tensorflow/python/data/kernel_tests/inputs_test.py149
-rw-r--r--tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py9
-rw-r--r--tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py3
-rw-r--r--tensorflow/python/data/kernel_tests/map_dataset_op_test.py32
-rw-r--r--tensorflow/python/data/kernel_tests/multi_device_iterator_test.py191
-rw-r--r--tensorflow/python/data/kernel_tests/optional_ops_test.py177
-rw-r--r--tensorflow/python/data/kernel_tests/prefetch_dataset_op_test.py3
-rw-r--r--tensorflow/python/data/kernel_tests/range_dataset_op_test.py3
-rw-r--r--tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py7
-rw-r--r--tensorflow/python/data/kernel_tests/reduce_dataset_op_test.py124
-rw-r--r--tensorflow/python/data/kernel_tests/sequence_dataset_op_test.py3
-rw-r--r--tensorflow/python/data/kernel_tests/shard_dataset_op_test.py3
-rw-r--r--tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py3
-rw-r--r--tensorflow/python/data/kernel_tests/test_base.py (renamed from tensorflow/contrib/data/python/ops/contrib_op_loader.py)15
-rw-r--r--tensorflow/python/data/kernel_tests/window_dataset_op_test.py291
-rw-r--r--tensorflow/python/data/kernel_tests/zip_dataset_op_test.py3
-rw-r--r--tensorflow/python/data/ops/BUILD19
-rw-r--r--tensorflow/python/data/ops/dataset_ops.py321
-rw-r--r--tensorflow/python/data/ops/iterator_ops.py13
-rw-r--r--tensorflow/python/data/ops/multi_device_iterator_ops.py231
-rw-r--r--tensorflow/python/data/ops/optional_ops.py150
-rw-r--r--tensorflow/python/data/ops/readers.py12
-rw-r--r--tensorflow/python/data/util/structure.py131
-rw-r--r--tensorflow/python/data/util/structure_test.py36
-rw-r--r--tensorflow/python/debug/BUILD1
-rw-r--r--tensorflow/python/debug/cli/analyzer_cli_test.py23
-rw-r--r--tensorflow/python/debug/cli/stepper_cli_test.py4
-rw-r--r--tensorflow/python/debug/lib/debug_graph_reconstruction_test.py3
-rw-r--r--tensorflow/python/debug/lib/debug_utils_test.py4
-rw-r--r--tensorflow/python/debug/lib/dist_session_debug_grpc_test.py4
-rw-r--r--tensorflow/python/debug/lib/grpc_large_data_test.py12
-rw-r--r--tensorflow/python/debug/lib/session_debug_file_test.py4
-rw-r--r--tensorflow/python/debug/lib/session_debug_grpc_test.py48
-rw-r--r--tensorflow/python/debug/lib/session_debug_testlib.py90
-rw-r--r--tensorflow/python/debug/lib/stepper_test.py14
-rw-r--r--tensorflow/python/debug/wrappers/dumping_wrapper_test.py2
-rw-r--r--tensorflow/python/debug/wrappers/local_cli_wrapper_test.py14
-rw-r--r--tensorflow/python/distribute/distribute_coordinator.py4
-rw-r--r--tensorflow/python/distribute/estimator_training.py23
-rw-r--r--tensorflow/python/eager/BUILD35
-rw-r--r--tensorflow/python/eager/backprop.py43
-rw-r--r--tensorflow/python/eager/backprop_test.py12
-rw-r--r--tensorflow/python/eager/def_function.py235
-rw-r--r--tensorflow/python/eager/def_function_test.py87
-rw-r--r--tensorflow/python/eager/function.py233
-rw-r--r--tensorflow/python/eager/function_test.py218
-rw-r--r--tensorflow/python/eager/imperative_grad.py5
-rw-r--r--tensorflow/python/eager/pywrap_tensor.cc41
-rw-r--r--tensorflow/python/eager/pywrap_tensor.h5
-rw-r--r--tensorflow/python/eager/pywrap_tfe_src.cc464
-rw-r--r--tensorflow/python/estimator/BUILD32
-rw-r--r--tensorflow/python/estimator/canned/boosted_trees.py383
-rw-r--r--tensorflow/python/estimator/canned/boosted_trees_test.py704
-rw-r--r--tensorflow/python/estimator/canned/boosted_trees_utils.py80
-rw-r--r--tensorflow/python/estimator/canned/boosted_trees_utils_test.py187
-rw-r--r--tensorflow/python/estimator/canned/dnn.py188
-rw-r--r--tensorflow/python/estimator/canned/dnn_linear_combined.py7
-rw-r--r--tensorflow/python/estimator/canned/dnn_test.py161
-rw-r--r--tensorflow/python/estimator/canned/dnn_testing_utils.py116
-rw-r--r--tensorflow/python/estimator/estimator.py57
-rw-r--r--tensorflow/python/estimator/estimator_test.py56
-rw-r--r--tensorflow/python/estimator/export/export_test.py2
-rw-r--r--tensorflow/python/estimator/keras.py39
-rw-r--r--tensorflow/python/estimator/keras_test.py28
-rw-r--r--tensorflow/python/estimator/model_fn.py6
-rw-r--r--tensorflow/python/feature_column/BUILD2
-rw-r--r--tensorflow/python/feature_column/feature_column.py35
-rw-r--r--tensorflow/python/feature_column/feature_column_test.py12
-rw-r--r--tensorflow/python/feature_column/feature_column_v2.py590
-rw-r--r--tensorflow/python/feature_column/feature_column_v2_test.py1878
-rw-r--r--tensorflow/python/framework/function.py26
-rw-r--r--tensorflow/python/framework/function_test.py29
-rw-r--r--tensorflow/python/framework/graph_util_test.py8
-rw-r--r--tensorflow/python/framework/load_library.py65
-rw-r--r--tensorflow/python/framework/ops.py4
-rw-r--r--tensorflow/python/framework/ops_test.py12
-rw-r--r--tensorflow/python/framework/subscribe_test.py4
-rw-r--r--tensorflow/python/framework/test_util.py180
-rw-r--r--tensorflow/python/framework/test_util_test.py8
-rw-r--r--tensorflow/python/grappler/item_test.py2
-rw-r--r--tensorflow/python/grappler/memory_optimizer_test.py10
-rw-r--r--tensorflow/python/grappler/tf_optimizer_test.py2
-rwxr-xr-xtensorflow/python/keras/BUILD5
-rw-r--r--tensorflow/python/keras/applications/__init__.py3
-rw-r--r--tensorflow/python/keras/backend.py66
-rw-r--r--tensorflow/python/keras/callbacks.py16
-rw-r--r--tensorflow/python/keras/callbacks_test.py40
-rw-r--r--tensorflow/python/keras/engine/base_layer.py161
-rw-r--r--tensorflow/python/keras/engine/distributed_training_utils.py16
-rw-r--r--tensorflow/python/keras/engine/saving_test.py7
-rw-r--r--tensorflow/python/keras/engine/topology_test.py2
-rw-r--r--tensorflow/python/keras/engine/training.py65
-rw-r--r--tensorflow/python/keras/engine/training_distributed.py399
-rw-r--r--tensorflow/python/keras/engine/training_eager.py2
-rw-r--r--tensorflow/python/keras/engine/training_eager_test.py14
-rw-r--r--tensorflow/python/keras/engine/training_generator.py11
-rw-r--r--tensorflow/python/keras/engine/training_test.py12
-rw-r--r--tensorflow/python/keras/layers/advanced_activations.py4
-rw-r--r--tensorflow/python/keras/layers/advanced_activations_test.py8
-rw-r--r--tensorflow/python/keras/layers/core.py51
-rw-r--r--tensorflow/python/keras/layers/core_test.py45
-rw-r--r--tensorflow/python/keras/layers/embeddings.py29
-rw-r--r--tensorflow/python/keras/layers/embeddings_test.py13
-rw-r--r--tensorflow/python/keras/metrics.py58
-rw-r--r--tensorflow/python/keras/models.py9
-rw-r--r--tensorflow/python/keras/optimizers_test.py17
-rw-r--r--tensorflow/python/keras/utils/multi_gpu_utils_test.py10
-rw-r--r--tensorflow/python/keras/wrappers/scikit_learn_test.py12
-rw-r--r--tensorflow/python/kernel_tests/BUILD117
-rw-r--r--tensorflow/python/kernel_tests/array_ops_test.py202
-rw-r--r--tensorflow/python/kernel_tests/basic_gpu_test.py2
-rw-r--r--tensorflow/python/kernel_tests/boosted_trees/prediction_ops_test.py300
-rw-r--r--tensorflow/python/kernel_tests/boosted_trees/quantile_ops_test.py4
-rw-r--r--tensorflow/python/kernel_tests/broadcast_to_ops_test.py8
-rw-r--r--tensorflow/python/kernel_tests/check_ops_test.py10
-rw-r--r--tensorflow/python/kernel_tests/cond_v2_test.py60
-rw-r--r--tensorflow/python/kernel_tests/conditional_accumulator_test.py4
-rw-r--r--tensorflow/python/kernel_tests/control_flow_ops_py_test.py257
-rw-r--r--tensorflow/python/kernel_tests/cwise_ops_binary_test.py878
-rw-r--r--tensorflow/python/kernel_tests/cwise_ops_test.py1156
-rw-r--r--tensorflow/python/kernel_tests/cwise_ops_unary_test.py541
-rw-r--r--tensorflow/python/kernel_tests/dense_update_ops_test.py6
-rw-r--r--tensorflow/python/kernel_tests/depthwise_conv_op_test.py16
-rw-r--r--tensorflow/python/kernel_tests/distributions/bernoulli_test.py12
-rw-r--r--tensorflow/python/kernel_tests/distributions/normal_test.py8
-rw-r--r--tensorflow/python/kernel_tests/extract_volume_patches_op_test.py131
-rw-r--r--tensorflow/python/kernel_tests/functional_ops_test.py10
-rw-r--r--tensorflow/python/kernel_tests/identity_op_py_test.py2
-rw-r--r--tensorflow/python/kernel_tests/init_ops_test.py34
-rw-r--r--tensorflow/python/kernel_tests/linalg/linear_operator_addition_test.py24
-rw-r--r--tensorflow/python/kernel_tests/linalg/linear_operator_circulant_test.py73
-rw-r--r--tensorflow/python/kernel_tests/linalg_grad_test.py2
-rw-r--r--tensorflow/python/kernel_tests/list_ops_test.py26
-rw-r--r--tensorflow/python/kernel_tests/logging_ops_logging_level_test.py70
-rw-r--r--tensorflow/python/kernel_tests/logging_ops_test.py276
-rw-r--r--tensorflow/python/kernel_tests/lookup_ops_test.py70
-rw-r--r--tensorflow/python/kernel_tests/numerics_test.py8
-rw-r--r--tensorflow/python/kernel_tests/random/random_ops_test.py9
-rw-r--r--tensorflow/python/kernel_tests/reduction_ops_test.py6
-rw-r--r--tensorflow/python/kernel_tests/reduction_ops_test_big.py12
-rw-r--r--tensorflow/python/kernel_tests/regex_full_match_op_test.py6
-rw-r--r--tensorflow/python/kernel_tests/regex_replace_op_test.py16
-rw-r--r--tensorflow/python/kernel_tests/resource_variable_ops_test.py2
-rw-r--r--tensorflow/python/kernel_tests/rnn_test.py8
-rw-r--r--tensorflow/python/kernel_tests/scalar_test.py2
-rw-r--r--tensorflow/python/kernel_tests/scatter_nd_ops_test.py4
-rw-r--r--tensorflow/python/kernel_tests/scatter_ops_test.py4
-rw-r--r--tensorflow/python/kernel_tests/segment_reduction_ops_test.py2
-rw-r--r--tensorflow/python/kernel_tests/softmax_op_test.py21
-rw-r--r--tensorflow/python/kernel_tests/softplus_op_test.py7
-rw-r--r--tensorflow/python/kernel_tests/softsign_op_test.py5
-rw-r--r--tensorflow/python/kernel_tests/sparse_conditional_accumulator_test.py4
-rw-r--r--tensorflow/python/kernel_tests/sparse_tensors_map_ops_test.py2
-rw-r--r--tensorflow/python/kernel_tests/string_format_op_test.py384
-rw-r--r--tensorflow/python/kernel_tests/string_length_op_test.py27
-rw-r--r--tensorflow/python/kernel_tests/substr_op_test.py14
-rw-r--r--tensorflow/python/kernel_tests/summary_audio_op_test.py2
-rw-r--r--tensorflow/python/kernel_tests/summary_image_op_test.py4
-rw-r--r--tensorflow/python/kernel_tests/tensor_array_ops_test.py13
-rw-r--r--tensorflow/python/kernel_tests/unicode_script_op_test.py57
-rw-r--r--tensorflow/python/kernel_tests/variable_scope_test.py4
-rw-r--r--tensorflow/python/kernel_tests/variables_test.py36
-rw-r--r--tensorflow/python/kernel_tests/while_v2_test.py276
-rw-r--r--tensorflow/python/layers/base.py16
-rw-r--r--tensorflow/python/layers/convolutional_test.py36
-rw-r--r--tensorflow/python/layers/core_test.py6
-rw-r--r--tensorflow/python/ops/array_ops.py61
-rw-r--r--tensorflow/python/ops/cond_v2_impl.py6
-rw-r--r--tensorflow/python/ops/control_flow_ops.py23
-rw-r--r--tensorflow/python/ops/distributions/beta.py9
-rw-r--r--tensorflow/python/ops/distributions/bijector_impl.py39
-rw-r--r--tensorflow/python/ops/distributions/dirichlet.py9
-rw-r--r--tensorflow/python/ops/distributions/distribution.py147
-rw-r--r--tensorflow/python/ops/distributions/gamma.py9
-rw-r--r--tensorflow/python/ops/distributions/kullback_leibler.py4
-rw-r--r--tensorflow/python/ops/distributions/normal.py9
-rw-r--r--tensorflow/python/ops/distributions/student_t.py14
-rw-r--r--tensorflow/python/ops/distributions/util.py12
-rw-r--r--tensorflow/python/ops/functional_ops.py40
-rw-r--r--tensorflow/python/ops/gradients_test.py2
-rw-r--r--tensorflow/python/ops/image_ops_impl.py54
-rw-r--r--tensorflow/python/ops/image_ops_test.py12
-rw-r--r--tensorflow/python/ops/linalg/linear_operator_test_util.py16
-rw-r--r--tensorflow/python/ops/logging_ops.py260
-rw-r--r--tensorflow/python/ops/lookup_ops.py40
-rw-r--r--tensorflow/python/ops/losses/util_test.py6
-rw-r--r--tensorflow/python/ops/math_grad.py34
-rw-r--r--tensorflow/python/ops/math_grad_test.py88
-rw-r--r--tensorflow/python/ops/math_ops.py20
-rw-r--r--tensorflow/python/ops/math_ops_test.py71
-rw-r--r--tensorflow/python/ops/matmul_benchmark.py8
-rw-r--r--tensorflow/python/ops/nn_ops.py34
-rw-r--r--tensorflow/python/ops/parallel_for/BUILD2
-rw-r--r--tensorflow/python/ops/parallel_for/gradients.py2
-rw-r--r--tensorflow/python/ops/parallel_for/gradients_test.py26
-rw-r--r--tensorflow/python/ops/parsing_ops.py7
-rw-r--r--tensorflow/python/ops/resource_variable_ops.py74
-rw-r--r--tensorflow/python/ops/rnn_cell_impl.py2
-rw-r--r--tensorflow/python/ops/string_ops.py97
-rw-r--r--tensorflow/python/ops/variable_scope.py123
-rw-r--r--tensorflow/python/ops/variables.py323
-rw-r--r--tensorflow/python/ops/while_v2.py584
-rw-r--r--tensorflow/python/profiler/model_analyzer_test.py42
-rw-r--r--tensorflow/python/profiler/pprof_profiler_test.py2
-rw-r--r--tensorflow/python/pywrap_tensorflow.py2
-rw-r--r--tensorflow/python/saved_model/loader_test.py14
-rw-r--r--tensorflow/python/saved_model/saved_model_test.py56
-rw-r--r--tensorflow/python/summary/writer/writer_test.py4
-rw-r--r--tensorflow/python/tools/BUILD9
-rw-r--r--tensorflow/python/tools/api/generator/create_python_api.py1
-rw-r--r--tensorflow/python/tools/freeze_graph_test.py6
-rw-r--r--tensorflow/python/tools/optimize_for_inference_test.py16
-rw-r--r--tensorflow/python/tools/saved_model_cli.py2
-rw-r--r--tensorflow/python/training/adagrad.py2
-rw-r--r--tensorflow/python/training/basic_session_run_hooks.py5
-rw-r--r--tensorflow/python/training/checkpointable/util.py2
-rw-r--r--tensorflow/python/training/distribute.py4
-rw-r--r--tensorflow/python/training/evaluation.py68
-rw-r--r--tensorflow/python/training/ftrl_test.py4
-rw-r--r--tensorflow/python/training/gradient_descent_test.py10
-rw-r--r--tensorflow/python/training/learning_rate_decay_test.py4
-rw-r--r--tensorflow/python/training/learning_rate_decay_v2_test.py2
-rw-r--r--tensorflow/python/training/monitored_session.py24
-rw-r--r--tensorflow/python/training/monitored_session_test.py28
-rw-r--r--tensorflow/python/training/optimizer.py2
-rw-r--r--tensorflow/python/training/quantize_training.i7
-rw-r--r--tensorflow/python/training/quantize_training_test.py3
-rw-r--r--tensorflow/python/training/queue_runner_test.py22
-rw-r--r--tensorflow/python/training/saver.py8
-rw-r--r--tensorflow/python/training/saver_test.py249
-rw-r--r--tensorflow/python/training/server_lib_same_variables_no_clear_test.py4
-rw-r--r--tensorflow/python/training/server_lib_test.py18
-rw-r--r--tensorflow/python/training/session_manager_test.py98
-rw-r--r--tensorflow/python/training/supervisor.py7
-rw-r--r--tensorflow/python/training/supervisor_test.py52
-rw-r--r--tensorflow/python/training/sync_replicas_optimizer_test.py17
-rw-r--r--tensorflow/python/training/training_ops_test.py32
-rw-r--r--tensorflow/python/training/training_util_test.py4
-rw-r--r--tensorflow/python/training/warm_starting_util_test.py8
-rw-r--r--tensorflow/python/util/function_utils.py23
-rw-r--r--tensorflow/python/util/function_utils_test.py95
-rw-r--r--tensorflow/python/util/nest.py38
-rw-r--r--tensorflow/python/util/nest_test.py40
-rw-r--r--tensorflow/python/util/util.cc356
-rw-r--r--tensorflow/python/util/util.h9
-rw-r--r--tensorflow/python/util/util.i12
-rw-r--r--tensorflow/stream_executor/cuda/cuda_dnn.cc129
-rw-r--r--tensorflow/stream_executor/cuda/cuda_dnn.h16
-rw-r--r--tensorflow/stream_executor/device_description.h6
-rw-r--r--tensorflow/stream_executor/dnn.h4
-rw-r--r--tensorflow/stream_executor/lib/array_slice.h8
-rw-r--r--tensorflow/stream_executor/lib/inlined_vector.h4
-rw-r--r--tensorflow/stream_executor/lib/strcat.h6
-rw-r--r--tensorflow/stream_executor/lib/stringpiece.h5
-rw-r--r--tensorflow/stream_executor/plugin_registry.h2
-rw-r--r--tensorflow/stream_executor/stream.cc38
-rw-r--r--tensorflow/stream_executor/stream_executor_pimpl.h11
-rw-r--r--tensorflow/tensorflow.bzl87
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.-variable.pbtxt1
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt8
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt8
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt8
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt8
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.estimator.-boosted-trees-classifier.pbtxt9
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.estimator.-boosted-trees-regressor.pbtxt9
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.math.pbtxt8
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.pbtxt22
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.strings.pbtxt10
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.-variable-scope.pbtxt105
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.-variable.-save-slice-info.pbtxt17
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.-variable.pbtxt130
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt8
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt8
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt8
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt8
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.estimator.-boosted-trees-classifier.pbtxt9
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.estimator.-boosted-trees-regressor.pbtxt9
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.initializers.pbtxt12
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.math.pbtxt8
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.pbtxt114
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.strings.pbtxt10
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.train.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.variable_scope.pbtxt9
-rw-r--r--tensorflow/tools/api/tests/BUILD1
-rw-r--r--tensorflow/tools/api/tests/api_compatibility_test.py39
-rw-r--r--tensorflow/tools/benchmark/README.md2
-rw-r--r--tensorflow/tools/ci_build/Dockerfile.rocm97
-rw-r--r--tensorflow/tools/ci_build/README.md2
-rwxr-xr-xtensorflow/tools/ci_build/builds/docker_test.sh9
-rwxr-xr-xtensorflow/tools/ci_build/builds/pip.sh4
-rwxr-xr-xtensorflow/tools/ci_build/builds/run_pip_tests.sh3
-rwxr-xr-xtensorflow/tools/ci_build/builds/with_the_same_user6
-rwxr-xr-xtensorflow/tools/ci_build/ci_build.sh11
-rwxr-xr-xtensorflow/tools/ci_build/ci_parameterized_build.sh2
-rwxr-xr-xtensorflow/tools/ci_build/gpu_build/parallel_gpu_execute.sh4
-rwxr-xr-xtensorflow/tools/ci_build/install/install_pip_packages.sh4
-rwxr-xr-xtensorflow/tools/ci_build/linux/cpu/run_cc_core.sh1
-rwxr-xr-xtensorflow/tools/ci_build/linux/cpu/run_py2_core.sh1
-rwxr-xr-xtensorflow/tools/ci_build/linux/cpu/run_py3_contrib.sh1
-rwxr-xr-xtensorflow/tools/ci_build/linux/cpu/run_py3_core.sh1
-rwxr-xr-xtensorflow/tools/ci_build/linux/libtensorflow.sh3
-rwxr-xr-xtensorflow/tools/ci_build/linux/libtensorflow_cpu.sh1
-rwxr-xr-xtensorflow/tools/ci_build/linux/libtensorflow_docker.sh6
-rwxr-xr-x[-rw-r--r--]tensorflow/tools/ci_build/linux/libtensorflow_rocm.sh (renamed from tensorflow/contrib/linalg/python/__init__.py)11
-rwxr-xr-xtensorflow/tools/ci_build/linux/rocm/run_cc_core.sh39
-rwxr-xr-xtensorflow/tools/ci_build/linux/rocm/run_py3_core.sh39
-rwxr-xr-xtensorflow/tools/ci_build/osx/cpu/run_py2_cc_core.sh1
-rwxr-xr-xtensorflow/tools/ci_build/osx/libtensorflow_cpu.sh1
-rwxr-xr-xtensorflow/tools/ci_build/osx/libtensorflow_gpu.sh1
-rwxr-xr-x[-rw-r--r--]tensorflow/tools/ci_build/osx/libtensorflow_rocm.sh (renamed from tensorflow/contrib/tensorboard/plugins/trace/__init__.py)28
-rw-r--r--tensorflow/tools/ci_build/windows/gpu/pip/build_tf_windows.sh2
-rwxr-xr-xtensorflow/tools/ci_build/xla/linux/rocm/run_py3.sh41
-rw-r--r--tensorflow/tools/compatibility/testdata/test_file_v0_11.py2
-rw-r--r--tensorflow/tools/compatibility/tf_upgrade_v2.py8
-rw-r--r--tensorflow/tools/dist_test/README.md2
-rw-r--r--tensorflow/tools/dist_test/server/BUILD1
-rw-r--r--tensorflow/tools/docker/Dockerfile.devel2
-rw-r--r--tensorflow/tools/docker/Dockerfile.devel-gpu2
-rwxr-xr-xtensorflow/tools/docker/Dockerfile.devel-mkl2
-rw-r--r--tensorflow/tools/docker/jupyter_notebook_config.py2
-rwxr-xr-xtensorflow/tools/docker/parameterized_docker_build.sh2
-rw-r--r--tensorflow/tools/docs/BUILD4
-rw-r--r--tensorflow/tools/docs/generate_lib.py14
-rw-r--r--tensorflow/tools/lib_package/BUILD40
-rw-r--r--tensorflow/tools/pip_package/BUILD33
-rw-r--r--tensorflow/tools/pip_package/pip_smoke_test.py1
-rw-r--r--tensorflow/tools/pip_package/setup.py6
-rw-r--r--tensorflow/tools/quantization/BUILD78
-rw-r--r--tensorflow/tools/quantization/graph_to_dot.py68
-rw-r--r--tensorflow/tools/quantization/quantize_graph.py1302
-rw-r--r--tensorflow/tools/quantization/quantize_graph_test.py966
-rw-r--r--tensorflow/tools/test/check_futures_test.py3
-rwxr-xr-xtensorflow/workspace.bzl405
-rw-r--r--third_party/cub.BUILD1
-rw-r--r--third_party/eigen3/BUILD10
-rw-r--r--third_party/flatbuffers/BUILD.bazel3
-rw-r--r--third_party/flatbuffers/workspace.bzl8
-rw-r--r--third_party/gpus/crosstool/CROSSTOOL_hipcc.tpl158
-rwxr-xr-xthird_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl241
-rw-r--r--third_party/gpus/cuda_configure.bzl2
-rw-r--r--third_party/gpus/rocm/BUILD0
-rw-r--r--third_party/gpus/rocm/BUILD.tpl99
-rw-r--r--third_party/gpus/rocm/build_defs.bzl.tpl45
-rw-r--r--third_party/gpus/rocm/rocm_config.h.tpl21
-rw-r--r--third_party/gpus/rocm_configure.bzl784
-rw-r--r--third_party/icu/BUILD1
-rw-r--r--third_party/icu/BUILD.bazel88
-rw-r--r--third_party/icu/workspace.bzl15
-rw-r--r--third_party/mkl/BUILD23
-rw-r--r--third_party/mkl/build_defs.bzl41
-rw-r--r--third_party/mkl_dnn/BUILD6
-rw-r--r--third_party/mkl_dnn/build_defs.bzl2
-rw-r--r--third_party/ngraph/ngraph.BUILD122
-rw-r--r--third_party/ngraph/ngraph_tf.BUILD67
-rw-r--r--third_party/ngraph/tbb.BUILD63
-rw-r--r--third_party/repo.bzl15
-rw-r--r--third_party/systemlibs/absl_py.BUILD1
-rw-r--r--third_party/systemlibs/absl_py.absl.flags.BUILD11
-rw-r--r--third_party/systemlibs/absl_py.absl.testing.BUILD7
-rw-r--r--third_party/systemlibs/boringssl.BUILD21
-rw-r--r--third_party/systemlibs/double_conversion.BUILD12
-rw-r--r--third_party/systemlibs/gast.BUILD12
-rw-r--r--third_party/systemlibs/google_cloud_cpp.BUILD6
-rw-r--r--third_party/systemlibs/google_cloud_cpp.google.cloud.bigtable.BUILD7
-rw-r--r--third_party/systemlibs/googleapis.BUILD12
-rw-r--r--third_party/systemlibs/jsoncpp.BUILD2
-rw-r--r--third_party/systemlibs/syslibs_configure.bzl6
-rw-r--r--third_party/toolchains/BUILD2
-rw-r--r--tools/bazel.rc21
1577 files changed, 59810 insertions, 30910 deletions
diff --git a/CODEOWNERS b/CODEOWNERS
index 78f80c8d71..94cc865479 100644
--- a/CODEOWNERS
+++ b/CODEOWNERS
@@ -2,6 +2,7 @@
/tenosrflow/core/debug @caisq
/tensorflow/core/platform/windows/ @mrry
+/tensorflow/core/platform/s3 @yongtang
/tensorflow/go @asimshankar
/tensorflow/java/ @asimshankar
/tensorflow/python/debug @caisq
@@ -30,14 +31,16 @@
/tensorflow/contrib/gan/ @joel-shor
/tensorflow/contrib/graph_editor/ @purpledog
# NEED OWNER: /tensorflow/contrib/grid_rnn/
+/tensorflow/contrib/hadoop @yongtang
/tensorflow/contrib/hvx/ @satok16
/tensorflow/contrib/integrate/ @shoyer
+/tensorflow/contrib/kafka @yongtang
/tensorflow/contrib/kernel_methods/ @petrosmol
+/tensorflow/contrib/kinesis @yongtang
/tensorflow/contrib/ios_examples/ @petewarden
/tensorflow/contrib/labeled_tensor/ @shoyer
/tensorflow/contrib/layers/ @fchollet @martinwicke
/tensorflow/contrib/learn/ @martinwicke @ispirmustafa @alextp
-/tensorflow/contrib/linalg/ @langmore
/tensorflow/contrib/linear_optimizer/ @petrosmol @andreasst @katsiapis
/tensorflow/contrib/lookup/ @ysuematsu @andreasst
/tensorflow/contrib/losses/ @alextp @ispirmustafa
diff --git a/README.md b/README.md
index e3092e551e..57efb876c9 100644
--- a/README.md
+++ b/README.md
@@ -29,7 +29,7 @@ subscribing to
[announce@tensorflow.org](https://groups.google.com/a/tensorflow.org/forum/#!forum/announce).
## Installation
-*See [Installing TensorFlow](https://www.tensorflow.org/get_started/os_setup.html) for instructions on how to install our release binaries or how to build from source.*
+*See [Installing TensorFlow](https://www.tensorflow.org/install) for instructions on how to install our release binaries or how to build from source.*
People who are a little more adventurous can also try our nightly binaries:
@@ -48,15 +48,12 @@ $ python
```
```python
>>> import tensorflow as tf
+>>> tf.enable_eager_execution()
+>>> tf.add(1, 2)
+3
>>> hello = tf.constant('Hello, TensorFlow!')
->>> sess = tf.Session()
->>> sess.run(hello)
+>>> hello.numpy()
'Hello, TensorFlow!'
->>> a = tf.constant(10)
->>> b = tf.constant(32)
->>> sess.run(a + b)
-42
->>> sess.close()
```
Learn more examples about how to do specific tasks in TensorFlow at the [tutorials page of tensorflow.org](https://www.tensorflow.org/tutorials/).
@@ -106,13 +103,13 @@ The TensorFlow project strives to abide by generally accepted best practices in
## For more information
+* [TensorFlow Website](https://www.tensorflow.org)
+* [TensorFlow Tutorials](https://www.tensorflow.org/tutorials/)
+* [TensorFlow Model Zoo](https://github.com/tensorflow/models)
+* [TensorFlow Twitter](https://twitter.com/tensorflow)
* [TensorFlow Blog](https://medium.com/tensorflow)
* [TensorFlow Course at Stanford](https://web.stanford.edu/class/cs20si)
-* [TensorFlow Model Zoo](https://github.com/tensorflow/models)
-* [TensorFlow MOOC on Udacity](https://www.udacity.com/course/deep-learning--ud730)
* [TensorFlow Roadmap](https://www.tensorflow.org/community/roadmap)
-* [TensorFlow Twitter](https://twitter.com/tensorflow)
-* [TensorFlow Website](https://www.tensorflow.org)
* [TensorFlow White Papers](https://www.tensorflow.org/about/bib)
* [TensorFlow YouTube Channel](https://www.youtube.com/channel/UC0rqucBdTuFTjJiefW5t-IQ)
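The README hunk above swaps the old graph/Session quickstart for an eager-execution one. For readers trying it outside the interactive prompt, a self-contained equivalent might look like the sketch below; it assumes a TensorFlow 1.x build where `tf.enable_eager_execution()` is available and is illustrative only, not part of the patch.

```python
# Standalone version of the eager-execution quickstart shown in the README diff.
# Assumes TensorFlow 1.x with eager execution available; not part of the patch.
import tensorflow as tf

tf.enable_eager_execution()  # ops now run immediately instead of building a graph

print(tf.add(1, 2))                        # a tf.Tensor whose numpy value is 3
hello = tf.constant('Hello, TensorFlow!')
print(hello.numpy())                       # b'Hello, TensorFlow!' on Python 3
```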
diff --git a/RELEASE.md b/RELEASE.md
index bdc23795e5..20e1d9217b 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -1,9 +1,86 @@
+# Release 1.11.0
+
+## Major Features and Improvements
+
+* Nvidia GPU:
+ * Prebuilt binaries are now (as of TensorFlow 1.11) built against cuDNN 7.2 and TensorRT 4. See updated install guides: [Installing TensorFlow on Ubuntu](https://www.tensorflow.org/install/install_linux#tensorflow_gpu_support)
+* Google Cloud TPU:
+ * Experimental tf.data integration for Keras on Google Cloud TPUs.
+ * Experimental / preview support for eager execution on Google Cloud TPUs.
+* DistributionStrategy:
+ * Add multi-GPU DistributionStrategy support in tf.keras. Users can now use `fit`, `evaluate` and `predict` to distribute their model on multiple GPUs.
+ * Add multi-worker DistributionStrategy and standalone client support in Estimator. See [README](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/distribute) for more details.
+* Add C, C++, and Python functions for querying kernels
+
+## Breaking Changes
+
+* Keras:
+ * The default values for tf.keras `RandomUniform`, `RandomNormal`, and `TruncatedNormal` initializers have been changed to match those in external Keras.
+ * Breaking change: `model.get_config()` on a Sequential model now returns a config dictionary (consistent with other Model instances) instead of a list of configs for the underlying layers.
+
+## Bug Fixes and Other Changes
+
+* C++:
+ * Changed the signature of SessionFactory::NewSession so that it can return a meaningful error message on failure.
+* tf.data:
+ * Remove `num_parallel_parser_calls` argument from `tf.contrib.data.make_csv_dataset()`.
+ * `tf.data.Dataset.list_files()` raises an exception at initialization time if the argument matches no files.
+ * Renamed the BigTable class to BigtableTable for clarity.
+ * Document use of the Cloud Bigtable API.
+ * Adding `tf.contrib.data.reduce_dataset`, which can be used to reduce a dataset to a single element.
+ * Generalization of `tf.contrib.data.sliding_window_batch`.
+* INC:
+ * Runtime improvements to triangular solve.
+* `tf.contrib`:
+ * Add an `implementation` argument to `tf.keras.layers.LocallyConnected2D` and `tf.keras.layers.LocallyConnected1D`. The new mode (`implementation=2`) performs the forward pass as a single dense matrix multiplication, allowing dramatic speedups in certain scenarios (but worse performance in others; see the docstring). The option also allows `padding=same` to be used.
+ * Add documentation clarifying the differences between tf.fill and tf.constant.
+ * Add experimental IndexedDatasets.
+ * Add selective registration target using the lite proto runtime.
+ * Add simple Tensor and DataType classes to TensorFlow Lite Java.
+ * Add support for bitcasting to/from uint32 and uint64.
+ * Added a subclass of Estimator that can be created from a SavedModel (SavedModelEstimator).
+ * Adds leaf index modes as an argument.
+ * Allow a different output shape from the input in tf.contrib.image.transform.
+ * Change the state_size order of the StackedRNNCell to the natural order. To keep the existing behavior, users can pass reverse_state_order=True when constructing the StackedRNNCells.
+ * Deprecate self.test_session() in favor of self.session() or self.cached_session().
+ * Directly import tensor.proto.h (the transitive import will be removed from tensor.h soon).
+ * Estimator.train() now supports tf.contrib.summary.\* summaries out of the box; each call to .train() will now create a separate tfevents file rather than re-using a shared one.
+ * Fix FTRL L2-shrinkage behavior: the gradient from the L2 shrinkage term should not end up in the accumulator.
+ * Fix toco compilation/execution on Windows.
+ * GoogleZoneProvider class added to detect which Google Cloud Engine zone TensorFlow is running in.
+ * It is now safe to call any of the C API's TF_Delete\* functions on nullptr.
+ * Log some errors on Android to logcat.
+ * Match FakeQuant numerics in TFLite to improve accuracy of TFLite quantized inference models.
+ * Optional bucket location check for the GCS Filesystem.
+ * Performance enhancements for StringSplitOp & StringSplitV2Op.
+ * Performance improvements for regex replace operations.
+ * TFRecordWriter now raises an error if .write() fails.
+ * TPU: More helpful error messages in TPUClusterResolvers.
+ * The legacy_init_op argument to SavedModelBuilder methods for adding MetaGraphs has been deprecated. Please use the equivalent main_op argument instead. As part of this, we now explicitly check for a single main_op or legacy_init_op at the time of SavedModel building, whereas the check on main_op was previously only done at load time.
+ * The protocol used for Estimator training is now configurable in RunConfig.
+ * Triangular solve performance improvements.
+ * Unify the RNN cell interface between TF and Keras. Add a new get_initial_state() method to Keras and TF RNN cells, which will be used to replace the existing zero_state() method.
+ * Update initialization of variables in Keras.
+ * Updates to "constrained_optimization" in tensorflow/contrib.
+ * Boosted trees: added a pruning mode.
+ * tf.train.Checkpoint does not delete old checkpoints by default.
+ * tfdbg: Limit the total disk space occupied by dumped tensor data to 100 GBytes. Add environment variable `TFDBG_DISK_BYTES_LIMIT` to allow adjustment of this upper limit.
+
+## Thanks to our Contributors
+
+This release contains contributions from many people at Google, as well as:
+
+Aapeli, adoda, Ag Ramesh, Amogh Mannekote, Andrew Gibiansky, Andy Craze, Anirudh Koul, Aurelien Geron, Avijit, Avijit-Nervana, Ben, Benjamin H. Myara, bhack, Brett Koonce, Cao Zongyan, cbockman, cheerss, Chikanaga Tomoyuki, Clayne Robison, cosine0, Cui Wei, Dan J, David, David Norman, Dmitry Klimenkov, Eliel Hojman, Florian Courtial, fo40225, formath, Geoffrey Irving, gracehoney, Grzegorz Pawelczak, Guoliang Hua, Guozhong Zhuang, Herman Zvonimir DošIlović, HuiyangFei, Jacker, Jan HüNnemeyer, Jason Taylor, Jason Zaman, Jesse, Jiang,Zhoulong, Jiawei Zhang, Jie, Joe Yearsley, Johannes Schmitz, Jon Perl, Jon Triebenbach, Jonathan, Jonathan Hseu, Jongmin Park, Justin Shenk, karl@kubx.ca, Kate Hodesdon, Kb Sriram, Keishi Hattori, Kenneth Blomqvist, Koan-Sin Tan, Li Liangbin, Li, Yiqiang, Loo Rong Jie, Madiyar, Mahmoud Abuzaina, Mark Ryan, Matt Dodge, mbhuiyan, melvinljy96, Miguel Mota, Nafis Sadat, Nathan Luehr, naurril, Nehal J Wani, Niall Moran, Niranjan Hasabnis, Nishidha Panpaliya, npow, olicht, Pei Zhang, Peng Wang (Simpeng), Peng Yu, Philipp Jund, Pradeep Banavara, Pratik Kalshetti, qwertWZ, Rakesh Chada, Randy West, Ray Kim, Rholais Lii, Robin Richtsfeld, Rodrigo Silveira, Ruizhi, Santosh Kumar, Seb Bro, Sergei Lebedev, sfujiwara, Shaba Abhiram, Shashi, SneakyFish5, Soila Kavulya, Stefan Dyulgerov, Steven Winston, Sunitha Kambhampati, Surry Shome, Taehoon Lee, Thor Johnsen, Tristan Rice, TShapinsky, tucan, tucan9389, Vicente Reyes, Vilmar-Hillow, Vitaly Lavrukhin, wangershi, weidan.kong, weidankong, Wen-Heng (Jack) Chung, William D. Irons, Wim Glenn, XFeiF, Yan Facai (颜发才), Yanbo Liang, Yong Tang, Yoshihiro Yamazaki, Yuan (Terry) Tang, Yuan, Man, zhaoyongke, ÁRon
+Ricardo Perez-Lopez, 张天启, 张晓飞
+
+
# Release 1.10.1
## Bug Fixes and Other Changes
* `tf.keras`:
* Fixing keras on Cloud TPUs. No new binaries will be built for Windows.
+
# Release 1.10.0
## Major Features And Improvements
@@ -17,7 +94,7 @@
## Breaking Changes
-* Prebuilt binaries are now (as of TensorFlow 1.10) built against NCCL 2.2 and no longer include NCCL in the binary install. TensorFlow usage with multiple GPUs and NCCL requires upgrade to [NCCL 2.2](https://developer.nvidia.com/nccl). See updated install guides: [Installing TensorFlow on Ubuntu](https://www.tensorflow.org/install/install_linux#tensorflow_gpu_support) and [Install TensorFlow from Sources](https://www.tensorflow.org/install/install_sources#optional_install_tensorflow_for_gpu_prerequisites).
+* Prebuilt binaries are now (as of TensorFlow 1.10) built against NCCL 2.2 and no longer include NCCL in the binary install. TensorFlow usage with multiple GPUs and NCCL requires upgrade to [NCCL 2.2](https://developer.nvidia.com/nccl). See updated install guides: [TensorFlow GPU support](https://www.tensorflow.org/install/gpu) and [Build TensorFlow from source](https://www.tensorflow.org/install/source).
* Starting from TensorFlow 1.11, Windows builds will use Bazel. Therefore, we will drop official support for cmake.
## Bug Fixes and Other Changes
diff --git a/configure.py b/configure.py
index e9d162fbd2..9899ae10e8 100644
--- a/configure.py
+++ b/configure.py
@@ -41,7 +41,6 @@ _DEFAULT_CUDA_PATH = '/usr/local/cuda'
_DEFAULT_CUDA_PATH_LINUX = '/opt/cuda'
_DEFAULT_CUDA_PATH_WIN = ('C:/Program Files/NVIDIA GPU Computing '
'Toolkit/CUDA/v%s' % _DEFAULT_CUDA_VERSION)
-_DEFAULT_TENSORRT_PATH_LINUX = '/usr/lib/%s-linux-gnu' % platform.machine()
_TF_OPENCL_VERSION = '1.2'
_DEFAULT_COMPUTECPP_TOOLKIT_PATH = '/usr/local/computecpp'
_DEFAULT_TRISYCL_INCLUDE_DIR = '/usr/local/triSYCL/include'
@@ -54,6 +53,11 @@ _TF_BAZELRC_FILENAME = '.tf_configure.bazelrc'
_TF_BAZELRC = os.path.join(_TF_WORKSPACE_ROOT, _TF_BAZELRC_FILENAME)
_TF_WORKSPACE = os.path.join(_TF_WORKSPACE_ROOT, 'WORKSPACE')
+if platform.machine() == 'ppc64le':
+ _DEFAULT_TENSORRT_PATH_LINUX = '/usr/lib/powerpc64le-linux-gnu/'
+else:
+ _DEFAULT_TENSORRT_PATH_LINUX = '/usr/lib/%s-linux-gnu' % platform.machine()
+
class UserInputError(Exception):
pass
@@ -153,14 +157,18 @@ def get_python_path(environ_cp, python_bin_path):
if environ_cp.get('PYTHONPATH'):
python_paths = environ_cp.get('PYTHONPATH').split(':')
try:
- library_paths = run_shell(
- [python_bin_path, '-c',
- 'import site; print("\\n".join(site.getsitepackages()))']).split('\n')
+ library_paths = run_shell([
+ python_bin_path, '-c',
+ 'import site; print("\\n".join(site.getsitepackages()))'
+ ]).split('\n')
except subprocess.CalledProcessError:
- library_paths = [run_shell(
- [python_bin_path, '-c',
- 'from distutils.sysconfig import get_python_lib;'
- 'print(get_python_lib())'])]
+ library_paths = [
+ run_shell([
+ python_bin_path, '-c',
+ 'from distutils.sysconfig import get_python_lib;'
+ 'print(get_python_lib())'
+ ])
+ ]
all_paths = set(python_paths + library_paths)
@@ -187,8 +195,7 @@ def setup_python(environ_cp):
environ_cp, 'PYTHON_BIN_PATH', ask_python_bin_path,
default_python_bin_path)
# Check if the path is valid
- if os.path.isfile(python_bin_path) and os.access(
- python_bin_path, os.X_OK):
+ if os.path.isfile(python_bin_path) and os.access(python_bin_path, os.X_OK):
break
elif not os.path.exists(python_bin_path):
print('Invalid python path: %s cannot be found.' % python_bin_path)
@@ -230,8 +237,9 @@ def setup_python(environ_cp):
environ_cp['PYTHON_BIN_PATH'] = python_bin_path
# Write tools/python_bin_path.sh
- with open(os.path.join(
- _TF_WORKSPACE_ROOT, 'tools', 'python_bin_path.sh'), 'w') as f:
+ with open(
+ os.path.join(_TF_WORKSPACE_ROOT, 'tools', 'python_bin_path.sh'),
+ 'w') as f:
f.write('export PYTHON_BIN_PATH="%s"' % python_bin_path)
@@ -250,7 +258,7 @@ def reset_tf_configure_bazelrc(workspace_path):
continue
f.write('%s\n' % l)
if is_windows():
- tf_bazelrc_path = _TF_BAZELRC.replace("\\", "/")
+ tf_bazelrc_path = _TF_BAZELRC.replace('\\', '/')
else:
tf_bazelrc_path = _TF_BAZELRC
f.write('import %s\n' % tf_bazelrc_path)
@@ -261,8 +269,8 @@ def cleanup_makefile():
These files could interfere with Bazel parsing.
"""
- makefile_download_dir = os.path.join(
- _TF_WORKSPACE_ROOT, 'tensorflow', 'contrib', 'makefile', 'downloads')
+ makefile_download_dir = os.path.join(_TF_WORKSPACE_ROOT, 'tensorflow',
+ 'contrib', 'makefile', 'downloads')
if os.path.isdir(makefile_download_dir):
for root, _, filenames in os.walk(makefile_download_dir):
for f in filenames:
@@ -330,9 +338,8 @@ def get_var(environ_cp,
'Environment variable %s must be set as a boolean indicator.\n'
'The following are accepted as TRUE : %s.\n'
'The following are accepted as FALSE: %s.\n'
- 'Current value is %s.' % (
- var_name, ', '.join(true_strings), ', '.join(false_strings),
- var))
+ 'Current value is %s.' % (var_name, ', '.join(true_strings),
+ ', '.join(false_strings), var))
while var is None:
user_input_origin = get_input(question)
@@ -355,8 +362,12 @@ def get_var(environ_cp,
return var
-def set_build_var(environ_cp, var_name, query_item, option_name,
- enabled_by_default, bazel_config_name=None):
+def set_build_var(environ_cp,
+ var_name,
+ query_item,
+ option_name,
+ enabled_by_default,
+ bazel_config_name=None):
"""Set if query_item will be enabled for the build.
Ask user if query_item will be enabled. Default is used if no input is given.
@@ -379,8 +390,8 @@ def set_build_var(environ_cp, var_name, query_item, option_name,
elif bazel_config_name is not None:
# TODO(mikecase): Migrate all users of configure.py to use --config Bazel
# options and not to set build configs through environment variables.
- write_to_bazelrc('build:%s --define %s=true'
- % (bazel_config_name, option_name))
+ write_to_bazelrc(
+ 'build:%s --define %s=true' % (bazel_config_name, option_name))
def set_action_env_var(environ_cp,
@@ -447,7 +458,8 @@ def check_bazel_version(min_version):
if which('bazel') is None:
print('Cannot find bazel. Please install bazel.')
sys.exit(0)
- curr_version = run_shell(['bazel', '--batch', '--bazelrc=/dev/null', 'version'])
+ curr_version = run_shell(
+ ['bazel', '--batch', '--bazelrc=/dev/null', 'version'])
for line in curr_version.split('\n'):
if 'Build label: ' in line:
@@ -499,6 +511,7 @@ def set_cc_opt_flags(environ_cp):
write_to_bazelrc('build:opt --host_copt=-march=native')
write_to_bazelrc('build:opt --define with_default_optimizations=true')
+
def set_tf_cuda_clang(environ_cp):
"""set TF_CUDA_CLANG action_env.
@@ -581,16 +594,14 @@ def set_clang_cuda_compiler_path(environ_cp):
clang_cuda_compiler_path)
-def prompt_loop_or_load_from_env(
- environ_cp,
- var_name,
- var_default,
- ask_for_var,
- check_success,
- error_msg,
- suppress_default_error=False,
- n_ask_attempts=_DEFAULT_PROMPT_ASK_ATTEMPTS
-):
+def prompt_loop_or_load_from_env(environ_cp,
+ var_name,
+ var_default,
+ ask_for_var,
+ check_success,
+ error_msg,
+ suppress_default_error=False,
+ n_ask_attempts=_DEFAULT_PROMPT_ASK_ATTEMPTS):
"""Loop over user prompts for an ENV param until receiving a valid response.
For the env param var_name, read from the environment or verify user input
@@ -629,9 +640,7 @@ def prompt_loop_or_load_from_env(
)
for _ in range(n_ask_attempts):
- val = get_from_env_or_user_or_default(environ_cp,
- var_name,
- full_query,
+ val = get_from_env_or_user_or_default(environ_cp, var_name, full_query,
default)
if check_success(val):
break
@@ -639,9 +648,9 @@ def prompt_loop_or_load_from_env(
print(error_msg % val)
environ_cp[var_name] = ''
else:
- raise UserInputError('Invalid %s setting was provided %d times in a row. '
- 'Assuming to be a scripting mistake.' %
- (var_name, n_ask_attempts))
+ raise UserInputError(
+ 'Invalid %s setting was provided %d times in a row. '
+ 'Assuming to be a scripting mistake.' % (var_name, n_ask_attempts))
environ_cp[var_name] = val
return val
@@ -650,8 +659,8 @@ def prompt_loop_or_load_from_env(
def create_android_ndk_rule(environ_cp):
"""Set ANDROID_NDK_HOME and write Android NDK WORKSPACE rule."""
if is_windows() or is_cygwin():
- default_ndk_path = cygpath('%s/Android/Sdk/ndk-bundle' %
- environ_cp['APPDATA'])
+ default_ndk_path = cygpath(
+ '%s/Android/Sdk/ndk-bundle' % environ_cp['APPDATA'])
elif is_macos():
default_ndk_path = '%s/library/Android/Sdk/ndk-bundle' % environ_cp['HOME']
else:
@@ -668,8 +677,7 @@ def create_android_ndk_rule(environ_cp):
ask_for_var='Please specify the home path of the Android NDK to use.',
check_success=valid_ndk_path,
error_msg=('The path %s or its child file "source.properties" '
- 'does not exist.')
- )
+ 'does not exist.'))
write_action_env_to_bazelrc('ANDROID_NDK_HOME', android_ndk_home_path)
write_action_env_to_bazelrc('ANDROID_NDK_API_LEVEL',
check_ndk_level(android_ndk_home_path))
@@ -703,9 +711,9 @@ def create_android_sdk_rule(environ_cp):
api_levels = [x.replace('android-', '') for x in api_levels]
def valid_api_level(api_level):
- return os.path.exists(os.path.join(android_sdk_home_path,
- 'platforms',
- 'android-' + api_level))
+ return os.path.exists(
+ os.path.join(android_sdk_home_path, 'platforms',
+ 'android-' + api_level))
android_api_level = prompt_loop_or_load_from_env(
environ_cp,
@@ -720,9 +728,8 @@ def create_android_sdk_rule(environ_cp):
versions = sorted(os.listdir(build_tools))
def valid_build_tools(version):
- return os.path.exists(os.path.join(android_sdk_home_path,
- 'build-tools',
- version))
+ return os.path.exists(
+ os.path.join(android_sdk_home_path, 'build-tools', version))
android_build_tools_version = prompt_loop_or_load_from_env(
environ_cp,
@@ -736,10 +743,8 @@ def create_android_sdk_rule(environ_cp):
write_action_env_to_bazelrc('ANDROID_BUILD_TOOLS_VERSION',
android_build_tools_version)
- write_action_env_to_bazelrc('ANDROID_SDK_API_LEVEL',
- android_api_level)
- write_action_env_to_bazelrc('ANDROID_SDK_HOME',
- android_sdk_home_path)
+ write_action_env_to_bazelrc('ANDROID_SDK_API_LEVEL', android_api_level)
+ write_action_env_to_bazelrc('ANDROID_SDK_HOME', android_sdk_home_path)
def check_ndk_level(android_ndk_home_path):
@@ -798,6 +803,7 @@ def reformat_version_sequence(version_str, sequence_count):
Args:
version_str: String, the version string.
sequence_count: int, an integer.
+
Returns:
string, reformatted version string.
"""
@@ -841,12 +847,19 @@ def set_tf_cuda_version(environ_cp):
if is_windows():
cuda_rt_lib_paths = ['lib/x64/cudart.lib']
elif is_linux():
- cuda_rt_lib_paths = ['%s/libcudart.so.%s' % (x, tf_cuda_version)
- for x in ['lib64', 'lib/x86_64-linux-gnu']]
+ cuda_rt_lib_paths = [
+ '%s/libcudart.so.%s' % (x, tf_cuda_version) for x in [
+ 'lib64',
+ 'lib/powerpc64le-linux-gnu',
+ 'lib/x86_64-linux-gnu',
+ ]
+ ]
elif is_macos():
cuda_rt_lib_paths = ['lib/libcudart.%s.dylib' % tf_cuda_version]
- cuda_toolkit_paths_full = [os.path.join(cuda_toolkit_path, x) for x in cuda_rt_lib_paths]
+ cuda_toolkit_paths_full = [
+ os.path.join(cuda_toolkit_path, x) for x in cuda_rt_lib_paths
+ ]
if any([os.path.exists(x) for x in cuda_toolkit_paths_full]):
break
@@ -919,8 +932,8 @@ def set_tf_cudnn_version(environ_cp):
cudnn_path_from_ldconfig)
if cudnn_path_from_ldconfig:
cudnn_path_from_ldconfig = cudnn_path_from_ldconfig.group(1)
- if os.path.exists('%s.%s' % (cudnn_path_from_ldconfig,
- tf_cudnn_version)):
+ if os.path.exists(
+ '%s.%s' % (cudnn_path_from_ldconfig, tf_cudnn_version)):
cudnn_install_path = os.path.dirname(cudnn_path_from_ldconfig)
break
@@ -1166,6 +1179,7 @@ def get_native_cuda_compute_capabilities(environ_cp):
Args:
environ_cp: copy of the os.environ.
+
Returns:
string of native cuda compute capabilities, separated by comma.
"""
@@ -1290,8 +1304,7 @@ def set_computecpp_toolkit_path(environ_cp):
else:
sycl_rt_lib_path = ''
- sycl_rt_lib_path_full = os.path.join(toolkit_path,
- sycl_rt_lib_path)
+ sycl_rt_lib_path_full = os.path.join(toolkit_path, sycl_rt_lib_path)
exists = os.path.exists(sycl_rt_lib_path_full)
if not exists:
print('Invalid SYCL %s library path. %s cannot be found' %
@@ -1319,8 +1332,8 @@ def set_trisycl_include_dir(environ_cp):
ask_trisycl_include_dir = ('Please specify the location of the triSYCL '
'include directory. (Use --config=sycl_trisycl '
'when building with Bazel) '
- '[Default is %s]: '
- ) % (_DEFAULT_TRISYCL_INCLUDE_DIR)
+ '[Default is %s]: ') % (
+ _DEFAULT_TRISYCL_INCLUDE_DIR)
while True:
trisycl_include_dir = get_from_env_or_user_or_default(
@@ -1329,13 +1342,12 @@ def set_trisycl_include_dir(environ_cp):
if os.path.exists(trisycl_include_dir):
break
- print('Invalid triSYCL include directory, %s cannot be found'
- % (trisycl_include_dir))
+ print('Invalid triSYCL include directory, %s cannot be found' %
+ (trisycl_include_dir))
# Set TRISYCL_INCLUDE_DIR
environ_cp['TRISYCL_INCLUDE_DIR'] = trisycl_include_dir
- write_action_env_to_bazelrc('TRISYCL_INCLUDE_DIR',
- trisycl_include_dir)
+ write_action_env_to_bazelrc('TRISYCL_INCLUDE_DIR', trisycl_include_dir)
def set_mpi_home(environ_cp):
@@ -1345,8 +1357,9 @@ def set_mpi_home(environ_cp):
default_mpi_home = os.path.dirname(os.path.dirname(default_mpi_home))
def valid_mpi_path(mpi_home):
- exists = (os.path.exists(os.path.join(mpi_home, 'include')) and
- os.path.exists(os.path.join(mpi_home, 'lib')))
+ exists = (
+ os.path.exists(os.path.join(mpi_home, 'include')) and
+ os.path.exists(os.path.join(mpi_home, 'lib')))
if not exists:
print('Invalid path to the MPI Toolkit. %s or %s cannot be found' %
(os.path.join(mpi_home, 'include'),
@@ -1395,16 +1408,22 @@ def set_other_mpi_vars(environ_cp):
raise ValueError('Cannot find the MPI library file in %s/lib' % mpi_home)
-def set_grpc_build_flags():
- write_to_bazelrc('build --define grpc_no_ares=true')
-
-
def set_system_libs_flag(environ_cp):
syslibs = environ_cp.get('TF_SYSTEM_LIBS', '')
- syslibs = ','.join(sorted(syslibs.split(',')))
if syslibs and syslibs != '':
+ if ',' in syslibs:
+ syslibs = ','.join(sorted(syslibs.split(',')))
+ else:
+ syslibs = ','.join(sorted(syslibs.split()))
write_action_env_to_bazelrc('TF_SYSTEM_LIBS', syslibs)
+ if 'PREFIX' in environ_cp:
+ write_to_bazelrc('build --define=PREFIX=%s' % environ_cp['PREFIX'])
+ if 'LIBDIR' in environ_cp:
+ write_to_bazelrc('build --define=LIBDIR=%s' % environ_cp['LIBDIR'])
+ if 'INCLUDEDIR' in environ_cp:
+ write_to_bazelrc('build --define=INCLUDEDIR=%s' % environ_cp['INCLUDEDIR'])
+
def set_windows_build_flags(environ_cp):
"""Set Windows specific build options."""
@@ -1421,14 +1440,20 @@ def set_windows_build_flags(environ_cp):
# TODO(pcloudy): Remove this flag when upgrading Bazel to 0.16.0
# Short object file path will be enabled by default.
write_to_bazelrc('build --experimental_shortened_obj_file_path=true')
+ # When building zip file for some py_binary and py_test targets, don't
+ # include its dependencies. This is for:
+ # 1. Running python tests against the system installed TF pip package.
+ # 2. Avoiding redundant files in
+ # //tensorflow/tools/pip_package:simple_console_windows,
+ # which is a py_binary used during creating TF pip package.
+ # See https://github.com/tensorflow/tensorflow/issues/22390
+ write_to_bazelrc('build --define=no_tensorflow_py_deps=true')
if get_var(
environ_cp, 'TF_OVERRIDE_EIGEN_STRONG_INLINE', 'Eigen strong inline',
- True,
- ('Would you like to override eigen strong inline for some C++ '
- 'compilation to reduce the compilation time?'),
- 'Eigen strong inline overridden.',
- 'Not overriding eigen strong inline, '
+ True, ('Would you like to override eigen strong inline for some C++ '
+ 'compilation to reduce the compilation time?'),
+ 'Eigen strong inline overridden.', 'Not overriding eigen strong inline, '
'some compilations could take more than 20 mins.'):
# Due to a known MSVC compiler issue
# https://github.com/tensorflow/tensorflow/issues/10521
@@ -1445,10 +1470,11 @@ def config_info_line(name, help_text):
def main():
parser = argparse.ArgumentParser()
- parser.add_argument("--workspace",
- type=str,
- default=_TF_WORKSPACE_ROOT,
- help="The absolute path to your active Bazel workspace.")
+ parser.add_argument(
+ '--workspace',
+ type=str,
+ default=_TF_WORKSPACE_ROOT,
+ help='The absolute path to your active Bazel workspace.')
args = parser.parse_args()
# Make a copy of os.environ to be clear when functions and getting and setting
@@ -1462,11 +1488,7 @@ def main():
setup_python(environ_cp)
if is_windows():
- environ_cp['TF_NEED_AWS'] = '0'
- environ_cp['TF_NEED_GCP'] = '0'
- environ_cp['TF_NEED_HDFS'] = '0'
environ_cp['TF_NEED_JEMALLOC'] = '0'
- environ_cp['TF_NEED_KAFKA'] = '0'
environ_cp['TF_NEED_OPENCL_SYCL'] = '0'
environ_cp['TF_NEED_COMPUTECPP'] = '0'
environ_cp['TF_NEED_OPENCL'] = '0'
@@ -1476,40 +1498,26 @@ def main():
# Windows.
environ_cp['TF_DOWNLOAD_CLANG'] = '0'
environ_cp['TF_ENABLE_XLA'] = '0'
- environ_cp['TF_NEED_GDR'] = '0'
- environ_cp['TF_NEED_VERBS'] = '0'
environ_cp['TF_NEED_MPI'] = '0'
environ_cp['TF_SET_ANDROID_WORKSPACE'] = '0'
if is_macos():
environ_cp['TF_NEED_JEMALLOC'] = '0'
environ_cp['TF_NEED_TENSORRT'] = '0'
+ environ_cp['TF_ENABLE_XLA'] = '0'
# The numpy package on ppc64le uses OpenBLAS which has multi-threading
# issues that lead to incorrect answers. Set OMP_NUM_THREADS=1 at
# runtime to allow the Tensorflow testcases which compare numpy
# results to Tensorflow results to succeed.
if is_ppc64le():
- write_action_env_to_bazelrc("OMP_NUM_THREADS", 1)
+ write_action_env_to_bazelrc('OMP_NUM_THREADS', 1)
set_build_var(environ_cp, 'TF_NEED_JEMALLOC', 'jemalloc as malloc',
'with_jemalloc', True)
- set_build_var(environ_cp, 'TF_NEED_GCP', 'Google Cloud Platform',
- 'with_gcp_support', True, 'gcp')
- set_build_var(environ_cp, 'TF_NEED_HDFS', 'Hadoop File System',
- 'with_hdfs_support', True, 'hdfs')
- set_build_var(environ_cp, 'TF_NEED_AWS', 'Amazon AWS Platform',
- 'with_aws_support', True, 'aws')
- set_build_var(environ_cp, 'TF_NEED_KAFKA', 'Apache Kafka Platform',
- 'with_kafka_support', True, 'kafka')
set_build_var(environ_cp, 'TF_ENABLE_XLA', 'XLA JIT', 'with_xla_support',
- False, 'xla')
- set_build_var(environ_cp, 'TF_NEED_GDR', 'GDR', 'with_gdr_support',
- False, 'gdr')
- set_build_var(environ_cp, 'TF_NEED_VERBS', 'VERBS', 'with_verbs_support',
- False, 'verbs')
- set_build_var(environ_cp, 'TF_NEED_NGRAPH', 'nGraph',
- 'with_ngraph_support', False, 'ngraph')
+ True, 'xla')
+
set_action_env_var(environ_cp, 'TF_NEED_OPENCL_SYCL', 'OpenCL SYCL', False)
if environ_cp.get('TF_NEED_OPENCL_SYCL') == '1':
@@ -1521,6 +1529,13 @@ def main():
else:
set_trisycl_include_dir(environ_cp)
+ set_action_env_var(environ_cp, 'TF_NEED_ROCM', 'ROCm', False)
+ if (environ_cp.get('TF_NEED_ROCM') == '1' and
+ 'LD_LIBRARY_PATH' in environ_cp and
+ environ_cp.get('LD_LIBRARY_PATH') != '1'):
+ write_action_env_to_bazelrc('LD_LIBRARY_PATH',
+ environ_cp.get('LD_LIBRARY_PATH'))
+
set_action_env_var(environ_cp, 'TF_NEED_CUDA', 'CUDA', False)
if (environ_cp.get('TF_NEED_CUDA') == '1' and
'TF_CUDA_CONFIG_REPO' not in environ_cp):
@@ -1561,12 +1576,24 @@ def main():
write_to_bazelrc('build --config=download_clang')
write_to_bazelrc('test --config=download_clang')
+ # SYCL / ROCm / CUDA are mutually exclusive.
+ # At most 1 GPU platform can be configured.
+ gpu_platform_count = 0
+ if environ_cp.get('TF_NEED_OPENCL_SYCL') == '1':
+ gpu_platform_count += 1
+ if environ_cp.get('TF_NEED_ROCM') == '1':
+ gpu_platform_count += 1
+ if environ_cp.get('TF_NEED_CUDA') == '1':
+ gpu_platform_count += 1
+ if gpu_platform_count >= 2:
+ raise UserInputError('SYCL / CUDA / ROCm are mutually exclusive. '
+ 'At most 1 GPU platform can be configured.')
+
set_build_var(environ_cp, 'TF_NEED_MPI', 'MPI', 'with_mpi_support', False)
if environ_cp.get('TF_NEED_MPI') == '1':
set_mpi_home(environ_cp)
set_other_mpi_vars(environ_cp)
- set_grpc_build_flags()
set_cc_opt_flags(environ_cp)
set_system_libs_flag(environ_cp)
if is_windows():
@@ -1575,13 +1602,10 @@ def main():
# Add a config option to build TensorFlow 2.0 API.
write_to_bazelrc('build:v2 --define=tf_api_version=2')
- if get_var(
- environ_cp, 'TF_SET_ANDROID_WORKSPACE', 'android workspace',
- False,
- ('Would you like to interactively configure ./WORKSPACE for '
- 'Android builds?'),
- 'Searching for NDK and SDK installations.',
- 'Not configuring the WORKSPACE for Android builds.'):
+ if get_var(environ_cp, 'TF_SET_ANDROID_WORKSPACE', 'android workspace', False,
+ ('Would you like to interactively configure ./WORKSPACE for '
+ 'Android builds?'), 'Searching for NDK and SDK installations.',
+ 'Not configuring the WORKSPACE for Android builds.'):
create_android_ndk_rule(environ_cp)
create_android_sdk_rule(environ_cp)
@@ -1594,6 +1618,10 @@ def main():
'more details.')
config_info_line('mkl', 'Build with MKL support.')
config_info_line('monolithic', 'Config for mostly static monolithic build.')
+ config_info_line('gdr', 'Build with GDR support.')
+ config_info_line('verbs', 'Build with libverbs support.')
+ config_info_line('ngraph', 'Build with Intel nGraph support.')
+
if __name__ == '__main__':
main()
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index 386e0096ff..5f73da68a2 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -225,60 +225,6 @@ config_setting(
)
config_setting(
- name = "with_gcp_support",
- define_values = {"with_gcp_support": "true"},
- visibility = ["//visibility:public"],
-)
-
-config_setting(
- name = "with_hdfs_support",
- define_values = {"with_hdfs_support": "true"},
- visibility = ["//visibility:public"],
-)
-
-config_setting(
- name = "with_aws_support",
- define_values = {"with_aws_support": "true"},
- visibility = ["//visibility:public"],
-)
-
-config_setting(
- name = "with_kafka_support",
- define_values = {"with_kafka_support": "true"},
- visibility = ["//visibility:public"],
-)
-
-# Crosses between platforms and file system libraries not supported on those
-# platforms due to limitations in nested select() statements.
-config_setting(
- name = "with_gcp_support_windows_override",
- define_values = {"with_gcp_support": "true"},
- values = {"cpu": "x64_windows"},
- visibility = ["//visibility:public"],
-)
-
-config_setting(
- name = "with_hdfs_support_windows_override",
- define_values = {"with_hdfs_support": "true"},
- values = {"cpu": "x64_windows"},
- visibility = ["//visibility:public"],
-)
-
-config_setting(
- name = "with_aws_support_windows_override",
- define_values = {"with_aws_support": "true"},
- values = {"cpu": "x64_windows"},
- visibility = ["//visibility:public"],
-)
-
-config_setting(
- name = "with_kafka_support_windows_override",
- define_values = {"with_kafka_support": "true"},
- values = {"cpu": "x64_windows"},
- visibility = ["//visibility:public"],
-)
-
-config_setting(
name = "with_cuda_support_windows_override",
define_values = {"using_cuda_nvcc": "true"},
values = {"cpu": "x64_windows"},
@@ -286,48 +232,6 @@ config_setting(
)
config_setting(
- name = "with_gcp_support_android_override",
- define_values = {"with_gcp_support": "true"},
- values = {"crosstool_top": "//external:android/crosstool"},
- visibility = ["//visibility:public"],
-)
-
-config_setting(
- name = "with_hdfs_support_android_override",
- define_values = {"with_hdfs_support": "true"},
- values = {"crosstool_top": "//external:android/crosstool"},
- visibility = ["//visibility:public"],
-)
-
-config_setting(
- name = "with_aws_support_android_override",
- define_values = {"with_aws_support": "true"},
- values = {"crosstool_top": "//external:android/crosstool"},
- visibility = ["//visibility:public"],
-)
-
-config_setting(
- name = "with_gcp_support_ios_override",
- define_values = {"with_gcp_support": "true"},
- values = {"crosstool_top": "//tools/osx/crosstool:crosstool"},
- visibility = ["//visibility:public"],
-)
-
-config_setting(
- name = "with_hdfs_support_ios_override",
- define_values = {"with_hdfs_support": "true"},
- values = {"crosstool_top": "//tools/osx/crosstool:crosstool"},
- visibility = ["//visibility:public"],
-)
-
-config_setting(
- name = "with_aws_support_ios_override",
- define_values = {"with_aws_support": "true"},
- values = {"crosstool_top": "//tools/osx/crosstool:crosstool"},
- visibility = ["//visibility:public"],
-)
-
-config_setting(
name = "with_xla_support",
define_values = {"with_xla_support": "true"},
visibility = ["//visibility:public"],
@@ -564,6 +468,7 @@ tf_cc_shared_object(
"$(location //tensorflow/c:version_script.lds)",
],
}),
+ visibility = ["//visibility:public"],
deps = [
"//tensorflow/c:c_api",
"//tensorflow/c:c_api_experimental",
@@ -588,6 +493,7 @@ tf_cc_shared_object(
"$(location //tensorflow:tf_version_script.lds)",
],
}),
+ visibility = ["//visibility:public"],
deps = [
"//tensorflow:tf_exported_symbols.lds",
"//tensorflow:tf_version_script.lds",
@@ -608,6 +514,55 @@ exports_files(
],
)
+genrule(
+ name = "install_headers",
+ srcs = [
+ "//tensorflow/c:headers",
+ "//tensorflow/c/eager:headers",
+ "//tensorflow/cc:headers",
+ "//tensorflow/core:headers",
+ ],
+ outs = ["include"],
+ cmd = """
+ mkdir $@
+ for f in $(SRCS); do
+ d="$${f%/*}"
+ d="$${d#bazel-out*genfiles/}"
+ d="$${d#*external/eigen_archive/}"
+
+ if [[ $${d} == *local_config_* ]]; then
+ continue
+ fi
+
+ if [[ $${d} == external* ]]; then
+ extname="$${d#*external/}"
+ extname="$${extname%%/*}"
+ if [[ $${TF_SYSTEM_LIBS:-} == *$${extname}* ]]; then
+ continue
+ fi
+ fi
+
+ mkdir -p "$@/$${d}"
+ cp "$${f}" "$@/$${d}/"
+ done
+ """,
+ tags = ["manual"],
+ visibility = ["//visibility:public"],
+)
+
+genrule(
+ name = "root_init_gen",
+ srcs = select({
+ "api_version_2": [":tf_python_api_gen_v2"],
+ "//conditions:default": [":tf_python_api_gen_v1"],
+ }),
+ outs = ["__init__.py"],
+ cmd = select({
+ "api_version_2": "cp $(@D)/_api/v2/__init__.py $(OUTS)",
+ "//conditions:default": "cp $(@D)/_api/v1/__init__.py $(OUTS)",
+ }),
+)
+
gen_api_init_files(
name = "tf_python_api_gen_v1",
srcs = ["api_template.__init__.py"],
@@ -629,19 +584,6 @@ gen_api_init_files(
root_init_template = "api_template.__init__.py",
)
-genrule(
- name = "root_init_gen",
- srcs = select({
- "api_version_2": [":tf_python_api_gen_v2"],
- "//conditions:default": [":tf_python_api_gen_v1"],
- }),
- outs = ["__init__.py"],
- cmd = select({
- "api_version_2": "cp $(@D)/_api/v2/__init__.py $(OUTS)",
- "//conditions:default": "cp $(@D)/_api/v1/__init__.py $(OUTS)",
- }),
-)
-
py_library(
name = "tensorflow_py",
srcs = ["//tensorflow/python/estimator/api:estimator_python_api_gen"],
diff --git a/tensorflow/api_template.__init__.py b/tensorflow/api_template.__init__.py
index 53a72b8443..2de740e145 100644
--- a/tensorflow/api_template.__init__.py
+++ b/tensorflow/api_template.__init__.py
@@ -14,9 +14,9 @@
# ==============================================================================
"""Bring in all of the public TensorFlow interface into this module."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
+from __future__ import absolute_import as _absolute_import
+from __future__ import division as _division
+from __future__ import print_function as _print_function
import os as _os
@@ -41,6 +41,11 @@ except (ImportError, AttributeError):
from tensorflow.python.util.lazy_loader import LazyLoader # pylint: disable=g-import-not-at-top
contrib = LazyLoader('contrib', globals(), 'tensorflow.contrib')
del LazyLoader
+# The templated code that replaces the placeholder above sometimes
+# sets the __all__ variable. If it does, we have to be sure to add
+# "contrib".
+if '__all__' in vars():
+ vars()['__all__'].append('contrib')
from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top
app.flags = flags # pylint: disable=undefined-variable
@@ -51,10 +56,6 @@ _tf_api_dir = _os.path.dirname(_os.path.dirname(app.__file__)) # pylint: disabl
if _tf_api_dir not in __path__:
__path__.append(_tf_api_dir)
-del absolute_import
-del division
-del print_function
-
# These symbols appear because we import the python package which
# in turn imports from tensorflow.core and tensorflow.python. They
# must come from this module. So python adds these symbols for the
diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD
index 43c279bd80..17e2e292eb 100644
--- a/tensorflow/c/BUILD
+++ b/tensorflow/c/BUILD
@@ -246,6 +246,7 @@ tf_cc_test(
":c_api_experimental",
":c_test_util",
"//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc
index 3bcc62cf2d..d4b78138e9 100644
--- a/tensorflow/c/c_api_experimental.cc
+++ b/tensorflow/c/c_api_experimental.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/platform.h"
#include "tensorflow/core/protobuf/config.pb.h"
+#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
using tensorflow::FunctionDef;
using tensorflow::Node;
@@ -8508,6 +8509,20 @@ void TF_EnqueueNamedTensor(TF_Session* session, int tensor_id,
VLOG(1) << "Enqueuing is done.";
}
+TF_Buffer* TFE_GetServerDef(const char* text_proto, TF_Status* status) {
+ tensorflow::ServerDef server_def;
+ if (!tensorflow::protobuf::TextFormat::ParseFromString(text_proto,
+ &server_def)) {
+ status->status = tensorflow::errors::Internal(
+ "Invalid text proto for ServerDef: ", text_proto);
+ return nullptr;
+ }
+ status->status = tensorflow::Status();
+ TF_Buffer* ret = TF_NewBuffer();
+ TF_CHECK_OK(MessageToBuffer(server_def, ret));
+ return ret;
+}
+
TFE_Context* TFE_CreateContextFromSession(TF_Session* session,
TF_Status* status) {
auto* opts = TFE_NewContextOptions();
@@ -8723,35 +8738,7 @@ void TFE_TensorHandlePrintDebugString(TFE_TensorHandle* handle) {
TF_DeleteStatus(status);
}
-TFE_TensorHandle* TFE_RunConstOp(TFE_Context* ctx) {
- // Intentionally LOG into INFO below for ease of debugging.
- VLOG(1) << "TFE_RunConstOp called";
-
- auto* status = TF_NewStatus();
- auto* op = TFE_NewOp(ctx, "Const", status);
- CheckOk(status);
- TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
-
- auto* tensor =
- TF_AllocateTensor(TF_FLOAT, /*shape.data()*/ nullptr, /*shape.size()*/ 0,
- TF_DataTypeSize(TF_FLOAT) * 1);
- auto* ptr = reinterpret_cast<char*>(TF_TensorData(tensor));
- *reinterpret_cast<float*>(ptr) = 17.0;
-
- TFE_OpSetAttrTensor(op, "value", tensor, status);
- CheckOk(status);
- TF_DeleteTensor(tensor);
- VLOG(1) << "New op created";
-
- TFE_TensorHandle* retval;
- int num_retvals = 1;
- TFE_Execute(op, &retval, &num_retvals, status);
- CheckOk(status);
- CHECK_EQ(num_retvals, 1);
- VLOG(1) << "Op executed";
-
- TFE_DeleteOp(op);
- TF_DeleteStatus(status);
-
- return retval;
+TF_CAPI_EXPORT extern void TF_MakeInternalErrorStatus(TF_Status* status,
+ const char* errMsg) {
+ status->status = tensorflow::errors::Internal(errMsg);
}
diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h
index a3ca847d96..d98d532e32 100644
--- a/tensorflow/c/c_api_experimental.h
+++ b/tensorflow/c/c_api_experimental.h
@@ -131,6 +131,8 @@ TF_CAPI_EXPORT extern void TF_EnqueueNamedTensor(TF_Session* session,
int tensor_id,
TF_Tensor* tensor,
TF_Status* status);
+// Create a serialized tensorflow.ServerDef proto.
+TF_Buffer* TFE_GetServerDef(const char* text_proto, TF_Status* status);
// TODO: remove this API in favor of the next one.
TF_CAPI_EXPORT extern TFE_Context* TFE_NewContextFromSession(
@@ -178,10 +180,8 @@ TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_DequeueVariantTensor(
TF_CAPI_EXPORT extern void TFE_TensorHandlePrintDebugString(
TFE_TensorHandle* handle);
-// Returns a const scalar tensor.
-// Caller owns both the input and the output tensor handles.
-// TODO: Remove this API with hard-coded tensor computation.
-TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_RunConstOp(TFE_Context* ctx);
+TF_CAPI_EXPORT extern void TF_MakeInternalErrorStatus(TF_Status* status,
+ const char* errMsg);
#ifdef __cplusplus
} /* end extern "C" */
diff --git a/tensorflow/c/c_api_experimental_test.cc b/tensorflow/c/c_api_experimental_test.cc
index 30fcfd401d..c6effd3969 100644
--- a/tensorflow/c/c_api_experimental_test.cc
+++ b/tensorflow/c/c_api_experimental_test.cc
@@ -16,8 +16,10 @@ limitations under the License.
#include "tensorflow/c/c_api_experimental.h"
#include "tensorflow/c/c_test_util.h"
#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
namespace tensorflow {
namespace {
@@ -116,5 +118,49 @@ TEST(CAPI_EXPERIMENTAL, ImagenetIteratorGetNext) {
TF_DeleteStatus(s);
}
+TEST(CAPI_EXPERIMENTAL, GetServerDefTest) {
+ const string expected_text_proto(R"(cluster {
+ job {
+ name: "worker"
+ tasks {
+ key: 0
+ value: "tpuserver:0"
+ }
+ tasks {
+ key: 1
+ value: "localhost:1"
+ }
+ }
+}
+job_name: "worker"
+task_index: 1
+protocol: "grpc"
+)");
+
+ TF_Status* status = TF_NewStatus();
+ TF_Buffer* result = TFE_GetServerDef(expected_text_proto.c_str(), status);
+ EXPECT_EQ(TF_GetCode(status), TF_OK);
+
+ ServerDef actual;
+ ASSERT_TRUE(actual.ParseFromArray(result->data, result->length));
+ string actual_text_proto;
+ tensorflow::protobuf::TextFormat::PrintToString(actual, &actual_text_proto);
+ EXPECT_EQ(expected_text_proto, actual_text_proto);
+
+ const string malformed_text_proto(R"(cluster {
+ job {
+ name: "worker")");
+ TF_Buffer* null_result =
+ TFE_GetServerDef(malformed_text_proto.c_str(), status);
+ EXPECT_NE(TF_GetCode(status), TF_OK);
+ EXPECT_TRUE(tensorflow::str_util::StrContains(
+ TF_Message(status), "Invalid text proto for ServerDef"));
+ EXPECT_EQ(null_result, nullptr);
+
+ // Cleanup
+ TF_DeleteBuffer(result);
+ TF_DeleteStatus(status);
+}
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD
index 37be52f57d..3ee31a6a7a 100644
--- a/tensorflow/c/eager/BUILD
+++ b/tensorflow/c/eager/BUILD
@@ -68,7 +68,10 @@ tf_cuda_library(
tf_cuda_library(
name = "c_api_internal",
hdrs = ["c_api_internal.h"],
- visibility = ["//tensorflow:internal"],
+ visibility = [
+ "//learning/deepmind/courier:__pkg__",
+ "//tensorflow:internal",
+ ],
deps = [
":c_api",
"//tensorflow/c:c_api",
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index 6f86ea80e5..0bf3d9542b 100755
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -375,6 +375,17 @@ int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status) {
return result;
}
+int64_t TFE_TensorHandleNumElements(TFE_TensorHandle* h, TF_Status* status) {
+ if (h == nullptr || h->handle == nullptr) {
+ status->status = tensorflow::errors::InvalidArgument(
+ "The passed in handle is a nullptr");
+ return -1;
+ }
+ tensorflow::int64 result;
+ status->status = h->handle->NumElements(&result);
+ return result;
+}
+
int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index,
TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h
index a87d73ec8e..6323f8a053 100755
--- a/tensorflow/c/eager/c_api.h
+++ b/tensorflow/c/eager/c_api.h
@@ -163,6 +163,8 @@ TF_CAPI_EXPORT extern TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h);
// This function will block till the operation that produces `h` has completed.
TF_CAPI_EXPORT extern int TFE_TensorHandleNumDims(TFE_TensorHandle* h,
TF_Status* status);
+TF_CAPI_EXPORT extern int64_t TFE_TensorHandleNumElements(TFE_TensorHandle* h,
+ TF_Status* status);
// This function will block till the operation that produces `h` has completed.
TF_CAPI_EXPORT extern int64_t TFE_TensorHandleDim(TFE_TensorHandle* h,
int dim_index,
diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h
index ce038a4b57..41b5b8ff36 100644
--- a/tensorflow/c/eager/tape.h
+++ b/tensorflow/c/eager/tape.h
@@ -29,15 +29,8 @@ limitations under the License.
namespace tensorflow {
namespace eager {
-// Information about a tensor.
-struct TapeTensor {
- int64 id; // Expected to be unique in the lifetime of this process.
- DataType dtype;
- TensorShape shape;
-};
-
// Represents an entry in the tape.
-template <typename BackwardFunction>
+template <typename BackwardFunction, typename TapeTensor>
struct OpTapeEntry {
string op_type;
std::vector<TapeTensor> output_tensor_info;
@@ -57,8 +50,8 @@ struct OpTapeEntry {
using TensorTape = gtl::FlatMap<int64, int64>;
// Map from operation-id to tape entry.
-template <typename BackwardFunction>
-using OpTape = gtl::FlatMap<int64, OpTapeEntry<BackwardFunction>>;
+template <typename BackwardFunction, typename TapeTensor>
+using OpTape = gtl::FlatMap<int64, OpTapeEntry<BackwardFunction, TapeTensor>>;
// Operations the tape needs to perform on tensors to do backpropagation. Named
// "vspace" because a subset of these are related to a vector space, such as
@@ -79,7 +72,7 @@ using OpTape = gtl::FlatMap<int64, OpTapeEntry<BackwardFunction>>;
// TODO(apassos) provide concrete template instantiations for TFE_TensorHandle
// specialization, which is blocked by quite a few things needing to loop back
// into python now.
-template <typename Gradient, typename BackwardFunction>
+template <typename Gradient, typename BackwardFunction, typename TapeTensor>
class VSpace {
public:
virtual ~VSpace() {}
@@ -93,10 +86,10 @@ class VSpace {
gtl::ArraySlice<Gradient*> gradient_tensors) const = 0;
// Returns a tensor of the right shape and dtype filled with zeros.
- virtual Gradient* Zeros(TensorShape shape, DataType dtype) const = 0;
+ virtual Gradient* Zeros(const TapeTensor& tensor) const = 0;
// Returns a Tensor which is filled with ones and like the input.
- virtual Gradient* Ones(TensorShape shape, DataType dtype) const = 0;
+ virtual Gradient* Ones(const TapeTensor& tensor) const = 0;
// Calls the passed-in backward function.
virtual Status CallBackwardFunction(
@@ -114,7 +107,7 @@ class VSpace {
// Traces the execution of operations, doing eager garbage collection, and
// exporting a full trace so other code can do backpropagation. Not thread-safe.
-template <typename Gradient, typename BackwardFunction>
+template <typename Gradient, typename BackwardFunction, typename TapeTensor>
class GradientTape {
public:
// If `persistent` is true, GradientTape will not eagerly delete backward
@@ -134,7 +127,7 @@ class GradientTape {
void Watch(int64 tensor_id);
void RecordOperation(
- const string& op_type, gtl::ArraySlice<TapeTensor> output_tensors,
+ const string& op_type, std::vector<TapeTensor>& output_tensors,
gtl::ArraySlice<int64> input_tensor_id,
gtl::ArraySlice<tensorflow::DataType> input_dtypes,
BackwardFunction* backward_function,
@@ -146,17 +139,18 @@ class GradientTape {
// once) and produces the gradient of the target tensors with respect to the
// source tensors. The output gradients are used if not empty and not
// null. The result is populated with one tensor per target element.
- Status ComputeGradient(const VSpace<Gradient, BackwardFunction>& vspace,
- gtl::ArraySlice<int64> target_tensor_ids,
- gtl::ArraySlice<int64> source_tensor_id,
- gtl::ArraySlice<Gradient*> output_gradients,
- std::vector<Gradient*>* result);
+ Status ComputeGradient(
+ const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace,
+ gtl::ArraySlice<int64> target_tensor_ids,
+ gtl::ArraySlice<int64> source_tensor_id,
+ gtl::ArraySlice<Gradient*> output_gradients,
+ std::vector<Gradient*>* result);
bool IsPersistent() const { return persistent_; }
private:
TensorTape tensor_tape_;
- OpTape<BackwardFunction> op_tape_;
+ OpTape<BackwardFunction, TapeTensor> op_tape_;
int64 next_op_id_{0};
// Map from tensor id to number of remaining usages (i.e. how many entries in
@@ -186,8 +180,8 @@ inline bool IsDtypeTrainable(DataType dtype) {
}
}
-template <typename Gradient, typename BackwardFunction>
-bool GradientTape<Gradient, BackwardFunction>::ShouldRecord(
+template <typename Gradient, typename BackwardFunction, typename TapeTensor>
+bool GradientTape<Gradient, BackwardFunction, TapeTensor>::ShouldRecord(
gtl::ArraySlice<int64> tensor_ids,
gtl::ArraySlice<tensorflow::DataType> dtypes) {
CHECK_EQ(tensor_ids.size(), dtypes.size());
@@ -201,14 +195,15 @@ bool GradientTape<Gradient, BackwardFunction>::ShouldRecord(
return false;
}
-template <typename Gradient, typename BackwardFunction>
-void GradientTape<Gradient, BackwardFunction>::Watch(int64 tensor_id) {
+template <typename Gradient, typename BackwardFunction, typename TapeTensor>
+void GradientTape<Gradient, BackwardFunction, TapeTensor>::Watch(
+ int64 tensor_id) {
tensor_tape_.emplace(tensor_id, -1);
}
-template <typename Gradient, typename BackwardFunction>
-void GradientTape<Gradient, BackwardFunction>::RecordOperation(
- const string& op_type, gtl::ArraySlice<TapeTensor> output_tensors,
+template <typename Gradient, typename BackwardFunction, typename TapeTensor>
+void GradientTape<Gradient, BackwardFunction, TapeTensor>::RecordOperation(
+ const string& op_type, std::vector<TapeTensor>& output_tensors,
gtl::ArraySlice<int64> input_tensor_id,
gtl::ArraySlice<tensorflow::DataType> input_dtypes,
BackwardFunction* backward_function,
@@ -229,16 +224,18 @@ void GradientTape<Gradient, BackwardFunction>::RecordOperation(
for (const TapeTensor& o : output_tensors) {
// Note: the tensor can have already been watched and hence be in the tape,
// so we cannot check that we're inserting it here.
- tensor_tape_[o.id] = op_id;
- tensor_usage_[o.id] = 1;
+ tensor_tape_[o.GetID()] = op_id;
+ tensor_usage_[o.GetID()] = 1;
tensors.push_back(o);
}
- op_tape_[op_id] = OpTapeEntry<BackwardFunction>{
- op_type, tensors, ids, backward_function, backward_function_deleter};
+ op_tape_[op_id] = OpTapeEntry<BackwardFunction, TapeTensor>{
+ op_type, std::move(tensors), ids, backward_function,
+ backward_function_deleter};
}
-template <typename Gradient, typename BackwardFunction>
-void GradientTape<Gradient, BackwardFunction>::DeleteTrace(int64 tensor_id) {
+template <typename Gradient, typename BackwardFunction, typename TapeTensor>
+void GradientTape<Gradient, BackwardFunction, TapeTensor>::DeleteTrace(
+ int64 tensor_id) {
auto it = tensor_usage_.find(tensor_id);
if (it == tensor_usage_.end()) {
return;
@@ -261,7 +258,7 @@ void GradientTape<Gradient, BackwardFunction>::DeleteTrace(int64 tensor_id) {
auto op_it = op_tape_.find(op_id);
CHECK(op_it != op_tape_.end());
for (const auto& output : op_it->second.output_tensor_info) {
- if (tensor_usage_.find(output.id) != tensor_usage_.end()) {
+ if (tensor_usage_.find(output.GetID()) != tensor_usage_.end()) {
// Found a usage for an output, so cannot delete the op.
return;
}
@@ -304,9 +301,9 @@ void GradientTape<Gradient, BackwardFunction>::DeleteTrace(int64 tensor_id) {
namespace {
-template <typename BackwardFunction>
+template <typename BackwardFunction, typename TapeTensor>
struct BackpropInitialState {
- OpTape<BackwardFunction> op_tape;
+ OpTape<BackwardFunction, TapeTensor> op_tape;
// Map from tensor ID to how many references still exist for this tensor in
// the tape.
@@ -322,17 +319,17 @@ struct BackpropInitialState {
// If `persistent_tape` is false, op_tape is cleared and backwards functions
// not needed for gradient computation are deleted. Backwards functions that
// are needed, are copied and returned in BackpropInitialState.
-template <typename BackwardFunction>
-BackpropInitialState<BackwardFunction> PrepareBackprop(
+template <typename BackwardFunction, typename TapeTensor>
+BackpropInitialState<BackwardFunction, TapeTensor> PrepareBackprop(
gtl::ArraySlice<int64> target, const TensorTape& tensor_tape,
- OpTape<BackwardFunction>* op_tape, const gtl::FlatSet<int64>& sources_set,
- bool persistent_tape) {
+ OpTape<BackwardFunction, TapeTensor>* op_tape,
+ const gtl::FlatSet<int64>& sources_set, bool persistent_tape) {
std::vector<int64> tensor_stack;
tensor_stack.reserve(target.size());
for (auto t : target) {
tensor_stack.push_back(t);
}
- BackpropInitialState<BackwardFunction> result;
+ BackpropInitialState<BackwardFunction, TapeTensor> result;
while (!tensor_stack.empty()) {
int64 tensor_id = tensor_stack.back();
tensor_stack.pop_back();
@@ -383,9 +380,9 @@ BackpropInitialState<BackwardFunction> PrepareBackprop(
return result;
}
-template <typename BackwardFunction>
+template <typename BackwardFunction, typename TapeTensor>
std::vector<int64> InitialStack(
- const OpTape<BackwardFunction>& op_tape,
+ const OpTape<BackwardFunction, TapeTensor>& op_tape,
const gtl::FlatMap<int64, int64>& op_missing_tensor) {
std::vector<int64> result;
for (auto& op_entry : op_tape) {
@@ -396,13 +393,13 @@ std::vector<int64> InitialStack(
return result;
}
-template <typename Gradient, typename BackwardFunction>
-Status InitialGradients(const VSpace<Gradient, BackwardFunction>& vspace,
- gtl::ArraySlice<int64> target_tensor_ids,
- gtl::ArraySlice<Gradient*> output_gradients,
- const TensorTape& tensor_tape,
- const OpTape<BackwardFunction>& op_tape,
- gtl::FlatMap<int64, std::vector<Gradient*>>* result) {
+template <typename Gradient, typename BackwardFunction, typename TapeTensor>
+Status InitialGradients(
+ const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace,
+ gtl::ArraySlice<int64> target_tensor_ids,
+ gtl::ArraySlice<Gradient*> output_gradients, const TensorTape& tensor_tape,
+ const OpTape<BackwardFunction, TapeTensor>& op_tape,
+ gtl::FlatMap<int64, std::vector<Gradient*>>* result) {
for (int i = 0; i < target_tensor_ids.size(); ++i) {
const int64 id = target_tensor_ids[i];
if (output_gradients.empty() || output_gradients[i] == nullptr) {
@@ -416,11 +413,10 @@ Status InitialGradients(const VSpace<Gradient, BackwardFunction>& vspace,
}
bool found = false;
for (int j = 0; j < op_it->second.output_tensor_info.size(); ++j) {
- if (op_it->second.output_tensor_info[j].id == id) {
+ if (op_it->second.output_tensor_info[j].GetID() == id) {
found = true;
(*result)[id].push_back(
- vspace.Ones(op_it->second.output_tensor_info[j].shape,
- op_it->second.output_tensor_info[j].dtype));
+ vspace.Ones(op_it->second.output_tensor_info[j]));
break;
}
}
@@ -440,6 +436,18 @@ Status InitialGradients(const VSpace<Gradient, BackwardFunction>& vspace,
return Status::OK();
}
+// TODO(agarwal): use an automatic mechanism for handling None arguments to
+// gradient functions.
+//
+// Some gradient functions can accept None arguments for gradients. The
+// following maps the operation name to the indices at which the corresponding
+// gradient function can accept None values. e.g. FusedBatchNorm outputs 5
+// values and hence receives 5 gradient values during backprop. However the
+// gradient function uses only the first of those values and ignores the rest.
+// The entry, "FusedBatchNorm": [1, 2, 3, 4], indicates that only the gradient
+// corresponding to index 0 is used, and the gradient values at indices 1-4 are
+// ignored (and hence can be None). The backprop algorithm can then leverage
+// this by not constructing zeros to pass for those indices.
gtl::FlatMap<string, gtl::FlatSet<int>>* FunctionsAcceptingNoneForIndicesMap() {
static auto* const m = new gtl::FlatMap<string, gtl::FlatSet<int>>({
{"SoftmaxCrossEntropyWithLogits", {1}},
@@ -457,16 +465,16 @@ gtl::FlatMap<string, gtl::FlatSet<int>>* FunctionsAcceptingNoneForIndicesMap() {
constexpr int kMinAggregateCount = 4;
constexpr int kMinAggregateBytes = 128 * 1024 * 1024;
-template <typename Gradient, typename BackwardFunction>
-Status GradientTape<Gradient, BackwardFunction>::ComputeGradient(
- const VSpace<Gradient, BackwardFunction>& vspace,
+template <typename Gradient, typename BackwardFunction, typename TapeTensor>
+Status GradientTape<Gradient, BackwardFunction, TapeTensor>::ComputeGradient(
+ const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace,
gtl::ArraySlice<int64> target_tensor_ids,
gtl::ArraySlice<int64> source_tensor_ids,
gtl::ArraySlice<Gradient*> output_gradients,
std::vector<Gradient*>* result) {
gtl::FlatSet<int64> sources_set(source_tensor_ids.begin(),
source_tensor_ids.end());
- BackpropInitialState<BackwardFunction> state = PrepareBackprop(
+ BackpropInitialState<BackwardFunction, TapeTensor> state = PrepareBackprop(
target_tensor_ids, tensor_tape_, &op_tape_, sources_set, persistent_);
std::vector<int64> op_stack =
InitialStack(state.op_tape, state.op_missing_tensor);
@@ -510,7 +518,7 @@ Status GradientTape<Gradient, BackwardFunction>::ComputeGradient(
out_gradients.reserve(trace.output_tensor_info.size());
bool any_gradient_nonzero = false;
for (int i = 0; i < trace.output_tensor_info.size(); ++i) {
- const int64 id = trace.output_tensor_info[i].id;
+ const int64 id = trace.output_tensor_info[i].GetID();
auto grad_it = gradients.find(id);
if (grad_it == gradients.end()) {
auto func_name_it =
@@ -519,9 +527,7 @@ Status GradientTape<Gradient, BackwardFunction>::ComputeGradient(
func_name_it->second.find(i) != func_name_it->second.end()) {
out_gradients.push_back(nullptr);
} else {
- out_gradients.push_back(
- vspace.Zeros(trace.output_tensor_info[i].shape,
- trace.output_tensor_info[i].dtype));
+ out_gradients.push_back(vspace.Zeros(trace.output_tensor_info[i]));
}
} else {
any_gradient_nonzero = true;
diff --git a/tensorflow/c/python_api.cc b/tensorflow/c/python_api.cc
index 8486b585c8..247236b760 100644
--- a/tensorflow/c/python_api.cc
+++ b/tensorflow/c/python_api.cc
@@ -110,7 +110,7 @@ void ExtendSession(TF_Session* session, TF_Status* status) {
session->extend_before_run = false;
}
-std::string GetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output) {
+std::string GetHandleShapeAndType(TF_Graph* graph, TF_Output output) {
Node* node = &output.oper->node;
CppShapeInferenceResult::HandleData handle_data;
handle_data.set_is_set(true);
@@ -135,9 +135,8 @@ std::string GetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output) {
return result;
}
-void SetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output,
- const void* proto, size_t proto_len,
- TF_Status* status) {
+void SetHandleShapeAndType(TF_Graph* graph, TF_Output output, const void* proto,
+ size_t proto_len, TF_Status* status) {
tensorflow::CppShapeInferenceResult::HandleData handle_data;
if (!handle_data.ParseFromArray(proto, proto_len)) {
status->status = tensorflow::errors::InvalidArgument(
diff --git a/tensorflow/c/python_api.h b/tensorflow/c/python_api.h
index 4bcb5bde62..5cce84020b 100644
--- a/tensorflow/c/python_api.h
+++ b/tensorflow/c/python_api.h
@@ -54,16 +54,17 @@ void SetRequireShapeInferenceFns(TF_Graph* graph, bool require);
void ExtendSession(TF_Session* session, TF_Status* status);
// Returns the serialized CppShapeInferenceResult::HandleData proto for
-// `output` if its a resource tensor, or otherwise returns the empty string.
-std::string GetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output);
+// `output` if its a resource or variant tensor, or otherwise returns the empty
+// string.
+std::string GetHandleShapeAndType(TF_Graph* graph, TF_Output output);
// Sets `output` based on `proto`, which should be a serialized
-// CppShapeInferenceResult::HandleData proto.
+// CppShapeInferenceResult::HandleData proto. `output` should be a resource
+// or variant tensor.
// NOTE(skyewm): `proto` is passed a void*/size_t pair instead of a std::string
// because I couldn't get SWIG to work otherwise.
-void SetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output,
- const void* proto, size_t proto_len,
- TF_Status* status);
+void SetHandleShapeAndType(TF_Graph* graph, TF_Output output, const void* proto,
+ size_t proto_len, TF_Status* status);
} // namespace tensorflow
#endif // TENSORFLOW_C_PYTHON_API_H_
diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD
index f56521dac0..b587e63227 100644
--- a/tensorflow/cc/BUILD
+++ b/tensorflow/cc/BUILD
@@ -10,11 +10,12 @@ licenses(["notice"]) # Apache 2.0
load(
"//tensorflow:tensorflow.bzl",
- "tf_cc_test",
+ "cc_library_with_android_deps",
"tf_cc_binary",
+ "tf_cc_test",
"tf_copts",
"tf_gen_op_wrappers_cc",
- "cc_library_with_android_deps",
+ "transitive_hdrs",
)
cc_library(
@@ -716,3 +717,26 @@ tf_cc_test(
"//tensorflow/core:testlib",
],
)
+
+transitive_hdrs(
+ name = "headers",
+ visibility = ["//tensorflow:__subpackages__"],
+ deps = [
+ ":cc_ops",
+ ":client_session",
+ ":coordinator",
+ ":gradient_checker",
+ ":gradients",
+ ":ops",
+ ":queue_runner",
+ ":remote_fused_graph_ops",
+ ":scope",
+ "//tensorflow/cc/profiler",
+ "//tensorflow/cc/saved_model:constants",
+ "//tensorflow/cc/saved_model:loader",
+ "//tensorflow/cc/saved_model:reader",
+ "//tensorflow/cc/saved_model:signature_constants",
+ "//tensorflow/cc/saved_model:tag_constants",
+ "//tensorflow/cc/tools:freeze_saved_model",
+ ],
+)
diff --git a/tensorflow/compiler/aot/tests/BUILD b/tensorflow/compiler/aot/tests/BUILD
index 7a0932d44d..10fa33ab5e 100644
--- a/tensorflow/compiler/aot/tests/BUILD
+++ b/tensorflow/compiler/aot/tests/BUILD
@@ -25,6 +25,7 @@ test_suite(
":test_graph_tfmatmul_test",
":test_graph_tfmatmulandadd_test",
":test_graph_tfsplits_test",
+ ":test_graph_tftop_k_test",
":tfcompile_test",
],
)
@@ -42,6 +43,7 @@ py_binary(
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:math_ops",
+ "//tensorflow/python:nn_ops",
"//tensorflow/python:platform",
"//tensorflow/python:session",
"//tensorflow/python:training",
@@ -66,6 +68,7 @@ genrule(
"test_graph_tfmatmul.pb",
"test_graph_tfmatmulandadd.pb",
"test_graph_tfsplits.pb",
+ "test_graph_tftop_k.pb",
],
# Set CUDA_VISIBLE_DEVICES='' to prevent the code we launch from using any
# GPUs which might be present. This is important because builds may run
@@ -208,6 +211,17 @@ tf_library(
],
)
+tf_library(
+ name = "test_graph_tftop_k",
+ testonly = 1,
+ config = "test_graph_tftop_k.config.pbtxt",
+ cpp_class = "TopKComp",
+ graph = "test_graph_tftop_k.pb",
+ tags = [
+ "manual",
+ ],
+)
+
tf_cc_test(
name = "tfcompile_test",
srcs = ["tfcompile_test.cc"],
@@ -226,6 +240,7 @@ tf_cc_test(
":test_graph_tfmatmulandadd",
":test_graph_tfmatmulandadd_with_profiling",
":test_graph_tfsplits",
+ ":test_graph_tftop_k",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:xla_data_proto",
diff --git a/tensorflow/compiler/aot/tests/make_test_graphs.py b/tensorflow/compiler/aot/tests/make_test_graphs.py
index 9ec7df163b..64b861a730 100644
--- a/tensorflow/compiler/aot/tests/make_test_graphs.py
+++ b/tensorflow/compiler/aot/tests/make_test_graphs.py
@@ -31,6 +31,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import app
from tensorflow.python.training import saver as saver_lib
@@ -46,7 +47,7 @@ def tfadd(_):
def tfadd_with_ckpt(out_dir):
x = array_ops.placeholder(dtypes.int32, name='x_hold')
- y = variables.Variable(constant_op.constant([0]), name='y_saved')
+ y = variables.VariableV1(constant_op.constant([0]), name='y_saved')
math_ops.add(x, y, name='x_y_sum')
init_op = variables.initialize_all_variables()
@@ -61,7 +62,7 @@ def tfadd_with_ckpt(out_dir):
def tfadd_with_ckpt_saver(out_dir):
x = array_ops.placeholder(dtypes.int32, name='x_hold')
- y = variables.Variable(constant_op.constant([0]), name='y_saved')
+ y = variables.VariableV1(constant_op.constant([0]), name='y_saved')
math_ops.add(x, y, name='x_y_sum')
init_op = variables.initialize_all_variables()
@@ -142,6 +143,12 @@ def tfsplits(_):
array_ops.identity(y, name='result')
+def tftop_k(_):
+ x = array_ops.placeholder(dtypes.int32, shape=[5], name='x')
+ output = nn_ops.top_k(x, 2, name='values')
+ array_ops.identity(output[1], name='indices')
+
+
def write_graph(build_graph, out_dir):
"""Build a graph using build_graph and write it out."""
g = ops.Graph()
@@ -163,6 +170,7 @@ def main(_):
write_graph(tfmatmul, FLAGS.out_dir)
write_graph(tfmatmulandadd, FLAGS.out_dir)
write_graph(tfsplits, FLAGS.out_dir)
+ write_graph(tftop_k, FLAGS.out_dir)
if __name__ == '__main__':
diff --git a/tensorflow/compiler/aot/tests/test_graph_tftop_k.config.pbtxt b/tensorflow/compiler/aot/tests/test_graph_tftop_k.config.pbtxt
new file mode 100644
index 0000000000..6b4ac2d7cb
--- /dev/null
+++ b/tensorflow/compiler/aot/tests/test_graph_tftop_k.config.pbtxt
@@ -0,0 +1,13 @@
+# Text form of tensorflow.tf2xla.Config proto.
+feed {
+ id { node_name: "x" }
+ shape {
+ dim { size: 5 }
+ }
+}
+fetch {
+ id { node_name: "values" }
+}
+fetch {
+ id { node_name: "indices" }
+}
diff --git a/tensorflow/compiler/aot/tests/tfcompile_test.cc b/tensorflow/compiler/aot/tests/tfcompile_test.cc
index 7ac90fb8a9..f10852c785 100644
--- a/tensorflow/compiler/aot/tests/tfcompile_test.cc
+++ b/tensorflow/compiler/aot/tests/tfcompile_test.cc
@@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd_with_profiling.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfsplits.h"
+#include "tensorflow/compiler/aot/tests/test_graph_tftop_k.h"
#include "tensorflow/compiler/xla/service/hlo_profile_printer.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
@@ -448,6 +449,30 @@ TEST(TFCompileTest, Splits) {
EXPECT_NEAR(expected[3], fn.result0(1, 1), 1e4);
}
+TEST(TFCompileTest, TopK) {
+ Eigen::ThreadPool tp(1);
+ Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());
+
+ TopKComp fn;
+
+ fn.set_thread_pool(&device);
+ // x = [4, 1, 4, 4, 3]
+ fn.arg0(0) = 4;
+ fn.arg0(1) = 1;
+ fn.arg0(2) = 4;
+ fn.arg0(3) = 4;
+ fn.arg0(4) = 3;
+
+ EXPECT_TRUE(fn.Run());
+ EXPECT_EQ(fn.error_msg(), "");
+ const int32 expected_values[] = {4, 4};
+ const int32 expected_indices[] = {0, 2};
+ EXPECT_EQ(expected_values[0], fn.result0(0));
+ EXPECT_EQ(expected_values[1], fn.result0(1));
+ EXPECT_EQ(expected_indices[0], fn.result1(0));
+ EXPECT_EQ(expected_indices[1], fn.result1(1));
+}
+
TEST(TFCompileTest, AssertEqAndReturnDiff) {
// Assert is converted into a no-op in XLA, so there is no failure even if the
// two args are different.
diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl
index 792b7fe14a..859c84bb91 100644
--- a/tensorflow/compiler/aot/tfcompile.bzl
+++ b/tensorflow/compiler/aot/tfcompile.bzl
@@ -273,6 +273,7 @@ def tf_library(
"//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_1d",
"//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_2d",
"//tensorflow/compiler/xla/service/cpu:runtime_conv2d",
+ "//tensorflow/compiler/xla/service/cpu:runtime_key_value_sort",
"//tensorflow/compiler/xla/service/cpu:runtime_matmul",
"//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_conv2d",
"//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_matmul",
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index 1001c57f3d..5bf4af1014 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -26,6 +26,7 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_test")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured")
load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
+load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
# Target that bundles up the XLA CPU and GPU JIT devices.
cc_library(
@@ -50,7 +51,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":jit_compilation_passes",
- "//tensorflow/compiler/jit/kernels:xla_launch_op",
+ "//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/xla/service:cpu_plugin",
],
@@ -62,7 +63,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = if_cuda([
":jit_compilation_passes",
- "//tensorflow/compiler/jit/kernels:xla_launch_op",
+ "//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/xla/service:gpu_plugin",
]),
@@ -76,7 +77,7 @@ cc_library(
deps = [
":jit_compilation_passes",
":xla_device",
- "//tensorflow/compiler/jit/kernels:xla_launch_op",
+ "//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/jit/legacy_flags:xla_device_flags",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
@@ -94,7 +95,7 @@ cc_library(
deps = [
":jit_compilation_passes",
":xla_device",
- "//tensorflow/compiler/jit/kernels:xla_launch_op",
+ "//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/xla/service:gpu_plugin", # buildcleaner: keep
@@ -111,7 +112,7 @@ cc_library(
deps = [
":jit_compilation_passes",
":xla_device",
- "//tensorflow/compiler/jit/kernels:xla_launch_op",
+ "//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/xla/service:interpreter_plugin", # buildcleaner: keep
@@ -280,7 +281,7 @@ cc_library(
deps = [
":common",
":compilation_passes",
- "//tensorflow/compiler/jit/kernels:xla_launch_op",
+ "//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
@@ -341,7 +342,7 @@ tf_cc_test(
"//tensorflow/cc:ops",
"//tensorflow/cc:resource_variable_ops",
"//tensorflow/cc:sendrecv_ops",
- "//tensorflow/compiler/jit/kernels:xla_launch_op",
+ "//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/core:core_cpu",
@@ -359,7 +360,7 @@ tf_cc_test(
cc_library(
name = "compilation_passes",
srcs = [
- "build_xla_launch_ops_pass.cc",
+ "build_xla_ops_pass.cc",
"deadness_analysis.cc",
"deadness_analysis_internal.h",
"encapsulate_subgraphs_pass.cc",
@@ -369,7 +370,7 @@ cc_library(
"partially_decluster_pass.cc",
],
hdrs = [
- "build_xla_launch_ops_pass.h",
+ "build_xla_ops_pass.h",
"deadness_analysis.h",
"encapsulate_subgraphs_pass.h",
"encapsulate_xla_computations_pass.h",
@@ -459,7 +460,7 @@ tf_cc_test(
"//tensorflow/cc:function_ops",
"//tensorflow/cc:ops",
"//tensorflow/cc:sendrecv_ops",
- "//tensorflow/compiler/jit/kernels:xla_launch_op",
+ "//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/core:core_cpu",
@@ -477,6 +478,7 @@ tf_cc_test(
name = "compilation_passes_test",
size = "small",
srcs = [
+ "build_xla_ops_pass_test.cc",
"encapsulate_subgraphs_pass_test.cc",
"encapsulate_xla_computations_pass_test.cc",
"mark_for_compilation_pass_test.cc",
@@ -485,6 +487,7 @@ tf_cc_test(
deps = [
":common",
":compilation_passes",
+ ":node_matchers",
":xla_cluster_util",
":xla_gpu_device",
"//tensorflow/cc:cc_ops",
@@ -493,7 +496,7 @@ tf_cc_test(
"//tensorflow/cc:ops",
"//tensorflow/cc:resource_variable_ops",
"//tensorflow/cc:sendrecv_ops",
- "//tensorflow/compiler/jit/kernels:xla_launch_op",
+ "//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/tf2xla:test_util",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/cc:xla_jit_ops",
@@ -524,7 +527,7 @@ tf_cc_test(
"//tensorflow/cc:cc_ops_internal",
"//tensorflow/cc:function_ops",
"//tensorflow/cc:ops",
- "//tensorflow/compiler/jit/kernels:xla_launch_op",
+ "//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/core:core_cpu",
@@ -628,6 +631,15 @@ tf_cc_test(
],
)
+tf_custom_op_py_library(
+ name = "xla_ops_py",
+ kernels = ["//tensorflow/compiler/jit/ops:xla_ops"],
+ visibility = [
+ ":friends",
+ ],
+ deps = ["//tensorflow/compiler/jit/ops:xla_ops_wrapper_py"],
+)
+
# This target can be used by XLA device plugins to prevent circular dependencies, and provides access to all of the required headers for building a device library.
cc_header_only_library(
name = "xla_jit_headers_lib",
diff --git a/tensorflow/compiler/jit/build_xla_launch_ops_pass.cc b/tensorflow/compiler/jit/build_xla_launch_ops_pass.cc
deleted file mode 100644
index b17ff589e2..0000000000
--- a/tensorflow/compiler/jit/build_xla_launch_ops_pass.cc
+++ /dev/null
@@ -1,142 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/compiler/jit/build_xla_launch_ops_pass.h"
-#include "tensorflow/compiler/jit/defs.h"
-#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
-#include "tensorflow/compiler/tf2xla/dump_graph.h"
-#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/core/common_runtime/function.h"
-#include "tensorflow/core/common_runtime/optimization_registry.h"
-#include "tensorflow/core/framework/graph_def_util.h"
-#include "tensorflow/core/framework/node_def_builder.h"
-#include "tensorflow/core/framework/node_def_util.h"
-#include "tensorflow/core/graph/algorithm.h"
-#include "tensorflow/core/graph/graph.h"
-#include "tensorflow/core/graph/graph_constructor.h"
-#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/hash/hash.h"
-#include "tensorflow/core/public/version.h"
-
-namespace tensorflow {
-
-static Status BuildLaunchNode(
- const string& nodename, const string& function_name,
- const AttrValueMap& function_attr, const string& device_name,
- const DataTypeVector& constant_dtypes, int num_resources,
- const DataTypeVector& arg_dtypes, const DataTypeVector& result_dtypes,
- Graph* graph, Node** node) {
- NodeDef def;
- def.set_name(graph->NewName(nodename));
- def.set_op("XlaLaunch");
- def.set_device(device_name);
- AddNodeAttr("Tconstants", constant_dtypes, &def);
- AddNodeAttr("Targs", arg_dtypes, &def);
- AddNodeAttr("Nresources", num_resources, &def);
- AddNodeAttr("Tresults", result_dtypes, &def);
- NameAttrList function;
- function.set_name(function_name);
- *function.mutable_attr() = function_attr;
- AddNodeAttr("function", function, &def);
-
- Status status;
- *node = graph->AddNode(def, &status);
- return status;
-}
-
-static Status ReplaceNodeWithXlaLaunch(Graph* graph, Node* node) {
- VLOG(2) << "Replacing " << node->name() << " with XlaLaunch";
-
- int num_constant_args, num_resource_args;
- TF_RETURN_IF_ERROR(
- GetNodeAttr(node->attrs(), kXlaNumConstantArgsAttr, &num_constant_args));
- TF_RETURN_IF_ERROR(
- GetNodeAttr(node->attrs(), kXlaNumResourceArgsAttr, &num_resource_args));
-
- if (num_constant_args < 0 || num_resource_args < 0 ||
- num_constant_args + num_resource_args > node->num_inputs()) {
- return errors::InvalidArgument(
- "Invalid number of constant/resource arguments to XLA kernel.");
- }
- const int num_nonconst_args =
- node->num_inputs() - num_constant_args - num_resource_args;
-
- DataTypeVector const_dtypes(node->input_types().begin(),
- node->input_types().begin() + num_constant_args);
- DataTypeVector arg_dtypes(
- node->input_types().begin() + num_constant_args,
- node->input_types().begin() + num_constant_args + num_nonconst_args);
-
- // Build a XlaLaunch operator to execute the function body.
- Node* launch_node;
- TF_RETURN_IF_ERROR(BuildLaunchNode(
- graph->NewName(node->name()), node->type_string(), node->def().attr(),
- node->requested_device(), const_dtypes, num_resource_args, arg_dtypes,
- node->output_types(), graph, &launch_node));
- launch_node->set_assigned_device_name(node->assigned_device_name());
-
- // Copy incoming edges to the launch node.
- for (const Edge* edge : node->in_edges()) {
- if (edge->IsControlEdge()) {
- graph->AddControlEdge(edge->src(), launch_node);
- } else {
- graph->AddEdge(edge->src(), edge->src_output(), launch_node,
- edge->dst_input());
- }
- }
-
- // Copy outgoing edges to the launch node.
- std::vector<const Edge*> out_edges(node->out_edges().begin(),
- node->out_edges().end());
- for (const Edge* edge : out_edges) {
- Node* dst = edge->dst();
- int src_output = edge->src_output();
- int dst_input = edge->dst_input();
- graph->RemoveEdge(edge);
-
- if (edge->IsControlEdge()) {
- graph->AddControlEdge(launch_node, dst);
- } else {
- graph->AddEdge(launch_node, src_output, dst, dst_input);
- }
- }
- graph->RemoveNode(node);
-
- return Status::OK();
-}
-
-Status BuildXlaLaunchOpsPass::Run(const GraphOptimizationPassOptions& options) {
- Graph* graph = options.graph->get();
-
- for (Node* n : graph->op_nodes()) {
- // In all cases, only try to compile computational nodes.
- if (n->IsSend() || n->IsRecv() || n->IsControlFlow()) {
- continue;
- }
-
- // Only compile nodes that are marked for compilation by the
- // compilation-marking pass (via 'attr_name').
- if (IsXlaCompiledKernel(*n)) {
- TF_RETURN_IF_ERROR(ReplaceNodeWithXlaLaunch(graph, n));
- }
- }
-
- if (VLOG_IS_ON(1)) {
- dump_graph::DumpGraphToFile("build_xla_launch_ops", *graph,
- options.flib_def);
- }
- return Status::OK();
-}
-} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/build_xla_ops_pass.cc b/tensorflow/compiler/jit/build_xla_ops_pass.cc
new file mode 100644
index 0000000000..9e3fd93cda
--- /dev/null
+++ b/tensorflow/compiler/jit/build_xla_ops_pass.cc
@@ -0,0 +1,182 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/build_xla_ops_pass.h"
+#include "tensorflow/compiler/jit/defs.h"
+#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
+#include "tensorflow/compiler/tf2xla/dump_graph.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/common_runtime/optimization_registry.h"
+#include "tensorflow/core/framework/graph_def_util.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/hash/hash.h"
+#include "tensorflow/core/public/version.h"
+
+namespace tensorflow {
+
+static Status BuildXlaCompileNode(
+ const string& nodename, const string& function_name,
+ const AttrValueMap& function_attr, const string& device_name,
+ const DataTypeVector& constant_dtypes, int num_resources,
+ const DataTypeVector& arg_dtypes, Graph* graph, Node** node) {
+ NodeDef def;
+ def.set_name(graph->NewName(nodename));
+ def.set_op("_XlaCompile");
+ def.set_device(device_name);
+ AddNodeAttr("Tconstants", constant_dtypes, &def);
+ AddNodeAttr("Targs", arg_dtypes, &def);
+ AddNodeAttr("Nresources", num_resources, &def);
+ NameAttrList function;
+ function.set_name(function_name);
+ *function.mutable_attr() = function_attr;
+ AddNodeAttr("function", function, &def);
+
+ Status status;
+ *node = graph->AddNode(def, &status);
+ return status;
+}
+
+static Status BuildXlaRunNode(const string& nodename, const string& device_name,
+ const DataTypeVector& arg_dtypes,
+ const DataTypeVector& result_dtypes, Graph* graph,
+ Node** node) {
+ NodeDef def;
+ def.set_name(graph->NewName(nodename));
+ def.set_op("_XlaRun");
+ def.set_device(device_name);
+ AddNodeAttr("Targs", arg_dtypes, &def);
+ AddNodeAttr("Tresults", result_dtypes, &def);
+
+ Status status;
+ *node = graph->AddNode(def, &status);
+ return status;
+}
+
+static Status GetXlaAttrs(Node* node, int* num_constant_args,
+ int* num_resource_args, DataTypeVector* const_dtypes,
+ DataTypeVector* arg_dtypes) {
+ TF_RETURN_IF_ERROR(
+ GetNodeAttr(node->attrs(), kXlaNumConstantArgsAttr, num_constant_args));
+ TF_RETURN_IF_ERROR(
+ GetNodeAttr(node->attrs(), kXlaNumResourceArgsAttr, num_resource_args));
+
+ if (*num_constant_args < 0 || *num_resource_args < 0 ||
+ *num_constant_args + *num_resource_args > node->num_inputs()) {
+ return errors::InvalidArgument(
+ "Invalid number of constant/resource arguments to XLA kernel.");
+ }
+
+ const int num_nonconst_args =
+ node->num_inputs() - *num_constant_args - *num_resource_args;
+
+ const DataTypeVector& input_types = node->input_types();
+ std::copy(input_types.begin(), input_types.begin() + *num_constant_args,
+ std::back_inserter(*const_dtypes));
+ std::copy(input_types.begin() + *num_constant_args,
+ input_types.begin() + *num_constant_args + num_nonconst_args,
+ std::back_inserter(*arg_dtypes));
+ return Status::OK();
+}
+
+static void CopyIncomingEdges(Graph* g, Node* old_node, Node* new_node,
+ int prefix_to_ignore) {
+ for (const Edge* edge : old_node->in_edges()) {
+ if (edge->IsControlEdge()) {
+ g->AddControlEdge(edge->src(), new_node);
+ } else if (edge->dst_input() >= prefix_to_ignore) {
+ g->AddEdge(edge->src(), edge->src_output(), new_node,
+ edge->dst_input() - prefix_to_ignore);
+ }
+ }
+}
+
+static void MoveOutgoingEdges(Graph* g, Node* old_node, Node* new_node) {
+ std::vector<const Edge*> out_edges(old_node->out_edges().begin(),
+ old_node->out_edges().end());
+ for (const Edge* edge : out_edges) {
+ // TODO(sanjoy): This does not update NodeDef inputs.
+ g->AddEdge(new_node, edge->src_output(), edge->dst(), edge->dst_input());
+ g->RemoveEdge(edge);
+ }
+}
+
+static Status ReplaceNodeWithXlaCompileAndRun(Graph* g, Node* n) {
+ int num_constant_args, num_resource_args;
+ DataTypeVector const_dtypes;
+ DataTypeVector arg_dtypes;
+
+ TF_RETURN_IF_ERROR(GetXlaAttrs(n, &num_constant_args, &num_resource_args,
+ &const_dtypes, &arg_dtypes));
+
+ Node *compile_node, *run_node;
+
+ TF_RETURN_IF_ERROR(BuildXlaCompileNode(
+ n->name(), n->type_string(), n->def().attr(), n->requested_device(),
+ const_dtypes, num_resource_args, arg_dtypes, g, &compile_node));
+
+ DataTypeVector arg_dtypes_with_resources = arg_dtypes;
+ for (int i = 0; i < num_resource_args; i++) {
+ arg_dtypes_with_resources.push_back(DT_RESOURCE);
+ }
+
+ TF_RETURN_IF_ERROR(BuildXlaRunNode(n->name(), n->requested_device(),
+ arg_dtypes_with_resources,
+ n->output_types(), g, &run_node));
+
+ compile_node->set_assigned_device_name(n->assigned_device_name());
+ run_node->set_assigned_device_name(n->assigned_device_name());
+
+ CopyIncomingEdges(g, /*old_node=*/n, /*new_node=*/compile_node,
+ /*prefix_to_ignore=*/0);
+ CopyIncomingEdges(g, /*old_node=*/n, /*new_node=*/run_node,
+ /*prefix_to_ignore=*/num_constant_args);
+
+ // The compilation_key output.
+ g->AddEdge(compile_node, 0, run_node, n->num_inputs() - num_constant_args);
+
+ MoveOutgoingEdges(g, /*old_node=*/n, /*new_node=*/run_node);
+ g->RemoveNode(n);
+
+ return Status::OK();
+}
+
+Status BuildXlaOpsPass::Run(const GraphOptimizationPassOptions& options) {
+ Graph* graph = options.graph->get();
+
+ for (Node* n : graph->op_nodes()) {
+ // In all cases, only try to compile computational nodes.
+ if (n->IsSend() || n->IsRecv() || n->IsControlFlow()) {
+ continue;
+ }
+
+ // Only compile nodes that are marked for compilation by the
+ // compilation-marking pass (via 'attr_name').
+ if (IsXlaCompiledKernel(*n)) {
+ TF_RETURN_IF_ERROR(ReplaceNodeWithXlaCompileAndRun(graph, n));
+ }
+ }
+
+ if (VLOG_IS_ON(1)) {
+ dump_graph::DumpGraphToFile("build_xla_ops", *graph, options.flib_def);
+ }
+ return Status::OK();
+}
+} // namespace tensorflow
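The bookkeeping worth noting in the pass above is how the original cluster node's inputs are redistributed: _XlaCompile receives the full [constants, args, resources] input list unchanged, while _XlaRun receives only [args, resources] (the constant prefix is dropped via prefix_to_ignore) followed by the compilation key as its last input. A hypothetical helper, not part of the patch, that spells out that index arithmetic:

    // Mirrors the edge wiring in ReplaceNodeWithXlaCompileAndRun above.
    struct XlaRunInputLayout {
      int num_inputs;         // inputs of the original clustered call node
      int num_constant_args;  // value of kXlaNumConstantArgsAttr

      // _XlaRun input index receiving original input i (i >= num_constant_args),
      // matching CopyIncomingEdges(..., /*prefix_to_ignore=*/num_constant_args).
      int RunInputForOriginal(int i) const { return i - num_constant_args; }

      // _XlaRun input index receiving output 0 (the key) of _XlaCompile, matching
      // g->AddEdge(compile_node, 0, run_node, num_inputs - num_constant_args).
      int KeyInputIndex() const { return num_inputs - num_constant_args; }
    };

For example, a cluster node with one constant, two regular args and one resource (four inputs in total) yields an _XlaRun whose inputs are the two args at indices 0 and 1, the resource at 2, and the compilation key at 3.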
diff --git a/tensorflow/compiler/jit/build_xla_launch_ops_pass.h b/tensorflow/compiler/jit/build_xla_ops_pass.h
index 1dfea93f02..1dd38fa951 100644
--- a/tensorflow/compiler/jit/build_xla_launch_ops_pass.h
+++ b/tensorflow/compiler/jit/build_xla_ops_pass.h
@@ -13,19 +13,21 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMPILER_JIT_BUILD_XLA_LAUNCH_OPS_PASS_H_
-#define TENSORFLOW_COMPILER_JIT_BUILD_XLA_LAUNCH_OPS_PASS_H_
+#ifndef TENSORFLOW_COMPILER_JIT_BUILD_XLA_OPS_PASS_H_
+#define TENSORFLOW_COMPILER_JIT_BUILD_XLA_OPS_PASS_H_
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
-class BuildXlaLaunchOpsPass : public GraphOptimizationPass {
+// Adds _XlaCompile and _XlaRun operations to the TF graph that compiles and
+// executes (using XLA) TF function calls marked with "_XlaCompiledKernel".
+class BuildXlaOpsPass : public GraphOptimizationPass {
public:
Status Run(const GraphOptimizationPassOptions& options) override;
};
} // namespace tensorflow
-#endif // TENSORFLOW_COMPILER_JIT_BUILD_XLA_LAUNCH_OPS_PASS_H_
+#endif // TENSORFLOW_COMPILER_JIT_BUILD_XLA_OPS_PASS_H_
diff --git a/tensorflow/compiler/jit/build_xla_ops_pass_test.cc b/tensorflow/compiler/jit/build_xla_ops_pass_test.cc
new file mode 100644
index 0000000000..b7cb4506b9
--- /dev/null
+++ b/tensorflow/compiler/jit/build_xla_ops_pass_test.cc
@@ -0,0 +1,112 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/build_xla_ops_pass.h"
+
+#include "tensorflow/cc/framework/ops.h"
+#include "tensorflow/cc/ops/array_ops.h"
+#include "tensorflow/cc/ops/resource_variable_ops.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/compiler/jit/defs.h"
+#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
+#include "tensorflow/compiler/jit/node_matchers.h"
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace {
+
+using ::tensorflow::testing::FindNodeByName;
+using ::tensorflow::testing::matchers::CtrlDeps;
+using ::tensorflow::testing::matchers::NodeWith;
+using ::tensorflow::testing::matchers::Op;
+
+Status BuildXlaOps(const Scope& s, std::unique_ptr<Graph>* result) {
+ auto graph = absl::make_unique<Graph>(OpRegistry::Global());
+ TF_RETURN_IF_ERROR(s.ToGraph(graph.get()));
+
+ // Assign all nodes to the CPU device.
+ static const char* kCpuDevice = "/job:localhost/replica:0/task:0/cpu:0";
+ for (Node* n : graph->nodes()) {
+ if (n->assigned_device_name().empty()) {
+ n->set_assigned_device_name(kCpuDevice);
+ }
+ }
+
+ GraphOptimizationPassOptions opt_options;
+ opt_options.graph = &graph;
+ BuildXlaOpsPass pass;
+ TF_RETURN_IF_ERROR(pass.Run(opt_options));
+ *result = std::move(graph);
+ return Status::OK();
+}
+
+Status MakeXlaCompiledKernel(Graph* graph, const string& callee_name,
+ const string& node_name, Node** result) {
+ NodeDef call_node;
+ call_node.set_name(node_name);
+ call_node.set_op(callee_name);
+ AddNodeAttr(kXlaCompiledKernelAttr, true, &call_node);
+ AddNodeAttr(kXlaNumConstantArgsAttr, 0, &call_node);
+ AddNodeAttr(kXlaNumResourceArgsAttr, 0, &call_node);
+ Status s;
+ *result = graph->AddNode(call_node, &s);
+ return s;
+}
+
+Node* MakeWrite(const Scope& scope, const string& id) {
+ Output var_handle =
+ ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({}));
+ Output value_to_write =
+ ops::Const(scope.WithOpName("ValueToAssign" + id), 1.0f);
+ ops::AssignVariableOp assign_op(scope.WithOpName("Assignee" + id), var_handle,
+ value_to_write);
+ return assign_op.operation.node();
+}
+
+FunctionDefLibrary CreateFunctionDefLibWithConstFunction(const string& name) {
+ FunctionDefLibrary flib_def;
+ FunctionDef func = FunctionDefHelper::Create(
+ /*function_name=*/name, /*in_def=*/{}, /*out_def=*/{"out: float"},
+ /*attr_def*/
+ {}, /*node_def=*/{FunctionDefHelper::Const("one", 1.0f)},
+ /*ret_def=*/{{"out", "out:output:0"}});
+ *flib_def.add_function() = std::move(func);
+ return flib_def;
+}
+
+TEST(BuildXlaOps, ControlDepsPreserved) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ FunctionDefLibrary flib_def =
+ CreateFunctionDefLibWithConstFunction("cluster_0");
+ TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def));
+ Node* call;
+ TF_ASSERT_OK(MakeXlaCompiledKernel(root.graph(), "cluster_0", "C", &call));
+ Node* write_op = MakeWrite(root, "write");
+ root.graph()->AddControlEdge(call, write_op);
+
+ std::unique_ptr<Graph> graph;
+ TF_ASSERT_OK(BuildXlaOps(root, &graph));
+
+ Node* write_op_new = FindNodeByName(graph.get(), write_op->name());
+ ASSERT_NE(write_op_new, nullptr);
+ EXPECT_THAT(write_op_new, NodeWith(CtrlDeps(NodeWith(Op("_XlaRun")))));
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/create_xla_launch_op.cc b/tensorflow/compiler/jit/create_xla_launch_op.cc
index 56b034a30b..6f1ff85f24 100644
--- a/tensorflow/compiler/jit/create_xla_launch_op.cc
+++ b/tensorflow/compiler/jit/create_xla_launch_op.cc
@@ -16,7 +16,7 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "tensorflow/compiler/jit/defs.h"
-#include "tensorflow/compiler/jit/kernels/xla_launch_op.h"
+#include "tensorflow/compiler/jit/kernels/xla_ops.h"
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
diff --git a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc
index 3770eea6d0..085c0e5adb 100644
--- a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc
+++ b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/compiler/jit/build_xla_launch_ops_pass.h"
+#include "tensorflow/compiler/jit/build_xla_ops_pass.h"
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
#include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h"
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
@@ -55,6 +55,6 @@ REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 30,
// Must run after EncapsulateSubgraphsPass.
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 40,
- BuildXlaLaunchOpsPass);
+ BuildXlaOpsPass);
} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/kernels/BUILD b/tensorflow/compiler/jit/kernels/BUILD
index 253a5d2547..0839f1cb3d 100644
--- a/tensorflow/compiler/jit/kernels/BUILD
+++ b/tensorflow/compiler/jit/kernels/BUILD
@@ -7,9 +7,9 @@ package(
)
cc_library(
- name = "xla_launch_op",
- srcs = ["xla_launch_op.cc"],
- hdrs = ["xla_launch_op.h"],
+ name = "xla_ops",
+ srcs = ["xla_ops.cc"],
+ hdrs = ["xla_ops.h"],
deps = [
"//tensorflow/compiler/jit:common",
"//tensorflow/compiler/jit:xla_compilation_cache",
@@ -26,6 +26,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
"//tensorflow/core/kernels:variable_ops",
+ "@com_google_absl//absl/memory",
],
alwayslink = 1,
)
diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
deleted file mode 100644
index b6f2f632f7..0000000000
--- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc
+++ /dev/null
@@ -1,276 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/compiler/jit/kernels/xla_launch_op.h"
-
-#include "tensorflow/compiler/jit/defs.h"
-#include "tensorflow/compiler/jit/xla_launch_util.h"
-#include "tensorflow/compiler/tf2xla/shape_util.h"
-#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
-#include "tensorflow/compiler/tf2xla/xla_compiler.h"
-#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/client_library.h"
-#include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/statusor.h"
-#include "tensorflow/core/common_runtime/dma_helper.h"
-#include "tensorflow/core/common_runtime/function.h"
-#include "tensorflow/core/framework/allocator.h"
-#include "tensorflow/core/framework/node_def_util.h"
-#include "tensorflow/core/framework/op.h"
-#include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow/core/framework/tensor.h"
-#include "tensorflow/core/framework/types.h"
-#include "tensorflow/core/kernels/variable_ops.h"
-#include "tensorflow/core/platform/env.h"
-#include "tensorflow/core/platform/stream_executor_no_cuda.h"
-#include "tensorflow/core/util/stream_executor_util.h"
-
-namespace tensorflow {
-
-XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx,
- const std::vector<int>& constants,
- const std::vector<int>& resources,
- const NameAttrList& function)
- : OpKernel(ctx),
- constants_(constants),
- resources_(resources),
- device_type_(ctx->device_type()),
- function_(function) {
- if (device_type_ == DeviceType(DEVICE_CPU)) {
- platform_id_ = se::host::kHostPlatformId;
- } else if (device_type_ == DeviceType(DEVICE_GPU)) {
- platform_id_ = ctx->device()
- ->tensorflow_gpu_device_info()
- ->stream->parent()
- ->platform()
- ->id();
- } else if (XlaDevice::GetMetadata(ctx, &xla_device_metadata_).ok()) {
- use_multiple_streams_ = xla_device_metadata_->UseMultipleStreams();
- platform_id_ = xla_device_metadata_->platform()->id();
- }
-}
-
-Status XlaLocalLaunchBase::BuildCompilationCache(OpKernelContext* ctx,
- XlaCompilationCache** cache) {
- if (xla_device_metadata_) {
- *cache = new XlaCompilationCache(xla_device_metadata_->client(),
- xla_device_metadata_->jit_device_type());
- return Status::OK();
- }
-
- auto platform = se::MultiPlatformManager::PlatformWithId(platform_id_);
- if (!platform.ok()) {
- return platform.status();
- }
- xla::LocalClientOptions client_options;
- client_options.set_platform(platform.ValueOrDie());
- client_options.set_intra_op_parallelism_threads(
- ctx->device()->tensorflow_cpu_worker_threads()->num_threads);
- auto client = xla::ClientLibrary::GetOrCreateLocalClient(client_options);
- if (!client.ok()) {
- return client.status();
- }
- const XlaOpRegistry::DeviceRegistration* registration;
- if (!XlaOpRegistry::GetCompilationDevice(device_type_.type(),
- &registration)) {
- return errors::InvalidArgument("No JIT device registered for ",
- device_type_.type());
- }
- *cache = new XlaCompilationCache(
- client.ValueOrDie(), DeviceType(registration->compilation_device_name));
- return Status::OK();
-}
-
-void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
- VLOG(1) << "XlaLocalLaunchOpBase::Compute "
- << Canonicalize(function_.name(), AttrSlice(&function_.attr()));
- // We store information about the JIT-compiled XLA computation
- // in the ResourceMgr.
- ResourceMgr* rm = ctx->resource_manager();
- OP_REQUIRES(ctx, rm, errors::Internal("No resource manager."));
-
- se::Stream* stream =
- ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
-
- XlaCompilationCache* cache;
- OP_REQUIRES_OK(ctx, rm->LookupOrCreate<XlaCompilationCache>(
- rm->default_container(), "xla_cache", &cache,
- [this, ctx](XlaCompilationCache** cache) {
- return BuildCompilationCache(ctx, cache);
- }));
- // Hold the reference to the JIT during evaluation. (We could probably
- // free it sooner because the ResourceMgr will retain a reference, but
- // this is more obviously correct.)
- core::ScopedUnref cache_ref(cache);
-
- std::map<int, OptionalTensor> variables =
- SnapshotResourceVariables(ctx, resources_);
-
- xla::LocalClient* client = static_cast<xla::LocalClient*>(cache->client());
-
- XlaAllocator local_xla_allocator(client->backend().platform(),
- ctx->device()->GetAllocator({}));
- xla::DeviceMemoryAllocator* xla_allocator;
- // If we are on an XlaDevice, use the underlying XLA platform's allocator
- // directly. We could use the StreamExecutor's allocator which may
- // theoretically be more correct, but XLA returns a nice OOM message in a
- // Status and StreamExecutor does not.
- //
- // Importantly we can't use ctx->device()->GetAllocator() as the allocator
- // (which local_xla_allocator above uses) as on an XlaDevice, this is a
- // dummy allocator that returns XlaTensor objects. The XlaCompiler needs a
- // real allocator to allocate real buffers.
- if (xla_device_metadata_) {
- xla_allocator = client->backend().memory_allocator();
- } else {
- xla_allocator = &local_xla_allocator;
- }
-
- XlaCompiler::Options options;
- options.client = client;
- if (ctx->op_device_context() != nullptr) {
- options.device_ordinal =
- ctx->op_device_context()->stream()->parent()->device_ordinal();
- }
- options.device_type = cache->device_type();
- options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition();
- options.graph_def_version = ctx->function_library()->graph_def_version();
- options.allow_cpu_custom_calls = (platform_id_ == se::host::kHostPlatformId);
- options.device_allocator = xla_allocator;
- if (xla_device_metadata_) {
- options.shape_representation_fn =
- xla_device_metadata_->shape_representation_fn();
- }
-
- const XlaCompiler::CompilationResult* kernel;
- xla::LocalExecutable* executable;
-
- std::map<int, Tensor> constant_args;
- for (int i : constants_) {
- constant_args.insert({i, ctx->input(i)});
- }
- XlaCompiler::CompileOptions compile_options;
- compile_options.is_entry_computation = true;
- // If we resolve constants we never emit them on the device, meaning that if
- // they are needed by a following computation the host has to transfer
- // them. Not resolving constants is expected to be faster than resolving
- // constants.
- compile_options.resolve_compile_time_constants = true;
- // Optimization: where possible, have the computation return a naked array
- // rather than a one-element tuple.
- compile_options.always_return_tuple = false;
-
- OP_REQUIRES_OK(
- ctx, cache->Compile(options, function_, constant_args, variables, ctx,
- &kernel, &executable, compile_options));
-
- VLOG(1) << "Executing XLA Computation...";
-
- XlaComputationLaunchContext launch_context(
- client, xla_allocator,
- /*allocate_xla_tensors=*/xla_device_metadata_ != nullptr,
- use_multiple_streams_);
- launch_context.PopulateInputs(ctx, kernel, variables);
-
- // Execute the computation.
- VLOG(2) << "Executing computation.";
- xla::ExecutableRunOptions run_options;
- run_options.set_stream(stream);
- run_options.set_allocator(xla_allocator);
- run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
- run_options.set_rng_seed(GetXLARandomSeed());
- Env* env = Env::Default();
- auto start_time = env->NowMicros();
-
- auto run_result = executable->Run(launch_context.arguments(), run_options);
- OP_REQUIRES(ctx, run_result.ok(), run_result.status());
-
- auto elapsed = env->NowMicros() - start_time;
- VLOG(2) << "Elapsed time: " << elapsed << "us";
-
- OP_REQUIRES_OK(ctx, launch_context.PopulateOutputs(
- ctx, kernel, run_result.ConsumeValueOrDie()));
- VLOG(1) << "Done";
-}
-
-namespace {
-
-// OP_REQUIRES_OK_RETURN is the same as OP_REQUIRES_OK except that
-// in error case, it returns RET instead of void.
-#define OP_REQUIRES_OK_RETURN(CTX, RET, ...) \
- do { \
- ::tensorflow::Status _s(__VA_ARGS__); \
- if (!TF_PREDICT_TRUE(_s.ok())) { \
- (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \
- return RET; \
- } \
- } while (0)
-
-// Helper static functions to construct parameters for
-// XlaLocalLaunchBase constructor from OpKernelConstruction.
-std::vector<int> ConstantsVector(OpKernelConstruction* ctx) {
- DataTypeVector constant_types;
- OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
- ctx->GetAttr("Tconstants", &constant_types));
- std::vector<int> constants(constant_types.size());
- std::iota(constants.begin(), constants.end(), 0);
- return constants;
-}
-
-std::vector<int> ResourcesVector(OpKernelConstruction* ctx) {
- DataTypeVector constant_types;
- OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
- ctx->GetAttr("Tconstants", &constant_types));
-
- DataTypeVector arg_types;
- OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
- ctx->GetAttr("Targs", &arg_types));
-
- int num_resources;
- OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
- ctx->GetAttr("Nresources", &num_resources));
-
- std::vector<int> resources(num_resources);
- std::iota(resources.begin(), resources.end(),
- constant_types.size() + arg_types.size());
- return resources;
-}
-
-NameAttrList FunctionAttr(OpKernelConstruction* ctx) {
- const NameAttrList* func;
- OP_REQUIRES_OK_RETURN(ctx, NameAttrList(), ctx->GetAttr("function", &func));
- return *func;
-}
-
-#undef OP_REQUIRES_OK_RETURN
-} // namespace
-
-XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx)
- : XlaLocalLaunchBase(ctx, ConstantsVector(ctx), ResourcesVector(ctx),
- FunctionAttr(ctx)) {}
-
-XlaLocalLaunchOp::~XlaLocalLaunchOp() {
- VLOG(1) << "XlaLocalLaunchOp destroyed";
-}
-
-REGISTER_KERNEL_BUILDER(Name("XlaLaunch").Device(DEVICE_CPU), XlaLocalLaunchOp);
-
-REGISTER_KERNEL_BUILDER(Name("XlaLaunch")
- .Device(DEVICE_GPU)
- .HostMemory("constants")
- .HostMemory("resources"),
- XlaLocalLaunchOp);
-
-} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.h b/tensorflow/compiler/jit/kernels/xla_launch_op.h
deleted file mode 100644
index e0f10e9817..0000000000
--- a/tensorflow/compiler/jit/kernels/xla_launch_op.h
+++ /dev/null
@@ -1,87 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_COMPILER_JIT_KERNELS_XLA_LAUNCH_OP_H_
-#define TENSORFLOW_COMPILER_JIT_KERNELS_XLA_LAUNCH_OP_H_
-
-#include "tensorflow/compiler/jit/xla_compilation_cache.h"
-#include "tensorflow/compiler/jit/xla_device.h"
-#include "tensorflow/core/framework/allocator.h"
-#include "tensorflow/core/framework/op.h"
-#include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow/core/framework/tensor.h"
-#include "tensorflow/core/platform/macros.h"
-#include "tensorflow/core/util/stream_executor_util.h"
-
-namespace tensorflow {
-
-// XlaLocalLaunchBase is almost the same as XlaLocalLaunchOp.
-// The only difference is that it does not require arguments to follow
-// the "constants, then regular args, then resources" order.
-// It takes vectors of constant and resource arguments explicitly.
-// It does not have corresponding OpDef because it is never present
-// in the GraphDef.
-// Currently, it is used by eager runtime. FunctionLibraryRuntime creates
-// this kernel when asked to create a kernel for an XLA-compiled function.
-class XlaLocalLaunchBase : public OpKernel {
- public:
- XlaLocalLaunchBase(OpKernelConstruction* ctx,
- const std::vector<int>& constants,
- const std::vector<int>& resources,
- const NameAttrList& function);
- XlaLocalLaunchBase(const XlaLocalLaunchBase&) = delete;
- XlaLocalLaunchBase& operator=(const XlaLocalLaunchBase&) = delete;
- ~XlaLocalLaunchBase() override = default;
-
- void Compute(OpKernelContext* ctx) override;
-
- protected:
- // Builds a XlaCompilationCache class suitable for the current device.
- Status BuildCompilationCache(OpKernelContext* ctx,
- XlaCompilationCache** cache);
-
- // Indexes of compile-time constant inputs
- std::vector<int> constants_;
- // Indexes of resource inputs
- std::vector<int> resources_;
-
- DeviceType device_type_;
- NameAttrList function_;
- se::Platform::Id platform_id_ = nullptr;
- bool use_multiple_streams_ = false;
- const XlaDevice::Metadata* xla_device_metadata_ = nullptr;
-};
-
-// XlaLocalLaunchOp is used to replace a region of the TensorFlow graph
-// which will be compiled and executed using XLA. The XlaLocalLaunchOp is
-// responsible for handling interactions with the TensorFlow executor.
-// Once all inputs are present, and their shapes are known, the op can
-// use a 'XlaCompilationCache' to compile and execute code which is specific
-// to the shapes of input Tensors.
-// XlaLocalLaunchOp uses xla::LocalClient::Compile() and
-// xla::LocalExecutable::Run(), and passes arguments into/out of XLA in device
-// memory.
-class XlaLocalLaunchOp : public XlaLocalLaunchBase {
- public:
- explicit XlaLocalLaunchOp(OpKernelConstruction* ctx);
- ~XlaLocalLaunchOp() override;
-
- private:
- TF_DISALLOW_COPY_AND_ASSIGN(XlaLocalLaunchOp);
-};
-
-} // namespace tensorflow
-
-#endif // TENSORFLOW_COMPILER_JIT_KERNELS_XLA_LAUNCH_OP_H_
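Before the large replacement file below: the launch kernel deleted above is superseded by three kernels in the new xla_ops.cc. XlaLocalLaunchOp keeps the existing XlaLaunch behavior, while XlaCompileOp and XlaRunOp split compilation from execution and hand state across via a scalar DT_STRING key into a process-global closure store. A condensed sketch of that handoff, using only names that appear in the new file (a summary, not additional code in the patch):

    // In XlaCompileOp::Compute: stash the compiled executable plus the resource
    // variable snapshots and emit the lookup key as output 0.
    XlaExecutableClosureStore::KeyT key =
        XlaExecutableClosureStore::Global()->Produce(XlaExecutableClosure(
            client, executable, kernel, std::move(variables), constants_.size()));
    compilation_key.flat<string>()(0) = key;

    // In XlaRunOp::Compute: the key arrives as the last input; consume the
    // closure and run the stored executable.
    Tensor key_tensor = ctx->input(ctx->num_inputs() - 1);
    const XlaExecutableClosureStore::KeyT& run_key = key_tensor.flat<string>()(0);
    XlaExecutableClosure closure =
        XlaExecutableClosureStore::Global()->Consume(run_key);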
diff --git a/tensorflow/compiler/jit/kernels/xla_ops.cc b/tensorflow/compiler/jit/kernels/xla_ops.cc
new file mode 100644
index 0000000000..a85006eb03
--- /dev/null
+++ b/tensorflow/compiler/jit/kernels/xla_ops.cc
@@ -0,0 +1,499 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/kernels/xla_ops.h"
+
+#include "absl/memory/memory.h"
+#include "tensorflow/compiler/jit/defs.h"
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
+#include "tensorflow/compiler/tf2xla/xla_compiler.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/client_library.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/core/common_runtime/dma_helper.h"
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/variable_ops.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+#include "tensorflow/core/util/stream_executor_util.h"
+
+namespace tensorflow {
+
+namespace {
+
+Status PlatformInfoFromContext(OpKernelConstruction* ctx,
+ XlaPlatformInfo* result) {
+ DeviceType device_type = ctx->device_type();
+ se::Platform::Id platform_id = nullptr;
+ const XlaDevice::Metadata* xla_device_metadata = nullptr;
+ std::unique_ptr<XlaAllocator> xla_allocator;
+ xla::DeviceMemoryAllocator* device_allocator = nullptr;
+
+ if (ctx->device_type() == DeviceType(DEVICE_CPU)) {
+ platform_id = se::host::kHostPlatformId;
+ } else if (ctx->device_type() == DeviceType(DEVICE_GPU)) {
+ platform_id = ctx->device()
+ ->tensorflow_gpu_device_info()
+ ->stream->parent()
+ ->platform()
+ ->id();
+ } else if (XlaDevice::GetMetadata(ctx, &xla_device_metadata).ok()) {
+ // If we are on an XlaDevice, use the underlying XLA platform's allocator
+ // directly. We could use the StreamExecutor's allocator which may
+ // theoretically be more correct, but XLA returns a nice OOM message in a
+ // Status and StreamExecutor does not.
+ //
+ // Importantly we can't use ctx->device()->GetAllocator() as the allocator
+ // (which xla_allocator above uses) as on an XlaDevice, this is a dummy
+ // allocator that returns XlaTensor objects. The XlaCompiler needs a real
+ // allocator to allocate real buffers.
+
+ platform_id = xla_device_metadata->platform()->id();
+ device_allocator =
+ xla_device_metadata->client()->backend().memory_allocator();
+ }
+
+ if (!device_allocator) {
+ TF_ASSIGN_OR_RETURN(se::Platform* const platform,
+ se::MultiPlatformManager::PlatformWithId(platform_id));
+ xla_allocator = absl::make_unique<XlaAllocator>(
+ platform, ctx->device()->GetAllocator({}));
+ }
+
+ *result = XlaPlatformInfo(device_type, platform_id, xla_device_metadata,
+ std::move(xla_allocator), device_allocator);
+
+ return Status::OK();
+}
+
+// A closure describing how to run a compiled version of a TensorFlow function.
+//
+// It may seem unusual to stick the resource variable snapshots in this class.
+// This is necessary: we need to use the snapshots observed by the compiler as
+// the initial values for the resource variables (and cannot snapshot them again
+// during execution) because otherwise we risk observing a different snapshot
+// with shapes different from what we compiled for.
+class XlaExecutableClosure {
+ public:
+ explicit XlaExecutableClosure(
+ xla::LocalClient* client, xla::LocalExecutable* executable,
+ const XlaCompiler::CompilationResult* compilation_result,
+ std::map<int, OptionalTensor> resource_var_snapshots,
+ int num_constant_args)
+ : client_(client),
+ executable_(executable),
+ compilation_result_(compilation_result),
+ resource_var_snapshots_(std::move(resource_var_snapshots)),
+ num_constant_args_(num_constant_args) {}
+
+ XlaExecutableClosure(XlaExecutableClosure&&) = default;
+ XlaExecutableClosure& operator=(XlaExecutableClosure&&) = default;
+
+ xla::LocalClient* client() const { return client_; }
+ xla::LocalExecutable* executable() const { return executable_; }
+ const XlaCompiler::CompilationResult* compilation_result() const {
+ return compilation_result_;
+ }
+ const std::map<int, OptionalTensor>& resource_var_snapshots() const {
+ return resource_var_snapshots_;
+ }
+ int num_constant_args() const { return num_constant_args_; }
+
+ private:
+ xla::LocalClient* client_;
+ xla::LocalExecutable* executable_;
+ const XlaCompiler::CompilationResult* compilation_result_;
+ std::map<int, OptionalTensor> resource_var_snapshots_;
+ int num_constant_args_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(XlaExecutableClosure);
+};
+
+// This maintains a mapping from a globally unique ID to XlaExecutableClosure
+// instances.
+class XlaExecutableClosureStore {
+ public:
+ XlaExecutableClosureStore() : key_counter_(0) {}
+
+ using KeyT = string;
+
+ KeyT Produce(XlaExecutableClosure result) {
+ mutex_lock l(mutex_);
+ KeyT key = absl::StrCat(key_counter_++);
+ bool insert_successful = closures_.emplace(key, std::move(result)).second;
+ DCHECK(insert_successful);
+ (void)insert_successful;
+ return key;
+ }
+
+ XlaExecutableClosure Consume(const KeyT& key) {
+ mutex_lock l(mutex_);
+ auto it = closures_.find(key);
+ DCHECK(it != closures_.end());
+ XlaExecutableClosure value = std::move(it->second);
+ closures_.erase(it);
+ return value;
+ }
+
+ static XlaExecutableClosureStore* Global() {
+ static XlaExecutableClosureStore* instance = new XlaExecutableClosureStore;
+ return instance;
+ }
+
+ private:
+ mutex mutex_;
+ int64 key_counter_ GUARDED_BY(mutex_);
+ gtl::FlatMap<KeyT, XlaExecutableClosure> closures_ GUARDED_BY(mutex_);
+
+ TF_DISALLOW_COPY_AND_ASSIGN(XlaExecutableClosureStore);
+};
+
+} // namespace
+
+XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx,
+ const std::vector<int>& constants,
+ const std::vector<int>& resources,
+ const NameAttrList& function)
+ : OpKernel(ctx),
+ constants_(constants),
+ resources_(resources),
+ function_(function) {
+ OP_REQUIRES_OK(ctx, PlatformInfoFromContext(ctx, &platform_info_));
+}
+
+static Status BuildCompilationCache(OpKernelContext* ctx,
+ const XlaPlatformInfo& platform_info,
+ XlaCompilationCache** cache) {
+ if (platform_info.xla_device_metadata()) {
+ *cache = new XlaCompilationCache(
+ platform_info.xla_device_metadata()->client(),
+ platform_info.xla_device_metadata()->jit_device_type());
+ return Status::OK();
+ }
+
+ auto platform =
+ se::MultiPlatformManager::PlatformWithId(platform_info.platform_id());
+ if (!platform.ok()) {
+ return platform.status();
+ }
+ xla::LocalClientOptions client_options;
+ client_options.set_platform(platform.ValueOrDie());
+ client_options.set_intra_op_parallelism_threads(
+ ctx->device()->tensorflow_cpu_worker_threads()->num_threads);
+ auto client = xla::ClientLibrary::GetOrCreateLocalClient(client_options);
+ if (!client.ok()) {
+ return client.status();
+ }
+ const XlaOpRegistry::DeviceRegistration* registration;
+ if (!XlaOpRegistry::GetCompilationDevice(platform_info.device_type().type(),
+ &registration)) {
+ return errors::InvalidArgument("No JIT device registered for ",
+ platform_info.device_type().type());
+ }
+ *cache = new XlaCompilationCache(
+ client.ValueOrDie(), DeviceType(registration->compilation_device_name));
+ return Status::OK();
+}
+
+static Status CompileToLocalExecutable(
+ OpKernelContext* ctx, const NameAttrList& function,
+ const XlaPlatformInfo& platform_info, absl::Span<const int> resources,
+ absl::Span<const int> constants, xla::LocalClient** client,
+ std::map<int, OptionalTensor>* variables,
+ const XlaCompiler::CompilationResult** kernel,
+ xla::LocalExecutable** executable) {
+ // We store information about the JIT-compiled XLA computation
+ // in the ResourceMgr.
+ ResourceMgr* rm = ctx->resource_manager();
+ if (!rm) {
+ return errors::Internal("No resource manager.");
+ }
+
+ XlaCompilationCache* cache;
+ TF_RETURN_IF_ERROR(rm->LookupOrCreate<XlaCompilationCache>(
+ rm->default_container(), "xla_cache", &cache,
+ [&](XlaCompilationCache** cache) {
+ return BuildCompilationCache(ctx, platform_info, cache);
+ }));
+ // Hold the reference to the JIT during evaluation. (We could probably
+ // free it sooner because the ResourceMgr will retain a reference, but
+ // this is more obviously correct.)
+ core::ScopedUnref cache_ref(cache);
+
+ *variables = SnapshotResourceVariables(ctx, resources);
+ *client = static_cast<xla::LocalClient*>(cache->client());
+
+ XlaCompiler::Options options;
+ options.client = *client;
+ if (ctx->op_device_context() != nullptr) {
+ options.device_ordinal =
+ ctx->op_device_context()->stream()->parent()->device_ordinal();
+ }
+ options.device_type = cache->device_type();
+ options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition();
+ options.graph_def_version = ctx->function_library()->graph_def_version();
+ options.allow_cpu_custom_calls =
+ (platform_info.platform_id() == se::host::kHostPlatformId);
+ options.device_allocator = platform_info.allocator();
+ if (platform_info.xla_device_metadata()) {
+ options.shape_representation_fn =
+ platform_info.xla_device_metadata()->shape_representation_fn();
+ }
+
+ std::map<int, Tensor> constant_args;
+ for (int i : constants) {
+ constant_args.insert({i, ctx->input(i)});
+ }
+ XlaCompiler::CompileOptions compile_options;
+ compile_options.is_entry_computation = true;
+ // If we resolve constants we never emit them on the device, meaning that if
+ // they are needed by a following computation the host has to transfer
+ // them. Not resolving constants is expected to be faster than resolving
+ // constants.
+ compile_options.resolve_compile_time_constants = true;
+ // Optimization: where possible, have the computation return a naked array
+ // rather than a one-element tuple.
+ compile_options.always_return_tuple = false;
+
+ return cache->Compile(options, function, constant_args, *variables, ctx,
+ kernel, executable, compile_options);
+}
+
+void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
+ VLOG(1) << "XlaLocalLaunchOpBase::Compute "
+ << Canonicalize(function_.name(), AttrSlice(&function_.attr()));
+
+ xla::LocalClient* client;
+ const XlaCompiler::CompilationResult* kernel;
+ xla::LocalExecutable* executable;
+ std::map<int, OptionalTensor> variables;
+
+ OP_REQUIRES_OK(
+ ctx, CompileToLocalExecutable(ctx, function_, platform_info_, resources_,
+ constants_, &client, &variables, &kernel,
+ &executable));
+
+ se::Stream* stream =
+ ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
+
+ VLOG(1) << "Executing XLA Computation...";
+
+ XlaComputationLaunchContext launch_context(
+ client, platform_info_.allocator(),
+ /*allocate_xla_tensors=*/platform_info_.is_on_xla_device(),
+ platform_info_.UseMultipleStreams());
+ launch_context.PopulateInputs(ctx, kernel, variables,
+ /*missing_ctx_input_prefix=*/0);
+
+ // Execute the computation.
+ VLOG(2) << "Executing computation.";
+ xla::ExecutableRunOptions run_options;
+ run_options.set_stream(stream);
+ run_options.set_allocator(platform_info_.allocator());
+ run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
+ run_options.set_rng_seed(GetXLARandomSeed());
+ Env* env = Env::Default();
+ auto start_time = env->NowMicros();
+
+ auto run_result = executable->Run(launch_context.arguments(), run_options);
+ OP_REQUIRES(ctx, run_result.ok(), run_result.status());
+
+ auto elapsed = env->NowMicros() - start_time;
+ VLOG(2) << "Elapsed time: " << elapsed << "us";
+
+ OP_REQUIRES_OK(ctx, launch_context.PopulateOutputs(
+ ctx, kernel, run_result.ConsumeValueOrDie(),
+ /*missing_ctx_input_prefix=*/0));
+ VLOG(1) << "Done";
+}
+
+namespace {
+
+// OP_REQUIRES_OK_RETURN is the same as OP_REQUIRES_OK except that
+// in error case, it returns RET instead of void.
+#define OP_REQUIRES_OK_RETURN(CTX, RET, ...) \
+ do { \
+ ::tensorflow::Status _s(__VA_ARGS__); \
+ if (!TF_PREDICT_TRUE(_s.ok())) { \
+ (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \
+ return RET; \
+ } \
+ } while (0)
+
+// Helper static functions to construct parameters for
+// XlaLocalLaunchBase constructor from OpKernelConstruction.
+std::vector<int> ConstantsVector(OpKernelConstruction* ctx) {
+ DataTypeVector constant_types;
+ OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
+ ctx->GetAttr("Tconstants", &constant_types));
+ std::vector<int> constants(constant_types.size());
+ std::iota(constants.begin(), constants.end(), 0);
+ return constants;
+}
+
+std::vector<int> ResourcesVector(OpKernelConstruction* ctx) {
+ DataTypeVector constant_types;
+ OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
+ ctx->GetAttr("Tconstants", &constant_types));
+
+ DataTypeVector arg_types;
+ OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
+ ctx->GetAttr("Targs", &arg_types));
+
+ int num_resources;
+ OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
+ ctx->GetAttr("Nresources", &num_resources));
+
+ std::vector<int> resources(num_resources);
+ std::iota(resources.begin(), resources.end(),
+ constant_types.size() + arg_types.size());
+ return resources;
+}
+
+NameAttrList FunctionAttr(OpKernelConstruction* ctx) {
+ const NameAttrList* func;
+ OP_REQUIRES_OK_RETURN(ctx, NameAttrList(), ctx->GetAttr("function", &func));
+ return *func;
+}
+
+#undef OP_REQUIRES_OK_RETURN
+} // namespace
+
+XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx)
+ : XlaLocalLaunchBase(ctx, ConstantsVector(ctx), ResourcesVector(ctx),
+ FunctionAttr(ctx)) {}
+
+XlaLocalLaunchOp::~XlaLocalLaunchOp() {
+ VLOG(1) << "XlaLocalLaunchOp destroyed";
+}
+
+XlaCompileOp::XlaCompileOp(OpKernelConstruction* ctx)
+ : OpKernel(ctx),
+ constants_(ConstantsVector(ctx)),
+ resources_(ResourcesVector(ctx)),
+ function_(FunctionAttr(ctx)) {
+ OP_REQUIRES_OK(ctx, PlatformInfoFromContext(ctx, &platform_info_));
+}
+
+void XlaCompileOp::Compute(OpKernelContext* ctx) {
+ xla::LocalClient* client;
+ const XlaCompiler::CompilationResult* kernel;
+ xla::LocalExecutable* executable;
+ std::map<int, OptionalTensor> variables;
+
+ OP_REQUIRES_OK(
+ ctx, CompileToLocalExecutable(ctx, function_, platform_info_, resources_,
+ constants_, &client, &variables, &kernel,
+ &executable));
+
+ // Each execution of an XlaCompile op creates a new XlaExecutableClosure, even
+ // if it didn't have to compile the cluster because of a compilation-cache
+ // hit. This is because we at least need new snapshots of the resource
+ // variables.
+ XlaExecutableClosureStore::KeyT key =
+ XlaExecutableClosureStore::Global()->Produce(XlaExecutableClosure(
+ client, executable, kernel, std::move(variables), constants_.size()));
+
+ Allocator* cpu_allocator = [&] {
+ AllocatorAttributes host_alloc_attrs;
+ host_alloc_attrs.set_gpu_compatible(true);
+ host_alloc_attrs.set_on_host(true);
+ return ctx->device()->GetAllocator(host_alloc_attrs);
+ }();
+
+ Tensor compilation_key(cpu_allocator, DT_STRING, TensorShape({}));
+ compilation_key.flat<string>()(0) = key;
+
+ Tensor compilation_successful(cpu_allocator, DT_BOOL, TensorShape({}));
+ compilation_successful.flat<bool>()(0) = true;
+
+ ctx->set_output(0, compilation_key);
+ ctx->set_output(1, compilation_successful);
+}
+
+XlaRunOp::XlaRunOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, PlatformInfoFromContext(ctx, &platform_info_));
+}
+
+void XlaRunOp::Compute(OpKernelContext* ctx) {
+ Tensor key_tensor = ctx->input(ctx->num_inputs() - 1);
+ const XlaExecutableClosureStore::KeyT& key = key_tensor.flat<string>()(0);
+
+ XlaExecutableClosure closure =
+ XlaExecutableClosureStore::Global()->Consume(key);
+
+ XlaComputationLaunchContext launch_context(
+ closure.client(), platform_info_.allocator(),
+ /*allocate_xla_tensors=*/platform_info_.is_on_xla_device(),
+ /*use_multiple_streams=*/platform_info_.UseMultipleStreams());
+
+ // We're missing the must-be-constant inputs, so tell `PopulateInputs`
+ // about this. We don't actually need these inputs because they've
+ // already been baked into the compiled kernel.
+ launch_context.PopulateInputs(
+ ctx, closure.compilation_result(), closure.resource_var_snapshots(),
+ /*missing_ctx_input_prefix=*/closure.num_constant_args());
+
+ se::Stream* stream =
+ ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
+ xla::ExecutableRunOptions run_options;
+ run_options.set_stream(stream);
+ run_options.set_allocator(platform_info_.allocator());
+ run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
+ run_options.set_rng_seed(GetXLARandomSeed());
+ Env* env = Env::Default();
+ auto start_time = env->NowMicros();
+
+ auto run_result =
+ closure.executable()->Run(launch_context.arguments(), run_options);
+ OP_REQUIRES(ctx, run_result.ok(), run_result.status());
+
+ auto elapsed = env->NowMicros() - start_time;
+ VLOG(2) << "Elapsed time in computation: " << elapsed << "us";
+
+ OP_REQUIRES_OK(
+ ctx,
+ launch_context.PopulateOutputs(
+ ctx, closure.compilation_result(), run_result.ConsumeValueOrDie(),
+ /*missing_ctx_input_prefix=*/closure.num_constant_args()));
+}
+
+REGISTER_KERNEL_BUILDER(Name("XlaLaunch").Device(DEVICE_CPU), XlaLocalLaunchOp);
+
+REGISTER_KERNEL_BUILDER(Name("XlaLaunch")
+ .Device(DEVICE_GPU)
+ .HostMemory("constants")
+ .HostMemory("resources"),
+ XlaLocalLaunchOp);
+
+REGISTER_KERNEL_BUILDER(Name("_XlaCompile").Device(DEVICE_CPU), XlaCompileOp);
+REGISTER_KERNEL_BUILDER(Name("_XlaCompile")
+ .Device(DEVICE_GPU)
+ .HostMemory("constants")
+ .HostMemory("resources"),
+ XlaCompileOp);
+
+REGISTER_KERNEL_BUILDER(Name("_XlaRun").Device(DEVICE_CPU), XlaRunOp);
+REGISTER_KERNEL_BUILDER(Name("_XlaRun").Device(DEVICE_GPU), XlaRunOp);
+
+} // namespace tensorflow
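A minimal sketch of the key-based handoff implemented above may help readers new to the _XlaCompile/_XlaRun split: the compile kernel stores an executable closure under a freshly generated string key and emits that key as a scalar tensor, and the run kernel consumes the closure by key exactly once. The stand-in store below is illustrative only; `FakeClosure` and `FakeClosureStore` are hypothetical names, not part of the patch or of TensorFlow.

    #include <map>
    #include <mutex>
    #include <string>
    #include <utility>

    // Stand-in for the payload kept alive between _XlaCompile and _XlaRun.
    struct FakeClosure {
      int executable_id;  // in the real code: client, executable, kernel, vars
    };

    class FakeClosureStore {
     public:
      // Called by the compile kernel: stash the closure, hand back a key.
      std::string Produce(FakeClosure closure) {
        std::lock_guard<std::mutex> lock(mu_);
        std::string key = "closure_" + std::to_string(next_id_++);
        closures_.emplace(key, std::move(closure));
        return key;
      }
      // Called by the run kernel: look up and erase, so each key is used once.
      FakeClosure Consume(const std::string& key) {
        std::lock_guard<std::mutex> lock(mu_);
        auto it = closures_.find(key);
        FakeClosure result = std::move(it->second);
        closures_.erase(it);
        return result;
      }

     private:
      std::mutex mu_;
      int next_id_ = 0;
      std::map<std::string, FakeClosure> closures_;
    };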
diff --git a/tensorflow/compiler/jit/kernels/xla_ops.h b/tensorflow/compiler/jit/kernels/xla_ops.h
new file mode 100644
index 0000000000..489d26eb30
--- /dev/null
+++ b/tensorflow/compiler/jit/kernels/xla_ops.h
@@ -0,0 +1,168 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_JIT_KERNELS_XLA_OPS_H_
+#define TENSORFLOW_COMPILER_JIT_KERNELS_XLA_OPS_H_
+
+#include "tensorflow/compiler/jit/xla_compilation_cache.h"
+#include "tensorflow/compiler/jit/xla_device.h"
+#include "tensorflow/compiler/jit/xla_launch_util.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/util/stream_executor_util.h"
+
+namespace tensorflow {
+
+// Holds some information about the platform on which an
+// XlaLaunch/_XlaCompile/_XlaRun op must run.
+class XlaPlatformInfo {
+ public:
+ XlaPlatformInfo() : device_type_("") {}
+ explicit XlaPlatformInfo(const DeviceType device_type,
+ se::Platform::Id platform_id,
+ const XlaDevice::Metadata* xla_device_metadata,
+ std::unique_ptr<XlaAllocator> xla_allocator,
+ xla::DeviceMemoryAllocator* device_allocator)
+ : device_type_(device_type),
+ platform_id_(platform_id),
+ xla_device_metadata_(xla_device_metadata),
+ xla_allocator_(std::move(xla_allocator)),
+ device_allocator_(device_allocator) {
+ CHECK((device_allocator_ != nullptr) ^ (xla_allocator_.get() != nullptr));
+ }
+
+ XlaPlatformInfo& operator=(XlaPlatformInfo&& other) = default;
+
+ bool UseMultipleStreams() const {
+ return xla_device_metadata_ && xla_device_metadata_->UseMultipleStreams();
+ }
+
+ xla::DeviceMemoryAllocator* allocator() const {
+ return device_allocator_ ? device_allocator_ : xla_allocator_.get();
+ }
+ DeviceType device_type() const { return device_type_; }
+
+ // This is equal to xla_device_metadata()->platform()->id() if
+ // xla_device_metadata() is not nullptr.
+ se::Platform::Id platform_id() const { return platform_id_; }
+
+ // This may be null if the op this XlaPlatformInfo is for was not placed on an
+ // XLA device.
+ const XlaDevice::Metadata* xla_device_metadata() const {
+ return xla_device_metadata_;
+ }
+ bool is_on_xla_device() const { return xla_device_metadata() != nullptr; }
+
+ private:
+ DeviceType device_type_;
+ se::Platform::Id platform_id_;
+
+ // xla_device_metadata_ lives in the tensorflow::DeviceBase in which the
+ // XlaLaunch/_XlaCompile/_XlaRun op is placed and thus does not die before the
+ // XlaLaunch/_XlaCompile/_XlaRun OpKernel.
+ const XlaDevice::Metadata* xla_device_metadata_;
+
+ // If the op associated with this XlaPlatformInfo is placed on an XLA device
+ // then device_allocator_ is the xla::Backend's memory allocator and
+ // xla_allocator_ is null. If the op is placed on a regular CPU or GPU device
+ // then device_allocator_ is null and xla_allocator_ points to an appropriate
+ // XlaAllocator instance.
+ std::unique_ptr<XlaAllocator> xla_allocator_;
+ xla::DeviceMemoryAllocator* device_allocator_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(XlaPlatformInfo);
+};
+
+// XlaLocalLaunchBase is almost the same as XlaLocalLaunchOp.
+// The only difference is that it does not require arguments to follow
+// the "constants, then regular args, then resources" order.
+// It takes vectors of constant and resource arguments explicitly.
+// It does not have a corresponding OpDef because it is never present
+// in the GraphDef.
+// Currently, it is used by the eager runtime. FunctionLibraryRuntime creates
+// this kernel when asked to create a kernel for an XLA-compiled function.
+class XlaLocalLaunchBase : public OpKernel {
+ public:
+ XlaLocalLaunchBase(OpKernelConstruction* ctx,
+ const std::vector<int>& constants,
+ const std::vector<int>& resources,
+ const NameAttrList& function);
+ XlaLocalLaunchBase(const XlaLocalLaunchBase&) = delete;
+ XlaLocalLaunchBase& operator=(const XlaLocalLaunchBase&) = delete;
+ ~XlaLocalLaunchBase() override = default;
+
+ void Compute(OpKernelContext* ctx) override;
+
+ protected:
+ // Indexes of compile-time constant inputs
+ std::vector<int> constants_;
+ // Indexes of resource inputs
+ std::vector<int> resources_;
+
+ NameAttrList function_;
+ XlaPlatformInfo platform_info_;
+};
+
+// XlaLocalLaunchOp is used to replace a region of the TensorFlow graph
+// which will be compiled and executed using XLA. The XlaLocalLaunchOp is
+// responsible for handling interactions with the TensorFlow executor.
+// Once all inputs are present, and their shapes are known, the op can
+// use a 'XlaCompilationCache' to compile and execute code which is specific
+// to the shapes of input Tensors.
+// XlaLocalLaunchOp uses xla::LocalClient::Compile() and
+// xla::LocalExecutable::Run(), and passes arguments into/out of XLA in device
+// memory.
+class XlaLocalLaunchOp : public XlaLocalLaunchBase {
+ public:
+ explicit XlaLocalLaunchOp(OpKernelConstruction* ctx);
+ ~XlaLocalLaunchOp() override;
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(XlaLocalLaunchOp);
+};
+
+class XlaCompileOp : public OpKernel {
+ public:
+ explicit XlaCompileOp(OpKernelConstruction* ctx);
+
+ void Compute(OpKernelContext* ctx) override;
+
+ private:
+ // Indexes of compile-time constant inputs
+ std::vector<int> constants_;
+ // Indexes of resource inputs
+ std::vector<int> resources_;
+
+ NameAttrList function_;
+
+ XlaPlatformInfo platform_info_;
+};
+
+class XlaRunOp : public OpKernel {
+ public:
+ explicit XlaRunOp(OpKernelConstruction* ctx);
+
+ void Compute(OpKernelContext* ctx) override;
+
+ private:
+ XlaPlatformInfo platform_info_;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_JIT_KERNELS_XLA_OPS_H_
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
index e6cc6e52ae..133d982360 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
@@ -365,10 +365,13 @@ bool IsXlaFusable(const NodeDef& node) {
return elementwise_ops->count(node.op()) > 0;
}
+// Nodes that XLA can compile are put in `candidates`. Nodes put in
+// `isolated_nodes` must either be unclustered or be put in trivial single-node
+// clusters.
Status FindCompilationCandidates(
const Graph& graph, FunctionLibraryDefinition* flib_def, Env* env,
const std::function<bool(const Node*, const DeviceType&)>& is_compilable_fn,
- OrderedNodeSet* candidates) {
+ OrderedNodeSet* candidates, gtl::FlatSet<Node*>* isolated_nodes) {
OptimizerOptions opts;
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
new ProcessFunctionLibraryRuntime(nullptr, env, TF_GRAPH_DEF_VERSION,
@@ -411,6 +414,8 @@ Status FindCompilationCandidates(
DeviceType device_type("");
TF_RETURN_IF_ERROR(
DeviceToDeviceType(node->assigned_device_name(), &device_type));
+ VLOG(4) << "Device type for " << node->name() << ": "
+ << device_type.type_string();
if (is_compilable_fn && !is_compilable_fn(node, device_type)) {
// is_compilable_fn has already logged the reason if it returned false.
@@ -439,19 +444,56 @@ Status FindCompilationCandidates(
<< node->type_string();
continue;
}
- if (compile_time_const_nodes[node->id()] &&
- !registration->requires_compilation) {
+ if (compile_time_const_nodes[node->id()]) {
const OpDef* op_def;
TF_RETURN_IF_ERROR(
graph.op_registry()->LookUpOpDef(node->type_string(), &op_def));
if (op_def->is_stateful()) {
- // We need to be able to constant fold the nodes in
- // compile_time_const_nodes given constant inputs (required by XLA) and
- // therefore can't auto-cluster stateful ops since these can never be
- // constant folded.
- VLOG(2) << "Rejecting " << node->name()
- << ": must-be-constant stateful op";
- continue;
+ // It is easiest to demonstrate the problem we're trying to solve with
+ // an example. Say we have this graph:
+ //
+ // shape = RandomUniformInt();
+ // reshape = Reshape(input, shape)
+ //
+ // Both RandomUniformInt and Reshape are compilable by XLA so, absent
+ // any other reason, we will try to put both shape and reshape in the
+ // same cluster. However, since XLA only supports statically shaped
+ // values, it will expect to be able to constant fold `shape` to get a
+ // static shape for `reshape`. This is a problem because side-effecting
+ // ops like RandomUniformInt() cannot be constant folded. We fix this
+ // by putting `shape` and `reshape` in different clusters, which results
+ // in us recompiling `reshape`'s cluster for every new value of `shape`,
+ // making `reshape` statically sized within each compilation. We
+ // simplify the solution even further by disallowing operations like
+ // `shape` from being part of *any* non-trivial cluster. They're either
+ // not compiled by XLA altogether or, if assigned to an XLA_* device
+ // with "must compile" semantics, compiled into a trivial single-op
+ // cluster. This approach leaves some room for improvement, and we can
+ // consider implementing a more aggressive data-flow-analysis based
+ // solution in the future if needed.
+ //
+ // One ugly problem we have to contend with: certain sets of ops *have*
+ // to be in the same cluster because values flowing between them have
+ // types that can't be live-in or live-out of a cluster. These ops are:
+ //
+ // - TensorArray ops operating on the same TensorArray instance.
+ // - Stack ops operating on the same Stack instance.
+ //
+ // To work around this we avoid isolating these specific ops. This
+ // concession makes it unsound to auto-cluster them, since we would then
+ // create clusters we could not compile (we cannot constant fold, say, a
+ // TensorArrayRead or a StackPopV2). But we don't auto-cluster these
+ // operations today, so we're good for now.
+ const XlaResourceOpInfo* op_info =
+ GetResourceOpInfoForOp(node->type_string());
+ bool is_tensor_array_or_stack_op =
+ op_info && op_info->resource_kind() != XlaResourceKind::kVariable;
+ if (!is_tensor_array_or_stack_op) {
+ VLOG(2) << "Isolating " << node->name()
+ << ": must-be-constant stateful op";
+ isolated_nodes->insert(node);
+ // Keep going and execute all the other checks.
+ }
}
}
// We don't auto-cluster functional control flow nodes containing resource
@@ -807,11 +849,12 @@ Status MarkForCompilationPass::RunImpl(
Graph* graph = options.graph->get();
OrderedNodeSet compilation_candidates;
+ gtl::FlatSet<Node*> isolated_nodes;
TF_RETURN_IF_ERROR(FindCompilationCandidates(
*graph, options.flib_def,
(options.session_options != nullptr) ? options.session_options->env
: Env::Default(),
- is_compilable_fn, &compilation_candidates));
+ is_compilable_fn, &compilation_candidates, &isolated_nodes));
if (compilation_candidates.empty()) {
VLOG(2) << "No compilable candidates";
@@ -856,6 +899,11 @@ Status MarkForCompilationPass::RunImpl(
"Found control flow node in clustering worklist: ",
node_from->type_string());
}
+
+ if (isolated_nodes.count(node_from)) {
+ continue;
+ }
+
string from_scope;
string to_scope;
for (int to : cycles.Successors(from)) {
@@ -873,6 +921,9 @@ Status MarkForCompilationPass::RunImpl(
node_to->assigned_device_name()) {
continue;
}
+ if (isolated_nodes.count(node_to)) {
+ continue;
+ }
// Look for an _XlaScope on both nodes. If both nodes have a
// scope and the scopes do not match, do not cluster along this
// edge. This restriction is overridden if the global_jit_level is ON. If
@@ -931,6 +982,11 @@ Status MarkForCompilationPass::RunImpl(
// Names for each cluster.
std::unordered_map<int, string> cluster_names;
+ if (flags->tf_xla_clustering_debug) {
+ dump_graph::DumpGraphToFile("before_mark_for_compilation", **options.graph,
+ options.flib_def);
+ }
+
// Mark clusters for compilation that:
// * are placed on a device that requires compilation (an XlaDevice),
// * are explicitly marked for compilation (_XlaCompile=true), or
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
index c59770a4c8..4f9145b479 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
@@ -894,5 +894,71 @@ TEST(XlaCompilationTest, RandomShapeWithFunc) {
EXPECT_EQ(clusters["fn_call"], "");
}
+TEST(XlaCompilationTest, RandomShapeOnXlaDevice) {
+ absl::string_view xla_gpu_device =
+ "/job:worker/replica:0/task:0/device:XLA_GPU:0";
+
+ Scope root = Scope::NewRootScope().ExitOnError();
+ Output shape_shape =
+ ops::Const(root.WithOpName("test/shape_shape"), {2}, {1});
+ Output shape =
+ ops::RandomUniformInt(root.WithOpName("test/shape_rng"), shape_shape,
+ ops::Const(root.WithOpName("test/minval"), 1),
+ ops::Const(root.WithOpName("test/maxval"), 20));
+ Output reshape_input =
+ ops::Placeholder(root.WithOpName("test/reshape_input"), DT_FLOAT,
+ ops::Placeholder::Shape(TensorShape({500, 500})));
+ Output reshape =
+ ops::Reshape(root.WithOpName("test/reshape"), reshape_input, shape);
+
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ TF_ASSERT_OK(root.ToGraph(graph.get()));
+
+ for (Node* n : graph->nodes()) {
+ if (absl::StartsWith(n->name(), /*prefix=*/"test/")) {
+ n->set_assigned_device_name(string(xla_gpu_device));
+ }
+ }
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
+
+ std::unordered_map<string, string> clusters = GetClusters(*graph);
+ EXPECT_NE(clusters["test/shape_rng"], "");
+ EXPECT_NE(clusters["test/reshape"], "");
+ EXPECT_NE(clusters["test/shape_rng"], clusters["test/reshape"]);
+}
+
+TEST(XlaCompilationTest, TensorArrayShapeOnXlaDevice) {
+ absl::string_view xla_gpu_device =
+ "/job:worker/replica:0/task:0/device:XLA_GPU:0";
+ Scope root = Scope::NewRootScope().ExitOnError();
+ ops::TensorArray tensor_array(root.WithOpName("test/tensor_array"), 1,
+ DT_INT32);
+ Output zero = ops::Const(root.WithOpName("test/zero"), 0);
+ ops::TensorArrayWrite tensor_array_write(
+ root.WithOpName("test/write"), tensor_array.handle, zero,
+ ops::Const(root.WithOpName("test/forty_two"), 42.0f), tensor_array.flow);
+ Output tensor_array_read =
+ ops::TensorArrayRead(root.WithOpName("test/read"), tensor_array.handle,
+ zero, tensor_array_write.flow_out, DT_INT32);
+ Output reshape =
+ ops::Reshape(root.WithOpName("test/reshape"),
+ ops::Placeholder(root.WithOpName("placeholder"), DT_FLOAT),
+ tensor_array_read);
+
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ TF_ASSERT_OK(root.ToGraph(graph.get()));
+
+ for (Node* n : graph->nodes()) {
+ if (absl::StartsWith(n->name(), /*prefix=*/"test/")) {
+ n->set_assigned_device_name(string(xla_gpu_device));
+ }
+ }
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
+
+ std::unordered_map<string, string> clusters = GetClusters(*graph);
+ EXPECT_NE(clusters["test/read"], "");
+ EXPECT_EQ(clusters["test/read"], clusters["test/reshape"]);
+}
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc
index 65669877f7..d56d0f8ccf 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc
@@ -14,18 +14,35 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h"
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/public/session_options.h"
namespace tensorflow {
/*static*/ Status MarkForCompilationPassTestHelper::MarkForCompilation(
std::unique_ptr<Graph>* graph, FunctionLibraryDefinition* flib_def,
SessionOptions* session_options) {
- // Assign all nodes to the CPU device.
+ // Assign all unassigned nodes to the CPU device.
static const char* kCpuDevice = "/job:localhost/replica:0/task:0/cpu:0";
for (Node* n : (*graph)->nodes()) {
- n->set_assigned_device_name(kCpuDevice);
+ if (n->assigned_device_name().empty()) {
+ n->set_assigned_device_name(kCpuDevice);
+ }
}
+ // Call AddDevices to register the XLA devices.
+ //
+ // It may be worth refactoring out XlaOpRegistry::RegisterCompilationDevice to
+ // make this more direct, but probably not worth it solely for this test.
+ std::vector<Device*> devices;
+ TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(*session_options, "", &devices));
+
+ auto delete_devices = gtl::MakeCleanup([&] {
+ for (Device* d : devices) {
+ delete d;
+ }
+ });
+
GraphOptimizationPassOptions opt_options;
opt_options.graph = graph;
opt_options.session_options = session_options;
diff --git a/tensorflow/compiler/jit/ops/BUILD b/tensorflow/compiler/jit/ops/BUILD
index 13804c6a05..f72224545b 100644
--- a/tensorflow/compiler/jit/ops/BUILD
+++ b/tensorflow/compiler/jit/ops/BUILD
@@ -4,9 +4,17 @@ package(
default_visibility = ["//tensorflow/compiler/tf2xla:internal"],
)
+load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py")
+
cc_library(
name = "xla_ops",
srcs = ["xla_ops.cc"],
deps = ["//tensorflow/core:framework"],
alwayslink = 1,
)
+
+tf_gen_op_wrapper_py(
+ name = "xla_ops_wrapper_py",
+ out = "xla_ops.py",
+ deps = ["//tensorflow/compiler/jit/ops:xla_ops"],
+)
diff --git a/tensorflow/compiler/jit/ops/xla_ops.cc b/tensorflow/compiler/jit/ops/xla_ops.cc
index 1a29c3caab..bcd1a29b1f 100644
--- a/tensorflow/compiler/jit/ops/xla_ops.cc
+++ b/tensorflow/compiler/jit/ops/xla_ops.cc
@@ -51,4 +51,43 @@ REGISTER_OP("XlaClusterOutput")
"Operator that connects the output of an XLA computation to other "
"consumer graph nodes.");
+REGISTER_OP("_XlaCompile")
+ .Input("constants: Tconstants")
+ .Attr("Tconstants: list(type) >= 0")
+ .Input("args: Targs")
+ .Attr("Targs: list(type) >= 0")
+ .Input("resources: Nresources * resource")
+ .Attr("Nresources: int >= 0")
+ .Output("key: string")
+ .Output("compilation_successful: bool")
+ .Attr("function: func")
+ // The compilation cache is stateful.
+ .SetIsStateful()
+ .Doc(R"(XLA Compile Op. For use by the XLA JIT only.
+
+Compiles a TensorFlow function into an XLA LocalExecutable and returns a key
+that _XlaRun can use to look up the LocalExecutable and execute it.
+
+key: A key that can be used to look up the local executable compiled by the
+ node and associated metadata.
+
+compilation_successful: True iff the compilation was successful. Always true
+for now.
+)");
+
+REGISTER_OP("_XlaRun")
+ .Input("args: Targs")
+ .Attr("Targs: list(type) >= 0")
+ .Output("results: Tresults")
+ .Attr("Tresults: list(type) >= 0")
+ .Input("key: string")
+ // XLA random-number generation ops are stateful.
+ // TODO(phawkins): create stateful and non-stateful variants of _XlaRun.
+ .SetIsStateful()
+ .Doc(R"(XLA Run Op. For use by the XLA JIT only.
+
+Executes a TensorFlow function previously compiled into a LocalExecutable by an
+_XlaCompile op.
+)");
+
} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/partially_decluster_pass_test.cc b/tensorflow/compiler/jit/partially_decluster_pass_test.cc
index 35872daa65..0feb73a89e 100644
--- a/tensorflow/compiler/jit/partially_decluster_pass_test.cc
+++ b/tensorflow/compiler/jit/partially_decluster_pass_test.cc
@@ -60,9 +60,9 @@ class FakeBinaryOp : public OpKernel {
void Compute(OpKernelContext* ctx) override { CHECK(false); }
};
-class FakeResourceVarUpdateOp : public OpKernel {
+class FakeResourceUpdateOp : public OpKernel {
public:
- explicit FakeResourceVarUpdateOp(OpKernelConstruction* context)
+ explicit FakeResourceUpdateOp(OpKernelConstruction* context)
: OpKernel(context) {}
void Compute(OpKernelContext* ctx) override { CHECK(false); }
@@ -74,10 +74,9 @@ REGISTER_KERNEL_BUILDER(Name("FakeBinary")
.HostMemory("host_out"),
FakeBinaryOp);
-REGISTER_KERNEL_BUILDER(Name("FakeResourceVarUpdate")
- .Device(DEVICE_CPU)
- .HostMemory("something_else"),
- FakeResourceVarUpdateOp);
+REGISTER_KERNEL_BUILDER(
+ Name("FakeResourceUpdate").Device(DEVICE_CPU).HostMemory("something_else"),
+ FakeResourceUpdateOp);
Status PartiallyDecluster(std::unique_ptr<Graph>* graph) {
FixupSourceAndSinkEdges(graph->get());
diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
index 3ba48e8c31..b98c0cb028 100644
--- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
+++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
@@ -34,6 +34,7 @@ std::map<int, OptionalTensor> GetVariables(OpKernelContext* ctx) {
OptionalTensor& optional = variables[i];
optional.name = handle.name();
if (LookupResource(ctx, handle, &variable).ok()) {
+ core::ScopedUnref scoped_unref(variable);
tf_shared_lock lock(*variable->mu());
optional.present = true;
optional.value = *variable->tensor();
@@ -58,7 +59,8 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
/*allocate_xla_tensors=*/true,
/*use_multiple_streams=*/metadata.UseMultipleStreams());
- launch_context.PopulateInputs(ctx, result, variables);
+ launch_context.PopulateInputs(ctx, result, variables,
+ /*missing_ctx_input_prefix=*/0);
se::Stream* stream =
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
@@ -79,7 +81,8 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
TF_RETURN_IF_ERROR(run_result.status());
TF_RETURN_IF_ERROR(launch_context.PopulateOutputs(
- ctx, result, run_result.ConsumeValueOrDie()));
+ ctx, result, run_result.ConsumeValueOrDie(),
+ /*missing_ctx_input_prefix=*/0));
return Status::OK();
}
diff --git a/tensorflow/compiler/jit/xla_cpu_device.cc b/tensorflow/compiler/jit/xla_cpu_device.cc
index 7e159e3171..003c1d8081 100644
--- a/tensorflow/compiler/jit/xla_cpu_device.cc
+++ b/tensorflow/compiler/jit/xla_cpu_device.cc
@@ -16,7 +16,7 @@ limitations under the License.
// Registers the XLA_CPU device, which is an XlaDevice instantiation that runs
// operators using XLA via the XLA "Host" (CPU) backend.
-#include "tensorflow/compiler/jit/kernels/xla_launch_op.h"
+#include "tensorflow/compiler/jit/kernels/xla_ops.h"
#include "tensorflow/compiler/jit/legacy_flags/xla_device_flags.h"
#include "tensorflow/compiler/jit/xla_compile_on_demand_op.h"
#include "tensorflow/compiler/jit/xla_device.h"
@@ -65,10 +65,14 @@ REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_CPU, XlaCpuDeviceFactory);
// Kernel registrations
-constexpr std::array<DataType, 7> kAllXlaCpuTypes = {
- {DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}};
+constexpr std::array<DataType, 12> kAllXlaCpuTypes = {
+ {DT_UINT8, DT_QUINT8, DT_INT8, DT_QINT8, DT_INT32, DT_QINT32, DT_INT64,
+ DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}};
REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_CPU, XlaLocalLaunchOp, kAllXlaCpuTypes);
+REGISTER_XLA_COMPILE_KERNEL(DEVICE_XLA_CPU, XlaCompileOp, kAllXlaCpuTypes);
+REGISTER_XLA_RUN_KERNEL(DEVICE_XLA_CPU, XlaRunOp, kAllXlaCpuTypes);
+
REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_CPU, kAllXlaCpuTypes);
} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc
index 51797def04..0824c4644e 100644
--- a/tensorflow/compiler/jit/xla_device.cc
+++ b/tensorflow/compiler/jit/xla_device.cc
@@ -373,7 +373,7 @@ Status XlaDevice::FillContextMap(const Graph* graph,
void XlaDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) {
VLOG(2) << "XlaDevice::Compute " << op_kernel->name() << ":"
<< op_kernel->type_string();
- TracingDevice::Compute(op_kernel, context);
+ op_kernel->Compute(context);
}
void XlaDevice::ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
@@ -434,6 +434,16 @@ Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto,
return status;
}
+void XlaDevice::SetRequiresSyncOnCompletion(bool sync_on_completion) {
+ mutex_lock lock(mu_);
+ sync_on_completion_ = sync_on_completion;
+}
+
+bool XlaDevice::RequiresSyncOnCompletion() const {
+ mutex_lock lock(mu_);
+ return sync_on_completion_;
+}
+
XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device,
const char* jit_device) {
// Any op assigned to the device that isn't rewritten by the graph rewriter
diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h
index 92891ffa8c..0f06b3fc80 100644
--- a/tensorflow/compiler/jit/xla_device.h
+++ b/tensorflow/compiler/jit/xla_device.h
@@ -151,6 +151,12 @@ class XlaDevice : public LocalDevice {
// information for GPU and TPU devices.
Status UseGpuDeviceInfo() LOCKS_EXCLUDED(mu_);
+ // Instructs this XlaDevice to return 'sync_on_completion' for
+ // RequiresSyncOnCompletion().
+ void SetRequiresSyncOnCompletion(bool sync_on_completion) LOCKS_EXCLUDED(mu_);
+
+ bool RequiresSyncOnCompletion() const override LOCKS_EXCLUDED(mu_);
+
private:
xla::LocalClient* client() const;
Allocator* GetAllocatorLocked(AllocatorAttributes attr)
@@ -165,7 +171,7 @@ class XlaDevice : public LocalDevice {
static Status GetMetadataFromDevice(DeviceBase* device,
const XlaDevice::Metadata** metadata);
- mutex mu_;
+ mutable mutex mu_;
// The metadata of this XlaDevice.
const Metadata xla_metadata_;
// Which hardware device in the client's platform this XlaDevice controls.
@@ -207,6 +213,10 @@ class XlaDevice : public LocalDevice {
// Thread pool used for running closures
std::unique_ptr<thread::ThreadPool> thread_pool_;
+
+ // True if the device requires XlaDevice::Sync to be called on completion
+ // regardless of status.
+ bool sync_on_completion_ GUARDED_BY(mu_) = false;
};
// Builds OpKernel registrations on 'device' for the JIT operators
diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h
index 49c8582682..6967ad1f03 100644
--- a/tensorflow/compiler/jit/xla_device_ops.h
+++ b/tensorflow/compiler/jit/xla_device_ops.h
@@ -65,6 +65,16 @@ class XlaAssignVariableOp : public AsyncOpKernel {
.HostMemory("resources"), \
KERNEL);
+#define REGISTER_XLA_COMPILE_KERNEL(DEVICE, KERNEL, TYPES) \
+ REGISTER_KERNEL_BUILDER(Name("_XlaCompile") \
+ .Device(DEVICE) \
+ .HostMemory("constants") \
+ .HostMemory("resources"), \
+ KERNEL);
+
+#define REGISTER_XLA_RUN_KERNEL(DEVICE, KERNEL, TYPES) \
+ REGISTER_KERNEL_BUILDER(Name("_XlaRun").Device(DEVICE), KERNEL);
+
#define REGISTER_XLA_DEVICE_KERNELS(DEVICE, TYPES) \
REGISTER_KERNEL_BUILDER(Name("_Send").Device(DEVICE), SendOp); \
REGISTER_KERNEL_BUILDER(Name("_Recv").Device(DEVICE), RecvOp); \
@@ -90,9 +100,15 @@ class XlaAssignVariableOp : public AsyncOpKernel {
Name("VarHandleOp").Device(DEVICE).HostMemory("resource"), \
ResourceHandleOp<Var>); \
REGISTER_KERNEL_BUILDER( \
+ Name("_VarHandlesOp").Device(DEVICE).HostMemory("resources"), \
+ ResourceHandlesOp<Var>); \
+ REGISTER_KERNEL_BUILDER( \
Name("ReadVariableOp").Device(DEVICE).HostMemory("resource"), \
ReadVariableOp); \
REGISTER_KERNEL_BUILDER( \
+ Name("_ReadVariablesOp").Device(DEVICE).HostMemory("resources"), \
+ ReadVariablesOp); \
+ REGISTER_KERNEL_BUILDER( \
Name("DestroyResourceOp").Device(DEVICE).HostMemory("resource"), \
DestroyResourceOp); \
REGISTER_KERNEL_BUILDER(Name("Shape") \
diff --git a/tensorflow/compiler/jit/xla_gpu_device.cc b/tensorflow/compiler/jit/xla_gpu_device.cc
index ef4466f005..60979556a3 100644
--- a/tensorflow/compiler/jit/xla_gpu_device.cc
+++ b/tensorflow/compiler/jit/xla_gpu_device.cc
@@ -16,7 +16,7 @@ limitations under the License.
// Registers the XLA_GPU device, which is an XlaDevice instantiation that runs
// operators using XLA via the XLA "CUDA" (GPU) backend.
-#include "tensorflow/compiler/jit/kernels/xla_launch_op.h"
+#include "tensorflow/compiler/jit/kernels/xla_ops.h"
#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_device_ops.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
@@ -74,11 +74,14 @@ REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_GPU, XlaGpuDeviceFactory);
// Kernel registrations
-constexpr std::array<DataType, 8> kAllXlaGpuTypes = {
- {DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL,
- DT_BFLOAT16}};
+constexpr std::array<DataType, 13> kAllXlaGpuTypes = {
+ {DT_UINT8, DT_QUINT8, DT_INT8, DT_QINT8, DT_INT32, DT_QINT32, DT_INT64,
+ DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL, DT_BFLOAT16}};
REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_GPU, XlaLocalLaunchOp, kAllXlaGpuTypes);
+REGISTER_XLA_COMPILE_KERNEL(DEVICE_XLA_GPU, XlaCompileOp, kAllXlaGpuTypes);
+REGISTER_XLA_RUN_KERNEL(DEVICE_XLA_GPU, XlaRunOp, kAllXlaGpuTypes);
+
REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_GPU, kAllXlaGpuTypes);
} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/xla_interpreter_device.cc b/tensorflow/compiler/jit/xla_interpreter_device.cc
index 4574559674..19e681af0c 100644
--- a/tensorflow/compiler/jit/xla_interpreter_device.cc
+++ b/tensorflow/compiler/jit/xla_interpreter_device.cc
@@ -15,7 +15,7 @@ limitations under the License.
// Registers the XLA_INTERPRETER device which exposes the XLA Interpreter.
-#include "tensorflow/compiler/jit/kernels/xla_launch_op.h"
+#include "tensorflow/compiler/jit/kernels/xla_ops.h"
#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_device_ops.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
@@ -72,6 +72,10 @@ static bool OpFilter(KernelDef* kdef) { return true; }
REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_INTERPRETER, XlaLocalLaunchOp,
kExecAllTypes);
+REGISTER_XLA_COMPILE_KERNEL(DEVICE_XLA_INTERPRETER, XlaCompileOp,
+ kExecAllTypes);
+REGISTER_XLA_RUN_KERNEL(DEVICE_XLA_INTERPRETER, XlaRunOp, kExecAllTypes);
+
REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_INTERPRETER, kExecAllTypes);
REGISTER_XLA_BACKEND(DEVICE_INTERPRETER_XLA_JIT, kExecAllTypes, OpFilter);
diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc
index affeab4a8c..4f6fc4e068 100644
--- a/tensorflow/compiler/jit/xla_launch_util.cc
+++ b/tensorflow/compiler/jit/xla_launch_util.cc
@@ -42,13 +42,14 @@ using xla::ShapedBuffer;
} // anonymous namespace
std::map<int, OptionalTensor> SnapshotResourceVariables(
- OpKernelContext* ctx, const std::vector<int>& variables) {
+ OpKernelContext* ctx, absl::Span<const int> variables) {
std::map<int, OptionalTensor> snapshot;
for (int i : variables) {
Var* variable = nullptr;
ResourceHandle handle = HandleFromInput(ctx, i);
OptionalTensor& tensor = snapshot[i];
if (LookupResource(ctx, handle, &variable).ok()) {
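+      // LookupResource took a reference on `variable`; release it when this
+      // scope exits, after the tensor has been snapshotted below.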
+ core::ScopedUnref scoped_unref(variable);
tf_shared_lock lock(*variable->mu());
tensor.name = handle.name();
tensor.present = true;
@@ -133,7 +134,8 @@ XlaComputationLaunchContext::XlaComputationLaunchContext(
void XlaComputationLaunchContext::PopulateInputs(
OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel,
- const std::map<int, OptionalTensor>& variables) {
+ const std::map<int, OptionalTensor>& variables,
+ int missing_ctx_input_prefix) {
se::Stream* stream =
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
// Build ShapedBuffers that point directly to the Tensor buffers.
@@ -145,12 +147,13 @@ void XlaComputationLaunchContext::PopulateInputs(
const Tensor* t;
for (int i = 0; i < kernel->xla_input_shapes.size(); ++i) {
int arg_num = kernel->input_mapping[i];
+ DCHECK_GE(arg_num, missing_ctx_input_prefix);
const xla::Shape& shape = kernel->xla_input_shapes[i];
if (variables.count(arg_num)) {
t = &(variables.at(arg_num).value);
CHECK(t);
} else {
- t = &(ctx->input(arg_num));
+ t = &(ctx->input(arg_num - missing_ctx_input_prefix));
}
if (use_multiple_streams_) {
@@ -187,7 +190,7 @@ void XlaComputationLaunchContext::PopulateInputs(
Status XlaComputationLaunchContext::PopulateOutputs(
OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel,
- ScopedShapedBuffer output) {
+ ScopedShapedBuffer output, int missing_ctx_input_prefix) {
se::Stream* stream =
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
@@ -275,6 +278,8 @@ Status XlaComputationLaunchContext::PopulateOutputs(
VLOG(2) << "Retval " << i << " shape " << shape.DebugString() << " type "
<< DataTypeString(type);
if (type == DT_RESOURCE) {
+ TF_RET_CHECK(kernel->outputs[i].input_index >= 0)
+ << "Invalid input for outputs " << i;
ctx->set_output(i, ctx->input(kernel->outputs[i].input_index));
} else {
se::DeviceMemoryBase buffer = output.buffer({output_num});
@@ -313,7 +318,8 @@ Status XlaComputationLaunchContext::PopulateOutputs(
for (int i = 0; i < kernel->resource_updates.size(); ++i) {
Allocator* allocator = ctx->device()->GetAllocator({});
const XlaCompiler::ResourceUpdate& write = kernel->resource_updates[i];
- if (write.input_index < 0 || write.input_index >= ctx->num_inputs()) {
+ int actual_input_index = write.input_index - missing_ctx_input_prefix;
+ if (actual_input_index < 0 || actual_input_index >= ctx->num_inputs()) {
return errors::Internal("Invalid input index for variable write.");
}
@@ -323,7 +329,7 @@ Status XlaComputationLaunchContext::PopulateOutputs(
// TODO(b/35625933): tensorflow::Var should contain a PersistentTensor,
// not a Tensor.
TF_RETURN_IF_ERROR(LookupOrCreateResource<Var>(
- ctx, HandleFromInput(ctx, write.input_index), &variable,
+ ctx, HandleFromInput(ctx, actual_input_index), &variable,
[&write](Var** ptr) {
*ptr = new Var(write.type);
return Status::OK();
diff --git a/tensorflow/compiler/jit/xla_launch_util.h b/tensorflow/compiler/jit/xla_launch_util.h
index 7ac275fab8..326d70a027 100644
--- a/tensorflow/compiler/jit/xla_launch_util.h
+++ b/tensorflow/compiler/jit/xla_launch_util.h
@@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/variable_ops.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
namespace tensorflow {
class XlaAllocator;
@@ -43,7 +44,7 @@ class XlaAllocator;
// resource variable is not initialized, the corresponding OptionalTensor
// will have its `present` field set to false.
std::map<int, OptionalTensor> SnapshotResourceVariables(
- OpKernelContext* ctx, const std::vector<int>& variables);
+ OpKernelContext* ctx, absl::Span<const int> variables);
// Adapter class that wraps a Tensorflow allocator as an XLA allocator.
// Assumes that the Tensorflow allocator permits asynchronous deallocation:
@@ -88,14 +89,24 @@ class XlaComputationLaunchContext {
// Add all inputs within `ctx` as XLA arguments (returned by arguments()).
// `variables` is a map from TensorFlow argument number to resource variable.
+ //
+ // Assumes that the first `missing_ctx_input_prefix` inputs to the kernel are
+ // missing and adjusts input indices accordingly. All elements in kernel's
+ // input_mapping must be greater than or equal to `missing_ctx_input_prefix`
+ // (in other words, no inputs actually required by the kernel can be missing).
void PopulateInputs(OpKernelContext* ctx,
const XlaCompiler::CompilationResult* kernel,
- const std::map<int, OptionalTensor>& variables);
+ const std::map<int, OptionalTensor>& variables,
+ int missing_ctx_input_prefix);
// Given the XLA output in `output`, populate all outputs of `ctx`.
+ //
+ // Assumes that the first `missing_ctx_input_prefix` inputs to the kernel are
+ // missing and adjusts input indices accordingly.
Status PopulateOutputs(OpKernelContext* ctx,
const XlaCompiler::CompilationResult* kernel,
- xla::ScopedShapedBuffer output);
+ xla::ScopedShapedBuffer output,
+ int missing_ctx_input_prefix);
// Return the argument list. Only valid after PopulateInputs() has been
// called.
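A worked example of the `missing_ctx_input_prefix` arithmetic introduced above, using assumed numbers rather than values taken from the patch:

    // Suppose a cluster was compiled with 2 must-be-constant inputs that the
    // _XlaRun op never receives, so missing_ctx_input_prefix == 2.
    int missing_ctx_input_prefix = 2;
    int compiled_arg_num = 3;  // index into the compiled argument list
    int ctx_input_index = compiled_arg_num - missing_ctx_input_prefix;  // == 1
    // PopulateInputs therefore reads ctx->input(1) for compiled argument 3,
    // and a resource update recorded against compiled input 4 is written back
    // through HandleFromInput(ctx, 4 - 2) in PopulateOutputs.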
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index 97ed554171..3cf74fa788 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -978,7 +978,7 @@ tf_xla_py_test(
name = "gather_test",
size = "medium",
srcs = ["gather_test.py"],
- tags = ["noasan"], # times out, http://b/78599043
+ tags = ["optonly"],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
@@ -1198,6 +1198,19 @@ tf_xla_py_test(
)
tf_xla_py_test(
+ name = "quantized_ops_test",
+ size = "small",
+ srcs = ["quantized_ops_test.py"],
+ deps = [
+ ":xla_test",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+tf_xla_py_test(
name = "xla_ops_test",
size = "medium",
srcs = ["xla_ops_test.py"],
diff --git a/tensorflow/compiler/tests/argminmax_test.py b/tensorflow/compiler/tests/argminmax_test.py
index 4155342787..68f52e796c 100644
--- a/tensorflow/compiler/tests/argminmax_test.py
+++ b/tensorflow/compiler/tests/argminmax_test.py
@@ -50,12 +50,12 @@ class ArgMinMaxTest(xla_test.XLATestCase):
def testArgMinMax(self):
# Complex numbers do not support argmin/argmax.
- minmax_types = set(self.numeric_types) - set(self.complex_types)
+ minmax_types = self.all_types & {np.int32, np.int64}
for dtype in minmax_types:
# output_type is a numpy data type that is used to specify the desired
# output type of the op as well as to convert the Python number to the
# array scalar of the type.
- for output_type in self.int_types:
+ for output_type in minmax_types:
self._assertOpOutputMatchesExpected(
math_ops.argmax,
axis=0,
diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py
index 17280e445b..1b39d53dc0 100644
--- a/tensorflow/compiler/tests/binary_ops_test.py
+++ b/tensorflow/compiler/tests/binary_ops_test.py
@@ -210,7 +210,7 @@ class BinaryOpsTest(xla_test.XLATestCase):
equality_test=self.ListsAreClose)
def testIntOps(self):
- for dtype in self.int_types:
+ for dtype in self.signed_int_types:
self._testBinary(
gen_math_ops.truncate_div,
np.array([3, 3, -1, -9, -8], dtype=dtype),
@@ -287,7 +287,8 @@ class BinaryOpsTest(xla_test.XLATestCase):
dtype(7),
expected=np.array([[-6], [-5]], dtype=dtype))
- if dtype not in self.complex_types: # min/max not supported for complex
+ # min/max not supported for complex
+ if dtype not in self.complex_types | {np.uint8, np.int8}:
self._testBinary(
math_ops.maximum,
np.array([1, 2], dtype=dtype),
@@ -337,7 +338,7 @@ class BinaryOpsTest(xla_test.XLATestCase):
expected=np.array([[70], [14]], dtype=dtype))
# Complex support for squared_difference is incidental, see b/68205550
- if dtype not in self.complex_types:
+ if dtype not in self.complex_types | {np.uint8, np.int8}:
self._testBinary(
math_ops.squared_difference,
np.array([1, 2], dtype=dtype),
@@ -559,6 +560,13 @@ class BinaryOpsTest(xla_test.XLATestCase):
dtype(2),
expected=np.array([[5], [2]], dtype=dtype))
+ if dtype in [np.float32, np.float64]:
+ nums = np.arange(-10, 10, .25, dtype=dtype).reshape(80, 1)
+ divs = np.arange(-3, 3, .25, dtype=dtype).reshape(1, 24)
+ np_result = np.true_divide(nums, divs)
+ np_result[:, divs[0] == 0] = 0
+ self._testBinary(gen_math_ops.div_no_nan, nums, divs, expected=np_result)
+
if dtype not in self.complex_types: # floordiv unsupported for complex.
self._testBinary(
gen_math_ops.floor_div,
@@ -567,7 +575,7 @@ class BinaryOpsTest(xla_test.XLATestCase):
expected=np.array([1, -2, -1, -5, 2], dtype=dtype))
def testIntDivision(self):
- for dtype in self.int_types:
+ for dtype in self.signed_int_types:
self._testDivision(dtype)
def testFloatDivision(self):
@@ -588,7 +596,7 @@ class BinaryOpsTest(xla_test.XLATestCase):
expected=np.array([1, 1, -1, 0], dtype=dtype))
def testIntRemainder(self):
- for dtype in self.int_types:
+ for dtype in self.signed_int_types - {np.int8}:
self._testRemainder(dtype)
def testFloatRemainder(self):
@@ -1437,6 +1445,13 @@ class BinaryOpsTest(xla_test.XLATestCase):
np.array([4, 0], dtype=np.int32),
expected=np.zeros([4, 0], dtype=dtype))
+ x = np.arange(3).reshape((3, 1, 1, 1)).astype(dtype)
+ self._testBinary(
+ array_ops.broadcast_to,
+ x,
+ np.array((3, 7, 8, 9), dtype=np.int32),
+ expected=np.tile(x, (1, 7, 8, 9)))
+
if __name__ == "__main__":
googletest.main()
diff --git a/tensorflow/compiler/tests/build_defs.bzl b/tensorflow/compiler/tests/build_defs.bzl
index a76f136736..1d3979b21b 100644
--- a/tensorflow/compiler/tests/build_defs.bzl
+++ b/tensorflow/compiler/tests/build_defs.bzl
@@ -2,6 +2,10 @@
load("@local_config_cuda//cuda:build_defs.bzl", "cuda_is_configured")
load("//tensorflow/compiler/tests:plugin.bzl", "plugins")
+load(
+ "//tensorflow/core:platform/default/build_config_root.bzl",
+ "tf_cuda_tests_tags",
+)
def all_backends():
b = ["cpu"] + plugins.keys()
@@ -58,14 +62,14 @@ def tf_xla_py_test(
if backend == "cpu":
backend_args += [
"--test_device=XLA_CPU",
- "--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL,DT_COMPLEX64",
+ "--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_UINT8,DT_QUINT8,DT_INT8,DT_QINT8,DT_INT32,DT_QINT32,DT_INT64,DT_BOOL,DT_COMPLEX64",
]
elif backend == "gpu":
backend_args += [
"--test_device=XLA_GPU",
- "--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL,DT_COMPLEX64,DT_BFLOAT16",
+ "--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_UINT8,DT_QUINT8,DT_INT8,DT_QINT8,DT_INT32,DT_QINT32,DT_INT64,DT_BOOL,DT_COMPLEX64,DT_BFLOAT16",
]
- backend_tags += ["requires-gpu-sm35"]
+ backend_tags += tf_cuda_tests_tags()
elif backend in plugins:
backend_args += [
"--test_device=" + plugins[backend]["device"],
diff --git a/tensorflow/compiler/tests/dense_layer_test.py b/tensorflow/compiler/tests/dense_layer_test.py
index 0af74c2d8f..9390870e07 100644
--- a/tensorflow/compiler/tests/dense_layer_test.py
+++ b/tensorflow/compiler/tests/dense_layer_test.py
@@ -45,17 +45,21 @@ def InLabels(labels, substr):
return any([substr in x for x in labels])
-def XlaLaunchOpCount(labels):
- """Count how many XlaLaunch labels are present."""
- return sum("XlaLaunch(" in x for x in labels)
+class DenseLayerTest(test.TestCase):
+ def countXlaOps(self, labels):
+ """Count how many XlaCompile/XlaRun labels are present."""
+ xla_compile_count = sum("XlaCompile(" in x for x in labels)
+ xla_run_count = sum("XlaRun(" in x for x in labels)
+ self.assertEqual(xla_compile_count, xla_run_count)
+ return xla_run_count
-class DenseLayerTest(test.TestCase):
def testDenseLayerAutoJit(self):
"""Tests dense layer compilation in auto-jit mode.
- Dense layer should be compiled into a single XlaLaunch op in auto-jit mode.
+ Dense layer should be compiled into a single XlaCompile/XlaRun op pair in
+ auto-jit mode.
"""
os.environ["TF_XLA_FLAGS"] = (
@@ -77,14 +81,14 @@ class DenseLayerTest(test.TestCase):
trace_level=config_pb2.RunOptions.FULL_TRACE))
labels = GetRunMetadataLabels(run_metadata)
- self.assertEqual(1, XlaLaunchOpCount(labels))
+ self.assertEqual(1, self.countXlaOps(labels))
self.assertFalse(InLabels(labels, "MatMult"))
def testDenseLayerJitScopeDefinedShape(self):
"""Tests that the dense layer node is properly compiled in jit scope.
Dense layer with static shape input tensor should be compiled into a single
- XlaLaunch op by XLA.
+ XlaCompile/XlaRun op pair by XLA.
"""
with self.cached_session() as sess:
@@ -101,7 +105,7 @@ class DenseLayerTest(test.TestCase):
trace_level=config_pb2.RunOptions.FULL_TRACE))
labels = GetRunMetadataLabels(run_metadata)
- self.assertEqual(1, XlaLaunchOpCount(labels))
+ self.assertEqual(1, self.countXlaOps(labels))
# No need to check whether ListDiff is compiled or not because ListDiff op
# is not used when input tensor shape is fully defined.
@@ -111,7 +115,8 @@ class DenseLayerTest(test.TestCase):
Dense layer uses shape op to get shape of input tensor if its shape is not
fully defined. XLA does not cluster shape op with other operators. But in
experimental_jit_scope, XLA is forced to compile shape op into its own
- cluster, causing dense layer to be split into TWO XlaLaunch ops.
+ cluster, causing dense layer to be split into TWO XlaCompile/XlaRun op
+ pairs.
"""
with self.cached_session() as sess:
@@ -128,7 +133,7 @@ class DenseLayerTest(test.TestCase):
trace_level=config_pb2.RunOptions.FULL_TRACE))
labels = GetRunMetadataLabels(run_metadata)
- self.assertEqual(2, XlaLaunchOpCount(labels))
+ self.assertEqual(2, self.countXlaOps(labels))
self.assertFalse(InLabels(labels, "MatMult"))
diff --git a/tensorflow/compiler/tests/fused_batchnorm_test.py b/tensorflow/compiler/tests/fused_batchnorm_test.py
index 8c018cccb8..374942a0b3 100644
--- a/tensorflow/compiler/tests/fused_batchnorm_test.py
+++ b/tensorflow/compiler/tests/fused_batchnorm_test.py
@@ -29,6 +29,11 @@ from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import nn
from tensorflow.python.platform import test
+DATA_FORMATS = (
+ ("_data_format_NHWC", "NHWC"),
+ ("_data_format_NCHW", "NCHW"),
+)
+
class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase):
@@ -65,12 +70,7 @@ class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase):
grad_offset = np.sum(grad_y, axis=(0, 1, 2))
return grad_x, grad_scale, grad_offset
- @parameterized.named_parameters(
- ("_data_format_NHWC", "NHWC"),
- ("_data_format_NCHW", "NCHW"),
- ("_data_format_HWNC", "HWNC"),
- ("_data_format_HWCN", "HWCN"),
- )
+ @parameterized.named_parameters(*DATA_FORMATS)
def testInference(self, data_format):
channel = 3
x_shape = [2, 2, 6, channel]
@@ -170,30 +170,15 @@ class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase):
self.assertAllClose(y_val, y_ref_converted, atol=1e-3)
self.assertAllClose(var_val, var_ref, atol=1e-3)
- @parameterized.named_parameters(
- ("_data_format_NHWC", "NHWC"),
- ("_data_format_NCHW", "NCHW"),
- ("_data_format_HWNC", "HWNC"),
- ("_data_format_HWCN", "HWCN"),
- )
+ @parameterized.named_parameters(*DATA_FORMATS)
def testLearning(self, data_format):
self._testLearning(False, data_format)
- @parameterized.named_parameters(
- ("_data_format_NHWC", "NHWC"),
- ("_data_format_NCHW", "NCHW"),
- ("_data_format_HWNC", "HWNC"),
- ("_data_format_HWCN", "HWCN"),
- )
+ @parameterized.named_parameters(*DATA_FORMATS)
def testLearningWithGradientChecker(self, data_format):
self._testLearning(True, data_format)
- @parameterized.named_parameters(
- ("_data_format_NHWC", "NHWC"),
- ("_data_format_NCHW", "NCHW"),
- ("_data_format_HWNC", "HWNC"),
- ("_data_format_HWCN", "HWCN"),
- )
+ @parameterized.named_parameters(*DATA_FORMATS)
def testGradientTraining(self, data_format):
# TODO(b/64270657): Use gradient_checker here in addition to comparing with
# this reference implementation.
@@ -241,12 +226,7 @@ class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase):
self.assertAllClose(grad_scale_val, grad_scale_ref, atol=1e-2)
self.assertAllClose(grad_offset_val, grad_offset_ref, atol=1e-3)
- @parameterized.named_parameters(
- ("_data_format_NHWC", "NHWC"),
- ("_data_format_NCHW", "NCHW"),
- ("_data_format_HWNC", "HWNC"),
- ("_data_format_HWCN", "HWCN"),
- )
+ @parameterized.named_parameters(*DATA_FORMATS)
def testGradientInference(self, data_format):
# TODO(b/64270657): Use gradient_checker here in addition to comparing with
# this reference implementation.
diff --git a/tensorflow/compiler/tests/gather_test.py b/tensorflow/compiler/tests/gather_test.py
index 089d95daab..a38e1edafe 100644
--- a/tensorflow/compiler/tests/gather_test.py
+++ b/tensorflow/compiler/tests/gather_test.py
@@ -51,7 +51,7 @@ class GatherTest(xla_test.XLATestCase):
indices_tf = constant_op.constant(indices)
gather_t = array_ops.gather(params, indices_tf)
gather_val = session.run(gather_t, feed_dict={params: params_np})
- np_val = params_np[indices]
+ np_val = constant_op.constant(params_np[indices])
self.assertAllEqual(np_val, gather_val)
def testScalar2D(self):
@@ -65,7 +65,8 @@ class GatherTest(xla_test.XLATestCase):
indices = constant_op.constant(2)
gather_t = array_ops.gather(params, indices, axis=axis)
gather_val = session.run(gather_t, feed_dict={params: params_np})
- expected = np.take(params_np, 2, axis=axis)
+ expected = constant_op.constant(
+ np.take(params_np, 2, axis=axis), dtype)
self.assertAllEqual(expected, gather_val)
def testSimpleTwoD32(self):
@@ -80,7 +81,8 @@ class GatherTest(xla_test.XLATestCase):
indices = constant_op.constant([0, 1, 0, 2])
gather_t = array_ops.gather(params, indices, axis=axis)
gather_val = session.run(gather_t, feed_dict={params: params_np})
- expected = np.take(params_np, [0, 1, 0, 2], axis=axis)
+ expected = constant_op.constant(
+ np.take(params_np, [0, 1, 0, 2], axis=axis), dtype)
self.assertAllEqual(expected, gather_val)
def testSimpleTwoD32_Int64Indices(self):
@@ -103,7 +105,8 @@ class GatherTest(xla_test.XLATestCase):
params: params_np,
indices: indices_np
})
- expected = np.take(params_np, [0, 1, 0, 2], axis=axis)
+ expected = constant_op.constant(
+ np.take(params_np, [0, 1, 0, 2], axis=axis), dtype)
self.assertAllEqual(expected, gather_val)
def testHigherRank(self):
@@ -119,7 +122,8 @@ class GatherTest(xla_test.XLATestCase):
tf_indices = constant_op.constant(indices, dtype=dtypes.int32)
gather = array_ops.gather(tf_params, tf_indices, axis=axis)
gather_value = sess.run(gather, feed_dict={tf_params: params})
- gather_np = np.take(params, indices, axis=axis)
+ gather_np = constant_op.constant(
+ np.take(params, indices, axis=axis), dtype)
self.assertAllEqual(gather_np, gather_value)
def testIndicesWithDifferentDimensions(self):
diff --git a/tensorflow/compiler/tests/image_ops_test.py b/tensorflow/compiler/tests/image_ops_test.py
index 6fe5a66e0e..68fdb5caf4 100644
--- a/tensorflow/compiler/tests/image_ops_test.py
+++ b/tensorflow/compiler/tests/image_ops_test.py
@@ -605,10 +605,6 @@ class ResizeBilinearTest(xla_test.XLATestCase):
class NonMaxSuppressionTest(xla_test.XLATestCase):
def testNMS128From1024(self):
- # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU.
- if self.device in ["XLA_CPU", "XLA_GPU"]:
- return
-
with compat.forward_compatibility_horizon(2018, 8, 8):
num_boxes = 1024
boxes_np = np.random.normal(50, 10, (num_boxes, 4)).astype("f4")
@@ -644,10 +640,6 @@ class NonMaxSuppressionTest(xla_test.XLATestCase):
self.assertEqual(indices_tf.size, max_output_size)
def testNMS3From6Boxes(self):
- # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU.
- if self.device in ["XLA_CPU", "XLA_GPU"]:
- return
-
with compat.forward_compatibility_horizon(2018, 8, 8):
# Three boxes are selected based on IOU.
boxes_data = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9],
@@ -693,10 +685,6 @@ class NonMaxSuppressionTest(xla_test.XLATestCase):
# Three boxes are selected based on IOU.
# One is filtered out by score threshold.
- # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU.
- if self.device in ["XLA_CPU", "XLA_GPU"]:
- return
-
with compat.forward_compatibility_horizon(2018, 8, 8):
boxes_data = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9],
[0, 10, 1, 11], [0, 10.1, 1, 11.1], [0, 100, 1, 101]]
@@ -736,6 +724,49 @@ class NonMaxSuppressionTest(xla_test.XLATestCase):
self.assertEqual(num_valid, 2)
self.assertAllClose(indices_tf[:num_valid], [3, 0])
+ def testNMS3Then1WithScoreMaxThresh(self):
+ # Three boxes are selected based on IOU.
+ # One is filtered out by score threshold.
+ # One is filtered out by max_output_size.
+
+ with compat.forward_compatibility_horizon(2018, 8, 8):
+ boxes_data = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9],
+ [0, 10, 1, 11], [0, 10.1, 1, 11.1], [0, 100, 1, 101]]
+ boxes_np = np.array(boxes_data, dtype=np.float32)
+
+ scores_data = [0.9, 0.75, 0.6, 0.95, 0.5, 0.3]
+ scores_np = np.array(scores_data, dtype=np.float32)
+ max_output_size = 1
+ iou_threshold_np = np.array(0.5, dtype=np.float32)
+ score_threshold_np = np.array(0.4, dtype=np.float32)
+
+ with self.cached_session() as sess:
+ boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape)
+ scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape)
+ iou_threshold = array_ops.placeholder(iou_threshold_np.dtype,
+ iou_threshold_np.shape)
+ score_threshold = array_ops.placeholder(score_threshold_np.dtype,
+ score_threshold_np.shape)
+ with self.test_scope():
+ selected_indices = image_ops.non_max_suppression_padded(
+ boxes=boxes,
+ scores=scores,
+ max_output_size=max_output_size,
+ iou_threshold=iou_threshold,
+ score_threshold=score_threshold,
+ pad_to_max_output_size=True)
+ inputs_feed = {
+ boxes: boxes_np,
+ scores: scores_np,
+ iou_threshold: iou_threshold_np,
+ score_threshold: score_threshold_np
+ }
+ (indices_tf, num_valid) = sess.run(
+ selected_indices, feed_dict=inputs_feed)
+
+ self.assertEqual(indices_tf.size, max_output_size)
+ self.assertEqual(num_valid, 1)
+ self.assertAllClose(indices_tf[:num_valid], [3])
if __name__ == "__main__":
test.main()
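For reference, the selection the new testNMS3Then1WithScoreMaxThresh case expects can be reproduced with a plain greedy NMS loop. The sketch below is illustrative only and not part of the change; `iou` and `greedy_nms` are hypothetical helpers, and boxes follow the [y1, x1, y2, x2] convention used by the test data.

```python
import numpy as np

def iou(a, b):
  # Boxes are [y1, x1, y2, x2].
  inter_min = np.maximum(a[:2], b[:2])
  inter_max = np.minimum(a[2:], b[2:])
  inter = np.prod(np.maximum(inter_max - inter_min, 0.0))
  area = lambda box: (box[2] - box[0]) * (box[3] - box[1])
  return inter / (area(a) + area(b) - inter)

def greedy_nms(boxes, scores, max_output_size, iou_threshold, score_threshold):
  keep = []
  for i in np.argsort(-scores):
    if scores[i] < score_threshold or len(keep) >= max_output_size:
      continue
    if all(iou(boxes[i], boxes[j]) < iou_threshold for j in keep):
      keep.append(i)
  return keep

# With the data above (max_output_size=1, iou_threshold=0.5,
# score_threshold=0.4) this keeps only box 3, matching indices_tf[:num_valid].
```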
diff --git a/tensorflow/compiler/tests/jit_test.py b/tensorflow/compiler/tests/jit_test.py
index 0839fb123e..de68ff0e32 100644
--- a/tensorflow/compiler/tests/jit_test.py
+++ b/tensorflow/compiler/tests/jit_test.py
@@ -77,11 +77,11 @@ def InLabels(labels, substr):
return any([substr in x for x in labels])
-def MetadataHasXlaLaunch(run_metadata):
- """Returns true if there is a XlaLaunch kernel in run_metadata's timeline."""
+def MetadataHasXlaOp(run_metadata):
+ """Returns true if there are XlaRun kernels in run_metadata's timeline."""
# TODO(phawkins): find a less hacky way to test whether a kernel ran.
- return InLabels(RunMetadataLabels(run_metadata), "XlaLaunch")
+ return InLabels(RunMetadataLabels(run_metadata), "XlaRun")
class JitLaunchTest(test.TestCase):
@@ -90,9 +90,10 @@ class JitLaunchTest(test.TestCase):
# Verifies that the outputs match and that XLA was invoked. 'fn' must take
# the same number of tensors as arguments that are in 'args', and must return
# a tuple of output tensors.
- # If 'require_kernel_launch' is True, then we verify that a XlaLaunch node
- # actually ran. However, it is sometimes possible for XlaLaunch ops to be
- # constant-folded away, so the check is optional.
+ #
+ # If 'require_kernel_launch' is True, then we verify that an XlaCompile/XlaRun
+ # node actually ran. However, it is sometimes possible for XlaCompile/XlaRun
+ # ops to be constant-folded away, so the check is optional.
def _compare(self, fn, args, require_kernel_launch=True, noinline=None):
with session_lib.Session(config=NoRewriteSessionConfig()) as sess:
placeholders = []
@@ -115,7 +116,7 @@ class JitLaunchTest(test.TestCase):
print("Compiled Result {}".format(compiled))
if require_kernel_launch:
- self.assert_(MetadataHasXlaLaunch(run_metadata))
+ self.assert_(MetadataHasXlaOp(run_metadata))
direct = sess.run(direct_op, feeds)
print("Direct Result {}".format(direct))
@@ -149,10 +150,10 @@ class JitLaunchTest(test.TestCase):
y = math_ops.add(x, x)
return y, y
- # Exercises compling a function (say, Foo) which calls another
- # function (say, Bar) which is not inlined. When the compiler compiles
- # Foo, it needs to symbolic execute Bar correctly regardless whether
- # Bar is inlined or not.
+ # Exercises compiling a function (say, Foo) which calls another function
+ # (say, Bar) which is not inlined. When the compiler compiles Foo, it needs
+ # to symbolically execute Bar correctly regardless of whether Bar is inlined
+ # or not.
# TODO(b/36139787): Re-enable this test when noinline works again.
# Tests compiled=True and noinline=True.
@@ -259,7 +260,7 @@ class JitLaunchTest(test.TestCase):
# TODO(phawkins): really we would like to test that there were exactly
# two kernel launches. However, we have no reliable way to determine
# that.
- self.assert_(MetadataHasXlaLaunch(run_metadata))
+ self.assert_(MetadataHasXlaOp(run_metadata))
expected = np.square(np.dot(dx, dw) + db)
self.assertAllClose(expected, output, rtol=1e-1)
@@ -289,7 +290,7 @@ class XlaCompilationTest(test.TestCase):
run_metadata=run_metadata,
options=config_pb2.RunOptions(
trace_level=config_pb2.RunOptions.FULL_TRACE))
- self.assert_(MetadataHasXlaLaunch(run_metadata))
+ self.assert_(MetadataHasXlaOp(run_metadata))
self.assertAllClose(np.array([[1, 2, 3], [4, 5, 6]], np.float32), out)
def testIgnoredArguments(self):
@@ -313,7 +314,7 @@ class XlaCompilationTest(test.TestCase):
run_metadata=run_metadata,
options=config_pb2.RunOptions(
trace_level=config_pb2.RunOptions.FULL_TRACE))
- self.assert_(MetadataHasXlaLaunch(run_metadata))
+ self.assert_(MetadataHasXlaOp(run_metadata))
self.assertAllClose(28, out)
def testLoops(self):
@@ -331,7 +332,7 @@ class XlaCompilationTest(test.TestCase):
run_metadata=run_metadata,
options=config_pb2.RunOptions(
trace_level=config_pb2.RunOptions.FULL_TRACE))
- self.assert_(MetadataHasXlaLaunch(run_metadata))
+ self.assert_(MetadataHasXlaOp(run_metadata))
self.assertAllClose(result, np.float32(95), rtol=1e-1)
def testCond(self):
@@ -356,7 +357,7 @@ class XlaCompilationTest(test.TestCase):
run_metadata=run_metadata,
options=config_pb2.RunOptions(
trace_level=config_pb2.RunOptions.FULL_TRACE))
- self.assert_(MetadataHasXlaLaunch(run_metadata))
+ self.assert_(MetadataHasXlaOp(run_metadata))
self.assertAllClose(result, np.float32(6), rtol=1e-1)
def testNestedFunction(self):
@@ -441,14 +442,16 @@ class XlaCompilationTest(test.TestCase):
self.assertFalse(InLabels(labels, "Log"))
self.assertTrue(InLabels(labels, "Reciprocal"))
self.assertTrue(InLabels(labels, "Mul"))
- self.assertFalse(InLabels(labels, "XlaLaunch"))
+ self.assertFalse(InLabels(labels, "XlaCompile"))
+ self.assertFalse(InLabels(labels, "XlaRun"))
- # Compile the backprop. One XlaLaunch.
+ # Compile the backprop. One XlaCompile/XlaRun pair.
labels = _Run(compiled=True)
self.assertFalse(InLabels(labels, "Log"))
self.assertFalse(InLabels(labels, "Reciprocal"))
self.assertFalse(InLabels(labels, "Mul"))
- self.assertTrue(InLabels(labels, "XlaLaunch"))
+ self.assertTrue(InLabels(labels, "XlaCompile"))
+ self.assertTrue(InLabels(labels, "XlaRun"))
class ElementWiseFusionTest(test.TestCase):
@@ -482,9 +485,12 @@ class ElementWiseFusionTest(test.TestCase):
trace_level=config_pb2.RunOptions.FULL_TRACE))
labels = RunMetadataLabels(run_metadata)
- count = sum("XlaLaunch(" in x for x in labels)
- return output, count
+ xla_compile_count = sum("XlaCompile(" in x for x in labels)
+ xla_run_count = sum("XlaRun(" in x for x in labels)
+ self.assertEqual(xla_compile_count, xla_run_count)
+
+ return output, xla_run_count
def testElementWiseClustering(self):
arg0 = np.random.rand(2, 2).astype(np.float32)
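The renamed helper inspects the RunMetadata proto collected with FULL_TRACE tracing. A rough sketch of the idea, assuming the standard step_stats layout (the real RunMetadataLabels/InLabels helpers live in jit_test.py):

```python
def run_metadata_labels(run_metadata):
  # Collect the timeline label of every kernel that actually executed.
  return [
      node_stats.timeline_label
      for dev_stats in run_metadata.step_stats.dev_stats
      for node_stats in dev_stats.node_stats
  ]

def metadata_has_xla_op(run_metadata):
  # After this change a compiled cluster shows up as an XlaCompile/XlaRun
  # pair instead of a single XlaLaunch kernel, so we look for XlaRun.
  return any("XlaRun" in label for label in run_metadata_labels(run_metadata))
```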
diff --git a/tensorflow/compiler/tests/lstm.py b/tensorflow/compiler/tests/lstm.py
index 43c469d032..73b3638e80 100644
--- a/tensorflow/compiler/tests/lstm.py
+++ b/tensorflow/compiler/tests/lstm.py
@@ -117,7 +117,7 @@ def LSTMLayer(cell_name, weights, m, c, x_seq, pad_seq):
def RandomVar(shape, name=None):
"""Returns a variable of the given shape initialized to random values."""
- return variables.Variable(
+ return variables.VariableV1(
random_ops.random_uniform(shape), dtype=dtypes.float32, name=name)
diff --git a/tensorflow/compiler/tests/quantized_ops_test.py b/tensorflow/compiler/tests/quantized_ops_test.py
new file mode 100644
index 0000000000..80c338513b
--- /dev/null
+++ b/tensorflow/compiler/tests/quantized_ops_test.py
@@ -0,0 +1,48 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for quantized operations."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.compiler.tests import xla_test
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import googletest
+
+
+class QuantizedOpsTest(xla_test.XLATestCase):
+
+ # Verify that quantized types can be clustered by XLA.
+ def testQuantizedTypeRoundtrip(self):
+ with self.cached_session() as session:
+ for dtype in self.quantized_tf_types:
+ in_values = np.array([1, 2, 3, 4, 5, 6])
+ expected = [[1, 2], [3, 4], [5, 6]]
+ with self.test_scope():
+ p = array_ops.placeholder(dtype=dtypes.int32)
+ x = math_ops.cast(p, dtype)
+ x = array_ops.reshape(x, [3, 2])
+
+ value = session.run(x, {p: in_values})
+ self.assertAllEqual(value, expected)
+
+
+if __name__ == "__main__":
+ googletest.main()
diff --git a/tensorflow/compiler/tests/random_ops_test.py b/tensorflow/compiler/tests/random_ops_test.py
index 6e18344117..36ef6ed5fe 100644
--- a/tensorflow/compiler/tests/random_ops_test.py
+++ b/tensorflow/compiler/tests/random_ops_test.py
@@ -35,7 +35,8 @@ class RandomOpsTest(xla_test.XLATestCase):
"""Test cases for random-number generating operators."""
def _random_types(self):
- return set(self.numeric_types) - set(self.complex_types)
+ return set(self.numeric_types) - set(
+ self.complex_types) - {np.uint8, np.int8}
def _testRngIsNotConstant(self, rng, dtype):
# Tests that 'rng' does not always return the same value.
@@ -68,9 +69,8 @@ class RandomOpsTest(xla_test.XLATestCase):
def rng(dtype):
return random_ops.random_normal(shape=[2], dtype=dtype)
- # TODO(b/34339814): implement inverse erf support for non-F32 types.
- dtype = dtypes.float32
- self._testRngIsNotConstant(rng, dtype)
+ for dtype in self._random_types() & self.float_types:
+ self._testRngIsNotConstant(rng, dtype)
def testRandomUniformIsInRange(self):
for dtype in self._random_types():
@@ -92,13 +92,13 @@ class RandomOpsTest(xla_test.XLATestCase):
def rng(dtype):
return random_ops.truncated_normal(shape=[2], dtype=dtype)
- # TODO(b/34339814): implement inverse erf support for non-F32 types.
- self._testRngIsNotConstant(rng, dtypes.float32)
+ for dtype in self._random_types() & self.float_types:
+ self._testRngIsNotConstant(rng, dtype)
def testTruncatedNormalIsInRange(self):
count = 10000000
- # TODO(b/34339814): implement inverse erf support for non-F32 types.
- for dtype in [dtypes.float32]:
+ # TODO(b/34339814): make this test work with 16 bit float types.
+ for dtype in self._random_types() & {dtypes.float32, dtypes.float64}:
with self.cached_session() as sess:
with self.test_scope():
x = random_ops.truncated_normal(shape=[count], dtype=dtype)
@@ -144,9 +144,6 @@ class RandomOpsTest(xla_test.XLATestCase):
self.assertAllClose(actual_variance, expected_variance, rtol=2*1e-3)
def testShuffle1d(self):
- # TODO(b/26783907): this test requires the CPU backend to implement sort.
- if self.device in ["XLA_CPU"]:
- return
with self.cached_session() as sess:
with self.test_scope():
x = math_ops.range(1 << 16)
diff --git a/tensorflow/compiler/tests/reverse_sequence_op_test.py b/tensorflow/compiler/tests/reverse_sequence_op_test.py
index 60c2337743..abc822ef36 100644
--- a/tensorflow/compiler/tests/reverse_sequence_op_test.py
+++ b/tensorflow/compiler/tests/reverse_sequence_op_test.py
@@ -85,7 +85,7 @@ class ReverseSequenceTest(xla_test.XLATestCase):
def testSeqLength(self):
for dtype in self.all_types:
- for seq_dtype in self.int_types:
+ for seq_dtype in self.all_types & {np.int32, np.int64}:
self._testBasic(dtype, seq_dtype)
diff --git a/tensorflow/compiler/tests/sort_ops_test.py b/tensorflow/compiler/tests/sort_ops_test.py
index 51c04b5c47..dbf4beb693 100644
--- a/tensorflow/compiler/tests/sort_ops_test.py
+++ b/tensorflow/compiler/tests/sort_ops_test.py
@@ -48,10 +48,6 @@ class XlaSortOpTest(xla_test.XLATestCase):
self.assertAllClose(v, result, rtol=1e-3)
def testSort(self):
- # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU.
- if self.device in ["XLA_CPU", "XLA_GPU"]:
- return
-
supported_types = set([dtypes.bfloat16.as_numpy_dtype, np.float32])
for dtype in supported_types.intersection(self.numeric_types):
x = np.arange(101, dtype=dtype)
@@ -60,10 +56,6 @@ class XlaSortOpTest(xla_test.XLATestCase):
xla.sort, [x], expected=[np.arange(101, dtype=dtype)])
def testTopK(self):
- # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU.
- if self.device in ["XLA_CPU", "XLA_GPU"]:
- return
-
supported_types = set(
[dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32])
for dtype in supported_types.intersection(self.numeric_types):
@@ -89,10 +81,6 @@ class XlaSortOpTest(xla_test.XLATestCase):
expected=[x[indices].astype(dtype), indices])
def testTopK2D(self):
- # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU.
- if self.device in ["XLA_CPU", "XLA_GPU"]:
- return
-
supported_types = set(
[dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32])
for dtype in supported_types.intersection(self.numeric_types):
@@ -122,10 +110,6 @@ class XlaSortOpTest(xla_test.XLATestCase):
def testTopKZeros(self):
"""Tests that positive and negative zeros sort correctly."""
- # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU.
- if self.device in ["XLA_CPU", "XLA_GPU"]:
- return
-
# Only bfloat16 is implemented.
bfloat16 = dtypes.bfloat16.as_numpy_dtype
if bfloat16 not in self.numeric_types:
@@ -144,10 +128,6 @@ class XlaSortOpTest(xla_test.XLATestCase):
def testTopKInfinities(self):
"""Tests that positive and negative infinity sort correctly."""
- # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU.
- if self.device in ["XLA_CPU", "XLA_GPU"]:
- return
-
# Only bfloat16 is implemented.
bfloat16 = dtypes.bfloat16.as_numpy_dtype
if bfloat16 not in self.numeric_types:
diff --git a/tensorflow/compiler/tests/stateless_random_ops_test.py b/tensorflow/compiler/tests/stateless_random_ops_test.py
index 1bea7d9355..f3861043b2 100644
--- a/tensorflow/compiler/tests/stateless_random_ops_test.py
+++ b/tensorflow/compiler/tests/stateless_random_ops_test.py
@@ -34,7 +34,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase):
"""Test cases for stateless random-number generator operators."""
def _random_types(self):
- return [dtypes.float32]
+ return self.float_types & {dtypes.float32, dtypes.float64}
def testDeterminism(self):
# Stateless values should be equal iff the seeds are equal (roughly)
@@ -124,8 +124,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase):
self.assertTrue(self._anderson_darling(y) < 2.492)
def testTruncatedNormalIsInRange(self):
- # TODO(b/34339814): implement inverse erf support for non-F32 types.
- for dtype in [dtypes.float32]:
+ for dtype in self._random_types():
with self.cached_session() as sess, self.test_scope():
seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
n = 10000000
@@ -159,7 +158,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase):
# Department of Scientific Computing website. Florida State University.
expected_mean = mu + (normal_pdf(alpha) - normal_pdf(beta)) / z * sigma
actual_mean = np.mean(y)
- self.assertAllClose(actual_mean, expected_mean, atol=2e-4)
+ self.assertAllClose(actual_mean, expected_mean, atol=5e-4)
expected_median = mu + probit(
(normal_cdf(alpha) + normal_cdf(beta)) / 2.) * sigma
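The expected_mean used here is the standard truncated-normal mean, E[X] = mu + sigma * (pdf(alpha) - pdf(beta)) / Z with Z = cdf(beta) - cdf(alpha). A quick numeric cross-check of that formula (illustrative only, assumes SciPy is available):

```python
import numpy as np
from scipy import stats

mu, sigma = 0.0, 1.0
alpha, beta = -1.0, 2.0  # truncation bounds in standard deviations
z = stats.norm.cdf(beta) - stats.norm.cdf(alpha)
expected_mean = mu + sigma * (stats.norm.pdf(alpha) - stats.norm.pdf(beta)) / z

# scipy's truncnorm implements the same distribution, so the two values agree.
assert np.isclose(
    expected_mean, stats.truncnorm.mean(alpha, beta, loc=mu, scale=sigma))
```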
diff --git a/tensorflow/compiler/tests/ternary_ops_test.py b/tensorflow/compiler/tests/ternary_ops_test.py
index 55a992195f..98a07709c6 100644
--- a/tensorflow/compiler/tests/ternary_ops_test.py
+++ b/tensorflow/compiler/tests/ternary_ops_test.py
@@ -122,8 +122,7 @@ class TernaryOpsTest(xla_test.XLATestCase):
expected=np.array([[2], [5]], dtype=dtype))
def testClipByValue(self):
- # TODO(b/78258593): enable integer types here too.
- for dtype in self.float_types:
+ for dtype in self.numeric_types - self.complex_types:
test_cases = [
(np.array([2, 4, 5], dtype=dtype), dtype(7)), #
(dtype(1), np.array([2, 4, 5], dtype=dtype)), #
diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py
index 5b0e57f83f..77f6eee0cf 100644
--- a/tensorflow/compiler/tests/unary_ops_test.py
+++ b/tensorflow/compiler/tests/unary_ops_test.py
@@ -84,7 +84,7 @@ class UnaryOpsTest(xla_test.XLATestCase):
self.assertAllClose(result[i], expected[i], rtol, atol)
def testAllTypeOps(self):
- for dtype in self.numeric_types:
+ for dtype in self.numeric_types - {np.int8, np.uint8}:
self._assertOpOutputMatchesExpected(
array_ops.diag, np.array([1, 2, 3, 4], dtype=dtype),
np.array(
@@ -158,9 +158,6 @@ class UnaryOpsTest(xla_test.XLATestCase):
def testFloatOps(self):
for dtype in self.float_types:
- # TODO(b/77694432): Half test failed on CPU, last ran on 04-06-2018.
- if dtype == np.float16 and self.device == "XLA_CPU":
- continue
x = np.arange(-0.90, 0.90, 0.25)
self._assertOpOutputMatchesExpected(
math_ops.acos, x.astype(dtype), expected=np.arccos(x).astype(dtype))
@@ -633,7 +630,7 @@ class UnaryOpsTest(xla_test.XLATestCase):
expected=np.array([-1, 0, -2, -17, -43], dtype=dtype))
def testNumericOps(self):
- for dtype in self.numeric_types:
+ for dtype in self.numeric_types - {np.int8, np.uint8}:
self._assertOpOutputMatchesExpected(
math_ops.abs,
np.array([[2, -1]], dtype=dtype),
diff --git a/tensorflow/compiler/tests/xla_ops_test.py b/tensorflow/compiler/tests/xla_ops_test.py
index 1e600c44e9..4cf88fc523 100644
--- a/tensorflow/compiler/tests/xla_ops_test.py
+++ b/tensorflow/compiler/tests/xla_ops_test.py
@@ -181,7 +181,7 @@ class XlaOpsTest(xla_test.XLATestCase, parameterized.TestCase):
dtype=dtype))
def testNeg(self):

- for dtype in self.numeric_types:
+ for dtype in self.numeric_types - {np.uint8, np.int8}:
self._assertOpOutputMatchesExpected(
xla.neg,
args=(np.array([1, 2, 3], dtype=dtype),),
diff --git a/tensorflow/compiler/tests/xla_test.py b/tensorflow/compiler/tests/xla_test.py
index 88827cb53b..98a41981cf 100644
--- a/tensorflow/compiler/tests/xla_test.py
+++ b/tensorflow/compiler/tests/xla_test.py
@@ -97,10 +97,23 @@ class XLATestCase(test.TestCase):
])
self._numeric_tf_types = set(
self.int_tf_types | self._float_tf_types | self.complex_tf_types)
-
- self._all_types = set(
- [dtype.as_numpy_dtype for dtype in self._all_tf_types])
+ self.quantized_tf_types = set(
+ dtype for dtype in self._all_tf_types if dtype.is_quantized)
+
+    # Quantized types don't have a numpy equivalent; include them in
+    # all_tf_types but not in all_types.
+ # TODO(b/115960798): Parametrize tests on TF types instead of numpy types
+ # and remove all_types.
+ self._all_types = set(dtype.as_numpy_dtype
+ for dtype in self._all_tf_types
+ if not dtype.is_quantized)
self._int_types = set([dtype.as_numpy_dtype for dtype in self.int_tf_types])
+ self.signed_int_types = set(dtype.as_numpy_dtype
+ for dtype in self.int_tf_types
+ if not dtype.is_unsigned)
+ self.unsigned_int_types = set(dtype.as_numpy_dtype
+ for dtype in self.int_tf_types
+ if dtype.is_unsigned)
self._float_types = set(
[dtype.as_numpy_dtype for dtype in self._float_tf_types])
self.complex_types = set([
diff --git a/tensorflow/compiler/tf2xla/const_analysis.cc b/tensorflow/compiler/tf2xla/const_analysis.cc
index 922ae7c79a..027ca6d2d2 100644
--- a/tensorflow/compiler/tf2xla/const_analysis.cc
+++ b/tensorflow/compiler/tf2xla/const_analysis.cc
@@ -29,14 +29,6 @@ Status BackwardsConstAnalysis(const Graph& g,
std::vector<bool>* compile_time_const_arg_indices,
std::vector<bool>* compile_time_const_nodes,
std::function<bool(const Edge&)> edge_filter) {
- // Operators that don't look at the data of their inputs, just the shapes.
- const std::unordered_set<string> metadata_ops = {
- "Rank",
- "Shape",
- "ShapeN",
- "Size",
- };
-
std::vector<bool> compile_time_const_nodes_impl;
if (compile_time_const_nodes) {
CHECK_EQ(compile_time_const_nodes->size(), g.num_node_ids());
@@ -50,7 +42,9 @@ Status BackwardsConstAnalysis(const Graph& g,
if (!status.ok()) return;
// If this is a metadata-only op, don't propagate the const requirement.
- if (metadata_ops.find(node->type_string()) != metadata_ops.end()) return;
+ if (XlaOpRegistry::IsMetadataOp(node->type_string())) {
+ return;
+ }
// If this node must be const, and it isn't a metadata op, then all of its
// parents must be const.
diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
index f792c52032..2d45507796 100644
--- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
+++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
@@ -31,11 +31,13 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/control_flow.h"
+#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
@@ -89,7 +91,6 @@ Status FunctionalizeControlFlowForFunction(
}
});
const FunctionBody* body = flr->GetFunctionBody(handle);
- const FunctionDef& fdef = body->fdef;
// If any node has associated functions, functionalize them first.
// Gather nodes with associated functions first, because rewriting those nodes
@@ -108,7 +109,8 @@ Status FunctionalizeControlFlowForFunction(
auto associated_functions = iter.second;
for (auto& associated_function : associated_functions) {
string name = associated_function.func_name();
- string canonicalized_name = Canonicalize(name, AttrSlice(&attrs));
+ string canonicalized_name =
+ Canonicalize(name, AttrSlice(&associated_function.attrs()));
auto iter = canonicalized_name_to_new_name->find(canonicalized_name);
string new_name;
if (iter != canonicalized_name_to_new_name->end()) {
@@ -118,7 +120,8 @@ Status FunctionalizeControlFlowForFunction(
} else {
new_name = fld->UniqueFunctionName(absl::StrCat(name, "_f15n_"));
TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction(
- name, new_name, attrs, fld, flr, canonicalized_name_to_new_name));
+ name, new_name, associated_function.attrs(), fld, flr,
+ canonicalized_name_to_new_name));
(*canonicalized_name_to_new_name)[canonicalized_name] = new_name;
}
// Notice that if "n" is a function call, RewriteAssociatedFunction() will
@@ -130,26 +133,54 @@ Status FunctionalizeControlFlowForFunction(
}
}
+  // Call the graph optimizer. The most important optimization we need is
+  // constant folding, which replaces ops like Shape/BroadcastGradientArgs with
+  // constant shape inputs. Without this optimization, those ops might become
+  // dynamic inputs to the then/else body functions, and XLA will complain that
+  // the inputs are not compile-time constants. We enable function inlining as
+  // well, because otherwise we won't be able to infer shapes for nodes that
+  // depend on function call nodes.
+ if (VLOG_IS_ON(4)) {
+ dump_graph::DumpGraphToFile(
+ absl::StrCat("functionalize_control_flow_before_opt_", func_name),
+ *body->graph, fld);
+ }
+  // The optimizer accepts a std::unique_ptr<Graph>* and might change the
+  // underlying pointer, so we create a new Graph and copy from body->graph.
+ std::unique_ptr<Graph> optimized_graph(new Graph(fld));
+ CopyGraph(*body->graph, optimized_graph.get());
+ OptimizerOptions opts;
+ opts.set_opt_level(OptimizerOptions::L0);
+ opts.set_do_function_inlining(true);
+ opts.set_do_constant_folding(true);
+ GraphOptimizer optimizer(opts);
+ auto cf_consider_fn = [](const Node* n) {
+ // Skip SymbolicGradient op when doing constant folding.
+ // Enabling SymbolicGradient op in constant folding requires
+ // flr->device() to be non-null, and here we have not constructed
+ // proper Device object yet (it will be constructed in XlaCompiler).
+ return n->type_string() != FunctionLibraryDefinition::kGradientOp;
+ };
+ optimizer.Optimize(flr, flr->env(),
+ /*device=*/nullptr, &optimized_graph,
+ /*shape_map=*/nullptr, /*cse_consider_fn=*/nullptr,
+ cf_consider_fn);
+
// Functionalize the function body.
if (VLOG_IS_ON(4)) {
dump_graph::DumpGraphToFile(
absl::StrCat("functionalize_control_flow_before_fdef_", func_name),
- *body->graph, fld);
+ *optimized_graph, fld);
}
- TF_RETURN_IF_ERROR(FunctionalizeControlFlow(body->graph, fld));
+ TF_RETURN_IF_ERROR(FunctionalizeControlFlow(optimized_graph.get(), fld));
if (VLOG_IS_ON(4)) {
dump_graph::DumpGraphToFile(
absl::StrCat("functionalize_control_flow_after_fdef_", func_name),
- *body->graph, fld);
+ *optimized_graph, fld);
}
FunctionDef functionalized_fdef;
- TF_RETURN_IF_ERROR(
- GraphToFunctionDef(*body->graph, new_func_name, &functionalized_fdef));
-
- // Copy signature and ret from original FunctionDef.
- *functionalized_fdef.mutable_signature() = fdef.signature();
- *functionalized_fdef.mutable_ret() = fdef.ret();
- functionalized_fdef.mutable_signature()->set_name(new_func_name);
+ TF_RETURN_IF_ERROR(GraphToFunctionDef(*optimized_graph, new_func_name,
+ &functionalized_fdef));
// Add rewritten FunctionDef into library.
if (func_name == new_func_name) {
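To make the constant-folding comment above concrete, here is a toy TF1-style example (not taken from this change) of the kind of graph that benefits. The reshape target is produced by a Shape op; once the optimizer folds Shape of a statically shaped tensor into a constant, the functionalized branch no longer receives a dynamic shape input that XLA would reject.

```python
import tensorflow as tf  # TF1-style graph mode, illustration only

x = tf.placeholder(tf.float32, shape=[2, 3])
y = tf.cond(
    tf.constant(True),
    # The branch uses tf.shape(x) as the reshape target. Constant folding
    # turns Shape([2, 3]) into a literal [2, 3], so the extracted branch
    # function sees a compile-time constant instead of a dynamic input.
    lambda: tf.reshape(x, tf.shape(x)),
    lambda: x)
```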
diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD
index 46794f7b50..3e823254d3 100644
--- a/tensorflow/compiler/tf2xla/kernels/BUILD
+++ b/tensorflow/compiler/tf2xla/kernels/BUILD
@@ -113,6 +113,7 @@ tf_kernel_library(
"shape_util.h",
],
deps = [
+ ":conv_op_helpers",
":if_op",
":while_op",
"//tensorflow/compiler/tf2xla:common",
@@ -172,6 +173,27 @@ tf_kernel_library(
],
)
+cc_library(
+ name = "conv_op_helpers",
+ srcs = ["conv_op_helpers.cc"],
+ hdrs = ["conv_op_helpers.h"],
+ deps = [
+ "//tensorflow/compiler/tf2xla:common",
+ "//tensorflow/compiler/tf2xla:xla_compiler",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla/client:xla_builder",
+ "//tensorflow/compiler/xla/client/lib:arithmetic",
+ "//tensorflow/compiler/xla/client/lib:constants",
+ "//tensorflow/compiler/xla/client/lib:numeric",
+ "//tensorflow/core:framework",
+ "//tensorflow/core/kernels:bounds_check",
+ "//tensorflow/core/kernels:conv_ops",
+ "//tensorflow/core/kernels:ops_util",
+ "@com_google_absl//absl/types:span",
+ ],
+)
+
tf_kernel_library(
name = "while_op",
srcs = ["while_op.cc"],
diff --git a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc
index b3ad0aea84..a267c0c72f 100644
--- a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc
@@ -34,12 +34,6 @@ class FusedBatchNormOp : public XlaOpKernel {
OP_REQUIRES(
ctx, FormatFromString(data_format_str, &data_format_),
errors::InvalidArgument("Invalid data format: ", data_format_str));
- OP_REQUIRES(ctx,
- (data_format_ == FORMAT_NHWC || data_format_ == FORMAT_NCHW ||
- data_format_ == FORMAT_HWNC || data_format_ == FORMAT_HWCN),
- errors::InvalidArgument(
- "Unsupported data format ", ToString(data_format_),
- "; supported formats are NHWC, NCHW, HWNC and HWCN"));
}
void Compile(XlaOpKernelContext* ctx) override {
@@ -110,12 +104,6 @@ class FusedBatchNormGradOp : public XlaOpKernel {
OP_REQUIRES(
ctx, FormatFromString(data_format_str, &data_format_),
errors::InvalidArgument("Invalid data format: ", data_format_str));
- OP_REQUIRES(ctx,
- (data_format_ == FORMAT_NHWC || data_format_ == FORMAT_NCHW ||
- data_format_ == FORMAT_HWNC || data_format_ == FORMAT_HWCN),
- errors::InvalidArgument(
- "Unsupported data format ", ToString(data_format_),
- "; supported formats are NHWC, NCHW, HWNC and HWCN"));
}
void Compile(XlaOpKernelContext* ctx) override {
diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
index 0d9a768a6f..a988d3c33e 100644
--- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/client_library.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -55,6 +56,24 @@ XLA_MAKE_BINARY(Div, xla::Div(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(Atan2, xla::Atan2(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(Complex, xla::Complex(lhs, rhs, extend_dimensions));
+// Implementation of DivNoNan. Pseudo-code:
+// if (y == 0) {
+// return 0
+// } else {
+// return x / y;
+// }
+static xla::XlaOp DivNoNanImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
+ xla::XlaOp y, const BCast& broadcast_helper) {
+ std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper);
+ auto zero = XlaHelpers::Zero(b, dtype);
+ auto y_equals_0 = xla::Eq(y, zero);
+ auto zeros = xla::ZerosLike(x);
+ auto result = xla::Select(y_equals_0, zeros, xla::Div(x, y));
+ return result;
+}
+XLA_MAKE_BINARY(DivNoNan,
+ DivNoNanImpl(b, input_type(0), lhs, rhs, broadcast_helper));
+
// Implementation of FloorDiv. Pseudo-code:
// if ((x < 0) != (y < 0)) {
// T abs_x = std::abs(x);
@@ -84,6 +103,24 @@ static xla::XlaOp FloorDivImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
XLA_MAKE_BINARY(FloorDiv,
FloorDivImpl(b, input_type(0), lhs, rhs, broadcast_helper));
+static xla::XlaOp XlogyImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
+ xla::XlaOp y, const BCast& broadcast_helper) {
+ std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper);
+ auto zero = XlaHelpers::Zero(b, dtype);
+ auto is_zero = xla::Eq(x, zero);
+ return xla::Select(is_zero, zero, xla::Mul(x, xla::Log(y)));
+}
+XLA_MAKE_BINARY(Xlogy, XlogyImpl(b, input_type(0), lhs, rhs, broadcast_helper));
+
+static xla::XlaOp XdivyImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
+ xla::XlaOp y, const BCast& broadcast_helper) {
+ std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper);
+ auto zero = XlaHelpers::Zero(b, dtype);
+ auto is_zero = xla::Eq(x, zero);
+ return xla::Select(is_zero, zero, xla::Div(x, y));
+}
+XLA_MAKE_BINARY(Xdivy, XdivyImpl(b, input_type(0), lhs, rhs, broadcast_helper));
+
// Implementation of FloorMod. Pseudo-code:
// T trunc_mod = std::fmod(x, y);
// return (x < T(0)) == (y < T(0)) ? trunc_mod : std::fmod(trunc_mod + y, y);
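The three new binary ops above share one pattern: broadcast, compare one operand against zero, and Select the safe value. A numpy sketch of the intended semantics (illustrative, not the kernel itself); like xla::Select, np.where evaluates both branches, so the guarded expression is still computed but its result is discarded where the predicate holds.

```python
import numpy as np

def div_no_nan(x, y):
  with np.errstate(divide="ignore", invalid="ignore"):
    return np.where(y == 0, np.zeros_like(x), x / y)

def xlogy(x, y):
  with np.errstate(divide="ignore", invalid="ignore"):
    return np.where(x == 0, np.zeros_like(x), x * np.log(y))

def xdivy(x, y):
  with np.errstate(divide="ignore", invalid="ignore"):
    return np.where(x == 0, np.zeros_like(x), x / y)
```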
diff --git a/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc b/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc
index 4bd7c74dca..696c1c39be 100644
--- a/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc
@@ -64,10 +64,9 @@ class BroadcastToOp : public XlaOpKernel {
output_shape.DebugString()));
broadcast_dims.push_back(broadcast_shape.size());
- if (output_dims[i] == input_dims[i] || input_dims[i] == 1) {
+ if (output_dims[i] == input_dims[i]) {
broadcast_shape.push_back(output_dims[i]);
- }
- if (output_dims[i] != input_dims[i]) {
+ } else if (output_dims[i] != input_dims[i]) {
// Add dimensions [I, O/I], which we will later flatten to just
// [O]. We must do this in two phases since XLA broadcasting does not
// support tiling.
diff --git a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc
new file mode 100644
index 0000000000..c9a1be4940
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc
@@ -0,0 +1,509 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// XLA-specific Ops for 2D convolution.
+
+#include "tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h"
+#include "absl/types/span.h"
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/type_util.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
+#include "tensorflow/compiler/xla/client/lib/numeric.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/numeric_op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_slice.h"
+#include "tensorflow/core/kernels/bounds_check.h"
+#include "tensorflow/core/kernels/conv_grad_ops.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/util/padding.h"
+#include "tensorflow/core/util/tensor_format.h"
+
+namespace tensorflow {
+namespace {
+
+// Returns the expanded size of a filter used for depthwise convolution.
+// If `shape` is [H, W, ..., M, N] returns [H, W, ..., M, M*N].
+xla::Shape ExpandedFilterShapeForDepthwiseConvolution(const xla::Shape& shape) {
+ int num_dims = shape.dimensions_size();
+ CHECK_GE(num_dims, 2); // Crash OK
+ xla::Shape expanded_shape = shape;
+ expanded_shape.set_dimensions(
+ num_dims - 1,
+ shape.dimensions(num_dims - 2) * shape.dimensions(num_dims - 1));
+ return expanded_shape;
+}
+
+// Create a mask for depthwise convolution that will make a normal convolution
+// produce the same results as a depthwise convolution. For a [2, 2, 3, 2]
+// depthwise filter this returns a [2, 2, 3, 6] tensor
+// 1 1 0 0 0 0 1 1 0 0 0 0
+// 0 0 1 1 0 0 0 0 1 1 0 0
+// 0 0 0 0 1 1 0 0 0 0 1 1
+//
+// 1 1 0 0 0 0 1 1 0 0 0 0
+// 0 0 1 1 0 0 0 0 1 1 0 0
+// 0 0 0 0 1 1 0 0 0 0 1 1
+//
+// The first step is to create an iota tensor, A, that is [3]
+// 0 1 2
+//
+// and another tensor, B, that is [3 * 2]
+// 0 1 2 3 4 5
+//
+// and divide B by 2 to get
+// 0 0 1 1 2 2
+//
+// then we broadcast the B to [2, 2, 3, 3 * 2]
+// 0 0 1 1 2 2 0 0 1 1 2 2
+// 0 0 1 1 2 2 0 0 1 1 2 2
+// 0 0 1 1 2 2 0 0 1 1 2 2
+//
+// 0 0 1 1 2 2 0 0 1 1 2 2
+// 0 0 1 1 2 2 0 0 1 1 2 2
+// 0 0 1 1 2 2 0 0 1 1 2 2
+//
+// Finally, compare A and the broadcasted B in dimension 2 and return the mask
+// shown at the beginning of this comment.
+xla::XlaOp CreateExpandedFilterMask(const xla::Shape& filter_shape,
+ xla::XlaBuilder* builder) {
+ xla::Shape expanded_filter_shape =
+ ExpandedFilterShapeForDepthwiseConvolution(filter_shape);
+ int64 depthwise_multiplier =
+ filter_shape.dimensions(filter_shape.dimensions_size() - 1);
+ int64 input_feature =
+ filter_shape.dimensions(filter_shape.dimensions_size() - 2);
+
+  // Create an M sized linspace and an M*N sized linspace that will be
+ // broadcasted into perpendicular dimensions and compared.
+ xla::XlaOp input_feature_iota = xla::Iota(builder, xla::S32, input_feature);
+ xla::XlaOp expanded_feature_iota =
+ xla::Iota(builder, xla::S32, input_feature * depthwise_multiplier);
+
+ // Divide the M*N sized linspace by the depthwise_multiplier to create
+ // [0 0 1 1 2 2] in the example in the function comment.
+ expanded_feature_iota =
+ xla::Div(expanded_feature_iota,
+ XlaHelpers::IntegerLiteral(builder, DataType::DT_INT32,
+ depthwise_multiplier));
+
+ // Broadcast the N*M linspace to [H, W, ..., M, M*N].
+ std::vector<int64> expanded_feature_broadcast_dims(
+ expanded_filter_shape.dimensions().begin(),
+ expanded_filter_shape.dimensions().end());
+ expanded_feature_broadcast_dims.pop_back();
+ auto broadcasted_expanded_feature_iota =
+ xla::Broadcast(expanded_feature_iota, expanded_feature_broadcast_dims);
+
+ // Compare the broadcasted linspace to the input feature linspace in the
+ // input feature dimension to create a diagonal predicate.
+ return xla::Eq(broadcasted_expanded_feature_iota, input_feature_iota,
+ {expanded_filter_shape.dimensions_size() - 2});
+}
+
+// Reshapes a filter of shape [H, W, ..., M, N] to [H, W, ..., 1, M*N]. Used to
+// build a depthwise convolution.
+xla::XlaOp ReshapeFilterForDepthwiseConvolution(const xla::Shape& filter_shape,
+ const xla::XlaOp& filter) {
+ int64 input_feature_dim = filter_shape.dimensions_size() - 2;
+ int64 output_feature_dim = filter_shape.dimensions_size() - 1;
+ int64 depthwise_multiplier = filter_shape.dimensions(output_feature_dim);
+ int64 input_feature = filter_shape.dimensions(input_feature_dim);
+
+ // Create a [H, W, ..., 1, N*M] reshape of the filter.
+ xla::Shape implicit_broadcast_filter_shape = filter_shape;
+ implicit_broadcast_filter_shape.set_dimensions(input_feature_dim, 1);
+ implicit_broadcast_filter_shape.set_dimensions(
+ output_feature_dim, depthwise_multiplier * input_feature);
+ return xla::Reshape(
+ filter, xla::AsInt64Slice(implicit_broadcast_filter_shape.dimensions()));
+}
+
+// Reduces the results of the convolution with an expanded filter to the
+// non-expanded filter.
+xla::XlaOp ContractFilterForDepthwiseBackprop(const xla::Shape& filter_shape,
+ const xla::XlaOp& filter_backprop,
+ xla::XlaBuilder* builder) {
+ auto masked_expanded_filter =
+ xla::Select(CreateExpandedFilterMask(filter_shape, builder),
+ filter_backprop, xla::ZerosLike(filter_backprop));
+
+ auto elem_type = filter_shape.element_type();
+ return xla::Reshape(
+ // This reduce does not need inputs to be converted with
+ // XlaHelpers::SumAccumulationType() since the select above guarantees
+ // that only one element is non zero, so there cannot be accumulated
+ // precision error.
+ xla::Reduce(masked_expanded_filter, xla::Zero(builder, elem_type),
+ CreateScalarAddComputation(elem_type, builder),
+ {filter_shape.dimensions_size() - 2}),
+ xla::AsInt64Slice(filter_shape.dimensions()));
+}
+
+// Performs some basic checks on ConvOpAttrs that are true for all kinds of XLA
+// convolutions (as currently implemented).
+Status CheckConvAttrs(const ConvOpAttrs& attrs) {
+ const int num_dims = attrs.num_spatial_dims + 2;
+ if (attrs.strides.size() != num_dims) {
+ return errors::InvalidArgument("Sliding window strides field must specify ",
+ num_dims, " dimensions");
+ }
+ int batch_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format);
+ int feature_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format);
+ if (attrs.strides[batch_dim] != 1 || attrs.strides[feature_dim] != 1) {
+ return errors::Unimplemented(
+ "Current implementation does not yet support strides in the batch and "
+ "depth dimensions.");
+ }
+ if (attrs.dilations.size() != num_dims) {
+ return errors::InvalidArgument("Dilations field must specify ", num_dims,
+ " dimensions");
+ }
+ if (attrs.dilations[batch_dim] != 1 || attrs.dilations[feature_dim] != 1) {
+ return errors::Unimplemented(
+ "Current implementation does not support dilations in the batch and "
+ "depth dimensions.");
+ }
+ for (int i = 0; i < attrs.num_spatial_dims; ++i) {
+ int input_dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i);
+ if (attrs.dilations[input_dim] < 1) {
+ return errors::Unimplemented("Dilation values must be positive; ", i,
+ "th spatial dimension had dilation ",
+ attrs.dilations[input_dim]);
+ }
+ }
+ return Status::OK();
+}
+
+// Wrapper around ConvBackpropComputeDimensions that converts from XLA shapes
+// to TensorShapes.
+Status ConvBackpropComputeDimensionsV2XlaShapes(
+ StringPiece label, int num_spatial_dims, const xla::Shape& input_shape,
+ const xla::Shape& filter_shape, const xla::Shape& out_backprop_shape,
+ absl::Span<const int32> dilations, const std::vector<int32>& strides,
+ Padding padding, TensorFormat data_format, ConvBackpropDimensions* dims) {
+ TensorShape input_tensor_shape, filter_tensor_shape,
+ out_backprop_tensor_shape;
+ TF_RETURN_IF_ERROR(XLAShapeToTensorShape(input_shape, &input_tensor_shape));
+ TF_RETURN_IF_ERROR(XLAShapeToTensorShape(filter_shape, &filter_tensor_shape));
+ TF_RETURN_IF_ERROR(
+ XLAShapeToTensorShape(out_backprop_shape, &out_backprop_tensor_shape));
+ return ConvBackpropComputeDimensionsV2(
+ label, num_spatial_dims, input_tensor_shape, filter_tensor_shape,
+ out_backprop_tensor_shape, dilations, strides, padding, data_format,
+ dims);
+}
+
+} // anonymous namespace
+
+xla::StatusOr<ConvOpAttrs> ConvOpAttrs::Create(int num_spatial_dims,
+ bool depthwise,
+ OpKernelConstruction* ctx) {
+ ConvOpAttrs attrs;
+ attrs.num_spatial_dims = num_spatial_dims;
+ attrs.depthwise = depthwise;
+ TF_RETURN_IF_ERROR(ctx->GetAttr("dilations", &attrs.dilations));
+ TF_RETURN_IF_ERROR(ctx->GetAttr("strides", &attrs.strides));
+ TF_RETURN_IF_ERROR(ctx->GetAttr("padding", &attrs.padding));
+
+ string data_format;
+ TF_RETURN_IF_ERROR(ctx->GetAttr("data_format", &data_format));
+ if (!FormatFromString(data_format, &attrs.data_format)) {
+ return errors::InvalidArgument("Invalid data format: ", data_format);
+ }
+
+ return attrs;
+}
+
+xla::StatusOr<xla::XlaOp> MakeXlaForwardConvOp(StringPiece /*type_string*/,
+ xla::XlaOp conv_input,
+ xla::XlaOp filter,
+ const ConvOpAttrs& attrs) {
+ TF_RETURN_IF_ERROR(CheckConvAttrs(attrs));
+
+ auto* builder = conv_input.builder();
+ TF_ASSIGN_OR_RETURN(xla::Shape input_shape, builder->GetShape(conv_input));
+ // Filter has the form [filter_rows, filter_cols, ..., in_depth, out_depth]
+ TF_ASSIGN_OR_RETURN(xla::Shape filter_shape, builder->GetShape(filter));
+
+ // For 2D convolution, there should be 4 dimensions.
+ int num_dims = attrs.num_spatial_dims + 2;
+ if (input_shape.dimensions_size() != num_dims) {
+    return errors::InvalidArgument("input must be ", num_dims,
+                                   "-dimensional: ", input_shape.DebugString());
+ }
+ if (filter_shape.dimensions_size() != num_dims) {
+ return errors::InvalidArgument(
+ "filter must be ", num_dims,
+ "-dimensional: ", filter_shape.DebugString());
+ }
+
+ // The last two dimensions of the filter are the input and output shapes.
+ int batch_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format);
+ int feature_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format);
+
+ int64 in_depth = filter_shape.dimensions(attrs.num_spatial_dims);
+ // The 'C' dimension for input is in_depth. It must be the same as
+ // the filter's in_depth.
+ if (in_depth != input_shape.dimensions(feature_dim)) {
+ return errors::InvalidArgument(
+ "input and filter must have the same depth: ", in_depth, " vs ",
+ input_shape.dimensions(feature_dim));
+ }
+
+ if (attrs.depthwise) {
+ filter = ReshapeFilterForDepthwiseConvolution(filter_shape, filter);
+ }
+
+ xla::ConvolutionDimensionNumbers dims;
+ std::vector<int64> window_strides(attrs.num_spatial_dims);
+ std::vector<int64> lhs_dilation(attrs.num_spatial_dims, 1);
+ std::vector<int64> rhs_dilation(attrs.num_spatial_dims);
+ std::vector<std::pair<int64, int64>> padding(attrs.num_spatial_dims);
+
+ dims.set_input_batch_dimension(batch_dim);
+ dims.set_output_batch_dimension(batch_dim);
+ dims.set_input_feature_dimension(feature_dim);
+ dims.set_output_feature_dimension(feature_dim);
+ dims.set_kernel_input_feature_dimension(attrs.num_spatial_dims);
+ dims.set_kernel_output_feature_dimension(attrs.num_spatial_dims + 1);
+
+ for (int i = 0; i < attrs.num_spatial_dims; ++i) {
+ const int64 dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i);
+ dims.add_input_spatial_dimensions(dim);
+ dims.add_kernel_spatial_dimensions(i);
+ dims.add_output_spatial_dimensions(dim);
+ window_strides[i] = attrs.strides.at(dim);
+ rhs_dilation[i] = attrs.dilations.at(dim);
+
+ int64 unused_output_size;
+ TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerboseV2(
+ input_shape.dimensions(dim), filter_shape.dimensions(i),
+ rhs_dilation[i], window_strides[i], attrs.padding, &unused_output_size,
+ &padding[i].first, &padding[i].second));
+ }
+
+ return xla::ConvGeneralDilated(
+ conv_input, filter, window_strides, padding, lhs_dilation, rhs_dilation,
+ dims, /*feature_group_count=*/attrs.depthwise ? in_depth : 1);
+}
+
+xla::StatusOr<xla::XlaOp> MakeXlaBackpropInputConvOp(
+ StringPiece type_string, const xla::Shape& input_shape, xla::XlaOp filter,
+ xla::XlaOp out_backprop, const ConvOpAttrs& attrs) {
+ TF_RETURN_IF_ERROR(CheckConvAttrs(attrs));
+
+ int num_dims = attrs.num_spatial_dims + 2;
+ int batch_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format);
+ int feature_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format);
+
+ auto* builder = filter.builder();
+ TF_ASSIGN_OR_RETURN(xla::Shape filter_shape, builder->GetShape(filter));
+ TF_ASSIGN_OR_RETURN(xla::Shape out_backprop_shape,
+ builder->GetShape(out_backprop));
+
+ xla::Shape expanded_filter_shape =
+ attrs.depthwise ? ExpandedFilterShapeForDepthwiseConvolution(filter_shape)
+ : filter_shape;
+ // Reuse dimension computation logic from conv_grad_ops.cc.
+ ConvBackpropDimensions dims;
+ TF_RETURN_IF_ERROR(ConvBackpropComputeDimensionsV2XlaShapes(
+ type_string, attrs.num_spatial_dims, input_shape, expanded_filter_shape,
+ out_backprop_shape, attrs.dilations, attrs.strides, attrs.padding,
+ attrs.data_format, &dims));
+
+ // The input gradients are computed by a convolution of the output
+ // gradients and the filter, with some appropriate padding. See the
+ // comment at the top of conv_grad_ops.h for details.
+
+ xla::ConvolutionDimensionNumbers dnums;
+ dnums.set_input_batch_dimension(batch_dim);
+ dnums.set_output_batch_dimension(batch_dim);
+ dnums.set_input_feature_dimension(feature_dim);
+ dnums.set_output_feature_dimension(feature_dim);
+
+ // TF filter shape is [ H, W, ..., inC, outC ]
+ // Transpose the input and output features for computing the gradient.
+ dnums.set_kernel_input_feature_dimension(attrs.num_spatial_dims + 1);
+ dnums.set_kernel_output_feature_dimension(attrs.num_spatial_dims);
+
+ std::vector<int64> kernel_spatial_dims(attrs.num_spatial_dims);
+ std::vector<std::pair<int64, int64>> padding(attrs.num_spatial_dims);
+ std::vector<int64> lhs_dilation(attrs.num_spatial_dims);
+ std::vector<int64> rhs_dilation(attrs.num_spatial_dims);
+ std::vector<int64> ones(attrs.num_spatial_dims, 1);
+ for (int i = 0; i < attrs.num_spatial_dims; ++i) {
+ int64 dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i);
+ dnums.add_input_spatial_dimensions(dim);
+ dnums.add_kernel_spatial_dimensions(i);
+ dnums.add_output_spatial_dimensions(dim);
+
+ kernel_spatial_dims[i] = i;
+ padding[i] = {dims.spatial_dims[i].pad_before,
+ dims.spatial_dims[i].pad_after};
+ lhs_dilation[i] = dims.spatial_dims[i].stride;
+ rhs_dilation[i] = attrs.dilations[dim];
+ }
+
+ // Mirror the filter in the spatial dimensions.
+ xla::XlaOp mirrored_weights = xla::Rev(filter, kernel_spatial_dims);
+
+ // activation gradients
+ // = gradients (with padding and dilation) <conv> mirrored_weights
+ return xla::ConvGeneralDilated(
+ out_backprop, mirrored_weights, /*window_strides=*/ones, padding,
+ lhs_dilation, rhs_dilation, dnums,
+ /*feature_group_count=*/
+ attrs.depthwise ? out_backprop_shape.dimensions(feature_dim) /
+ filter_shape.dimensions(attrs.num_spatial_dims + 1)
+ : 1);
+}
+
+xla::StatusOr<xla::XlaOp> MakeXlaBackpropFilterConvOp(
+ StringPiece type_string, xla::XlaOp activations,
+ const xla::Shape& filter_shape, xla::XlaOp gradients,
+ const ConvOpAttrs& attrs) {
+ TF_RETURN_IF_ERROR(CheckConvAttrs(attrs));
+
+ auto* builder = activations.builder();
+ TF_ASSIGN_OR_RETURN(xla::Shape activations_shape,
+ builder->GetShape(activations));
+ TF_ASSIGN_OR_RETURN(xla::Shape out_backprop_shape,
+ builder->GetShape(gradients));
+ const xla::Shape expanded_filter_shape =
+ attrs.depthwise ? ExpandedFilterShapeForDepthwiseConvolution(filter_shape)
+ : filter_shape;
+
+ // Reuse dimension computation logic from conv_grad_ops.cc.
+ ConvBackpropDimensions dims;
+ TF_RETURN_IF_ERROR(ConvBackpropComputeDimensionsV2XlaShapes(
+ type_string, attrs.num_spatial_dims, activations_shape,
+ expanded_filter_shape, out_backprop_shape, attrs.dilations, attrs.strides,
+ attrs.padding, attrs.data_format, &dims));
+
+ // The filter gradients are computed by a convolution of the input
+ // activations and the output gradients, with some appropriate padding.
+ // See the comment at the top of conv_grad_ops.h for details.
+
+ xla::ConvolutionDimensionNumbers dnums;
+
+ // The activations (inputs) form the LHS of the convolution.
+ // Activations have shape: [batch, in_rows, in_cols, ..., in_depth]
+ // For the gradient computation, we flip the roles of the batch and
+ // feature dimensions.
+ // Each spatial entry has size in_depth * batch
+
+ // The last two dimensions of the filter are the input and output shapes.
+ int num_dims = attrs.num_spatial_dims + 2;
+ int n_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format);
+ int c_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format);
+
+ // Swap n_dim and c_dim in the activations.
+ dnums.set_input_batch_dimension(c_dim);
+ dnums.set_input_feature_dimension(n_dim);
+
+ // The gradients become the RHS of the convolution.
+ // The gradients have shape [batch, out_rows, out_cols, ..., out_depth]
+ // where the batch becomes the input feature for the convolution.
+ dnums.set_kernel_input_feature_dimension(n_dim);
+ dnums.set_kernel_output_feature_dimension(c_dim);
+
+ std::vector<std::pair<int64, int64>> padding(attrs.num_spatial_dims);
+ std::vector<int64> rhs_dilation(attrs.num_spatial_dims);
+ std::vector<int64> window_strides(attrs.num_spatial_dims);
+ std::vector<int64> ones(attrs.num_spatial_dims, 1);
+
+ // Tensorflow filter shape is [ H, W, ..., inC, outC ].
+ for (int i = 0; i < attrs.num_spatial_dims; ++i) {
+ dnums.add_output_spatial_dimensions(i);
+ }
+ dnums.set_output_batch_dimension(attrs.num_spatial_dims);
+ dnums.set_output_feature_dimension(attrs.num_spatial_dims + 1);
+
+ for (int i = 0; i < attrs.num_spatial_dims; ++i) {
+ int64 dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i);
+ dnums.add_input_spatial_dimensions(dim);
+ dnums.add_kernel_spatial_dimensions(dim);
+
+ // We will also need to pad the input with zeros such that after the
+ // convolution, we get the right size for the filter.
+ // The padded_in_rows should be such that when we convolve this with the
+ // expanded_out_rows as a filter, we should get filter_rows back.
+ //
+ const int64 padded_in_size =
+ dims.spatial_dims[i].expanded_output_size +
+ (dims.spatial_dims[i].filter_size - 1) * attrs.dilations[dim];
+
+ // However it can be smaller than input_rows: in this
+ // case it means some of the inputs are not used.
+ //
+ // An example is to have input_cols = 3, filter_cols = 2 and stride = 2:
+ //
+ // INPUT = [ A B C ]
+ //
+ // FILTER = [ x y ]
+ //
+ // and the output will only have one column: a = A * x + B * y
+ //
+ // and input "C" is not used at all.
+ //
+ // We apply negative padding in this case.
+ const int64 pad_total = padded_in_size - dims.spatial_dims[i].input_size;
+
+ // + For the VALID padding, we don't pad anything on the top/left side
+ // and pad the bottom/right side with the remaining space.
+ // + For the SAME padding, we pad top/left side the same as bottom/right
+ // side.
+ //
+ // In addition, if the padded input size is smaller than the input size,
+    // we need to ignore some trailing elements of the input. We do this by
+ // applying negative padding on the right/bottom.
+ const int64 pad_before =
+ attrs.padding == Padding::SAME ? std::max<int64>(pad_total / 2, 0) : 0;
+
+ padding[i] = {pad_before, pad_total - pad_before};
+ rhs_dilation[i] = dims.spatial_dims[i].stride;
+ window_strides[i] = attrs.dilations[dim];
+ }
+
+ // Besides padding the input, we will also expand output_rows to
+ // expanded_out_rows = (output_rows - 1) * stride + 1
+ // with zeros in between:
+ //
+ // a . . . b . . . c . . . d . . . e
+ //
+ // This is done by specifying the window dilation factors in the
+ // convolution HLO below.
+ auto filter_backprop =
+ xla::ConvGeneralDilated(activations, gradients, window_strides, padding,
+ /*lhs_dilation=*/ones, rhs_dilation, dnums);
+
+ if (attrs.depthwise) {
+ filter_backprop = ContractFilterForDepthwiseBackprop(
+ filter_shape, filter_backprop, activations.builder());
+ }
+
+ return filter_backprop;
+}
+
+} // namespace tensorflow
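The filter-mask construction in CreateExpandedFilterMask is easiest to see with concrete numbers. Below is a small numpy re-creation of the same [2, 2, 3, 2] example from the comment; it is a sketch of the idea, not the XLA code itself.

```python
import numpy as np

m, n = 3, 2  # input features (M) and depthwise multiplier (N)
input_feature_iota = np.arange(m)              # [0 1 2]
expanded_feature_iota = np.arange(m * n) // n  # [0 0 1 1 2 2]

# Compare the two iotas along the input-feature dimension, then broadcast the
# [3, 6] mask up to the expanded filter shape [2, 2, 3, 6].
mask = input_feature_iota[:, None] == expanded_feature_iota[None, :]
mask = np.broadcast_to(mask, (2, 2, m, m * n))

print(mask[0, 0].astype(int))
# [[1 1 0 0 0 0]
#  [0 0 1 1 0 0]
#  [0 0 0 0 1 1]]
```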
diff --git a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h
new file mode 100644
index 0000000000..6e1b70a478
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h
@@ -0,0 +1,69 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_CONV_OP_HELPERS_H_
+#define TENSORFLOW_COMPILER_TF2XLA_KERNELS_CONV_OP_HELPERS_H_
+
+#include <vector>
+
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/util/padding.h"
+#include "tensorflow/core/util/tensor_format.h"
+
+// This header exposes utilities for translating TensorFlow convolution ops into
+// XLA ops.
+//
+// conv_ops.cc contains lowerings for many of these TF convolution ops (e.g.
+// Conv2D, Conv3DBackpropFilterV2), but you might want to use the utilities in
+// this header to implement a new and exciting convolution op, for example a
+// fused TensorFlow op that contains a convolution and other things.
+
+namespace tensorflow {
+
+// ConvOpAttrs contains all of the metadata necessary to specify a TF or XLA
+// convolution.
+struct ConvOpAttrs {
+ // Constructs a ConvOpAttrs, reading most of the attributes from `ctx`.
+ static xla::StatusOr<ConvOpAttrs> Create(int num_spatial_dims, bool depthwise,
+ OpKernelConstruction* ctx);
+
+ bool depthwise;
+ int num_spatial_dims;
+ std::vector<int32> dilations;
+ std::vector<int32> strides;
+ Padding padding;
+ TensorFormat data_format;
+};
+
+// Creates a new XLA forward or backward convolution with the given inputs and
+// attributes.
+xla::StatusOr<xla::XlaOp> MakeXlaForwardConvOp(StringPiece type_string,
+ xla::XlaOp conv_input,
+ xla::XlaOp filter,
+ const ConvOpAttrs& attrs);
+xla::StatusOr<xla::XlaOp> MakeXlaBackpropInputConvOp(
+ StringPiece type_string, const xla::Shape& input_shape, xla::XlaOp filter,
+ xla::XlaOp out_backprop, const ConvOpAttrs& attrs);
+xla::StatusOr<xla::XlaOp> MakeXlaBackpropFilterConvOp(
+ StringPiece type_string, xla::XlaOp activations,
+ const xla::Shape& filter_shape, xla::XlaOp gradients,
+ const ConvOpAttrs& attrs);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_TF2XLA_KERNELS_CONV_OP_HELPERS_H_
diff --git a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc
index 674720e22f..cd7c820be0 100644
--- a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc
@@ -15,12 +15,17 @@ limitations under the License.
// XLA-specific Ops for 2D convolution.
+#include "tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h"
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/numeric.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
@@ -33,250 +38,28 @@ limitations under the License.
#include "tensorflow/core/util/tensor_format.h"
namespace tensorflow {
-
namespace {
-// Returns the expanded size of a filter used for depthwise convolution.
-// If `shape` is [H, W, ..., M, N] returns [H, W, ..., M, M*N].
-TensorShape ExpandedFilterShapeForDepthwiseConvolution(
- const TensorShape& shape) {
- int num_dims = shape.dims();
- CHECK_GE(num_dims, 2);
- TensorShape expanded_shape = shape;
- expanded_shape.set_dim(num_dims - 1, shape.dim_size(num_dims - 2) *
- shape.dim_size(num_dims - 1));
- return expanded_shape;
-}
-
-// Broadcast zeros to ExpandedFilterShapeForDepthwiseConvolution.
-xla::XlaOp CreateExpandedZero(const TensorShape& filter_shape, DataType dtype,
- xla::XlaBuilder* builder) {
- TensorShape expanded_filter_shape =
- ExpandedFilterShapeForDepthwiseConvolution(filter_shape);
- return xla::Broadcast(XlaHelpers::Zero(builder, dtype),
- expanded_filter_shape.dim_sizes());
-}
-
-// Create a mask for depthwise convolution that will make a normal convolution
-// produce the same results as a depthwise convolution. For a [2, 2, 3, 2]
-// depthwise filter this returns a [2, 2, 3, 6] tensor
-// 1 1 0 0 0 0 1 1 0 0 0 0
-// 0 0 1 1 0 0 0 0 1 1 0 0
-// 0 0 0 0 1 1 0 0 0 0 1 1
-//
-// 1 1 0 0 0 0 1 1 0 0 0 0
-// 0 0 1 1 0 0 0 0 1 1 0 0
-// 0 0 0 0 1 1 0 0 0 0 1 1
-//
-// The first step is to create a one tensor, A, that is [3]
-// 0 1 2
-//
-// and another tensor, B, that is [3 * 2]
-// 0 1 2 3 4 5
-//
-// and divide B it by 2 to get
-// 0 0 1 1 2 2
-//
-// then we broadcast the B to [2, 2, 3, 3 * 2]
-// 0 0 1 1 2 2 0 0 1 1 2 2
-// 0 0 1 1 2 2 0 0 1 1 2 2
-// 0 0 1 1 2 2 0 0 1 1 2 2
-//
-// 0 0 1 1 2 2 0 0 1 1 2 2
-// 0 0 1 1 2 2 0 0 1 1 2 2
-// 0 0 1 1 2 2 0 0 1 1 2 2
-//
-// Finally compare A and broadcasted B in dimension 2 amd return the result at
-// the beginning of the comment.
-xla::XlaOp CreateExpandedFilterMask(const TensorShape& filter_shape,
- xla::XlaBuilder* builder) {
- TensorShape expanded_filter_shape =
- ExpandedFilterShapeForDepthwiseConvolution(filter_shape);
- int64 depthwise_multiplier = filter_shape.dim_size(filter_shape.dims() - 1);
- int64 input_feature = filter_shape.dim_size(filter_shape.dims() - 2);
-
- // Create a M sized linspace and an M*N sized linspace that will be
- // broadcasted into perpendicular dimensions and compared.
- xla::XlaOp input_feature_iota = xla::Iota(builder, xla::S32, input_feature);
- xla::XlaOp expanded_feature_iota =
- xla::Iota(builder, xla::S32, input_feature * depthwise_multiplier);
-
- // Divide the M*N sized linspace by the depthwise_multiplier to create
- // [0 0 1 1 2 2] in the example in the function comment.
- expanded_feature_iota =
- xla::Div(expanded_feature_iota,
- XlaHelpers::IntegerLiteral(builder, DataType::DT_INT32,
- depthwise_multiplier));
-
- // Broadcast the N*M linspace to [H, W, ..., M, M*N].
- auto expanded_feature_broadcast_dims = expanded_filter_shape.dim_sizes();
- expanded_feature_broadcast_dims.pop_back();
- auto broadcasted_expanded_feature_iota =
- xla::Broadcast(expanded_feature_iota, expanded_feature_broadcast_dims);
-
- // Compare the broadcasted linspace to the input feature linspace in the
- // input feature dimension to create a diagonal predicate.
- return xla::Eq(broadcasted_expanded_feature_iota, input_feature_iota,
- {expanded_filter_shape.dims() - 2});
-}
-
-// Reshapes a filter of shape [H, W, ..., M, N] to [H, W, ..., 1, M*N]. Used to
-// build a depthwise convolution.
-xla::XlaOp ReshapeFilterForDepthwiseConvolution(const TensorShape& filter_shape,
- const xla::XlaOp& filter) {
- int64 input_feature_dim = filter_shape.dims() - 2;
- int64 output_feature_dim = filter_shape.dims() - 1;
- int64 depthwise_multiplier = filter_shape.dim_size(output_feature_dim);
- int64 input_feature = filter_shape.dim_size(input_feature_dim);
-
- // Create a [H, W, ..., 1, N*M] reshape of the filter.
- TensorShape implicit_broadcast_filter_shape = filter_shape;
- implicit_broadcast_filter_shape.set_dim(input_feature_dim, 1);
- implicit_broadcast_filter_shape.set_dim(output_feature_dim,
- depthwise_multiplier * input_feature);
- return xla::Reshape(filter, implicit_broadcast_filter_shape.dim_sizes());
-}
-
-// Reduces the results of the convolution with an expanded filter to the
-// non-expanded filter.
-xla::XlaOp ContractFilterForDepthwiseBackprop(XlaOpKernelContext* ctx,
- const TensorShape& filter_shape,
- DataType dtype,
- const xla::XlaOp& filter_backprop,
- xla::XlaBuilder* builder) {
- auto masked_expanded_filter = xla::Select(
- CreateExpandedFilterMask(filter_shape, builder), filter_backprop,
- CreateExpandedZero(filter_shape, dtype, builder));
- return xla::Reshape(
- // This reduce does not need inputs to be converted with
- // XlaHelpers::SumAccumulationType() since the ExpandedFilterMask with
- // ExpandedZero guarantees that only one element is non zero, so there
- // cannot be accumulated precision error.
- xla::Reduce(masked_expanded_filter, XlaHelpers::Zero(builder, dtype),
- *ctx->GetOrCreateAdd(dtype), {filter_shape.dims() - 2}),
- filter_shape.dim_sizes());
-}
-
class ConvOp : public XlaOpKernel {
public:
explicit ConvOp(OpKernelConstruction* ctx, int num_spatial_dims,
bool depthwise)
- : XlaOpKernel(ctx),
- num_spatial_dims_(num_spatial_dims),
- depthwise_(depthwise) {
- OP_REQUIRES_OK(ctx, ctx->GetAttr("dilations", &dilations_));
- OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_));
- OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_));
-
- string data_format;
- OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format));
- OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_),
- errors::InvalidArgument("Invalid data format"));
+ : XlaOpKernel(ctx) {
+ xla::StatusOr<ConvOpAttrs> attrs =
+ ConvOpAttrs::Create(num_spatial_dims, depthwise, ctx);
+ OP_REQUIRES_OK(ctx, attrs.status());
+ attrs_ = attrs.ValueOrDie();
}
- int num_dims() const { return num_spatial_dims_ + 2; }
-
void Compile(XlaOpKernelContext* ctx) override {
- OP_REQUIRES(ctx, strides_.size() == num_dims(),
- errors::InvalidArgument("Sliding window strides field must "
- "specify ",
- num_dims(), " dimensions"));
- int batch_dim = GetTensorBatchDimIndex(num_dims(), data_format_);
- int feature_dim = GetTensorFeatureDimIndex(num_dims(), data_format_);
- OP_REQUIRES(
- ctx, strides_[batch_dim] == 1 && strides_[feature_dim] == 1,
- errors::Unimplemented("Current implementation does not yet support "
- "strides in the batch and depth dimensions."));
-
- OP_REQUIRES(ctx, dilations_.size() == num_dims(),
- errors::InvalidArgument("Dilations field must "
- "specify ",
- num_dims(), " dimensions"));
- OP_REQUIRES(
- ctx, dilations_[batch_dim] == 1 && dilations_[feature_dim] == 1,
- errors::Unimplemented("Current implementation does not support "
- "dilations in the batch and depth dimensions."));
- for (int i = 0; i < num_spatial_dims_; ++i) {
- int input_dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i);
- OP_REQUIRES(ctx, dilations_[input_dim] >= 1,
- errors::Unimplemented("Dilation values must be positive; ", i,
- "th spatial dimension had dilation ",
- dilations_[input_dim]));
- }
-
- const TensorShape input_shape = ctx->InputShape(0);
- // Input filter is of the following dimensions:
- // [ filter_rows, filter_cols, ..., in_depth, out_depth]
- const TensorShape filter_shape = ctx->InputShape(1);
-
- // For 2D convolution, there should be 4 dimensions.
- OP_REQUIRES(
- ctx, input_shape.dims() == num_dims(),
- errors::InvalidArgument("input must be ", num_dims(), "-dimensional",
- input_shape.DebugString()));
- OP_REQUIRES(
- ctx, filter_shape.dims() == num_dims(),
- errors::InvalidArgument("filter must be ", num_dims(),
- "-dimensional: ", filter_shape.DebugString()));
-
- // The last two dimension of the filter are the input and output shapes.
- const int64 in_depth = filter_shape.dim_size(num_spatial_dims_);
-
- // The 'C' dimension for input is in_depth. It must be the same as
- // the filter's in_depth.
- OP_REQUIRES(ctx, in_depth == input_shape.dim_size(feature_dim),
- errors::InvalidArgument(
- "input and filter must have the same depth: ", in_depth,
- " vs ", input_shape.dim_size(feature_dim)));
-
- xla::XlaOp filter = ctx->Input(1);
- if (depthwise_) {
- filter = ReshapeFilterForDepthwiseConvolution(filter_shape, filter);
- }
-
- xla::ConvolutionDimensionNumbers dims;
- std::vector<int64> window_strides(num_spatial_dims_);
- std::vector<int64> lhs_dilation(num_spatial_dims_, 1);
- std::vector<int64> rhs_dilation(num_spatial_dims_);
- std::vector<std::pair<int64, int64>> padding(num_spatial_dims_);
-
- dims.set_input_batch_dimension(batch_dim);
- dims.set_output_batch_dimension(batch_dim);
- dims.set_input_feature_dimension(feature_dim);
- dims.set_output_feature_dimension(feature_dim);
- dims.set_kernel_input_feature_dimension(num_spatial_dims_);
- dims.set_kernel_output_feature_dimension(num_spatial_dims_ + 1);
-
- for (int i = 0; i < num_spatial_dims_; ++i) {
- const int64 dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i);
- dims.add_input_spatial_dimensions(dim);
- dims.add_kernel_spatial_dimensions(i);
- dims.add_output_spatial_dimensions(dim);
- window_strides[i] = strides_.at(dim);
- rhs_dilation[i] = dilations_.at(dim);
-
- int64 unused_output_size;
- OP_REQUIRES_OK(
- ctx, GetWindowedOutputSizeVerboseV2(
- input_shape.dim_size(dim), filter_shape.dim_size(i),
- rhs_dilation[i], window_strides[i], padding_,
- &unused_output_size, &padding[i].first, &padding[i].second));
- }
-
- xla::XlaOp conv = xla::ConvGeneralDilated(
- ctx->Input(0), filter, window_strides, padding, lhs_dilation,
- rhs_dilation, dims,
- /*feature_group_count=*/depthwise_ ? in_depth : 1);
- ctx->SetOutput(0, conv);
+ xla::StatusOr<xla::XlaOp> conv = MakeXlaForwardConvOp(
+ ctx->op_kernel().type_string(), ctx->Input(0), ctx->Input(1), attrs_);
+ OP_REQUIRES_OK(ctx, conv.status());
+ ctx->SetOutput(0, conv.ValueOrDie());
}
protected:
- const int num_spatial_dims_;
- const bool depthwise_;
- std::vector<int32> dilations_;
- std::vector<int32> strides_;
- Padding padding_;
- TensorFormat data_format_ = FORMAT_NHWC;
+ ConvOpAttrs attrs_;
private:
TF_DISALLOW_COPY_AND_ASSIGN(ConvOp);
@@ -308,124 +91,28 @@ class ConvBackpropInputOp : public XlaOpKernel {
public:
explicit ConvBackpropInputOp(OpKernelConstruction* ctx, int num_spatial_dims,
bool depthwise)
- : XlaOpKernel(ctx),
- num_spatial_dims_(num_spatial_dims),
- depthwise_(depthwise) {
- OP_REQUIRES_OK(ctx, ctx->GetAttr("dilations", &dilations_));
- OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_));
- OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_));
- string data_format;
- OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format));
- OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_),
- errors::InvalidArgument("Invalid data format"));
+ : XlaOpKernel(ctx) {
+ xla::StatusOr<ConvOpAttrs> attrs =
+ ConvOpAttrs::Create(num_spatial_dims, depthwise, ctx);
+ OP_REQUIRES_OK(ctx, attrs.status());
+ attrs_ = attrs.ValueOrDie();
}
- int num_dims() const { return num_spatial_dims_ + 2; }
-
void Compile(XlaOpKernelContext* ctx) override {
- OP_REQUIRES(ctx, strides_.size() == num_dims(),
- errors::InvalidArgument("Sliding window strides field must "
- "specify ",
- num_dims(), " dimensions"));
- int batch_dim = GetTensorBatchDimIndex(num_dims(), data_format_);
- int feature_dim = GetTensorFeatureDimIndex(num_dims(), data_format_);
- OP_REQUIRES(
- ctx, strides_[batch_dim] == 1 && strides_[feature_dim] == 1,
- errors::Unimplemented("Current implementation does not yet support "
- "strides in the batch and depth dimensions."));
-
- OP_REQUIRES(ctx, dilations_.size() == num_dims(),
- errors::InvalidArgument("Dilations field must "
- "specify ",
- num_dims(), " dimensions"));
- OP_REQUIRES(
- ctx, dilations_[batch_dim] == 1 && dilations_[feature_dim] == 1,
- errors::Unimplemented("Current implementation does not support "
- "dilations in the batch and depth dimensions."));
- for (int i = 0; i < num_spatial_dims_; ++i) {
- int input_dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i);
- OP_REQUIRES(ctx, dilations_[input_dim] >= 1,
- errors::Unimplemented("Dilation values must be positive; ", i,
- "th spatial dimension had dilation ",
- dilations_[input_dim]));
- }
-
- TensorShape input_shape;
- OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &input_shape));
-
- const TensorShape filter_shape = ctx->InputShape(1);
- const TensorShape out_backprop_shape = ctx->InputShape(2);
-
- const TensorShape expanded_filter_shape =
- depthwise_ ? ExpandedFilterShapeForDepthwiseConvolution(filter_shape)
- : filter_shape;
- // Reuse dimension computation logic from conv_grad_ops.cc.
- ConvBackpropDimensions dims;
- OP_REQUIRES_OK(ctx,
- ConvBackpropComputeDimensionsV2(
- type_string(), num_spatial_dims_, input_shape,
- expanded_filter_shape, out_backprop_shape, dilations_,
- strides_, padding_, data_format_, &dims));
-
- auto filter = ctx->Input(1);
- auto out_backprop = ctx->Input(2);
-
- // The input gradients are computed by a convolution of the output
- // gradients and the filter, with some appropriate padding. See the
- // comment at the top of conv_grad_ops.h for details.
-
- xla::ConvolutionDimensionNumbers dnums;
- dnums.set_input_batch_dimension(batch_dim);
- dnums.set_output_batch_dimension(batch_dim);
- dnums.set_input_feature_dimension(feature_dim);
- dnums.set_output_feature_dimension(feature_dim);
-
- // TF filter shape is [ H, W, ..., inC, outC ]
- // Transpose the input and output features for computing the gradient.
- dnums.set_kernel_input_feature_dimension(num_spatial_dims_ + 1);
- dnums.set_kernel_output_feature_dimension(num_spatial_dims_);
-
- std::vector<int64> kernel_spatial_dims(num_spatial_dims_);
- std::vector<std::pair<int64, int64>> padding(num_spatial_dims_);
- std::vector<int64> lhs_dilation(num_spatial_dims_);
- std::vector<int64> rhs_dilation(num_spatial_dims_);
- std::vector<int64> ones(num_spatial_dims_, 1);
- for (int i = 0; i < num_spatial_dims_; ++i) {
- int64 dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i);
- dnums.add_input_spatial_dimensions(dim);
- dnums.add_kernel_spatial_dimensions(i);
- dnums.add_output_spatial_dimensions(dim);
-
- kernel_spatial_dims[i] = i;
- padding[i] = {dims.spatial_dims[i].pad_before,
- dims.spatial_dims[i].pad_after};
- lhs_dilation[i] = dims.spatial_dims[i].stride;
- rhs_dilation[i] = dilations_[dim];
- }
-
- // Mirror the filter in the spatial dimensions.
- xla::XlaOp mirrored_weights = xla::Rev(filter, kernel_spatial_dims);
-
- // activation gradients
- // = gradients (with padding and dilation) <conv> mirrored_weights
- xla::XlaOp in_backprop = xla::ConvGeneralDilated(
- out_backprop, mirrored_weights, /*window_strides=*/ones, padding,
- lhs_dilation, rhs_dilation, dnums,
- /*feature_group_count=*/
- depthwise_ ? out_backprop_shape.dim_size(feature_dim) /
- filter_shape.dim_size(num_spatial_dims_ + 1)
- : 1);
-
- ctx->SetOutput(0, in_backprop);
+ TensorShape input_tensor_shape;
+ OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &input_tensor_shape));
+ xla::Shape input_shape =
+ TensorShapeToXLAShape(ctx->input_xla_type(1), input_tensor_shape);
+
+ xla::StatusOr<xla::XlaOp> in_backprop =
+ MakeXlaBackpropInputConvOp(ctx->op_kernel().type_string(), input_shape,
+ ctx->Input(1), ctx->Input(2), attrs_);
+ OP_REQUIRES_OK(ctx, in_backprop.status());
+ ctx->SetOutput(0, in_backprop.ValueOrDie());
}
protected:
- const int num_spatial_dims_;
- const bool depthwise_;
- std::vector<int32> dilations_;
- std::vector<int32> strides_;
- Padding padding_;
- TensorFormat data_format_ = FORMAT_NHWC;
+ ConvOpAttrs attrs_;
private:
TF_DISALLOW_COPY_AND_ASSIGN(ConvBackpropInputOp);
@@ -462,172 +149,28 @@ class ConvBackpropFilterOp : public XlaOpKernel {
public:
explicit ConvBackpropFilterOp(OpKernelConstruction* ctx, int num_spatial_dims,
bool depthwise)
- : XlaOpKernel(ctx),
- num_spatial_dims_(num_spatial_dims),
- depthwise_(depthwise) {
- OP_REQUIRES_OK(ctx, ctx->GetAttr("dilations", &dilations_));
- OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_));
- OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_));
- string data_format;
- OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format));
- OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_),
- errors::InvalidArgument("Invalid data format"));
+ : XlaOpKernel(ctx) {
+ xla::StatusOr<ConvOpAttrs> attrs =
+ ConvOpAttrs::Create(num_spatial_dims, depthwise, ctx);
+ OP_REQUIRES_OK(ctx, attrs.status());
+ attrs_ = attrs.ValueOrDie();
}
- int num_dims() const { return num_spatial_dims_ + 2; }
-
void Compile(XlaOpKernelContext* ctx) override {
- const int n_dim = GetTensorBatchDimIndex(num_dims(), data_format_);
- const int c_dim = GetTensorFeatureDimIndex(num_dims(), data_format_);
-
- OP_REQUIRES(
- ctx, (strides_[n_dim] == 1 && strides_[c_dim] == 1),
- errors::InvalidArgument("Current implementation does not yet support "
- "strides in the batch and depth dimensions."));
-
- OP_REQUIRES(ctx, dilations_.size() == num_dims(),
- errors::InvalidArgument("Dilations field must "
- "specify ",
- num_dims(), " dimensions"));
- OP_REQUIRES(
- ctx, dilations_[n_dim] == 1 && dilations_[c_dim] == 1,
- errors::Unimplemented("Current implementation does not support "
- "dilations in the batch and depth dimensions."));
- for (int i = 0; i < num_spatial_dims_; ++i) {
- int input_dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i);
- OP_REQUIRES(ctx, dilations_[input_dim] >= 1,
- errors::Unimplemented("Dilation values must be positive; ", i,
- "th spatial dimension had dilation ",
- dilations_[input_dim]));
- }
-
- const TensorShape activations_shape = ctx->InputShape(0);
- TensorShape filter_shape;
- OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(1, &filter_shape));
- const TensorShape out_backprop_shape = ctx->InputShape(2);
-
- const TensorShape expanded_filter_shape =
- depthwise_ ? ExpandedFilterShapeForDepthwiseConvolution(filter_shape)
- : filter_shape;
-
- // Reuse dimension computation logic from conv_grad_ops.cc.
- ConvBackpropDimensions dims;
- OP_REQUIRES_OK(ctx,
- ConvBackpropComputeDimensionsV2(
- type_string(), num_spatial_dims_, activations_shape,
- expanded_filter_shape, out_backprop_shape, dilations_,
- strides_, padding_, data_format_, &dims));
-
- xla::XlaBuilder* b = ctx->builder();
- xla::XlaOp activations = ctx->Input(0);
- xla::XlaOp gradients = ctx->Input(2);
-
- // The filter gradients are computed by a convolution of the input
- // activations and the output gradients, with some appropriate padding.
- // See the comment at the top of conv_grad_ops.h for details.
-
- xla::ConvolutionDimensionNumbers dnums;
-
- // The activations (inputs) form the LHS of the convolution.
- // Activations have shape: [batch, in_rows, in_cols, ..., in_depth]
- // For the gradient computation, we flip the roles of the batch and
- // feature dimensions.
- // Each spatial entry has size in_depth * batch
-
- // Swap n_dim and c_dim in the activations.
- dnums.set_input_batch_dimension(c_dim);
- dnums.set_input_feature_dimension(n_dim);
-
- // The gradients become the RHS of the convolution.
- // The gradients have shape [batch, out_rows, out_cols, ..., out_depth]
- // where the batch becomes the input feature for the convolution.
- dnums.set_kernel_input_feature_dimension(n_dim);
- dnums.set_kernel_output_feature_dimension(c_dim);
-
- std::vector<std::pair<int64, int64>> padding(num_spatial_dims_);
- std::vector<int64> rhs_dilation(num_spatial_dims_);
- std::vector<int64> window_strides(num_spatial_dims_);
- std::vector<int64> ones(num_spatial_dims_, 1);
-
- // Tensorflow filter shape is [ H, W, ..., inC, outC ].
- for (int i = 0; i < num_spatial_dims_; ++i) {
- dnums.add_output_spatial_dimensions(i);
- }
- dnums.set_output_batch_dimension(num_spatial_dims_);
- dnums.set_output_feature_dimension(num_spatial_dims_ + 1);
-
- for (int i = 0; i < num_spatial_dims_; ++i) {
- int64 dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i);
- dnums.add_input_spatial_dimensions(dim);
- dnums.add_kernel_spatial_dimensions(dim);
-
- // We will also need to pad the input with zeros such that after the
- // convolution, we get the right size for the filter.
- // The padded_in_rows should be such that when we convolve this with the
- // expanded_out_rows as a filter, we should get filter_rows back.
- //
- const int64 padded_in_size =
- dims.spatial_dims[i].expanded_output_size +
- (dims.spatial_dims[i].filter_size - 1) * dilations_[dim];
-
- // However it can be smaller than input_rows: in this
- // case it means some of the inputs are not used.
- //
- // An example is to have input_cols = 3, filter_cols = 2 and stride = 2:
- //
- // INPUT = [ A B C ]
- //
- // FILTER = [ x y ]
- //
- // and the output will only have one column: a = A * x + B * y
- //
- // and input "C" is not used at all.
- //
- // We apply negative padding in this case.
- const int64 pad_total = padded_in_size - dims.spatial_dims[i].input_size;
-
- // + For the VALID padding, we don't pad anything on the top/left side
- // and pad the bottom/right side with the remaining space.
- // + For the SAME padding, we pad top/left side the same as bottom/right
- // side.
- //
- // In addition, if the padded input size is smaller than the input size,
- // we need to ignore some training elements of the input. We do this by
- // applying negative padding on the right/bottom.
- const int64 pad_before =
- padding_ == Padding::SAME ? std::max<int64>(pad_total / 2, 0) : 0;
-
- padding[i] = {pad_before, pad_total - pad_before};
- rhs_dilation[i] = dims.spatial_dims[i].stride;
- window_strides[i] = dilations_[dim];
- }
-
- // Besides padding the input, we will also expand output_rows to
- // expanded_out_rows = (output_rows - 1) * stride + 1
- // with zeros in between:
- //
- // a . . . b . . . c . . . d . . . e
- //
- // This is done by specifying the window dilation factors in the
- // convolution HLO below.
- auto filter_backprop =
- xla::ConvGeneralDilated(activations, gradients, window_strides, padding,
- /*lhs_dilation=*/ones, rhs_dilation, dnums);
-
- if (depthwise_) {
- filter_backprop = ContractFilterForDepthwiseBackprop(
- ctx, filter_shape, ctx->input_type(0), filter_backprop, b);
- }
- ctx->SetOutput(0, filter_backprop);
+ TensorShape filter_tensor_shape;
+ OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(1, &filter_tensor_shape));
+ xla::Shape filter_shape =
+ TensorShapeToXLAShape(ctx->input_xla_type(0), filter_tensor_shape);
+
+ xla::StatusOr<xla::XlaOp> filter_backprop = MakeXlaBackpropFilterConvOp(
+ ctx->op_kernel().type_string(), ctx->Input(0), filter_shape,
+ ctx->Input(2), attrs_);
+ OP_REQUIRES_OK(ctx, filter_backprop.status());
+ ctx->SetOutput(0, filter_backprop.ValueOrDie());
}
protected:
- const int num_spatial_dims_;
- const bool depthwise_;
- std::vector<int32> dilations_;
- std::vector<int32> strides_;
- Padding padding_;
- TensorFormat data_format_ = FORMAT_NHWC;
+ ConvOpAttrs attrs_;
private:
TF_DISALLOW_COPY_AND_ASSIGN(ConvBackpropFilterOp);
diff --git a/tensorflow/compiler/tf2xla/kernels/image_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_ops.cc
index 33a73fe5fd..921b4340c0 100644
--- a/tensorflow/compiler/tf2xla/kernels/image_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/image_ops.cc
@@ -355,6 +355,9 @@ class NonMaxSuppressionOp : public XlaOpKernel {
OP_REQUIRES(
context, output_size >= 0,
errors::InvalidArgument("Need output_size >= 0, got ", output_size));
+ OP_REQUIRES(context, output_size <= kint32max,
+ errors::InvalidArgument("Need output_size <= kint32Max, got ",
+ output_size));
xla::XlaOp score_thresh = context->Input("score_threshold");
xla::XlaOp iou_thresh = context->Input("iou_threshold");
@@ -439,12 +442,14 @@ class NonMaxSuppressionOp : public XlaOpKernel {
xla::Broadcast(xla::ConstantR0<int32>(builder, 1), {num_boxes}),
xla::Broadcast(xla::ConstantR0<int32>(builder, 0), {num_boxes}));
- // num_valid is scalar.
- xla::XlaOp num_valid = xla::Reduce(
+ // num_valid is scalar. Its value is bounded above by output_size.
+ xla::XlaOp num_valid_total = xla::Reduce(
ones_included,
/*init_value=*/xla::ConstantR0<int>(builder, 0),
/*computation=*/CreateScalarAddComputation(xla::S32, builder),
/*dimensions_to_reduce=*/{0});
+ xla::XlaOp num_valid =
+ xla::Min(num_valid_total, xla::ConstantR0<int32>(builder, output_size));
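For a concrete reading of the new clamp (illustrative numbers, not part of the patch): if 8 boxes survive the score and IOU thresholds but output_size is 5, num_valid_total reduces to 8 and num_valid becomes min(8, 5) = 5, matching the number of indices the TopK selection below can actually return.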
xla::XlaOp output_tuple = TopK(scores_included, output_size);
xla::XlaOp selected_indices = xla::GetTupleElement(output_tuple, 1);
diff --git a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc
index d9a0257b70..7b2bb4a7c5 100644
--- a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/array4d.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/numeric.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
@@ -132,14 +133,14 @@ int64 CalculateUpperPadding(int64 in_size, int64 out_size, int64 kernel_size,
// If the 2D kernel would be very large, the 1D kernel can instead be applied
// once in each dimension, exploiting the symmetry of the kernel along all
// axes to reduce the computational cost.
-std::vector<float> Make1DKernel(int64 n) {
+xla::XlaOp Make1DKernel(xla::XlaBuilder* builder, int64 n) {
std::vector<float> kernel(n * 2 - 1);
for (int64 i = 0; i < n; ++i) {
float v = (i + 1.0f) / n;
kernel[i] = v;
kernel[n * 2 - 2 - i] = v;
}
- return kernel;
+ return xla::ConstantR1<float>(builder, kernel);
}
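As a worked example of the rewritten helper: Make1DKernel(builder, 3) now emits the constant [1/3, 2/3, 1, 2/3, 1/3] of length 2*3 - 1, i.e. the symmetric triangular kernel used for bilinear interpolation, returned directly as an xla::XlaOp rather than a std::vector<float>.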
// Kernels with more than 16 spatial elements are considered intense and the
@@ -149,41 +150,26 @@ const int64 kMax2DKernelSize = 16;
xla::XlaOp MakeBilinearResizeKernel(xla::XlaBuilder* builder,
absl::Span<const int64> kernel_size,
int64 channels) {
- xla::XlaOp channels_iota = xla::Iota(builder, xla::S32, channels);
+ auto depthwise_kernel = xla::Broadcast(
+ xla::Zero(builder, xla::F32),
+ {(2 * kernel_size[0] - 1), (2 * kernel_size[1] - 1), channels, 1});
- auto diag = xla::ConvertElementType(
- xla::Eq(xla::Broadcast(channels_iota, {2 * kernel_size[0] - 1,
- 2 * kernel_size[1] - 1, channels}),
- channels_iota, /*broadcast_dimensions=*/{2}),
- xla::PrimitiveType::F32);
return xla::Mul(
- xla::Mul(diag,
- xla::ConstantR1<float>(builder, Make1DKernel(kernel_size[1])),
+ xla::Add(depthwise_kernel, Make1DKernel(builder, kernel_size[1]),
/*broadcast_dimensions=*/{1}),
- xla::ConstantR1<float>(builder, Make1DKernel(kernel_size[0])),
+ Make1DKernel(builder, kernel_size[0]),
/*broadcast_dimensions=*/{0});
}
xla::XlaOp MakeBilinearResizeKernelInDim(xla::XlaBuilder* builder,
absl::Span<const int64> kernel_size,
int64 channels, int64 dim) {
- xla::XlaOp channels_iota = xla::Iota(builder, xla::S32, channels);
-
- auto diag = xla::ConvertElementType(
- xla::Eq(
- xla::Broadcast(channels_iota,
- {dim == 0 ? (2 * kernel_size[0] - 1) : 1,
- dim == 1 ? (2 * kernel_size[1] - 1) : 1, channels}),
- channels_iota, /*broadcast_dimensions=*/{2}),
- xla::PrimitiveType::F32);
- if (dim == 1) {
- return xla::Mul(
- diag, xla::ConstantR1<float>(builder, Make1DKernel(kernel_size[1])),
- /*broadcast_dimensions=*/{1});
- }
- return xla::Mul(diag,
- xla::ConstantR1<float>(builder, Make1DKernel(kernel_size[0])),
- /*broadcast_dimensions=*/{0});
+ auto depthwise_kernel =
+ xla::Broadcast(xla::Zero(builder, xla::F32),
+ {dim == 0 ? (2 * kernel_size[0] - 1) : 1,
+ dim == 1 ? (2 * kernel_size[1] - 1) : 1, channels, 1});
+ return xla::Add(depthwise_kernel, Make1DKernel(builder, kernel_size[dim]),
+ /*broadcast_dimensions=*/{dim});
}
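A concrete shape example of the rewrite (illustrative numbers): for kernel_size = {3, 3} and channels = 4, MakeBilinearResizeKernel now produces a [5, 5, 4, 1] filter holding one copy of the separable kernel per channel, whereas the old construction materialized a [5, 5, 4, 4] filter with a diagonal channel mask. The feature_group_count=channels arguments added to the ConvGeneralDilated calls below are what let this smaller depthwise filter compute the same per-channel filtering.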
xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder,
@@ -206,8 +192,8 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder,
xla::ConvolutionDimensionNumbers dimension_numbers;
dimension_numbers.set_input_batch_dimension(0);
dimension_numbers.set_output_batch_dimension(0);
- dimension_numbers.set_input_feature_dimension(3);
- dimension_numbers.set_output_feature_dimension(3);
+ dimension_numbers.set_input_feature_dimension(num_spatial_dims + 1);
+ dimension_numbers.set_output_feature_dimension(num_spatial_dims + 1);
for (int i = 0; i < num_spatial_dims; ++i) {
dimension_numbers.add_input_spatial_dimensions(1 + i);
dimension_numbers.add_output_spatial_dimensions(1 + i);
@@ -285,7 +271,8 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder,
{{dims.kernel_size[0] - 1, upper_padding[0]},
{dims.kernel_size[1] - 1, upper_padding[1]}},
/*lhs_dilation=*/dims.kernel_size,
- /*rhs_dilation=*/{1, 1}, dimension_numbers);
+ /*rhs_dilation=*/{1, 1}, dimension_numbers,
+ /*feature_group_count=*/channels);
} else {
xla::XlaOp kernel0 =
MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 0);
@@ -294,7 +281,8 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder,
/*padding=*/
{{dims.kernel_size[0] - 1, upper_padding[0]}, {0, 0}},
/*lhs_dilation=*/{dims.kernel_size[0], 1},
- /*rhs_dilation=*/{1, 1}, dimension_numbers);
+ /*rhs_dilation=*/{1, 1}, dimension_numbers,
+ /*feature_group_count=*/channels);
xla::XlaOp kernel1 =
MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 1);
output = xla::ConvGeneralDilated(
@@ -302,7 +290,8 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder,
/*padding=*/
{{0, 0}, {dims.kernel_size[1] - 1, upper_padding[1]}},
/*lhs_dilation=*/{1, dims.kernel_size[1]},
- /*rhs_dilation=*/{1, 1}, dimension_numbers);
+ /*rhs_dilation=*/{1, 1}, dimension_numbers,
+ /*feature_group_count=*/channels);
}
// Add broadcasts to handle expanding from a size == 1 dimension to a
@@ -331,15 +320,15 @@ xla::XlaOp ResizeUsingDilationAndConvolutionGradOp(xla::XlaBuilder* builder,
xla::ConvolutionDimensionNumbers dimension_numbers;
dimension_numbers.set_input_batch_dimension(0);
dimension_numbers.set_output_batch_dimension(0);
- dimension_numbers.set_input_feature_dimension(3);
- dimension_numbers.set_output_feature_dimension(3);
+ dimension_numbers.set_input_feature_dimension(num_spatial_dims + 1);
+ dimension_numbers.set_output_feature_dimension(num_spatial_dims + 1);
for (int i = 0; i < num_spatial_dims; ++i) {
- dimension_numbers.add_input_spatial_dimensions(1 + i);
- dimension_numbers.add_output_spatial_dimensions(1 + i);
+ dimension_numbers.add_input_spatial_dimensions(i + 1);
+ dimension_numbers.add_output_spatial_dimensions(i + 1);
dimension_numbers.add_kernel_spatial_dimensions(i);
}
- dimension_numbers.set_kernel_input_feature_dimension(num_spatial_dims);
- dimension_numbers.set_kernel_output_feature_dimension(num_spatial_dims + 1);
+ dimension_numbers.set_kernel_input_feature_dimension(num_spatial_dims + 1);
+ dimension_numbers.set_kernel_output_feature_dimension(num_spatial_dims);
xla::XlaOp output;
if (dims.kernel_size[0] * dims.kernel_size[1] < kMax2DKernelSize) {
xla::XlaOp kernel =
@@ -362,7 +351,8 @@ xla::XlaOp ResizeUsingDilationAndConvolutionGradOp(xla::XlaBuilder* builder,
{{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1},
{dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}},
/*lhs_dilation=*/dims.stride,
- /*rhs_dilation=*/{1, 1}, dimension_numbers);
+ /*rhs_dilation=*/{1, 1}, dimension_numbers,
+ /*feature_group_count=*/channels);
} else {
xla::XlaOp kernel0 =
MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 0);
@@ -388,14 +378,16 @@ xla::XlaOp ResizeUsingDilationAndConvolutionGradOp(xla::XlaBuilder* builder,
/*padding=*/
{{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, {0, 0}},
/*lhs_dilation=*/{dims.stride[0], 1},
- /*rhs_dilation=*/{1, 1}, dimension_numbers);
+ /*rhs_dilation=*/{1, 1}, dimension_numbers,
+ /*feature_group_count=*/channels);
output = xla::ConvGeneralDilated(
output, kernel1, /*window_strides=*/{1, dims.kernel_size[1]},
/*padding=*/
{{0, 0}, {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}},
/*lhs_dilation=*/{1, dims.stride[1]},
- /*rhs_dilation=*/{1, 1}, dimension_numbers);
+ /*rhs_dilation=*/{1, 1}, dimension_numbers,
+ /*feature_group_count=*/channels);
}
// If in_size[i] > 1 and grad_size[i] == 1, pad the output in dimension i.
diff --git a/tensorflow/compiler/tf2xla/kernels/shape_op.cc b/tensorflow/compiler/tf2xla/kernels/shape_op.cc
index 2e0a69b70e..c8a0f31a03 100644
--- a/tensorflow/compiler/tf2xla/kernels/shape_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/shape_op.cc
@@ -44,7 +44,7 @@ class ShapeOp : public XlaOpKernel {
DataType out_dtype_;
};
-REGISTER_XLA_OP(Name("Shape").CompilationOnly(), ShapeOp);
+REGISTER_XLA_OP(Name("Shape").CompilationOnly().IsMetadataOp(), ShapeOp);
class ShapeNOp : public XlaOpKernel {
public:
@@ -66,7 +66,7 @@ class ShapeNOp : public XlaOpKernel {
private:
DataType out_dtype_;
};
-REGISTER_XLA_OP(Name("ShapeN").CompilationOnly(), ShapeNOp);
+REGISTER_XLA_OP(Name("ShapeN").CompilationOnly().IsMetadataOp(), ShapeNOp);
class RankOp : public XlaOpKernel {
public:
@@ -82,7 +82,7 @@ class RankOp : public XlaOpKernel {
}
};
-REGISTER_XLA_OP(Name("Rank").CompilationOnly(), RankOp);
+REGISTER_XLA_OP(Name("Rank").CompilationOnly().IsMetadataOp(), RankOp);
class SizeOp : public XlaOpKernel {
public:
@@ -101,7 +101,7 @@ class SizeOp : public XlaOpKernel {
}
};
-REGISTER_XLA_OP(Name("Size").CompilationOnly(), SizeOp);
+REGISTER_XLA_OP(Name("Size").CompilationOnly().IsMetadataOp(), SizeOp);
class ExpandDimsOp : public XlaOpKernel {
public:
diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc
index 02363500ef..733eeed3c6 100644
--- a/tensorflow/compiler/tf2xla/ops/xla_ops.cc
+++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc
@@ -121,8 +121,8 @@ Wraps the XLA DynamicSlice operator, documented at
DynamicSlice extracts a sub-array from the input array at dynamic
start_indices. The size of the slice in each dimension is passed in
size_indices, which specify the end point of exclusive slice intervals in each
-dimension -- [start, start + size). The shape of start_indices must be rank ==
-1, with dimension size equal to the rank of operand.
+dimension -- [start, start + size). The shape of start_indices must have rank 1,
+with dimension size equal to the rank of operand.
input: A `Tensor` of type T.
@@ -131,7 +131,8 @@ start_indices: Rank 1 tensor of N integers containing the starting indices of
size_indices: List of N integers containing the slice size for each
dimension. Each value must be strictly greater than zero, and start + size
- must be less
+ must be less than or equal to the size of the dimension to avoid
+ implementation-defined behavior.
)doc");
REGISTER_OP("XlaDynamicUpdateSlice")
diff --git a/tensorflow/compiler/tf2xla/shape_util.cc b/tensorflow/compiler/tf2xla/shape_util.cc
index 9d1992205b..b589512dcd 100644
--- a/tensorflow/compiler/tf2xla/shape_util.cc
+++ b/tensorflow/compiler/tf2xla/shape_util.cc
@@ -41,6 +41,14 @@ Status XLAShapeToTensorShape(const xla::Shape& shape,
// Convert a TensorShape into the equivalent XLA Shape proto.
Status TensorShapeToXLAShape(DataType dtype, const TensorShape& tensor_shape,
xla::Shape* shape) {
+ xla::PrimitiveType type;
+ TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(dtype, &type));
+ *shape = TensorShapeToXLAShape(type, tensor_shape);
+ return Status::OK();
+}
+
+xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type,
+ const TensorShape& tensor_shape) {
int rank = tensor_shape.dims();
std::vector<int64> dimensions(rank);
std::vector<int64> layout(rank);
@@ -50,11 +58,7 @@ Status TensorShapeToXLAShape(DataType dtype, const TensorShape& tensor_shape,
// XLA uses minor-to-major; Tensorflow uses major-to-minor.
std::iota(layout.rbegin(), layout.rend(), 0);
- xla::PrimitiveType type;
- TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(dtype, &type));
-
- *shape = xla::ShapeUtil::MakeShapeWithLayout(type, dimensions, layout);
- return Status::OK();
+ return xla::ShapeUtil::MakeShapeWithLayout(type, dimensions, layout);
}
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/shape_util.h b/tensorflow/compiler/tf2xla/shape_util.h
index 58240b9c96..f7e34a5b40 100644
--- a/tensorflow/compiler/tf2xla/shape_util.h
+++ b/tensorflow/compiler/tf2xla/shape_util.h
@@ -35,6 +35,11 @@ Status XLAShapeToTensorShape(const xla::Shape& shape,
Status TensorShapeToXLAShape(DataType dtype, const TensorShape& tensor_shape,
xla::Shape* shape);
+// Converts a TensorShape into the equivalent XLA Shape proto, taking an
+// xla::PrimitiveType to specify the element type. This never fails.
+xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type,
+ const TensorShape& tensor_shape);
+
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_TF2XLA_SHAPE_UTIL_H_
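A minimal usage sketch of the new overload (illustrative shape, not part of the patch):

#include "tensorflow/compiler/tf2xla/shape_util.h"

// An f32[2,3] XLA shape; no Status plumbing is needed because the element
// type is already given as an xla::PrimitiveType.
xla::Shape shape = tensorflow::TensorShapeToXLAShape(
    xla::F32, tensorflow::TensorShape({2, 3}));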
diff --git a/tensorflow/compiler/tf2xla/type_util.h b/tensorflow/compiler/tf2xla/type_util.h
index bda667eb1f..6354216eee 100644
--- a/tensorflow/compiler/tf2xla/type_util.h
+++ b/tensorflow/compiler/tf2xla/type_util.h
@@ -25,6 +25,14 @@ namespace tensorflow {
// Converts a Tensorflow DataType to an XLA PrimitiveType.
Status DataTypeToPrimitiveType(DataType data_type, xla::PrimitiveType* type);
+// N.B.: there is intentionally no function to convert an XLA PrimitiveType to
+// a TensorFlow DataType. The mapping from TF types to XLA types is not
+// one-to-one: for example, both DT_INT8 and DT_QINT8 map to xla::S8. So the
+// inverse would not be a well-defined function. If you find that you want the
+// inverse mapping, then most likely you should be preserving the original
+// TensorFlow type, rather than trying to convert an XLA type into a TensorFlow
+// type.
+
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_TF2XLA_TYPE_UTIL_H_
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc
index 739e47778a..d5094e8ec5 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler.cc
@@ -333,10 +333,8 @@ Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr<Graph> graph,
}
// Builds the XLA computation.
-//
-// `retvals` is the list of retvals produced by _Retval operators, in index
-// order. `variable_map` is a map from variable ID numbers to XlaOpContext
-// variable states, generated by the symbolic evaluation.
+// `args` is the list of input arguments; `retvals` is the list of retvals
+// produced by _Retval operators, in index order.
// If `return_updated_values_for_all_resources` is true, all resources will be
// included in `resource_updates`, regardless of whether their value changed.
// Sets `*num_nonconst_outputs` to the number of outputs of the `computation`.
diff --git a/tensorflow/compiler/tf2xla/xla_cpu_backend.cc b/tensorflow/compiler/tf2xla/xla_cpu_backend.cc
index 23d04d43b3..bc44301d40 100644
--- a/tensorflow/compiler/tf2xla/xla_cpu_backend.cc
+++ b/tensorflow/compiler/tf2xla/xla_cpu_backend.cc
@@ -20,21 +20,6 @@ limitations under the License.
namespace tensorflow {
bool CpuOpFilter(KernelDef* kdef) {
- // TODO(b/34339814): implement inverse erf for double types and remove this
- // workaround.
- if (kdef->op() == "RandomStandardNormal") {
- kdef->clear_constraint();
- // Change the type constraint to permit only DTD_FLOAT.
- KernelDef::AttrConstraint* attr_constraint = kdef->add_constraint();
- attr_constraint->set_name("dtype");
- attr_constraint->mutable_allowed_values()->mutable_list()->add_type(
- DT_FLOAT);
- return true;
- }
- // TODO(b/26783907): The CPU backend currently does not implement sort.
- if (kdef->op() == "XlaSort" || kdef->op() == "TopKV2") {
- return false;
- }
if (kdef->op() == "Const") {
AddDtypeToKernalDefConstraint("dtype", DT_STRING, kdef);
}
diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc
index b0eeee3174..91d48125f1 100644
--- a/tensorflow/compiler/tf2xla/xla_op_registry.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc
@@ -90,6 +90,11 @@ XlaOpRegistry::~XlaOpRegistry() = default;
<< " have incompatible compile time constant inputs.";
return false;
}
+ if (x.is_metadata_op != y.is_metadata_op) {
+ LOG(WARNING) << "Registrations of " << x.name
+ << " have incompatible values for is_metadata_op.";
+ return false;
+ }
return true;
}
@@ -350,6 +355,20 @@ XlaOpRegistry::CompileTimeConstantInputs(const string& op) {
return &it->second.front()->compile_time_constant_inputs;
}
+/*static*/ bool XlaOpRegistry::IsMetadataOp(const string& op) {
+ XlaOpRegistry& registry = Instance();
+ mutex_lock lock(registry.mutex_);
+ auto it = registry.ops_.find(op);
+ if (it == registry.ops_.end() || it->second.empty()) {
+ return false;
+ }
+
+ // The test in IsCompatible ensures that if there are multiple matching
+ // registrations for this op name, they all have the same value of
+ // is_metadata_op, so it is sufficient to check only the first match.
+ return it->second.front()->is_metadata_op;
+}
+
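A sketch of a hypothetical caller (the surrounding pass and the `node` variable are assumptions for the example; only XlaOpRegistry::IsMetadataOp comes from the patch):

// Skip shape-only ops when deciding which inputs must be materialized on the
// XLA device; metadata ops never read their operands' values.
if (XlaOpRegistry::IsMetadataOp(node->type_string())) {
  // ... treat `node` as a shape-only consumer ...
}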
std::vector<string> XlaOpRegistry::BackendNames() {
std::vector<string> names;
XlaOpRegistry& registry = Instance();
@@ -432,6 +451,11 @@ XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::CompileTimeConstInput(
return *this;
}
+XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::IsMetadataOp() {
+ registration_->is_metadata_op = true;
+ return *this;
+}
+
std::unique_ptr<XlaOpRegistry::OpRegistration> XlaOpRegistrationBuilder::Build(
XlaOpRegistry::Factory factory) {
registration_->factory = factory;
diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h
index 74a4885f1f..4b2c2bacd6 100644
--- a/tensorflow/compiler/tf2xla/xla_op_registry.h
+++ b/tensorflow/compiler/tf2xla/xla_op_registry.h
@@ -47,17 +47,18 @@ extern const char* const DEVICE_XLA_GPU;
constexpr std::array<DataType, 4> kFloatTypes = {
{DT_HALF, DT_FLOAT, DT_DOUBLE, DT_BFLOAT16}};
-constexpr std::array<DataType, 9> kNumericTypes = {
- {DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE,
- DT_COMPLEX64, DT_BFLOAT16}};
+constexpr std::array<DataType, 11> kNumericTypes = {
+ {DT_UINT8, DT_UINT32, DT_UINT64, DT_INT8, DT_INT32, DT_INT64, DT_HALF,
+ DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BFLOAT16}};
-constexpr std::array<DataType, 9> kCpuAllTypes = {
- {DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE,
- DT_COMPLEX64, DT_BOOL}};
+constexpr std::array<DataType, 14> kCpuAllTypes = {
+ {DT_UINT8, DT_QUINT8, DT_UINT32, DT_UINT64, DT_INT8, DT_QINT8, DT_INT32,
+ DT_QINT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}};
-constexpr std::array<DataType, 10> kGpuAllTypes = {
- {DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE,
- DT_COMPLEX64, DT_BOOL, DT_BFLOAT16}};
+constexpr std::array<DataType, 15> kGpuAllTypes = {
+ {DT_UINT8, DT_QUINT8, DT_UINT32, DT_UINT64, DT_INT8, DT_QINT8, DT_INT32,
+ DT_QINT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL,
+ DT_BFLOAT16}};
// Class that manages registrations of operators and devices for the XLA JIT.
// Not thread-safe.
@@ -136,6 +137,10 @@ class XlaOpRegistry {
static const std::unordered_set<string>* CompileTimeConstantInputs(
const string& op);
+ // Returns true if `op` is a "metadata" op, one that only looks at the shapes
+ // of its operands and not their values.
+ static bool IsMetadataOp(const string& op);
+
private:
friend class XlaBackendRegistrar;
friend class XlaOpRegistrar;
@@ -192,6 +197,10 @@ class XlaOpRegistry {
// Names of arguments that must be compile-time constants.
std::unordered_set<string> compile_time_constant_inputs;
+ // True if this is a "metadata" op, one that only looks at the shapes of its
+ // operands and not their values.
+ bool is_metadata_op = false;
+
// Factory used to build OpKernels that perform symbolic execution.
Factory factory;
};
@@ -256,6 +265,10 @@ class XlaOpRegistrationBuilder {
// Mark 'input_name' as an argument whose value must be known at compile-time.
XlaOpRegistrationBuilder& CompileTimeConstInput(absl::string_view input_name);
+ // Mark this op as a "metadata" op, one that only looks at the shapes of its
+ // operands and not their values.
+ XlaOpRegistrationBuilder& IsMetadataOp();
+
std::unique_ptr<XlaOpRegistry::OpRegistration> Build(
XlaOpRegistry::Factory factory);
diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD
index ef70c1f8ac..cc7390c6e6 100644
--- a/tensorflow/compiler/xla/BUILD
+++ b/tensorflow/compiler/xla/BUILD
@@ -245,6 +245,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:regexp_internal",
+ "@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
diff --git a/tensorflow/compiler/xla/client/lib/testing.cc b/tensorflow/compiler/xla/client/lib/testing.cc
index 25cc37edc4..ff0ec76a7f 100644
--- a/tensorflow/compiler/xla/client/lib/testing.cc
+++ b/tensorflow/compiler/xla/client/lib/testing.cc
@@ -97,13 +97,11 @@ std::vector<std::unique_ptr<GlobalData>> MakeFakeArgumentsOrDie(
<< "Computation should have progran shape.";
auto program_shape = computation.proto().program_shape();
- // Create and run a program which produces a tuple with one element per
- // parameter, then return the tuple's constituent buffers.
- std::vector<Shape> param_shapes(program_shape.parameters().begin(),
- program_shape.parameters().end());
- auto fake_input_tuple =
- MakeFakeDataOrDie(ShapeUtil::MakeTupleShape(param_shapes), client);
- return client->DeconstructTuple(*fake_input_tuple).ValueOrDie();
+ std::vector<std::unique_ptr<GlobalData>> results;
+ for (const Shape& shape : program_shape.parameters()) {
+ results.push_back(MakeFakeDataOrDie(shape, client));
+ }
+ return results;
}
} // namespace xla
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index 95ff6432a5..5277de6a85 100644
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -1278,7 +1278,7 @@ XlaOp XlaBuilder::AfterAll(absl::Span<const XlaOp> tokens) {
XlaOp XlaBuilder::CustomCall(const string& call_target_name,
absl::Span<const XlaOp> operands,
- const Shape& shape) {
+ const Shape& shape, const string& opaque) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
if (absl::StartsWith(call_target_name, "$")) {
@@ -1289,6 +1289,7 @@ XlaOp XlaBuilder::CustomCall(const string& call_target_name,
}
*instr.mutable_shape() = shape;
instr.set_custom_call_target(call_target_name);
+ instr.set_custom_call_opaque(opaque);
return AddInstruction(std::move(instr), HloOpcode::kCustomCall, operands);
});
}
@@ -2681,8 +2682,9 @@ XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
}
XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name,
- absl::Span<const XlaOp> operands, const Shape& shape) {
- return builder->CustomCall(call_target_name, operands, shape);
+ absl::Span<const XlaOp> operands, const Shape& shape,
+ const string& opaque) {
+ return builder->CustomCall(call_target_name, operands, shape, opaque);
}
XlaOp Complex(const XlaOp& real, const XlaOp& imag,
diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h
index d0c59fa6f2..1da6ddd318 100644
--- a/tensorflow/compiler/xla/client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_builder.h
@@ -577,11 +577,9 @@ class XlaBuilder {
absl::Span<const XlaOp> operands);
// Enqueues a custom call instruction onto the computation.
- // During code generation, a call instruction is emitted which targets a
- // symbol with the name |call_target_name|. The |operands| are passed to the
- // call instruction. |shape| is the resultant shape.
XlaOp CustomCall(const string& call_target_name,
- absl::Span<const XlaOp> operands, const Shape& shape);
+ absl::Span<const XlaOp> operands, const Shape& shape,
+ const string& opaque);
// The following methods enqueue element-wise binary arithmetic operations
// onto the computation. The shapes of the operands have to match unless one
@@ -1195,7 +1193,8 @@ class XlaBuilder {
friend XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
absl::Span<const XlaOp> operands);
friend XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name,
- absl::Span<const XlaOp> operands, const Shape& shape);
+ absl::Span<const XlaOp> operands, const Shape& shape,
+ const string& opaque);
friend XlaOp Complex(const XlaOp& real, const XlaOp& imag,
absl::Span<const int64> broadcast_dimensions);
friend XlaOp Conj(const XlaOp& operand);
@@ -1717,12 +1716,17 @@ XlaOp OutfeedWithToken(const XlaOp& operand, const XlaOp& token,
XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
absl::Span<const XlaOp> operands);
-// Enqueues a custom call instruction onto the computation.
-// During code generation, a call instruction is emitted which targets a
-// symbol with the name |call_target_name|. The |operands| are passed to the
-// call instruction. |shape| is the resultant shape.
+// Enqueues a custom call instruction onto the computation. A custom call
+// invokes code external to XLA. The |operands| are passed to the external code,
+// and the external code is expected to produce a result of the given
+// |shape|. The exact mechanism is backend-specific. For example, in the CPU
+// backend, a call instruction is emitted which targets a symbol with the name
+// |call_target_name|. |call_target_name| and |opaque| can be arbitrary strings,
+// but |call_target_name| should be short as it may be used in labels. |opaque|
+// can encode arbitrarily large amounts of information.
XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name,
- absl::Span<const XlaOp> operands, const Shape& shape);
+ absl::Span<const XlaOp> operands, const Shape& shape,
+ const string& opaque = "");
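A minimal usage sketch of the extended signature (not part of the patch; the target symbol and the opaque payload are illustrative, and the backend must know how to resolve "my_custom_target"):

xla::XlaBuilder b("custom_call_example");
xla::XlaOp p =
    xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {4}), "p");
// Pass a small configuration string to the external code via |opaque|.
xla::CustomCall(&b, "my_custom_target", {p},
                /*shape=*/xla::ShapeUtil::MakeShape(xla::F32, {4}),
                /*opaque=*/"alpha=1.5");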
// The following methods enqueue element-wise binary arithmetic operations
// onto the computation. The shapes of the operands have to match unless one
diff --git a/tensorflow/compiler/xla/executable_run_options.cc b/tensorflow/compiler/xla/executable_run_options.cc
index a472747bd1..0f9b591c70 100644
--- a/tensorflow/compiler/xla/executable_run_options.cc
+++ b/tensorflow/compiler/xla/executable_run_options.cc
@@ -45,6 +45,16 @@ stream_executor::Stream* ExecutableRunOptions::stream() const {
return stream_;
}
+ExecutableRunOptions& ExecutableRunOptions::set_host_to_device_stream(
+ stream_executor::Stream* stream) {
+ host_to_device_stream_ = stream;
+ return *this;
+}
+
+stream_executor::Stream* ExecutableRunOptions::host_to_device_stream() const {
+ return host_to_device_stream_;
+}
+
ExecutableRunOptions& ExecutableRunOptions::set_intra_op_thread_pool(
const Eigen::ThreadPoolDevice* intra_op_thread_pool) {
intra_op_thread_pool_ = intra_op_thread_pool;
diff --git a/tensorflow/compiler/xla/executable_run_options.h b/tensorflow/compiler/xla/executable_run_options.h
index 416131be00..ba3217f31b 100644
--- a/tensorflow/compiler/xla/executable_run_options.h
+++ b/tensorflow/compiler/xla/executable_run_options.h
@@ -65,6 +65,13 @@ class ExecutableRunOptions {
ExecutableRunOptions& set_stream(stream_executor::Stream* stream);
stream_executor::Stream* stream() const;
+ // If set, this is the stream to perform any pre-computation transfers on.
+ // The platform of the stream must match the platform the executable was
+ // built for. A value of nullptr indicates the option has not been set.
+ ExecutableRunOptions& set_host_to_device_stream(
+ stream_executor::Stream* stream);
+ stream_executor::Stream* host_to_device_stream() const;
+
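A short usage sketch (compute_stream and transfer_stream are placeholder se::Stream* values assumed to be created on the same platform as the executable):

xla::ExecutableRunOptions run_options;
run_options.set_stream(compute_stream)
    .set_host_to_device_stream(transfer_stream);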
// Sets the thread pool device on which to run Eigen subcomputations.
// Does not take ownership.
ExecutableRunOptions& set_intra_op_thread_pool(
@@ -90,6 +97,7 @@ class ExecutableRunOptions {
const Eigen::ThreadPoolDevice* intra_op_thread_pool_ = nullptr;
ExecutionProfile* execution_profile_ = nullptr;
int rng_seed_ = 0;
+ stream_executor::Stream* host_to_device_stream_ = nullptr;
};
} // namespace xla
diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc
index 0d3136b0cc..3ed3afcfce 100644
--- a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc
+++ b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc
@@ -57,6 +57,8 @@ void SetDebugOptionsDefaults(DebugOptions* flags) {
// regression.
flags->set_xla_cpu_enable_fast_math(true);
flags->set_xla_gpu_enable_fast_math(true);
+
+ flags->set_xla_force_host_platform_device_count(1);
}
// Allocates flag_values and flag_objects; this function must not be called more
@@ -323,6 +325,17 @@ void AllocateFlags() {
flag_values->xla_gpu_crash_on_verification_failures(),
"Crashes the program on extra verification failures, e.g. cuDNN "
"cross checking failures"),
+ tensorflow::Flag(
+ "xla_force_host_platform_device_count",
+ int32_setter_for(
+ &DebugOptions::set_xla_force_host_platform_device_count),
+ flag_values->xla_force_host_platform_device_count(),
+ "Force the host platform to pretend that there are these many "
+ "host \"devices\". All of these host devices are backed by the same"
+ "threadpool. Setting this to anything other than 1 can increase "
+ "overhead from context switching but we let the user override this "
+ "behavior to help run tests on the host that run models in parallel "
+ "across multiple devices."),
});
ParseFlagsFromEnv(*flag_objects);
}
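As a usage note (hedged, since the exact environment variable name consumed by ParseFlagsFromEnv depends on the TensorFlow version): a test harness that wants eight virtual host devices would pass the new flag through that mechanism, e.g. XLA_FLAGS="--xla_force_host_platform_device_count=8" in current builds.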
diff --git a/tensorflow/compiler/xla/literal.h b/tensorflow/compiler/xla/literal.h
index 1e0a2ad0dd..3cd3541fe1 100644
--- a/tensorflow/compiler/xla/literal.h
+++ b/tensorflow/compiler/xla/literal.h
@@ -203,6 +203,10 @@ class LiteralBase {
// Returns the count of the elements in the array at the given shape index in
// this literal.
int64 element_count(const ShapeIndex& index = {}) const {
+ if (index.empty()) {
+ // Common case, avoid GetSubshape().
+ return ShapeUtil::ElementsIn(shape());
+ }
return ShapeUtil::ElementsIn(ShapeUtil::GetSubshape(shape(), index));
}
@@ -852,9 +856,9 @@ class BorrowingLiteral : public LiteralBase {
template <typename NativeT>
absl::Span<const NativeT> LiteralBase::Piece::data() const {
- CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape());
- CHECK_EQ(subshape().element_type(),
- primitive_util::NativeToPrimitiveType<NativeT>())
+ DCHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape());
+ DCHECK_EQ(subshape().element_type(),
+ primitive_util::NativeToPrimitiveType<NativeT>())
<< "Attempting to access "
<< PrimitiveType_Name(primitive_util::NativeToPrimitiveType<NativeT>())
<< " type, but literal element type is "
@@ -865,9 +869,9 @@ absl::Span<const NativeT> LiteralBase::Piece::data() const {
template <typename NativeT>
absl::Span<NativeT> LiteralBase::Piece::data() {
- CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape());
- CHECK_EQ(subshape().element_type(),
- primitive_util::NativeToPrimitiveType<NativeT>())
+ DCHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape());
+ DCHECK_EQ(subshape().element_type(),
+ primitive_util::NativeToPrimitiveType<NativeT>())
<< "Attempting to access "
<< PrimitiveType_Name(primitive_util::NativeToPrimitiveType<NativeT>())
<< " type, but literal element type is "
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc
index 9da5dc0d2d..cd5fd33029 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.cc
+++ b/tensorflow/compiler/xla/python/local_computation_builder.cc
@@ -469,9 +469,11 @@ LocalOp LocalComputationBuilder::ConvGeneralDilated(
absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding,
absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
- const ConvolutionDimensionNumbers& dimension_numbers) {
+ const ConvolutionDimensionNumbers& dimension_numbers,
+ int64 feature_group_count) {
return xla::ConvGeneralDilated(lhs.op(), rhs.op(), window_strides, padding,
- lhs_dilation, rhs_dilation, dimension_numbers);
+ lhs_dilation, rhs_dilation, dimension_numbers,
+ feature_group_count);
}
LocalOp LocalComputationBuilder::ConvertElementType(
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h
index 1d5dfe5911..2166bb6721 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.h
+++ b/tensorflow/compiler/xla/python/local_computation_builder.h
@@ -248,7 +248,8 @@ class LocalComputationBuilder {
absl::Span<const std::pair<int64, int64> > padding,
absl::Span<const int64> lhs_dilation,
absl::Span<const int64> rhs_dilation,
- const ConvolutionDimensionNumbers& dimension_numbers);
+ const ConvolutionDimensionNumbers& dimension_numbers,
+ int64 feature_group_count);
LocalOp ConvertElementType(const LocalOp& operand,
PrimitiveType new_element_type);
diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py
index fa4366ff07..bb303c5678 100644
--- a/tensorflow/compiler/xla/python/xla_client.py
+++ b/tensorflow/compiler/xla/python/xla_client.py
@@ -1109,7 +1109,7 @@ class ComputationBuilder(object):
dimension_numbers = GetDotDimensionsFromLists(dimension_numbers)
return self._client.DotGeneral(lhs, rhs, dimension_numbers)
- def Conv(self, lhs, rhs, window_strides, padding):
+ def Conv(self, lhs, rhs, window_strides, padding, feature_group_count=1):
"""Enqueues a Conv operation onto the computation.
Args:
@@ -1117,6 +1117,7 @@ class ComputationBuilder(object):
rhs: LocalOp for the rank N+2 array of kernel weights.
window_strides: length-N array-like of integer kernel strides.
padding: PaddingType representing either 'SAME' or 'VALID' padding.
+ feature_group_count: number of feature groups for grouped convolution.
Returns: a LocalOp representing the Conv operation.
"""
@@ -1125,10 +1126,11 @@ class ComputationBuilder(object):
self.GetShape(rhs).dimensions()[2:], window_strides)
dimension_numbers = self._GetConvDimensionNumbers(len(window_strides))
return self._client.ConvGeneralDilated(lhs, rhs, window_strides, pads, (),
- (), dimension_numbers)
+ (), dimension_numbers,
+ feature_group_count)
def ConvWithGeneralPadding(self, lhs, rhs, window_strides, padding,
- lhs_dilation, rhs_dilation):
+ lhs_dilation, rhs_dilation, feature_group_count=1):
"""Enqueues a ConvWithGeneralPadding operation onto the computation.
Args:
@@ -1138,6 +1140,7 @@ class ComputationBuilder(object):
padding: length-N array-like of pairs of integers of (low, high) padding.
lhs_dilation: length-N array-like of dilation factors.
rhs_dilation: length-N array-like of dilation factors.
+ feature_group_count: number of feature groups for grouped convolution.
Returns:
A LocalOp representing the added ConvWithGeneralPadding op.
@@ -1145,7 +1148,8 @@ class ComputationBuilder(object):
dimension_numbers = self._GetConvDimensionNumbers(len(window_strides))
return self._client.ConvGeneralDilated(lhs, rhs, window_strides, padding,
lhs_dilation, rhs_dilation,
- dimension_numbers)
+ dimension_numbers,
+ feature_group_count)
def _GetConvDimensionNumbers(self, num_spatial_dims):
"""Create ConvolutionDimensionNumbers proto for convolutions."""
@@ -1163,7 +1167,8 @@ class ComputationBuilder(object):
return dimension_numbers
def ConvGeneralDilated(self, lhs, rhs, window_strides, padding, lhs_dilation,
- rhs_dilation, dimension_numbers):
+ rhs_dilation, dimension_numbers,
+ feature_group_count=1):
"""Enqueues a ConvGeneralDilated operation onto the computation.
Args:
@@ -1190,6 +1195,7 @@ class ComputationBuilder(object):
labels appear in the rhs_spec string, so that window_strides[0] is
matched with the dimension corresponding to the first character
appearing in rhs_spec that is not 'I' or 'O'.
+ feature_group_count: number of feature groups for grouped convolution.
Returns: a LocalOp representing the ConvGeneralDilated operation.
"""
@@ -1215,7 +1221,8 @@ class ComputationBuilder(object):
key=lambda i: rhs_spec.index(out_spec[i])))
return self._client.ConvGeneralDilated(lhs, rhs, window_strides, padding,
lhs_dilation, rhs_dilation,
- dimension_numbers)
+ dimension_numbers,
+ feature_group_count)
def Sort(self, operand, dimension=-1):
"""Enqueues a sort operation onto the computation."""
diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py
index fd98e19457..82103f0313 100644
--- a/tensorflow/compiler/xla/python/xla_client_test.py
+++ b/tensorflow/compiler/xla/python/xla_client_test.py
@@ -661,6 +661,30 @@ class SingleOpTest(LocalComputationTest):
[40., 50., 0.]]]])
self._ExecuteAndCompareClose(c, expected=np.transpose(result, (1, 3, 0, 2)))
+ def testConvGeneralDilatedGroupedConvolutionF32(self):
+ c = self._NewComputation()
+ a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32")
+ lhs = a(1, 2, 2, 3)
+ rhs = a(2, 1, 1, 2) * 10
+ strides = [1, 1]
+ pads = [(1, 0), (0, 1)]
+ lhs_dilation = (2, 1)
+ rhs_dilation = (1, 1)
+ dimension_numbers = ("NCHW", "OIHW", "NCHW")
+ feature_group_count = 2
+ c.ConvGeneralDilated(c.Constant(lhs), c.Constant(rhs),
+ strides, pads, lhs_dilation, rhs_dilation,
+ dimension_numbers, feature_group_count)
+ result = np.array([[[[0., 0., 0.],
+ [10., 20., 0.],
+ [0., 0., 0.],
+ [40., 50., 0.]],
+ [[0., 0., 0.],
+ [330., 380., 160.],
+ [0., 0., 0.],
+ [480., 530., 220.]]]])
+ self._ExecuteAndCompareClose(c, expected=result)
+
def testBooleanNot(self):
c = self._NewComputation()
arr = NumpyArrayBool([True, False, True])
diff --git a/tensorflow/compiler/xla/rpc/BUILD b/tensorflow/compiler/xla/rpc/BUILD
index 97fcd37f6b..3abb3855a4 100644
--- a/tensorflow/compiler/xla/rpc/BUILD
+++ b/tensorflow/compiler/xla/rpc/BUILD
@@ -34,19 +34,28 @@ cc_library(
],
)
-tf_cc_binary(
- name = "grpc_service_main_cpu",
+cc_library(
+ name = "grpc_service_main_library",
srcs = ["grpc_service_main.cc"],
deps = [
":grpc_service",
"//tensorflow:grpc++",
"//tensorflow/compiler/xla/service:cpu_plugin",
+ "//tensorflow/compiler/xla/service:platform_util",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"@com_google_absl//absl/strings:str_format",
],
)
+tf_cc_binary(
+ name = "grpc_service_main_cpu",
+ deps = [
+ ":grpc_service_main_library",
+ "//tensorflow/compiler/xla/service:cpu_plugin",
+ ],
+)
+
tf_cc_test(
name = "grpc_client_test",
srcs = ["grpc_client_test.cc"],
diff --git a/tensorflow/compiler/xla/rpc/grpc_service_main.cc b/tensorflow/compiler/xla/rpc/grpc_service_main.cc
index d6b5149a24..522ab99fb1 100644
--- a/tensorflow/compiler/xla/rpc/grpc_service_main.cc
+++ b/tensorflow/compiler/xla/rpc/grpc_service_main.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "grpcpp/server_builder.h"
#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/rpc/grpc_service.h"
+#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/command_line_flags.h"
@@ -29,8 +30,15 @@ namespace {
int RealMain(int argc, char** argv) {
int32 port = 1685;
+ bool any_address = false;
+ string platform_str;
std::vector<tensorflow::Flag> flag_list = {
- tensorflow::Flag("port", &port, "port to listen on"),
+ tensorflow::Flag("platform", &platform_str,
+ "The XLA platform this service should be bound to"),
+ tensorflow::Flag("port", &port, "The TCP port to listen on"),
+ tensorflow::Flag(
+ "any", &any_address,
+        "Whether to listen on any host address or only localhost"),
};
string usage = tensorflow::Flags::Usage(argv[0], flag_list);
bool parsed_values_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);
@@ -40,19 +48,24 @@ int RealMain(int argc, char** argv) {
}
tensorflow::port::InitMain(argv[0], &argc, &argv);
+ se::Platform* platform = nullptr;
+ if (!platform_str.empty()) {
+ platform = PlatformUtil::GetPlatform(platform_str).ValueOrDie();
+ }
std::unique_ptr<xla::GRPCService> service =
- xla::GRPCService::NewService().ConsumeValueOrDie();
+ xla::GRPCService::NewService(platform).ConsumeValueOrDie();
::grpc::ServerBuilder builder;
- string server_address(absl::StrFormat("localhost:%d", port));
+ string server_address(
+ absl::StrFormat("%s:%d", any_address ? "[::]" : "localhost", port));
+ builder.SetMaxReceiveMessageSize(INT_MAX);
builder.AddListeningPort(server_address, ::grpc::InsecureServerCredentials());
builder.RegisterService(service.get());
std::unique_ptr<::grpc::Server> server(builder.BuildAndStart());
LOG(INFO) << "Server listening on " << server_address;
server->Wait();
-
return 0;
}
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index fb80c78f68..e800cf470c 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -365,8 +365,11 @@ cc_library(
hdrs = ["pattern_matcher.h"],
deps = [
":hlo",
+ ":hlo_casting_utils",
+ "//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"@com_google_absl//absl/strings",
+ "@com_google_absl//absl/utility",
],
)
@@ -590,6 +593,7 @@ cc_library(
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
"@com_google_absl//absl/strings",
@@ -1166,6 +1170,7 @@ tf_cc_test(
":hlo",
":hlo_matchers",
":hlo_module_group",
+ ":hlo_module_group_metadata",
":hlo_parser",
":hlo_proto",
"//tensorflow/compiler/xla:test",
@@ -2557,6 +2562,7 @@ cc_library(
],
deps = [
":hlo",
+ ":hlo_module_group",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
@@ -2588,6 +2594,26 @@ cc_library(
],
)
+tf_cc_test(
+ name = "hlo_pass_pipeline_test",
+ srcs = ["hlo_pass_pipeline_test.cc"],
+ deps = [
+ ":hlo",
+ ":hlo_parser",
+ ":hlo_pass_pipeline",
+ "//tensorflow/compiler/xla:test",
+ "//tensorflow/compiler/xla:test_helpers",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
+ "//tensorflow/compiler/xla/tests:test_utils",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
cc_library(
name = "hlo_cse",
srcs = ["hlo_cse.cc"],
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index 4ef1dffa73..75dae7a714 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -754,11 +754,12 @@ StatusOr<bool> AlgebraicSimplifierVisitor::HandleDotStrengthReduction(
};
auto reshape_if_necessary = [&](HloInstruction* hlo) {
+ hlo = as_type(hlo, dot->shape().element_type());
if (!ShapeUtil::SameDimensions(hlo->shape(), dot->shape())) {
hlo = computation_->AddInstruction(
HloInstruction::CreateReshape(dot->shape(), hlo));
}
- return as_type(hlo, dot->shape().element_type());
+ return hlo;
};
auto add_reduce_in_f32 = [&](HloInstruction* hlo, const int64 dim) {
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.h b/tensorflow/compiler/xla/service/algebraic_simplifier.h
index b864c372fa..9f8d0ee88b 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.h
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.h
@@ -24,7 +24,7 @@ limitations under the License.
namespace xla {
// A pass which performs algebraic simplifications.
-class AlgebraicSimplifier : public HloPassInterface {
+class AlgebraicSimplifier : public HloModulePass {
public:
// Given shapes 'from_shape' and 'to_shape', determines if it is valid to
// bitcast from 'from_shape' to 'to_shape' after considering platform
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index 3fc1ba2427..2047f894b4 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -3233,17 +3233,18 @@ INSTANTIATE_TEST_CASE_P(
class DotStrengthReductionTest
: public AlgebraicSimplifierTest,
public ::testing::WithParamInterface<
- ::testing::tuple<int, int, int, bool, bool>> {};
+ ::testing::tuple<int, int, int, bool, bool, PrimitiveType>> {};
TEST_P(DotStrengthReductionTest, DotStrengthReduction) {
int m, k, n;
bool transpose_lhs, transpose_rhs;
- std::tie(m, k, n, transpose_lhs, transpose_rhs) = GetParam();
+ PrimitiveType element_type;
+ std::tie(m, k, n, transpose_lhs, transpose_rhs, element_type) = GetParam();
- Shape dot_shape = ShapeUtil::MakeShape(F32, {m, n});
- Shape lhs_shape = ShapeUtil::MakeShape(F32, {m, k});
- Shape transposed_lhs_shape = ShapeUtil::MakeShape(F32, {k, m});
- Shape rhs_shape = ShapeUtil::MakeShape(F32, {k, n});
- Shape transposed_rhs_shape = ShapeUtil::MakeShape(F32, {n, k});
+ Shape dot_shape = ShapeUtil::MakeShape(element_type, {m, n});
+ Shape lhs_shape = ShapeUtil::MakeShape(element_type, {m, k});
+ Shape transposed_lhs_shape = ShapeUtil::MakeShape(element_type, {k, m});
+ Shape rhs_shape = ShapeUtil::MakeShape(element_type, {k, n});
+ Shape transposed_rhs_shape = ShapeUtil::MakeShape(element_type, {n, k});
HloComputation::Builder builder(TestName());
auto lhs = builder.AddInstruction(HloInstruction::CreateParameter(
@@ -3285,7 +3286,7 @@ INSTANTIATE_TEST_CASE_P(
DotStrengthReductionTestInstantiation, DotStrengthReductionTest,
::testing::Combine(::testing::Values(1, 2), ::testing::Values(1, 2),
::testing::Values(1, 2), ::testing::Bool(),
- ::testing::Bool()));
+ ::testing::Bool(), ::testing::Values(F32, BF16)));
struct DotOfConcatTestSpec {
int64 m;
diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification.h b/tensorflow/compiler/xla/service/batch_dot_simplification.h
index 79d37f08d3..5b625bf3b9 100644
--- a/tensorflow/compiler/xla/service/batch_dot_simplification.h
+++ b/tensorflow/compiler/xla/service/batch_dot_simplification.h
@@ -25,7 +25,7 @@ namespace xla {
// Normally these would live in the algebraic simplifier, but we want to run
// this to fixpoint (this pass reaches fixed point in one execution) before we
// run the DotDecomposer.
-class BatchDotSimplification : public HloPassInterface {
+class BatchDotSimplification : public HloModulePass {
public:
StatusOr<bool> Run(HloModule* module) override;
absl::string_view name() const override;
diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.h b/tensorflow/compiler/xla/service/batchnorm_expander.h
index 76e32174f3..147f3ae7b6 100644
--- a/tensorflow/compiler/xla/service/batchnorm_expander.h
+++ b/tensorflow/compiler/xla/service/batchnorm_expander.h
@@ -26,7 +26,7 @@ namespace xla {
// A pass which rewrites batch norm operations into more operations. Breaking a
// big operation into smaller operations helps leverage our generic fusion
// logic.
-class BatchNormExpander : public HloPassInterface {
+class BatchNormExpander : public HloModulePass {
public:
// When use_fusion is set, a multi-output fusion node is created.
BatchNormExpander(bool rewrite_training_op = false,
diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.h b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.h
index 5dcd31b83d..cb3d12f0bf 100644
--- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.h
+++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.h
@@ -31,7 +31,7 @@ namespace xla {
// optimization pipeline followed by a DCE pass. If other passes are needed
// after this pass, run BFloat16MixedPrecisionRemoval first to undo some of the
// changes made by this pass.
-class BFloat16ConversionFolding : public HloPassInterface {
+class BFloat16ConversionFolding : public HloModulePass {
public:
explicit BFloat16ConversionFolding(const BFloat16Support* bfloat16_support)
: bfloat16_support_(bfloat16_support) {}
diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization.h b/tensorflow/compiler/xla/service/bfloat16_normalization.h
index 30b6346312..f48e925823 100644
--- a/tensorflow/compiler/xla/service/bfloat16_normalization.h
+++ b/tensorflow/compiler/xla/service/bfloat16_normalization.h
@@ -25,7 +25,7 @@ namespace xla {
// A pass which adds F32 <-> BF16 conversions for HLO instructions that do not
// support BF16 input/output or mixed precision, according to the passed-in
// backend-specific BF16 support rules.
-class BFloat16Normalization : public HloPassInterface {
+class BFloat16Normalization : public HloModulePass {
public:
explicit BFloat16Normalization(const BFloat16Support* bfloat16_support)
: bfloat16_support_(bfloat16_support) {}
@@ -48,7 +48,7 @@ class BFloat16Normalization : public HloPassInterface {
// use mixed precision; it removes mixed precision even if the backend supports
// it. This pass is used to make the HLO module valid for other HLO passes which
// do not support mixed precision.
-class BFloat16MixedPrecisionRemoval : public HloPassInterface {
+class BFloat16MixedPrecisionRemoval : public HloModulePass {
public:
BFloat16MixedPrecisionRemoval() {}
diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.h b/tensorflow/compiler/xla/service/bfloat16_propagation.h
index 1ee64971ab..6a62439f88 100644
--- a/tensorflow/compiler/xla/service/bfloat16_propagation.h
+++ b/tensorflow/compiler/xla/service/bfloat16_propagation.h
@@ -58,7 +58,7 @@ namespace xla {
// BFloat16ConversionFolding. If other passes are needed after this pass, run
// BFloat16MixedPrecisionRemoval first to undo some of the changes made by this
// pass.
-class BFloat16Propagation : public HloPassInterface {
+class BFloat16Propagation : public HloModulePass {
public:
explicit BFloat16Propagation(const BFloat16Support* bfloat16_support);
diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc
index 65fa951afe..34a7be0e9c 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment.cc
@@ -1064,6 +1064,19 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering(
// that seems to give the best results is lazy-best-fit, with all runs of
// alloc / free calls sorted in decreasing size order.
const HloOrdering& hlo_ordering = assignment->liveness().hlo_ordering();
+
+ // Returns a heap algorithm that chooses the best result from several
+ // algorithms.
+ auto get_heap_algorithm = [&](int64 alignment) {
+ auto algorithms =
+ absl::make_unique<std::vector<std::unique_ptr<HeapAlgorithm>>>();
+ algorithms->push_back(absl::make_unique<DecreasingSizeRunsHeap>(
+ absl::make_unique<LazyBestFitHeap>(alignment)));
+ algorithms->push_back(
+ absl::make_unique<GlobalDecreasingSizeBestFitHeap>(alignment));
+ return absl::make_unique<ChooseBestHeapAlgorithm>(std::move(algorithms));
+ };
+
if (run_whole_module_heap_simulation) {
// Run the heap simulation over the whole module. This reduces memory usage,
// since buffers for kCall, kWhile, and kConditional sub-computations are
@@ -1093,8 +1106,7 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering(
options.buffers_to_assign = &buffer_value_set;
TF_ASSIGN_OR_RETURN(
const HeapSimulator::Result result,
- HeapSimulator::Run(absl::make_unique<DecreasingSizeRunsHeap>(
- absl::make_unique<LazyBestFitHeap>(alignment)),
+ HeapSimulator::Run(get_heap_algorithm(alignment),
assignment->module(), schedule,
assignment->points_to_analysis(),
assignment->buffer_size_, options));
@@ -1123,12 +1135,10 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering(
options.buffers_to_assign = &buffer_value_set;
TF_ASSIGN_OR_RETURN(
const HeapSimulator::Result result,
- HeapSimulator::Run(
- absl::make_unique<DecreasingSizeRunsHeap>(
- absl::make_unique<LazyBestFitHeap>(alignment)),
- *computation, HloInstructionSequence(*instruction_sequence),
- assignment->points_to_analysis(), assignment->buffer_size_,
- options));
+ HeapSimulator::Run(get_heap_algorithm(alignment), *computation,
+ HloInstructionSequence(*instruction_sequence),
+ assignment->points_to_analysis(),
+ assignment->buffer_size_, options));
AssignBuffersFromHeapSimulator(result, assignment,
single_colored_set.first);
}
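The get_heap_algorithm helper above feeds a ChooseBestHeapAlgorithm that runs several packing strategies over the same buffer sequence and keeps the smallest result. A rough sketch of that selection idea, with stand-in allocators and a made-up heap-size metric rather than the real HeapSimulator interface:

# Sketch of "choose the best heap algorithm": run every candidate allocator
# over the same buffers and keep whichever reports the smallest heap size.
# The two strategies below are stand-ins, not the real XLA heap algorithms.
def lazy_best_fit(buffers):
    return sum(size for _, size in buffers) + 64  # illustrative metric only

def global_decreasing_best_fit(buffers):
    return sum(size for _, size in sorted(buffers, key=lambda b: -b[1]))

def choose_best_heap(buffers, algorithms):
    sizes = [(algorithm(buffers), algorithm.__name__) for algorithm in algorithms]
    return min(sizes)  # (smallest heap size, winning algorithm name)

buffers = [("dot", 1024), ("add", 256), ("tuple", 64)]
print(choose_best_heap(buffers, [lazy_best_fit, global_decreasing_best_fit]))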
diff --git a/tensorflow/compiler/xla/service/call_inliner.h b/tensorflow/compiler/xla/service/call_inliner.h
index c5cd88b9ea..08c4aff4f7 100644
--- a/tensorflow/compiler/xla/service/call_inliner.h
+++ b/tensorflow/compiler/xla/service/call_inliner.h
@@ -25,7 +25,7 @@ namespace xla {
// For every kCall operation in the main computation, we inline the body of the
// called function, and proceed recursively.
-class CallInliner : public HloPassInterface {
+class CallInliner : public HloModulePass {
public:
using InlinedInstructionMap =
std::unordered_map<HloInstruction*, HloInstruction*>;
diff --git a/tensorflow/compiler/xla/service/conditional_simplifier.h b/tensorflow/compiler/xla/service/conditional_simplifier.h
index 3de50cbd7f..2223ad6753 100644
--- a/tensorflow/compiler/xla/service/conditional_simplifier.h
+++ b/tensorflow/compiler/xla/service/conditional_simplifier.h
@@ -25,7 +25,7 @@ namespace xla {
// HLO pass that removes kConditional with a constant predicate, replacing them
// with their true or false computation as appropriate.
-class ConditionalSimplifier : public HloPassInterface {
+class ConditionalSimplifier : public HloModulePass {
public:
absl::string_view name() const override { return "simplify-conditional"; }
StatusOr<bool> Run(HloModule* module) override;
diff --git a/tensorflow/compiler/xla/service/convolution_feature_group_converter.h b/tensorflow/compiler/xla/service/convolution_feature_group_converter.h
index 498894737f..ce0138e56f 100644
--- a/tensorflow/compiler/xla/service/convolution_feature_group_converter.h
+++ b/tensorflow/compiler/xla/service/convolution_feature_group_converter.h
@@ -25,7 +25,7 @@ namespace xla {
// A pass which rewrites convolutions with feature_group_count > 1 into
// convolutions with feature_group_count = 1.
-class ConvolutionFeatureGroupConverter : public HloPassInterface {
+class ConvolutionFeatureGroupConverter : public HloModulePass {
public:
ConvolutionFeatureGroupConverter() {}
diff --git a/tensorflow/compiler/xla/service/copy_insertion.h b/tensorflow/compiler/xla/service/copy_insertion.h
index d308f6bc84..c097089e30 100644
--- a/tensorflow/compiler/xla/service/copy_insertion.h
+++ b/tensorflow/compiler/xla/service/copy_insertion.h
@@ -43,7 +43,7 @@ namespace xla {
// (3) The buffer set of the root instruction of the entry computation must be
// unambiguous and distinct. That is, InstructionAliasSet::IsAmbiguous and
// InstructionAliasSet::IsDistinct return true.
-class CopyInsertion : public HloPassInterface {
+class CopyInsertion : public HloModulePass {
public:
absl::string_view name() const override { return "copy-insertion"; }
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index 8cc522a59e..b7103118ac 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -50,6 +50,7 @@ cc_library(
"//tensorflow/compiler/xla/service/cpu:cpu_runtime",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
+ "//tensorflow/stream_executor",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/types:span",
],
@@ -180,6 +181,7 @@ cc_library(
":runtime_conv2d_mkl",
":runtime_fft",
":runtime_fork_join",
+ ":runtime_key_value_sort",
":runtime_matmul",
":runtime_matmul_mkl",
":runtime_single_threaded_conv2d",
@@ -461,12 +463,15 @@ cc_library(
],
copts = runtime_copts(),
deps = [
+ "//tensorflow/compiler/xla:executable_run_options",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
"//tensorflow/core:lib",
+ "//tensorflow/stream_executor",
+ "@com_google_absl//absl/synchronization",
"@com_google_absl//absl/types:span",
],
)
@@ -624,6 +629,18 @@ cc_library(
)
cc_library(
+ name = "runtime_key_value_sort",
+ srcs = ["runtime_key_value_sort.cc"],
+ hdrs = ["runtime_key_value_sort.h"],
+ copts = runtime_copts(),
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/core:framework_lite",
+ "//third_party/eigen3",
+ ],
+)
+
+cc_library(
name = "runtime_fork_join",
srcs = ["runtime_fork_join.cc"],
hdrs = ["runtime_fork_join.h"],
diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h
index 59437e88af..becee3f81f 100644
--- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h
+++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h
@@ -31,7 +31,7 @@ namespace cpu {
// called canonical convolutions). This pass expands non-canonical convolutions
// into reshapes and canonical convolutions, so that these non-canonical
// convolutions can run faster.
-class ConvCanonicalization : public HloPassInterface {
+class ConvCanonicalization : public HloModulePass {
public:
explicit ConvCanonicalization(
const TargetMachineFeatures* target_machine_features)
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h
index d49f7d7cc2..076235f887 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h
@@ -30,7 +30,7 @@ namespace xla {
//
// TODO(b/62548313): Remove this when buffer assignment is smarter
// (module-scoped).
-class CpuCopyInsertion : public HloPassInterface {
+class CpuCopyInsertion : public HloModulePass {
public:
absl::string_view name() const override { return "copy-insertion"; }
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h
index 6af724b2a5..a39a9d4724 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h
@@ -23,7 +23,7 @@ namespace xla {
// This pass should run early in the HLO pipeline and checks for HLO constructs
// which are not supported by the CPU backend and cannot be removed via HLO
// transformations (eg, sparse layouts).
-class CpuHloSupportChecker : public HloPassInterface {
+class CpuHloSupportChecker : public HloModulePass {
public:
CpuHloSupportChecker() = default;
~CpuHloSupportChecker() override = default;
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
index 8a44c384bb..20cf855735 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
@@ -17,19 +17,29 @@ limitations under the License.
#include <functional>
+#include "absl/synchronization/mutex.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/core/platform/dynamic_annotations.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
+#include "tensorflow/stream_executor/stream_executor.h"
namespace xla {
namespace cpu {
namespace runtime {
-XfeedManager* GetXfeedManager() {
- static XfeedManager* manager = new XfeedManager;
- return manager;
+XfeedManager* GetXfeedManager(int device_ordinal) {
+ static tensorflow::gtl::FlatMap<int, XfeedManager*>* managers =
+ new tensorflow::gtl::FlatMap<int, XfeedManager*>();
+ static absl::Mutex* mutex = new absl::Mutex();
+
+ absl::MutexLock lock(mutex);
+ auto it = managers->find(device_ordinal);
+ if (it == managers->end()) {
+ it = managers->emplace(device_ordinal, new XfeedManager()).first;
+ }
+ return it->second;
}
extern const char* const kEigenMatMulF16SymbolName =
@@ -74,6 +84,30 @@ extern const char* const kReleaseOutfeedBufferAfterPopulationSymbolName =
"__xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation";
extern const char* const kParallelForkJoinSymbolName =
"__xla_cpu_runtime_ParallelForkJoin";
+extern const char* const kKeyValueSortPREDSymbolName =
+ "__xla_cpu_runtime_KeyValueSortPRED";
+extern const char* const kKeyValueSortS8SymbolName =
+ "__xla_cpu_runtime_KeyValueSortS8";
+extern const char* const kKeyValueSortU8SymbolName =
+ "__xla_cpu_runtime_KeyValueSortU8";
+extern const char* const kKeyValueSortS16SymbolName =
+ "__xla_cpu_runtime_KeyValueSortS16";
+extern const char* const kKeyValueSortU16SymbolName =
+ "__xla_cpu_runtime_KeyValueSortU16";
+extern const char* const kKeyValueSortF16SymbolName =
+ "__xla_cpu_runtime_KeyValueSortF16";
+extern const char* const kKeyValueSortS32SymbolName =
+ "__xla_cpu_runtime_KeyValueSortS32";
+extern const char* const kKeyValueSortU32SymbolName =
+ "__xla_cpu_runtime_KeyValueSortU32";
+extern const char* const kKeyValueSortF32SymbolName =
+ "__xla_cpu_runtime_KeyValueSortF32";
+extern const char* const kKeyValueSortS64SymbolName =
+ "__xla_cpu_runtime_KeyValueSortS64";
+extern const char* const kKeyValueSortU64SymbolName =
+ "__xla_cpu_runtime_KeyValueSortU64";
+extern const char* const kKeyValueSortF64SymbolName =
+ "__xla_cpu_runtime_KeyValueSortF64";
extern const char* const kXlaCpuRuntimeSymbolNamePrefix = "__xla_cpu_runtime_";
} // namespace runtime
@@ -94,14 +128,18 @@ tensorflow::string ShapeString(const void* shape_ptr, xla::int32 shape_length) {
} // namespace
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void*
-__xla_cpu_runtime_AcquireInfeedBufferForDequeue(xla::int32 buffer_length,
- const void* shape,
- xla::int32 shape_length) {
- if (VLOG_IS_ON(2)) {
- LOG(INFO) << "AcquireInfeedBufferForDequeue: "
- << ShapeString(shape, shape_length);
- }
- xla::cpu::runtime::XfeedManager* xfeed = xla::cpu::runtime::GetXfeedManager();
+__xla_cpu_runtime_AcquireInfeedBufferForDequeue(
+ const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length,
+ const void* shape, xla::int32 shape_length) {
+ int device_ordinal =
+ run_options ? run_options->stream()->parent()->device_ordinal() : 0;
+
+ VLOG(2) << "AcquireInfeedBufferForDequeue: "
+ << ShapeString(shape, shape_length) << " on stream executor "
+ << device_ordinal;
+
+ xla::cpu::runtime::XfeedManager* xfeed =
+ xla::cpu::runtime::GetXfeedManager(device_ordinal);
// Wait until there's a buffer to dequeue.
xla::cpu::runtime::XfeedBuffer* buffer =
xfeed->infeed()->BlockingDequeueBuffer();
@@ -114,15 +152,18 @@ __xla_cpu_runtime_AcquireInfeedBufferForDequeue(xla::int32 buffer_length,
}
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void
-__xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue(xla::int32 buffer_length,
- void* buffer_ptr,
- const void* shape_ptr,
- xla::int32 shape_length) {
- if (VLOG_IS_ON(2)) {
- LOG(INFO) << "ReleaseInfeedBufferAfterDeque: "
- << ShapeString(shape_ptr, shape_length);
- }
- xla::cpu::runtime::XfeedManager* xfeed = xla::cpu::runtime::GetXfeedManager();
+__xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue(
+ const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length,
+ void* buffer_ptr, const void* shape_ptr, xla::int32 shape_length) {
+ int device_ordinal =
+ run_options ? run_options->stream()->parent()->device_ordinal() : 0;
+
+  VLOG(2) << "ReleaseInfeedBufferAfterDequeue: "
+ << ShapeString(shape_ptr, shape_length) << " on stream executor "
+ << device_ordinal;
+
+ xla::cpu::runtime::XfeedManager* xfeed =
+ xla::cpu::runtime::GetXfeedManager(device_ordinal);
xla::StatusOr<xla::Shape> shape =
xla::llvm_ir::DecodeSelfDescribingShapeConstant(shape_ptr, shape_length);
xfeed->infeed()->ReleaseCurrentBuffer(buffer_length, buffer_ptr,
@@ -130,14 +171,18 @@ __xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue(xla::int32 buffer_length,
}
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void*
-__xla_cpu_runtime_AcquireOutfeedBufferForPopulation(xla::int32 buffer_length,
- const void* shape_ptr,
- xla::int32 shape_length) {
- if (VLOG_IS_ON(2)) {
- LOG(INFO) << "AcquireOutfeedBufferForPopulation: "
- << ShapeString(shape_ptr, shape_length);
- }
- xla::cpu::runtime::XfeedManager* xfeed = xla::cpu::runtime::GetXfeedManager();
+__xla_cpu_runtime_AcquireOutfeedBufferForPopulation(
+ const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length,
+ const void* shape_ptr, xla::int32 shape_length) {
+ int device_ordinal =
+ run_options ? run_options->stream()->parent()->device_ordinal() : 0;
+
+ VLOG(2) << "AcquireOutfeedBufferForPopulation: "
+ << ShapeString(shape_ptr, shape_length) << " on stream executor "
+ << device_ordinal;
+
+ xla::cpu::runtime::XfeedManager* xfeed =
+ xla::cpu::runtime::GetXfeedManager(device_ordinal);
// Wait until there's a buffer to dequeue.
xla::cpu::runtime::XfeedBuffer* buffer =
xfeed->outfeed()->BlockingDequeueBuffer();
@@ -150,15 +195,18 @@ __xla_cpu_runtime_AcquireOutfeedBufferForPopulation(xla::int32 buffer_length,
}
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void
-__xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation(xla::int32 buffer_length,
- void* buffer_ptr,
- const void* shape_ptr,
- xla::int32 shape_length) {
- if (VLOG_IS_ON(2)) {
- LOG(INFO) << "ReleaseOutfeedBufferAfterPopulation: "
- << ShapeString(shape_ptr, shape_length);
- }
- xla::cpu::runtime::XfeedManager* xfeed = xla::cpu::runtime::GetXfeedManager();
+__xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation(
+ const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length,
+ void* buffer_ptr, const void* shape_ptr, xla::int32 shape_length) {
+ int device_ordinal =
+ run_options ? run_options->stream()->parent()->device_ordinal() : 0;
+
+ VLOG(2) << "ReleaseOutfeedBufferAfterPopulation: "
+ << ShapeString(shape_ptr, shape_length) << " on stream executor "
+ << device_ordinal;
+
+ xla::cpu::runtime::XfeedManager* xfeed =
+ xla::cpu::runtime::GetXfeedManager(device_ordinal);
xla::StatusOr<xla::Shape> shape =
xla::llvm_ir::DecodeSelfDescribingShapeConstant(shape_ptr, shape_length);
xfeed->outfeed()->ReleaseCurrentBuffer(buffer_length, buffer_ptr,
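GetXfeedManager now returns one XfeedManager per device ordinal, creating managers lazily under a mutex instead of sharing a single global instance. A small sketch of that lock-guarded, lazily populated registry pattern; XfeedManager here is a stand-in class, not the real one:

import threading

class XfeedManager:
    """Stand-in for the per-device infeed/outfeed manager."""
    def __init__(self, device_ordinal):
        self.device_ordinal = device_ordinal

_managers = {}
_managers_lock = threading.Lock()

def get_xfeed_manager(device_ordinal=0):
    # Lazily create exactly one manager per device ordinal, under a lock,
    # mirroring the FlatMap + absl::Mutex combination in cpu_runtime.cc.
    with _managers_lock:
        if device_ordinal not in _managers:
            _managers[device_ordinal] = XfeedManager(device_ordinal)
        return _managers[device_ordinal]

assert get_xfeed_manager(0) is get_xfeed_manager(0)
assert get_xfeed_manager(1) is not get_xfeed_manager(0)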
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h
index aa0e967123..b2e760a224 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h
@@ -26,6 +26,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_H_
+#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/service/cpu/xfeed_manager.h"
#include "tensorflow/compiler/xla/types.h"
@@ -63,13 +64,26 @@ extern const char* const kReleaseInfeedBufferAfterDequeueSymbolName;
extern const char* const kAcquireOutfeedBufferForPopulationSymbolName;
extern const char* const kReleaseOutfeedBufferAfterPopulationSymbolName;
extern const char* const kParallelForkJoinSymbolName;
+extern const char* const kKeyValueSortPREDSymbolName;
+extern const char* const kKeyValueSortS8SymbolName;
+extern const char* const kKeyValueSortU8SymbolName;
+extern const char* const kKeyValueSortS16SymbolName;
+extern const char* const kKeyValueSortU16SymbolName;
+extern const char* const kKeyValueSortF16SymbolName;
+extern const char* const kKeyValueSortS32SymbolName;
+extern const char* const kKeyValueSortU32SymbolName;
+extern const char* const kKeyValueSortF32SymbolName;
+extern const char* const kKeyValueSortS64SymbolName;
+extern const char* const kKeyValueSortU64SymbolName;
+extern const char* const kKeyValueSortF64SymbolName;
// All symbol names for XLA CPU runtime functions need to start with this
// prefix.
extern const char* const kXlaCpuRuntimeSymbolNamePrefix;
-// Returns the infeed manager used by the CPU runtime.
-XfeedManager* GetXfeedManager();
+// Returns the infeed manager used by the CPU runtime for the CPU device
+// `device_ordinal`. Note the device ordinal does not name a physical CPU; it
+// identifies the stream executor on which the computation is running.
+XfeedManager* GetXfeedManager(int device_ordinal);
} // namespace runtime
} // namespace cpu
@@ -77,6 +91,18 @@ XfeedManager* GetXfeedManager();
extern "C" {
+// Some things common to all of the runtime entry points below:
+//
+// * The shape pointer and shape_length reflect values that can be deserialized
+// via llvm_ir::DecodeSelfDescribingShapeConstant. This is the way we pass
+// reified type information from the generated program to the runtime, which
+// helps check the type safety and contract for the emitted-code/runtime
+// communication.
+//
+// * run_options is used to look up the device ordinal for the stream executor
+// we're executing under. If it is null the device ordinal is assumed to be
+// 0 (this behavior helps in writing tests).
+
// Note: in the runtime entry points below, the shape pointer and shape_length
// reflect values that can be deserialized via
// llvm_ir::DecodeSelfDescribingShapeConstant. This is the way we pass reified
@@ -89,7 +115,8 @@ extern "C" {
// the length would be more exact, but the length check is chosen as a
// tradeoff between error checking and speed/simplicity.
extern void* __xla_cpu_runtime_AcquireInfeedBufferForDequeue(
- xla::int32 buffer_length, const void* shape, xla::int32 shape_length);
+ const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length,
+ const void* shape, xla::int32 shape_length);
// Relinquishes the next infeed buffer that was returned by
// __xla_cpu_runtime_AcquireInfeedBufferForDequeue. Once this call
@@ -104,13 +131,14 @@ extern void* __xla_cpu_runtime_AcquireInfeedBufferForDequeue(
// implemented we will add support for multiple outstanding buffers
// that can be returned out of order.
extern void __xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue(
- xla::int32 buffer_length, void* buffer_ptr, const void* shape_ptr,
- xla::int32 shape_length);
+ const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length,
+ void* buffer_ptr, const void* shape_ptr, xla::int32 shape_length);
// Blocks until the next outfeed buffer is available to be populated, then
// returns it.
extern void* __xla_cpu_runtime_AcquireOutfeedBufferForPopulation(
- xla::int32 buffer_length, const void* shape_ptr, xla::int32 shape_length);
+ const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length,
+ const void* shape_ptr, xla::int32 shape_length);
// Relinquishes the outfeed buffer after it has been populated.
// buffer_ptr must have been previously returned by
@@ -122,8 +150,8 @@ extern void* __xla_cpu_runtime_AcquireOutfeedBufferForPopulation(
// acquired, i.e., there may only be one outstanding outfeed buffer in
// use by the runtime.
extern void __xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation(
- xla::int32 buffer_length, void* buffer_ptr, const void* shape_ptr,
- xla::int32 shape_length);
+ const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length,
+ void* buffer_ptr, const void* shape_ptr, xla::int32 shape_length);
} // extern "C"
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc
index 5519a43b2f..1cc2844470 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc
@@ -35,6 +35,7 @@ limitations under the License.
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/notification.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+#include "tensorflow/stream_executor/stream_executor.h"
namespace xla {
@@ -128,7 +129,8 @@ Status CpuTransferManager::TransferLiteralToInfeed(
buffers.push_back(buffer);
}
- cpu::runtime::XfeedManager* xfeed_manager = cpu::runtime::GetXfeedManager();
+ cpu::runtime::XfeedManager* xfeed_manager =
+ cpu::runtime::GetXfeedManager(executor->device_ordinal());
xfeed_manager->infeed()->EnqueueBuffersAtomically(buffers);
cleanup.release();
@@ -141,7 +143,8 @@ Status CpuTransferManager::TransferBufferToInfeed(se::StreamExecutor* executor,
TF_ASSIGN_OR_RETURN(cpu::runtime::XfeedBuffer * buffer,
TransferBufferToInfeedInternal(executor, size, source));
- cpu::runtime::XfeedManager* xfeed_manager = cpu::runtime::GetXfeedManager();
+ cpu::runtime::XfeedManager* xfeed_manager =
+ cpu::runtime::GetXfeedManager(executor->device_ordinal());
xfeed_manager->infeed()->EnqueueBuffersAtomically({buffer});
return Status::OK();
@@ -265,7 +268,8 @@ StatusOr<Shape> CpuTransferManager::TransferBuffersFromOutfeedInternal(
buffer_pointers.push_back(b.get());
}
- cpu::runtime::XfeedManager* xfeed_manager = cpu::runtime::GetXfeedManager();
+ cpu::runtime::XfeedManager* xfeed_manager =
+ cpu::runtime::GetXfeedManager(executor->device_ordinal());
xfeed_manager->outfeed()->EnqueueBuffersAtomically(buffer_pointers);
VLOG(2) << "Waiting for buffer to be notified as populated.";
std::vector<Shape> outfed_shapes;
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index df8c2a636b..c3e8020783 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -404,13 +404,12 @@ Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape,
llvm::Value * shape_ptr,
llvm_ir::EncodeSelfDescribingShapeConstant(shape, &shape_length, &b_));
- // The signature of the acquire infeed buffer function is:
- //
- // (void*)(int32 length);
llvm::Type* int32_type = b_.getInt32Ty();
llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext());
llvm::FunctionType* acquire_type = llvm::FunctionType::get(
- i8_ptr_type, {int32_type, i8_ptr_type, int32_type},
+ i8_ptr_type,
+ {/*run_options*/ i8_ptr_type, /*buffer_length*/ int32_type,
+ /*shape_ptr*/ i8_ptr_type, /*shape_length*/ int32_type},
/*isVarArg=*/false);
llvm::Function* acquire_func;
@@ -423,11 +422,11 @@ Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape,
}
acquire_func->setCallingConv(llvm::CallingConv::C);
- // The signature of the release infeed buffer function is:
- //
- // (void)(int32 length, void* buffer);
llvm::FunctionType* release_type = llvm::FunctionType::get(
- b_.getVoidTy(), {int32_type, i8_ptr_type, i8_ptr_type, int32_type},
+ b_.getVoidTy(),
+ {/*run_options*/ i8_ptr_type, /*buffer_length*/ int32_type,
+ /*buffer_ptr*/ i8_ptr_type, /*shape_ptr*/ i8_ptr_type,
+ /*shape_length*/ int32_type},
/*isVarArg=*/false);
llvm::Function* release_func;
@@ -444,9 +443,9 @@ Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape,
// of size exactly 'length_32', and the runtime is responsible for
// check-failing the process if there is a mismatch, versus passing us back a
// buffer that we might overrun.
- llvm::Value* acquired_pointer =
- Call(acquire_func,
- {b_.getInt32(length_32), shape_ptr, b_.getInt32(shape_length)});
+ llvm::Value* acquired_pointer = Call(
+ acquire_func, {GetExecutableRunOptionsArgument(), b_.getInt32(length_32),
+ shape_ptr, b_.getInt32(shape_length)});
if (kind == XfeedKind::kInfeed) {
// Copy to the program buffer address from the acquired buffer.
@@ -458,8 +457,8 @@ Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape,
/*SrcAlign=*/1, length_32);
}
- Call(release_func, {b_.getInt32(length_32), acquired_pointer, shape_ptr,
- b_.getInt32(shape_length)});
+ Call(release_func, {GetExecutableRunOptionsArgument(), b_.getInt32(length_32),
+ acquired_pointer, shape_ptr, b_.getInt32(shape_length)});
return Status::OK();
}
@@ -495,8 +494,150 @@ Status IrEmitter::HandleOutfeed(HloInstruction* outfeed) {
}
Status IrEmitter::HandleSort(HloInstruction* sort) {
- // TODO(b/26783907): Implement sort on CPU.
- return Unimplemented("Sort is not implemented on CPU.");
+ TF_RETURN_IF_ERROR(EmitTargetAddressForOp(sort));
+ auto keys = sort->operand(0);
+ auto values = sort->operand_count() > 1 ? sort->operand(1) : nullptr;
+ ShapeIndex keys_shape_index({});
+ ShapeIndex values_shape_index({});
+ if (values != nullptr) {
+ keys_shape_index = ShapeIndex({0});
+ values_shape_index = ShapeIndex({1});
+ }
+ auto keys_destination = GetAllocationSlice(*sort, keys_shape_index);
+ auto keys_destination_address =
+ EmitBufferPointer(keys_destination, keys->shape());
+ auto values_destination = GetAllocationSlice(*sort, values_shape_index);
+ llvm::Value* values_destination_address = nullptr;
+
+ // The sort is implemented in-place, therefore we first copy the operand
+ // buffer to the output buffer if they are not the same.
+ if (keys_destination != GetAllocationSlice(*keys)) {
+ int64 primitive_type_size =
+ ShapeUtil::ByteSizeOfPrimitiveType(keys->shape().element_type());
+ auto source_buffer = GetEmittedValueFor(keys);
+ int64 keys_size = ByteSizeOf(keys->shape());
+ MemCpy(keys_destination_address, /*DstAlign=*/primitive_type_size,
+ source_buffer,
+ /*SrcAlign=*/primitive_type_size, keys_size);
+ }
+ if (values != nullptr) {
+ values_destination_address =
+ EmitBufferPointer(values_destination, values->shape());
+ if (values_destination != GetAllocationSlice(*values)) {
+ int64 primitive_type_size =
+ ShapeUtil::ByteSizeOfPrimitiveType(values->shape().element_type());
+ auto source_buffer = GetEmittedValueFor(values);
+ int64 values_size = ByteSizeOf(values->shape());
+ MemCpy(values_destination_address, /*DstAlign=*/primitive_type_size,
+ source_buffer,
+ /*SrcAlign=*/primitive_type_size, values_size);
+ }
+ }
+
+ // Normalize the shape and the dimension to sort.
+ Shape normalized_keys_shape =
+ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
+ keys->shape());
+ int64 physical_dimension_to_sort = LayoutUtil::MakeLogicalToPhysical(
+ keys->shape().layout())[sort->dimensions(0)];
+
+ int64 sort_dimension_elements =
+ normalized_keys_shape.dimensions(physical_dimension_to_sort);
+ int64 higher_dimensions = 1;
+ for (int64 i = 0; i < physical_dimension_to_sort; ++i) {
+ higher_dimensions *= normalized_keys_shape.dimensions(i);
+ }
+ int64 lower_dimensions = 1;
+ for (int64 i = ShapeUtil::Rank(normalized_keys_shape) - 1;
+ i > physical_dimension_to_sort; --i) {
+ lower_dimensions *= normalized_keys_shape.dimensions(i);
+ }
+
+ PrimitiveType keys_type = keys->shape().element_type();
+ const char* fn_name = nullptr;
+ llvm::Type* keys_native_type = nullptr;
+ switch (keys_type) {
+ case PRED:
+ fn_name = runtime::kKeyValueSortPREDSymbolName;
+ keys_native_type = b_.getInt8PtrTy();
+ break;
+ case S8:
+ fn_name = runtime::kKeyValueSortS8SymbolName;
+ keys_native_type = b_.getInt8PtrTy();
+ break;
+ case U8:
+ fn_name = runtime::kKeyValueSortU8SymbolName;
+ keys_native_type = b_.getInt8PtrTy();
+ break;
+ case S16:
+ fn_name = runtime::kKeyValueSortS16SymbolName;
+ keys_native_type = b_.getInt16Ty()->getPointerTo();
+ break;
+ case U16:
+ fn_name = runtime::kKeyValueSortU16SymbolName;
+ keys_native_type = b_.getInt16Ty()->getPointerTo();
+ break;
+ case F16:
+ fn_name = runtime::kKeyValueSortF16SymbolName;
+ keys_native_type = b_.getHalfTy()->getPointerTo();
+ break;
+ case S32:
+ fn_name = runtime::kKeyValueSortS32SymbolName;
+ keys_native_type = b_.getInt32Ty()->getPointerTo();
+ break;
+ case U32:
+ fn_name = runtime::kKeyValueSortU32SymbolName;
+ keys_native_type = b_.getInt32Ty()->getPointerTo();
+ break;
+ case F32:
+ fn_name = runtime::kKeyValueSortF32SymbolName;
+ keys_native_type = b_.getFloatTy()->getPointerTo();
+ break;
+ case S64:
+ fn_name = runtime::kKeyValueSortS64SymbolName;
+ keys_native_type = b_.getInt64Ty()->getPointerTo();
+ break;
+ case U64:
+ fn_name = runtime::kKeyValueSortU64SymbolName;
+ keys_native_type = b_.getInt64Ty()->getPointerTo();
+ break;
+ case F64:
+ fn_name = runtime::kKeyValueSortF64SymbolName;
+ keys_native_type = b_.getDoubleTy()->getPointerTo();
+ break;
+ default:
+ return Unimplemented(
+ "Element type %s not supported in the Sort op on CPU.",
+ PrimitiveType_Name(keys_type));
+ }
+
+ llvm::FunctionType* key_value_sort_type = llvm::FunctionType::get(
+ b_.getVoidTy(),
+ {keys_native_type, b_.getInt64Ty(), b_.getInt64Ty(), b_.getInt64Ty(),
+ b_.getInt8PtrTy(), b_.getInt32Ty()},
+ /*isVarArg=*/false);
+ auto* key_value_sort_func = llvm::cast<llvm::Function>(
+ module_->getOrInsertFunction(fn_name, key_value_sort_type));
+ key_value_sort_func->setCallingConv(llvm::CallingConv::C);
+ key_value_sort_func->setDoesNotThrow();
+ key_value_sort_func->setOnlyAccessesArgMemory();
+ Call(key_value_sort_func,
+ {PointerCast(keys_destination_address, keys_native_type),
+ b_.getInt64(higher_dimensions), b_.getInt64(sort_dimension_elements),
+ b_.getInt64(lower_dimensions),
+ values != nullptr
+ ? PointerCast(values_destination_address, b_.getInt8PtrTy())
+ : llvm::Constant::getNullValue(b_.getInt8PtrTy()),
+ b_.getInt32(values != nullptr ? ShapeUtil::ByteSizeOfPrimitiveType(
+ values->shape().element_type())
+ : 0)});
+
+ if (values != nullptr) {
+ llvm_ir::EmitTuple(GetIrArrayFor(sort),
+ {keys_destination_address, values_destination_address},
+ &b_, module_);
+ }
+ return Status::OK();
}
Status IrEmitter::HandleTuple(HloInstruction* tuple) {
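HandleSort above normalizes the keys layout and then collapses the shape into three factors for the runtime call: the product of the dimensions major to the sort dimension, the sort dimension itself, and the product of the minor dimensions. A short sketch of that factoring over a physical (major-to-minor) dimension list; it mirrors the loops above but is not the emitter code itself:

def factor_for_sort(physical_dims, sort_dim):
    """Split a physical shape into (a, b, c) as passed to the sort runtime.

    a: product of the dimensions major to sort_dim
    b: number of elements along sort_dim
    c: product of the dimensions minor to sort_dim
    """
    a = 1
    for d in physical_dims[:sort_dim]:
        a *= d
    b = physical_dims[sort_dim]
    c = 1
    for d in physical_dims[sort_dim + 1:]:
        c *= d
    return a, b, c

# Example: a [4, 8, 16] array sorted along dimension 1 becomes 4 * 16 = 64
# independent rows of 8 keys each.
assert factor_for_sort([4, 8, 16], 1) == (4, 8, 16)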
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
index 3df99464ba..daafef4eb3 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
@@ -163,6 +163,12 @@ class IrEmitter : public DfsHloVisitorWithDefault,
Status Preprocess(HloInstruction* hlo) override;
Status Postprocess(HloInstruction* hlo) override;
+ // A convenient helper for calling BufferAssignment::GetUniqueSlice.
+ BufferAllocation::Slice GetAllocationSlice(
+ const HloInstruction& hlo, const ShapeIndex& index = {}) const {
+ return assignment_.GetUniqueSlice(&hlo, index).ConsumeValueOrDie();
+ }
+
private:
// Private helper to initialize an IR function for the computation.
void InitializeIrFunction(const string& function_name);
diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc
index b4c0c09ec0..ede7f433ca 100644
--- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc
+++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc
@@ -142,6 +142,7 @@ int64 ParallelTaskAssignment::GetTargetParallelTaskCount(
opcode == HloOpcode::kGetTupleElement || opcode == HloOpcode::kBitcast ||
opcode == HloOpcode::kFft || opcode == HloOpcode::kInfeed ||
opcode == HloOpcode::kOutfeed || opcode == HloOpcode::kRng ||
+ opcode == HloOpcode::kSort ||
(opcode == HloOpcode::kConvolution &&
PotentiallyImplementedAsEigenConvolution(*instruction,
target_machine_features_)) ||
diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h
index a99cd99c14..3822d5300e 100644
--- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h
+++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h
@@ -60,7 +60,7 @@ class ParallelTaskAssignment {
// own embedded computation, which is compiled as a parallel compute function,
// and which is invoked from a kCall instruction that is lowered in codegen to
// a runtime parallel fork/join call.
-class ParallelTaskAssigner : public HloPassInterface {
+class ParallelTaskAssigner : public HloModulePass {
public:
// 'max_parallelism': the maximum parallel task count per instruction.
// 'shape_size': shape size function used by HloCostAnalysis during parallel
diff --git a/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc
new file mode 100644
index 0000000000..e0e7deb98e
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc
@@ -0,0 +1,236 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h"
+
+#include <algorithm>
+#include <cmath>
+#include <cstring>
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/platform/dynamic_annotations.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace {
+using tensorflow::int16;
+using tensorflow::int32;
+using tensorflow::int64;
+using tensorflow::int8;
+using tensorflow::uint16;
+using tensorflow::uint32;
+using tensorflow::uint64;
+using tensorflow::uint8;
+
+template <typename KeyType>
+void KeyValueSort(std::pair<KeyType, int64>* row_to_sort, int64 num_elements) {
+ std::sort(row_to_sort, row_to_sort + num_elements);
+}
+
+// For floating point numbers, we want a total order comparator. -NaN and NaN
+// should appear at the beginning and end of the ordering, and -0.0 should
+// appear before 0.0. Also we want to have a stable sort, so if the keys are the
+// same, we compare the index values.
+template <typename KeyType>
+bool LessThan(KeyType lhs, int64 lhs_index, KeyType rhs, int64 rhs_index) {
+ bool lhs_is_negative = std::signbit(lhs);
+ bool rhs_is_negative = std::signbit(rhs);
+ // If the signs are different, we can just compare the signs.
+ if (lhs_is_negative != rhs_is_negative) {
+ return lhs_is_negative && !rhs_is_negative;
+ }
+ bool lhs_nan = std::isnan(lhs);
+ bool rhs_nan = std::isnan(rhs);
+  // Exactly one of the two numbers is NaN: a negative NaN sorts before
+  // everything else, a positive NaN sorts after everything else.
+ if (lhs_nan != rhs_nan) {
+ if (lhs_nan) {
+ return lhs_is_negative;
+ }
+ return !rhs_is_negative;
+ }
+ if (lhs != rhs) {
+ return lhs < rhs;
+ }
+ return lhs_index < rhs_index;
+}
+
+template <>
+void KeyValueSort(std::pair<double, int64>* row_to_sort, int64 num_elements) {
+ std::sort(row_to_sort, row_to_sort + num_elements,
+ [](const std::pair<double, int64>& lhs,
+ const std::pair<double, int64>& rhs) -> bool {
+ return LessThan(lhs.first, lhs.second, rhs.first, rhs.second);
+ });
+}
+
+template <>
+void KeyValueSort(std::pair<float, int64>* row_to_sort, int64 num_elements) {
+ std::sort(row_to_sort, row_to_sort + num_elements,
+ [](const std::pair<float, int64>& lhs,
+ const std::pair<float, int64>& rhs) -> bool {
+ return LessThan(lhs.first, lhs.second, rhs.first, rhs.second);
+ });
+}
+
+template <>
+void KeyValueSort(std::pair<Eigen::half, int64>* row_to_sort,
+ int64 num_elements) {
+ std::sort(row_to_sort, row_to_sort + num_elements,
+ [](const std::pair<Eigen::half, int64>& lhs,
+ const std::pair<Eigen::half, int64>& rhs) -> bool {
+ return LessThan(
+ Eigen::half_impl::half_to_float(lhs.first), lhs.second,
+ Eigen::half_impl::half_to_float(rhs.first), rhs.second);
+ });
+}
+
+template <typename KeyType>
+void KeyValueSortImpl(KeyType* keys, int64 a, int64 b, int64 c, char* values,
+ int32 values_primitive_type_size_in_bytes) {
+ // High-level idea of the iteration/sorting logic:
+ // Conceptually we have a 3-dimensional shape [a, b, c]. b corresponds to the
+ // dimension to sort, c is the product of the more minor dimensions (set to 1
+ // if b is the most minor dimension), and a is the product of the more major
+ // dimensions (set to 1 if b is the most major dimension). There are a * c
+ // many rows that we need to sort. We iterate through these, calculate a
+ // 'base_offset' value which points to the first element in that row, and add
+ // i * c for accessing the 'i'-th element in that row.
+
+ int64 sort_dimension_elements = b;
+ int64 num_iteration_elements = a * c;
+ int64 sort_dimension_offset = c;
+
+ std::unique_ptr<std::pair<KeyType, int64>[]> row_to_sort(
+ new std::pair<KeyType, int64>[sort_dimension_elements]);
+ std::unique_ptr<std::string[]> reordered_values(
+ new std::string[sort_dimension_elements]);
+ for (int64 index = 0; index < num_iteration_elements; ++index) {
+ // 'index' can be split into two values which index into the 'c' dimension
+ // and the 'a' dimension, respectively. 'index' % 'c' is the index into the
+ // 'c' dimension, 'index' / 'c' is the index into the 'a' dimension. When
+ // calculating the base offset, we need to multiply the index into the 'a'
+ // dimension with 'b' * 'c'.
+ // 'index' / 'c' * 'c' * 'b' = ('index' - 'index' % 'c') * 'b'.
+ int64 base_offset =
+ index % sort_dimension_offset +
+ (index - index % sort_dimension_offset) * sort_dimension_elements;
+ // TODO(b/26783907): We could define a custom iterator class that references
+ // both arrays. Then we could avoid the intermediate copy. However this
+ // would become more complicated, and it is not clear if the benefit is high
+ // enough.
+ for (int64 i = 0; i < sort_dimension_elements; ++i) {
+ row_to_sort[i] =
+ std::make_pair(keys[base_offset + i * sort_dimension_offset], i);
+ }
+ KeyValueSort(row_to_sort.get(), sort_dimension_elements);
+ for (int64 i = 0; i < sort_dimension_elements; ++i) {
+ keys[base_offset + i * sort_dimension_offset] = row_to_sort[i].first;
+ }
+ if (values == nullptr) {
+ continue;
+ }
+
+ // Reorder the values according to the order defined by the keys.
+ for (int64 i = 0; i < sort_dimension_elements; ++i) {
+ int64 memory_index =
+ (base_offset + row_to_sort[i].second * sort_dimension_offset) *
+ values_primitive_type_size_in_bytes;
+
+ reordered_values[i] = std::string(values + memory_index,
+ values_primitive_type_size_in_bytes);
+ }
+ for (int64 i = 0; i < sort_dimension_elements; ++i) {
+ int64 memory_index = (base_offset + i * sort_dimension_offset) *
+ values_primitive_type_size_in_bytes;
+ memcpy(values + memory_index, reordered_values[i].c_str(),
+ values_primitive_type_size_in_bytes);
+ }
+ }
+}
+} // namespace
+
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortPRED(
+ bool* keys, int64 a, int64 b, int64 c, char* values,
+ int32 values_primitive_type_size_in_bytes) {
+ KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes);
+}
+
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS8(
+ int8* keys, int64 a, int64 b, int64 c, char* values,
+ int32 values_primitive_type_size_in_bytes) {
+ KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes);
+}
+
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU8(
+ uint8* keys, int64 a, int64 b, int64 c, char* values,
+ int32 values_primitive_type_size_in_bytes) {
+ KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes);
+}
+
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS16(
+ int16* keys, int64 a, int64 b, int64 c, char* values,
+ int32 values_primitive_type_size_in_bytes) {
+ KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes);
+}
+
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU16(
+ uint16* keys, int64 a, int64 b, int64 c, char* values,
+ int32 values_primitive_type_size_in_bytes) {
+ KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes);
+}
+
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortF16(
+ Eigen::half* keys, int64 a, int64 b, int64 c, char* values,
+ int32 values_primitive_type_size_in_bytes) {
+ KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes);
+}
+
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS32(
+ int32* keys, int64 a, int64 b, int64 c, char* values,
+ int32 values_primitive_type_size_in_bytes) {
+ KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes);
+}
+
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU32(
+ uint32* keys, int64 a, int64 b, int64 c, char* values,
+ int32 values_primitive_type_size_in_bytes) {
+ KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes);
+}
+
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortF32(
+ float* keys, int64 a, int64 b, int64 c, char* values,
+ int32 values_primitive_type_size_in_bytes) {
+ KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes);
+}
+
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS64(
+ int64* keys, int64 a, int64 b, int64 c, char* values,
+ int32 values_primitive_type_size_in_bytes) {
+ KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes);
+}
+
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU64(
+ uint64* keys, int64 a, int64 b, int64 c, char* values,
+ int32 values_primitive_type_size_in_bytes) {
+ KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes);
+}
+
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortF64(
+ double* keys, int64 a, int64 b, int64 c, char* values,
+ int32 values_primitive_type_size_in_bytes) {
+ KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes);
+}
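
For intuition about the indexing scheme used by KeyValueSortImpl above, here is a minimal standalone sketch of how 'base_offset' and the per-row strides work out for a concrete shape; the shape [2, 3, 4] and the printing harness are illustrative assumptions, not part of the change.

#include <cstdint>
#include <cstdio>

int main() {
  // Conceptual shape [a, b, c] = [2, 3, 4]; 'b' is the dimension being sorted.
  const int64_t a = 2, b = 3, c = 4;
  const int64_t sort_dimension_elements = b;     // elements in each row
  const int64_t num_iteration_elements = a * c;  // number of rows to sort
  const int64_t sort_dimension_offset = c;       // stride between row elements

  for (int64_t index = 0; index < num_iteration_elements; ++index) {
    // index % c is the position in the 'c' dimension, index / c the position
    // in the 'a' dimension; the latter contributes (index / c) * b * c, which
    // equals (index - index % c) * b.
    const int64_t base_offset =
        index % sort_dimension_offset +
        (index - index % sort_dimension_offset) * sort_dimension_elements;
    for (int64_t i = 0; i < sort_dimension_elements; ++i) {
      std::printf("row %lld, element %lld -> flat offset %lld\n",
                  static_cast<long long>(index), static_cast<long long>(i),
                  static_cast<long long>(base_offset +
                                         i * sort_dimension_offset));
    }
  }
  return 0;
}
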
diff --git a/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h
new file mode 100644
index 0000000000..28e35e82c1
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h
@@ -0,0 +1,88 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_KEY_VALUE_SORT_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_KEY_VALUE_SORT_H_
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/platform/types.h"
+
+extern "C" {
+
+// 'keys' represents a 3-dimensional shape with dimensions [a, b, c]. The 'b'
+// dimension of 'keys' is sorted into ascending order. 'values' can be nullptr.
+// If 'values' is not nullptr, the elements in 'values' are reordered in such a
+// way that if the element at index 'i' in 'keys' was moved to index 'j', the
+// element at index 'i' in 'values' is also moved to index 'j' (i.e. each key
+// remains paired with its original value).
+extern void __xla_cpu_runtime_KeyValueSortPRED(
+ bool* keys, tensorflow::int64 a, tensorflow::int64 b, tensorflow::int64 c,
+ char* values, tensorflow::int32 values_primitive_type_size_in_bytes);
+
+extern void __xla_cpu_runtime_KeyValueSortS8(
+ tensorflow::int8* keys, tensorflow::int64 a, tensorflow::int64 b,
+ tensorflow::int64 c, char* values,
+ tensorflow::int32 values_primitive_type_size_in_bytes);
+
+extern void __xla_cpu_runtime_KeyValueSortU8(
+ tensorflow::uint8* keys, tensorflow::int64 a, tensorflow::int64 b,
+ tensorflow::int64 c, char* values,
+ tensorflow::int32 values_primitive_type_size_in_bytes);
+
+extern void __xla_cpu_runtime_KeyValueSortS16(
+ tensorflow::int16* keys, tensorflow::int64 a, tensorflow::int64 b,
+ tensorflow::int64 c, char* values,
+ tensorflow::int32 values_primitive_type_size_in_bytes);
+
+extern void __xla_cpu_runtime_KeyValueSortU16(
+ tensorflow::uint16* keys, tensorflow::int64 a, tensorflow::int64 b,
+ tensorflow::int64 c, char* values,
+ tensorflow::int32 values_primitive_type_size_in_bytes);
+
+extern void __xla_cpu_runtime_KeyValueSortF16(
+ Eigen::half* keys, tensorflow::int64 a, tensorflow::int64 b,
+ tensorflow::int64 c, char* values,
+ tensorflow::int32 values_primitive_type_size_in_bytes);
+
+extern void __xla_cpu_runtime_KeyValueSortS32(
+ tensorflow::int32* keys, tensorflow::int64 a, tensorflow::int64 b,
+ tensorflow::int64 c, char* values,
+ tensorflow::int32 values_primitive_type_size_in_bytes);
+
+extern void __xla_cpu_runtime_KeyValueSortU32(
+ tensorflow::uint32* keys, tensorflow::int64 a, tensorflow::int64 b,
+ tensorflow::int64 c, char* values,
+ tensorflow::int32 values_primitive_type_size_in_bytes);
+
+extern void __xla_cpu_runtime_KeyValueSortF32(
+ float* keys, tensorflow::int64 a, tensorflow::int64 b, tensorflow::int64 c,
+ char* values, tensorflow::int32 values_primitive_type_size_in_bytes);
+
+extern void __xla_cpu_runtime_KeyValueSortS64(
+ tensorflow::int64* keys, tensorflow::int64 a, tensorflow::int64 b,
+ tensorflow::int64 c, char* values,
+ tensorflow::int32 values_primitive_type_size_in_bytes);
+
+extern void __xla_cpu_runtime_KeyValueSortU64(
+ tensorflow::uint64* keys, tensorflow::int64 a, tensorflow::int64 b,
+ tensorflow::int64 c, char* values,
+ tensorflow::int32 values_primitive_type_size_in_bytes);
+
+extern void __xla_cpu_runtime_KeyValueSortF64(
+ double* keys, tensorflow::int64 a, tensorflow::int64 b, tensorflow::int64 c,
+ char* values, tensorflow::int32 values_primitive_type_size_in_bytes);
+}
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_KEY_VALUE_SORT_H_
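
To make the contract above concrete, here is a minimal usage sketch of the F32 entry point on a single row (a = c = 1) with int32 payload values; the wrapper function is hypothetical and only illustrates the calling convention.

#include <cstdint>

#include "tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h"

void ExampleKeyValueSort() {
  // Shape [a, b, c] = [1, 4, 1]: one row of four keys, sorted along 'b'.
  float keys[4] = {3.0f, 1.0f, 4.0f, 2.0f};
  // Values travel as raw bytes; the runtime only needs their element size.
  int32_t values[4] = {30, 10, 40, 20};
  __xla_cpu_runtime_KeyValueSortF32(
      keys, /*a=*/1, /*b=*/4, /*c=*/1, reinterpret_cast<char*>(values),
      /*values_primitive_type_size_in_bytes=*/sizeof(int32_t));
  // keys is now {1, 2, 3, 4}; values follows the keys: {10, 20, 30, 40}.
}
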
diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
index bf98064647..9ec0c8f657 100644
--- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
+++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
@@ -35,6 +35,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/runtime_fft.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_fork_join.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_fp16.h"
+#include "tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.h"
@@ -202,6 +203,18 @@ bool RegisterKnownJITSymbols() {
REGISTER_CPU_RUNTIME_SYMBOL(ParallelForkJoin);
REGISTER_CPU_RUNTIME_SYMBOL(ReleaseInfeedBufferAfterDequeue);
REGISTER_CPU_RUNTIME_SYMBOL(ReleaseOutfeedBufferAfterPopulation);
+ REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortPRED);
+ REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortS8);
+ REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortU8);
+ REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortS16);
+ REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortU16);
+ REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortF16);
+ REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortS32);
+ REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortU32);
+ REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortF32);
+ REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortS64);
+ REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortU64);
+ REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortF64);
registry->Register("__gnu_f2h_ieee", reinterpret_cast<void*>(__gnu_f2h_ieee));
registry->Register("__gnu_h2f_ieee", reinterpret_cast<void*>(__gnu_h2f_ieee));
diff --git a/tensorflow/compiler/xla/service/cpu/tests/BUILD b/tensorflow/compiler/xla/service/cpu/tests/BUILD
index c55206eee7..4b129c95d4 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/tests/BUILD
@@ -180,3 +180,17 @@ tf_cc_test(
"//tensorflow/core:test_main",
],
)
+
+tf_cc_test(
+ name = "cpu_key_value_sort_test",
+ srcs = ["cpu_key_value_sort_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_parser",
+ "//tensorflow/compiler/xla/service/cpu:cpu_compiler",
+ "//tensorflow/compiler/xla/service/cpu/tests:cpu_codegen_test",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_key_value_sort_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_key_value_sort_test.cc
new file mode 100644
index 0000000000..3934c03a04
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_key_value_sort_test.cc
@@ -0,0 +1,54 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"
+#include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
+
+namespace xla {
+namespace cpu {
+namespace {
+class CpuKeyValueSortTest : public CpuCodegenTest {};
+
+TEST_F(CpuKeyValueSortTest, SortR1) {
+ const string hlo_text = R"(
+HloModule KeyValueSort
+
+ENTRY main {
+ a = f32[10] parameter(0)
+
+ ROOT result = f32[10] sort(f32[10] a), dimensions={0}
+}
+)";
+
+ string filecheck_pattern = R"(
+CHECK: call void @__xla_cpu_runtime_KeyValueSort
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(hlo_text));
+
+ CpuAotCompilationOptions options{
+ /*triple=*/"x86_64", /*cpu_name=*/"", /*features=*/"",
+ /*entry_point_name=*/"entry",
+ /*relocation_model=*/CpuAotCompilationOptions::RelocationModel::Static};
+
+ CompileAheadOfTimeAndVerifyIr(std::move(module), options, filecheck_pattern,
+ /*match_optimized_ir=*/true);
+}
+
+} // namespace
+} // namespace cpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/xfeed_manager_test.cc b/tensorflow/compiler/xla/service/cpu/xfeed_manager_test.cc
index 8fe65f488a..cc38b81455 100644
--- a/tensorflow/compiler/xla/service/cpu/xfeed_manager_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/xfeed_manager_test.cc
@@ -66,9 +66,9 @@ void ProcessNextBuffer(int32 length) {
auto shape = ShapeUtil::MakeShape(U8, {length});
string bytes = shape.SerializeAsString();
void* buffer = __xla_cpu_runtime_AcquireInfeedBufferForDequeue(
- length, bytes.data(), bytes.size());
- __xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue(length, buffer,
- bytes.data(), bytes.size());
+ /*run_options=*/nullptr, length, bytes.data(), bytes.size());
+ __xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue(
+ /*run_options=*/nullptr, length, buffer, bytes.data(), bytes.size());
}
// Performs the acquire/release sequence on the outfeed, as the generated CPU
@@ -76,16 +76,16 @@ void ProcessNextBuffer(int32 length) {
void ProcessNextOutfeedBuffer(int32 length, const Shape& shape) {
string bytes = shape.SerializeAsString();
void* buffer = __xla_cpu_runtime_AcquireOutfeedBufferForPopulation(
- length, bytes.data(), bytes.size());
+ /*run_options=*/nullptr, length, bytes.data(), bytes.size());
__xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation(
- length, buffer, bytes.data(), bytes.size());
+ /*run_options=*/nullptr, length, buffer, bytes.data(), bytes.size());
}
TEST_F(InfeedManagerTest, SingleThreadedSequential) {
TestInfeedBuffer* a = new TestInfeedBuffer(64);
TestInfeedBuffer* b = new TestInfeedBuffer(32);
- cpu::runtime::XfeedManager* xfeed = cpu::runtime::GetXfeedManager();
+ cpu::runtime::XfeedManager* xfeed = cpu::runtime::GetXfeedManager(0);
xfeed->infeed()->EnqueueBuffersAtomically({a});
xfeed->infeed()->EnqueueBuffersAtomically({b});
@@ -97,7 +97,7 @@ TEST_F(InfeedManagerTest, SingleThreadedInterleaved) {
TestInfeedBuffer* a = new TestInfeedBuffer(64);
TestInfeedBuffer* b = new TestInfeedBuffer(32);
- cpu::runtime::XfeedManager* xfeed = cpu::runtime::GetXfeedManager();
+ cpu::runtime::XfeedManager* xfeed = cpu::runtime::GetXfeedManager(0);
xfeed->infeed()->EnqueueBuffersAtomically({a});
ProcessNextBuffer(a->length());
@@ -108,7 +108,7 @@ TEST_F(InfeedManagerTest, SingleThreadedInterleaved) {
TEST_F(InfeedManagerTest, MultiThreaded) {
tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "test", 2);
- cpu::runtime::XfeedManager* xfeed = cpu::runtime::GetXfeedManager();
+ cpu::runtime::XfeedManager* xfeed = cpu::runtime::GetXfeedManager(0);
const int32 length = 64;
@@ -130,7 +130,7 @@ TEST_F(InfeedManagerTest, MultiThreaded) {
TEST_F(InfeedManagerTest, OutfeedWrongShape) {
TestInfeedBuffer* b = new TestInfeedBuffer(32, /*expect_shape_match=*/false);
- cpu::runtime::XfeedManager* xfeed = cpu::runtime::GetXfeedManager();
+ cpu::runtime::XfeedManager* xfeed = cpu::runtime::GetXfeedManager(0);
xfeed->outfeed()->EnqueueBuffersAtomically({b});
ProcessNextOutfeedBuffer(32, ShapeUtil::MakeShape(U8, {33}));
diff --git a/tensorflow/compiler/xla/service/defuser.h b/tensorflow/compiler/xla/service/defuser.h
index c326beb899..aaa41fc4fe 100644
--- a/tensorflow/compiler/xla/service/defuser.h
+++ b/tensorflow/compiler/xla/service/defuser.h
@@ -25,7 +25,7 @@ namespace xla {
// A pass which replaces all fusion instructions with the equivalent un-fused
// instructions.
-class Defuser : public HloPassInterface {
+class Defuser : public HloModulePass {
public:
Defuser() {}
~Defuser() override {}
diff --git a/tensorflow/compiler/xla/service/despecializer.cc b/tensorflow/compiler/xla/service/despecializer.cc
index ba2a674d9a..b3549acfc2 100644
--- a/tensorflow/compiler/xla/service/despecializer.cc
+++ b/tensorflow/compiler/xla/service/despecializer.cc
@@ -24,7 +24,7 @@ namespace xla {
namespace {
// Pass which strips control dependencies from all instructions in the module.
-class ControlDepRemover : public HloPassInterface {
+class ControlDepRemover : public HloModulePass {
public:
ControlDepRemover() = default;
absl::string_view name() const override { return "control-dep-remover"; }
diff --git a/tensorflow/compiler/xla/service/despecializer.h b/tensorflow/compiler/xla/service/despecializer.h
index 7be70add2f..46dcc3a438 100644
--- a/tensorflow/compiler/xla/service/despecializer.h
+++ b/tensorflow/compiler/xla/service/despecializer.h
@@ -30,7 +30,7 @@ namespace xla {
//
// Current despecialization passes are Defuser, ImplicitBroadcastRemover,
// and BFloat16MixedPrecisionRemoval.
-class Despecializer : public HloPassInterface {
+class Despecializer : public HloModulePass {
public:
Despecializer();
absl::string_view name() const override { return "despecializer"; }
diff --git a/tensorflow/compiler/xla/service/dot_decomposer.h b/tensorflow/compiler/xla/service/dot_decomposer.h
index fc38e31700..40e7a3b4c2 100644
--- a/tensorflow/compiler/xla/service/dot_decomposer.h
+++ b/tensorflow/compiler/xla/service/dot_decomposer.h
@@ -23,7 +23,7 @@ namespace xla {
// DotDecomposer is a pass which decomposes batch Dot operations into a
// sequence of smaller (R2) Dot operations.
-class DotDecomposer : public HloPassInterface {
+class DotDecomposer : public HloModulePass {
public:
// Decomposes batch Dot operations when 'decompose_batch_dot' is true.
DotDecomposer(bool decompose_batch_dot = true)
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
index 4bb1e071d8..515267edd7 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
@@ -847,29 +847,34 @@ llvm::Value* ElementalIrEmitter::EmitFloatMin(llvm::Value* lhs_value,
StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type,
llvm::Value* x) {
- if (prim_type != F32) {
- // TODO(b/34339814): Implement inverse erf for F64.
+ if (prim_type != F16 && prim_type != F32 && prim_type != F64) {
return Unimplemented(
"Inverse erf is only implemented for element "
- "type F32.");
+ "types F16, F32 and F64.");
}
- auto getFloat = [&](const float f) {
- return llvm::ConstantFP::get(b_->getFloatTy(), f);
+
+ // Upcast half to float.
+ if (prim_type == F16) {
+ x = b_->CreateFPExt(x, b_->getFloatTy());
+ }
+
+ auto get_float = [&](const double f) {
+ return llvm::ConstantFP::get(x->getType(), f);
};
- auto multiply_add = [&](absl::Span<const float> coefficients,
+ auto multiply_add = [&](absl::Span<const double> coefficients,
llvm::Value* w) {
- llvm::Value* p = getFloat(coefficients.front());
+ llvm::Value* p = get_float(coefficients.front());
coefficients.remove_prefix(1);
for (float coefficient : coefficients) {
- p = FAdd(FMul(p, w), getFloat(coefficient));
+ p = FAdd(FMul(p, w), get_float(coefficient));
}
return p;
};
// Approximation for inverse error function from
// Giles, M., "Approximating the erfinv function".
- // The approximation has the form:
- // w = log((1-x)*(1+x))
+ // The approximation has the form (float version):
+ // w = -log((1-x)*(1+x))
// if ( w < 5 ) {
// w = w - 2.5
// p = sum_{i=1}^n lq[i]*w^i
@@ -879,46 +884,124 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type,
// }
// return p*x
llvm::Function* logf_fn = llvm::Intrinsic::getDeclaration(
- module_, llvm::Intrinsic::log, {b_->getFloatTy()});
+ module_, llvm::Intrinsic::log, {x->getType()});
- llvm::Value* w = FNeg(
- Call(logf_fn, {FMul(FSub(getFloat(1.0f), x), FAdd(getFloat(1.0f), x))}));
+ llvm::Value* w = FNeg(Call(
+ logf_fn, {FMul(FSub(get_float(1.0f), x), FAdd(get_float(1.0f), x))}));
llvm::Value* p_addr =
- llvm_ir::EmitAllocaAtFunctionEntry(b_->getFloatTy(), "p.addr", b_);
+ llvm_ir::EmitAllocaAtFunctionEntry(x->getType(), "p.addr", b_);
+
+ if (prim_type == F16 || prim_type == F32) {
+ llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
+ FCmpOLT(w, get_float(5.0f)), "w_less_than_five", b_);
+ // Handle true BB.
+ SetToFirstInsertPoint(if_data.true_block, b_);
+ {
+ llvm::Value* lw = FSub(w, get_float(2.5f));
+ absl::Span<const double> lq{
+ 2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f,
+ -4.39150654e-06f, 0.00021858087f, -0.00125372503f,
+ -0.00417768164f, 0.246640727f, 1.50140941f};
+ llvm::Value* p = multiply_add(lq, lw);
+ Store(p, p_addr);
+ }
- llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
- FCmpOLT(w, getFloat(5.0f)), "w_less_than_five", b_);
- // Handle true BB.
- SetToFirstInsertPoint(if_data.true_block, b_);
- {
- llvm::Value* lw = FSub(w, getFloat(2.5f));
- absl::Span<const float> lq{
- 2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f,
- -4.39150654e-06f, 0.00021858087f, -0.00125372503f,
- -0.00417768164f, 0.246640727f, 1.50140941f};
- llvm::Value* p = multiply_add(lq, lw);
- Store(p, p_addr);
- }
+ // Handle false BB.
+ SetToFirstInsertPoint(if_data.false_block, b_);
+ {
+ llvm::Function* sqrtf_fn = llvm::Intrinsic::getDeclaration(
+ module_, llvm::Intrinsic::sqrt, {b_->getFloatTy()});
+
+ llvm::Value* gw = FSub(Call(sqrtf_fn, w), get_float(3.0f));
+ absl::Span<const double> gq{
+ -0.000200214257f, 0.000100950558f, 0.00134934322f,
+ -0.00367342844f, 0.00573950773f, -0.0076224613f,
+ 0.00943887047f, 1.00167406f, 2.83297682f};
+ llvm::Value* p = multiply_add(gq, gw);
+ Store(p, p_addr);
+ }
- // Handle false BB.
- SetToFirstInsertPoint(if_data.false_block, b_);
- {
- llvm::Function* sqrtf_fn = llvm::Intrinsic::getDeclaration(
- module_, llvm::Intrinsic::sqrt, {b_->getFloatTy()});
-
- llvm::Value* gw = FSub(Call(sqrtf_fn, w), getFloat(3.0f));
- absl::Span<const float> gq{
- -0.000200214257f, 0.000100950558f, 0.00134934322f,
- -0.00367342844f, 0.00573950773f, -0.0076224613f,
- 0.00943887047f, 1.00167406f, 2.83297682f};
- llvm::Value* p = multiply_add(gq, gw);
- Store(p, p_addr);
- }
+ SetToFirstInsertPoint(if_data.after_block, b_);
+ } else {
+ DCHECK(prim_type == F64);
+
+ llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
+ FCmpOLT(w, get_float(6.25)), "w_less_than_6.25", b_);
+
+ SetToFirstInsertPoint(if_data.true_block, b_);
+ {
+ llvm::Value* lw = FSub(w, get_float(3.125));
+ absl::Span<const double> c{
+ -3.6444120640178196996e-21, -1.685059138182016589e-19,
+ 1.2858480715256400167e-18, 1.115787767802518096e-17,
+ -1.333171662854620906e-16, 2.0972767875968561637e-17,
+ 6.6376381343583238325e-15, -4.0545662729752068639e-14,
+ -8.1519341976054721522e-14, 2.6335093153082322977e-12,
+ -1.2975133253453532498e-11, -5.4154120542946279317e-11,
+ 1.051212273321532285e-09, -4.1126339803469836976e-09,
+ -2.9070369957882005086e-08, 4.2347877827932403518e-07,
+ -1.3654692000834678645e-06, -1.3882523362786468719e-05,
+ 0.0001867342080340571352, -0.00074070253416626697512,
+ -0.0060336708714301490533, 0.24015818242558961693,
+ 1.6536545626831027356};
+ llvm::Value* p = multiply_add(c, lw);
+ Store(p, p_addr);
+ }
- SetToFirstInsertPoint(if_data.after_block, b_);
+ SetToFirstInsertPoint(if_data.false_block, b_);
+ llvm_ir::LlvmIfData if_data_second = llvm_ir::EmitIfThenElse(
+ FCmpOLT(w, get_float(16.0)), "w_less_than_16", b_);
+ SetToFirstInsertPoint(if_data_second.true_block, b_);
+ {
+ llvm::Function* sqrtf_fn = llvm::Intrinsic::getDeclaration(
+ module_, llvm::Intrinsic::sqrt, {b_->getDoubleTy()});
+
+ llvm::Value* gw = FSub(Call(sqrtf_fn, w), get_float(3.25));
+ absl::Span<const double> t1{
+ 2.2137376921775787049e-09, 9.0756561938885390979e-08,
+ -2.7517406297064545428e-07, 1.8239629214389227755e-08,
+ 1.5027403968909827627e-06, -4.013867526981545969e-06,
+ 2.9234449089955446044e-06, 1.2475304481671778723e-05,
+ -4.7318229009055733981e-05, 6.8284851459573175448e-05,
+ 2.4031110387097893999e-05, -0.0003550375203628474796,
+ 0.00095328937973738049703, -0.0016882755560235047313,
+ 0.0024914420961078508066, -0.0037512085075692412107,
+ 0.005370914553590063617, 1.0052589676941592334,
+ 3.0838856104922207635};
+ llvm::Value* p = multiply_add(t1, gw);
+ Store(p, p_addr);
+ }
+
+ SetToFirstInsertPoint(if_data_second.false_block, b_);
+ {
+ llvm::Function* sqrtf_fn = llvm::Intrinsic::getDeclaration(
+ module_, llvm::Intrinsic::sqrt, {b_->getDoubleTy()});
+
+ llvm::Value* gw = FSub(Call(sqrtf_fn, w), get_float(5.0));
+ absl::Span<const double> t2{
+ -2.7109920616438573243e-11, -2.5556418169965252055e-10,
+ 1.5076572693500548083e-09, -3.7894654401267369937e-09,
+ 7.6157012080783393804e-09, -1.4960026627149240478e-08,
+ 2.9147953450901080826e-08, -6.7711997758452339498e-08,
+ 2.2900482228026654717e-07, -9.9298272942317002539e-07,
+ 4.5260625972231537039e-06, -1.9681778105531670567e-05,
+ 7.5995277030017761139e-05, -0.00021503011930044477347,
+ -0.00013871931833623122026, 1.0103004648645343977,
+ 4.8499064014085844221};
+ llvm::Value* p = multiply_add(t2, gw);
+ Store(p, p_addr);
+ }
+
+ SetToFirstInsertPoint(if_data.after_block, b_);
+ }
llvm::Value* p = Load(p_addr);
- return FMul(p, x);
+ x = FMul(p, x);
+ // Trunc back to half if needed.
+ if (prim_type == F16) {
+ x = b_->CreateFPTrunc(x, b_->getHalfTy());
+ }
+ return x;
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfcInv(PrimitiveType prim_type,
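
For reference, a scalar C++ transcription of the F16/F32 branch of the erfinv approximation emitted above (coefficients copied from the diff); this standalone helper is an editorial sketch, not code from the change.

#include <cmath>
#include <cstddef>

// Scalar reference for the single-precision branch of the approximation above
// (Giles, "Approximating the erfinv function").
float ErfInvApproxF32(float x) {
  float w = -std::log((1.0f - x) * (1.0f + x));
  static const float lq[] = {2.81022636e-08f,  3.43273939e-07f, -3.5233877e-06f,
                             -4.39150654e-06f, 0.00021858087f,  -0.00125372503f,
                             -0.00417768164f,  0.246640727f,    1.50140941f};
  static const float gq[] = {-0.000200214257f, 0.000100950558f, 0.00134934322f,
                             -0.00367342844f,  0.00573950773f,  -0.0076224613f,
                             0.00943887047f,   1.00167406f,     2.83297682f};
  const float* coeffs;
  size_t n;
  if (w < 5.0f) {
    w = w - 2.5f;
    coeffs = lq;
    n = sizeof(lq) / sizeof(lq[0]);
  } else {
    w = std::sqrt(w) - 3.0f;
    coeffs = gq;
    n = sizeof(gq) / sizeof(gq[0]);
  }
  // Same Horner-style evaluation as the multiply_add helper in the emitter.
  float p = coeffs[0];
  for (size_t i = 1; i < n; ++i) {
    p = p * w + coeffs[i];
  }
  return p * x;
}
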
diff --git a/tensorflow/compiler/xla/service/flatten_call_graph.h b/tensorflow/compiler/xla/service/flatten_call_graph.h
index 3cccec9862..986970f886 100644
--- a/tensorflow/compiler/xla/service/flatten_call_graph.h
+++ b/tensorflow/compiler/xla/service/flatten_call_graph.h
@@ -26,7 +26,7 @@ namespace xla {
// Flattening associates each call site with a unique computation (for
// sequential calling contexts) This simplifies buffer assignment and
// points-to analysis (see b/36865746 for details).
-class FlattenCallGraph : public HloPassInterface {
+class FlattenCallGraph : public HloModulePass {
public:
absl::string_view name() const override { return "flatten-call-graph"; }
diff --git a/tensorflow/compiler/xla/service/gather_expander.h b/tensorflow/compiler/xla/service/gather_expander.h
index 7bd9ea5984..2b39359aae 100644
--- a/tensorflow/compiler/xla/service/gather_expander.h
+++ b/tensorflow/compiler/xla/service/gather_expander.h
@@ -23,7 +23,7 @@ namespace xla {
// This pass rewrites gather operations into (roughly) while loops of dynamic
// slices. This lets backends that don't support gather directly to
// nevertheless have a minimum level of support.
-class GatherExpander : public HloPassInterface {
+class GatherExpander : public HloModulePass {
public:
absl::string_view name() const override { return "gather_expander"; }
StatusOr<bool> Run(HloModule* module) override;
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 64b9683628..51968d13d4 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -68,9 +68,7 @@ cc_library(
# srcs = [
# "partition_assignment_test.cc",
# ],
-# tags = [
-# "requires-gpu-sm35",
-# ],
+# tags = tf_cuda_tests_tags(),
# deps = [
# ":partition_assignment",
# "//tensorflow/core:stream_executor_no_cuda",
@@ -373,7 +371,6 @@ cc_library(
hdrs = ["ir_emission_utils.h"],
deps = [
":backend_configs",
- ":cudnn_convolution_runner",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:window_util",
@@ -414,6 +411,8 @@ cc_library(
srcs = ["cudnn_convolution_runner.cc"],
hdrs = ["cudnn_convolution_runner.h"],
deps = [
+ ":backend_configs",
+ ":ir_emission_utils",
":stream_executor_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status",
@@ -422,8 +421,10 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/service:hlo",
"//tensorflow/core:stream_executor_no_cuda",
"@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:optional",
],
)
@@ -432,6 +433,7 @@ cc_library(
srcs = ["cudnn_convolution_rewriter.cc"],
hdrs = ["cudnn_convolution_rewriter.h"],
deps = [
+ ":backend_configs",
":ir_emission_utils",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:util",
@@ -596,14 +598,11 @@ cc_library(
hdrs = ["pad_for_tensor_cores.h"],
deps = [
":ir_emission_utils",
- "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:window_util",
- "//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/service:hlo_creation_utils",
+ "//tensorflow/compiler/xla/service:hlo_casting_utils",
"//tensorflow/compiler/xla/service:hlo_pass",
- "//tensorflow/compiler/xla/service:shape_inference",
],
)
@@ -656,6 +655,7 @@ cc_library(
deps = [
":cudnn_convolution_algorithm_picker",
":cudnn_convolution_rewriter",
+ ":cudnn_fused_convolution_rewriter",
":fusion_merger",
":gpu_constants",
":gpu_copy_insertion",
@@ -783,6 +783,7 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:computation_layout",
"//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_casting_utils",
"//tensorflow/compiler/xla/service:layout_assignment",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
@@ -967,3 +968,19 @@ tf_cc_test(
"@com_google_absl//absl/strings",
],
)
+
+cc_library(
+ name = "cudnn_fused_convolution_rewriter",
+ srcs = ["cudnn_fused_convolution_rewriter.cc"],
+ hdrs = ["cudnn_fused_convolution_rewriter.h"],
+ deps = [
+ ":backend_configs",
+ ":ir_emission_utils",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_casting_utils",
+ "//tensorflow/compiler/xla/service:hlo_pass",
+ "//tensorflow/compiler/xla/service:pattern_matcher",
+ "//tensorflow/core:stream_executor_no_cuda",
+ ],
+)
diff --git a/tensorflow/compiler/xla/service/gpu/backend_configs.proto b/tensorflow/compiler/xla/service/gpu/backend_configs.proto
index 640c6392b8..78e14d860e 100644
--- a/tensorflow/compiler/xla/service/gpu/backend_configs.proto
+++ b/tensorflow/compiler/xla/service/gpu/backend_configs.proto
@@ -24,4 +24,18 @@ message CudnnConvBackendConfig {
// true, cudnn may choose not to use tensor cores, e.g. because the GPU or
// selected algorithm doesn't support it.
bool tensor_ops_enabled = 2;
+
+ // The scaling factor multiplied by the convolution result.
+ double conv_result_scale = 4;
+
+ // Below are the fields related to cuDNN's fused convolution. Refer to
+ // CudnnConvParams for their meanings.
+
+ // The requested activation (e.g. relu) applied after the convolution. Its
+ // value is an enumerator of stream_executor::dnn::ActivationMode.
+ int64 activation_mode = 3;
+
+ // The scaling factor multiplied by the side input. If no side input buffer
+ // is provided, this field must be 0.
+ double side_input_scale = 5;
}
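
As an editorial sketch of what the three new fields encode (together with the bias and optional side-input operands of the fused custom call), the per-output-element epilogue can be written as below; the helper is illustrative, with relu standing in for whatever activation_mode selects.

#include <algorithm>

// result = activation(conv_result_scale * conv(x, w)
//                     + side_input_scale * side_input + broadcast(bias))
float FusedConvEpilogue(float conv_value, float side_input, float bias,
                        double conv_result_scale, double side_input_scale) {
  const double acc = conv_result_scale * conv_value +
                     side_input_scale * side_input + bias;
  return static_cast<float>(std::max(0.0, acc));  // relu shown as an example
}
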
diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
index 3a23ac1d63..4effea637d 100644
--- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
@@ -29,37 +29,38 @@ limitations under the License.
namespace xla {
namespace gpu {
-using se::dnn::AlgorithmDesc;
+ConvolutionThunk::ConvolutionThunk(
+ const HloCustomCallInstruction* cudnn_call,
+ std::vector<BufferAllocation::Slice> operand_slices,
+ BufferAllocation::Slice result_slice, BufferAllocation::Slice scratch_slice,
+ BufferAllocation::Slice tuple_result_slice)
+ : Thunk(Kind::kConvolution, cudnn_call),
+ cudnn_call_(cudnn_call),
+ operand_buffers_(std::move(operand_slices)),
+ result_buffer_(result_slice),
+ scratch_buffer_(scratch_slice),
+ tuple_result_buffer_(tuple_result_slice) {}
Status ConvolutionThunk::ExecuteOnStream(
const BufferAllocations& buffer_allocations, se::Stream* stream,
HloExecutionProfiler* profiler) {
- CudnnConvParams params;
+ std::vector<se::DeviceMemoryBase> operand_se_buffers;
+ for (const auto& buffer : operand_buffers_) {
+ operand_se_buffers.push_back(buffer_allocations.GetDeviceAddress(buffer));
+ }
+
+ se::DeviceMemoryBase result_buffer =
+ buffer_allocations.GetDeviceAddress(result_buffer_);
- params.input_buf = buffer_allocations.GetDeviceAddress(input_buffer_);
- params.filter_buf = buffer_allocations.GetDeviceAddress(filter_buffer_);
- params.output_buf = buffer_allocations.GetDeviceAddress(output_buffer_);
se::DeviceMemoryBase scratch =
buffer_allocations.GetDeviceAddress(scratch_buffer_);
- TF_RETURN_IF_ERROR(PopulateCudnnConvParams(cudnn_call_, &params));
-
auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
- TF_RETURN_IF_ERROR(RunCudnnConvolution(params, scratch, stream));
+ TF_RETURN_IF_ERROR(RunCudnnConvolution(cudnn_call_,
+ absl::MakeSpan(operand_se_buffers),
+ result_buffer, scratch, stream));
- // Figure out which of output/input/filter is the result produced by
- // this op, and write the result tuple.
- void* result_ptr = [&] {
- switch (params.kind) {
- case CudnnConvKind::kForward:
- return params.output_buf.opaque();
- case CudnnConvKind::kBackwardInput:
- return params.input_buf.opaque();
- case CudnnConvKind::kBackwardFilter:
- return params.filter_buf.opaque();
- }
- }();
- void* ptrs[] = {result_ptr, scratch.opaque()};
+ void* ptrs[] = {result_buffer.opaque(), scratch.opaque()};
se::DeviceMemory<void*> tuple_addr(
buffer_allocations.GetDeviceAddress(tuple_result_buffer_));
stream->ThenMemcpyH2D<void*>(ptrs, &tuple_addr);
diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h
index d7d1f91fba..f53bc54198 100644
--- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h
@@ -42,24 +42,12 @@ class ConvolutionThunk : public Thunk {
// Constructs a thunk for launching a DNN convolution. When run, it will
// write a tuple (result, scratch_memory) into `tuple_result_buffer`.
//
- // Note that "output" here doesn't refer to the output from running this
- // thunk, but rather to the "output" of a hypothetical forward convolution
- // that corresponds to this input+filter+output triple. That is, the result
- // generated by this thunk is "output" for forward convs, "input" for
- // backward-input convs, and "filter" for backward-filter convs.
+ // operand_slices should be in the same order as cudnn_call->operands().
ConvolutionThunk(const HloCustomCallInstruction* cudnn_call,
- BufferAllocation::Slice input_slice,
- BufferAllocation::Slice filter_slice,
- BufferAllocation::Slice output_slice,
+ std::vector<BufferAllocation::Slice> operand_slices,
+ BufferAllocation::Slice result_slice,
BufferAllocation::Slice scratch_slice,
- BufferAllocation::Slice tuple_result_slice)
- : Thunk(Kind::kConvolution, cudnn_call),
- cudnn_call_(cudnn_call),
- input_buffer_(std::move(input_slice)),
- filter_buffer_(std::move(filter_slice)),
- output_buffer_(std::move(output_slice)),
- scratch_buffer_(std::move(scratch_slice)),
- tuple_result_buffer_(std::move(tuple_result_slice)) {}
+ BufferAllocation::Slice tuple_result_slice);
ConvolutionThunk(const ConvolutionThunk&) = delete;
ConvolutionThunk& operator=(const ConvolutionThunk&) = delete;
@@ -71,9 +59,8 @@ class ConvolutionThunk : public Thunk {
private:
const HloCustomCallInstruction* cudnn_call_;
- BufferAllocation::Slice input_buffer_;
- BufferAllocation::Slice filter_buffer_;
- BufferAllocation::Slice output_buffer_;
+ std::vector<BufferAllocation::Slice> operand_buffers_;
+ BufferAllocation::Slice result_buffer_;
BufferAllocation::Slice scratch_buffer_;
BufferAllocation::Slice tuple_result_buffer_;
};
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h
index 6e2e330edd..c3f58508dd 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h
@@ -52,7 +52,7 @@ namespace gpu {
// The GPU backend does not implement a lowering for the batchnorm HLOs -- it
// expects them to be lowered to cudnn calls via this pass or to HLO soup via
// BatchNormRewriter.
-class CudnnBatchNormRewriter : public HloPassInterface {
+class CudnnBatchNormRewriter : public HloModulePass {
public:
absl::string_view name() const override { return "cudnn_batchnorm_rewriter"; }
StatusOr<bool> Run(HloModule* module) override;
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
index f528e62b17..7125673887 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
@@ -76,54 +76,24 @@ StatusOr<se::DeviceMemory<uint8>> ScratchAllocator::AllocateBytes(
return se::DeviceMemory<uint8>(buffer_addr);
}
-// Determines whether we can safely perform a winograd non-fused convolution for
-// the given input and output shapes. This works around b/68264959, an integer
-// overflow in cuDNNv5 and cuDNNv6.
-bool ShouldIncludeWinogradNonfusedAlgo(const Shape& input_shape,
- const Shape& output_shape,
- const ConvolutionDimensionNumbers& dnums,
- se::StreamExecutor* stream_exec) {
- // Skip this check for cudnn7 and newer.
- auto version = stream_exec->AsDnn()->GetVersion();
- if (version.ok() && version.ValueOrDie().major_version() >= 7) {
- return true;
- }
-
- int64 batch = input_shape.dimensions(dnums.input_batch_dimension());
- int64 in_depths = input_shape.dimensions(dnums.input_feature_dimension());
- int64 in_rows = input_shape.dimensions(dnums.input_spatial_dimensions(0));
- int64 in_cols =
- dnums.input_spatial_dimensions_size() == 1
- ? 1
- : input_shape.dimensions(dnums.input_spatial_dimensions(1));
- int64 out_depths = output_shape.dimensions(dnums.output_feature_dimension());
-
- int64 total_size = CeilOfRatio(batch, int64{16}) *
- std::max(in_depths, out_depths) * in_cols * in_rows *
- sizeof(float);
-
- const int64 threshold = 1L << 31;
- return total_size < threshold;
-}
-
std::vector<AlgorithmDesc> GetAlgorithms(CudnnConvKind kind,
- bool with_winograd_nonfused,
se::StreamExecutor* stream_exec) {
std::vector<AlgorithmDesc> algorithms;
+ bool succ = false;
switch (kind) {
case CudnnConvKind::kBackwardFilter:
- CHECK(stream_exec->GetConvolveBackwardFilterAlgorithms(
- with_winograd_nonfused, &algorithms));
+ succ =
+ stream_exec->GetConvolveBackwardFilterAlgorithms(true, &algorithms);
break;
case CudnnConvKind::kBackwardInput:
- CHECK(stream_exec->GetConvolveBackwardDataAlgorithms(
- with_winograd_nonfused, &algorithms));
+ succ = stream_exec->GetConvolveBackwardDataAlgorithms(true, &algorithms);
break;
case CudnnConvKind::kForward:
- CHECK(stream_exec->GetConvolveAlgorithms(with_winograd_nonfused,
- &algorithms));
+ case CudnnConvKind::kForwardActivation:
+ succ = stream_exec->GetConvolveAlgorithms(true, &algorithms);
break;
}
+ DCHECK(succ);
return algorithms;
}
@@ -177,19 +147,11 @@ tensorflow::mutex_lock LockGpu(const se::StreamExecutor* stream_exec) {
// caching would speed up compilation a lot.
StatusOr<std::tuple<int64, bool, int64>>
CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
- const HloCustomCallInstruction* instr) {
- CudnnConvParams params;
- TF_RETURN_IF_ERROR(PopulateCudnnConvParams(instr, &params));
-
- const Shape& input_shape = *params.input_shape;
- const Shape& filter_shape = *params.filter_shape;
- const Shape& output_shape = *params.output_shape;
-
- CHECK_EQ(input_shape.element_type(), filter_shape.element_type());
- CHECK_EQ(input_shape.element_type(), output_shape.element_type());
+ HloCustomCallInstruction* instr) {
// TODO(timshen): for now only check fp16. It can be expanded to other types,
// with some work on the HLO routines.
- const bool cross_check_enabled = input_shape.element_type() == xla::F16;
+ const bool cross_check_enabled =
+ instr->shape().tuple_shapes(0).element_type() == xla::F16;
// Don't run this function concurrently on the same GPU.
//
@@ -257,51 +219,43 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
// use a ScratchAllocator for this instead of calling allocator_ directly so
// that our allocations don't leak.
ScratchAllocator input_output_allocator(device_ordinal, allocator);
- TF_ASSIGN_OR_RETURN(params.input_buf,
- input_output_allocator.AllocateBytes(
- &stream, ShapeUtil::ByteSizeOf(input_shape)));
- TF_ASSIGN_OR_RETURN(params.filter_buf,
- input_output_allocator.AllocateBytes(
- &stream, ShapeUtil::ByteSizeOf(filter_shape)));
- TF_ASSIGN_OR_RETURN(params.output_buf,
- input_output_allocator.AllocateBytes(
- &stream, ShapeUtil::ByteSizeOf(output_shape)));
-
- initialize_buffer(params.input_buf);
- initialize_buffer(params.filter_buf);
- initialize_buffer(params.output_buf);
-
- DeviceMemoryBase* result_buf = [&] {
- switch (params.kind) {
- case CudnnConvKind::kBackwardFilter:
- return &params.filter_buf;
- case CudnnConvKind::kBackwardInput:
- return &params.input_buf;
- case CudnnConvKind::kForward:
- return &params.output_buf;
- }
- }();
+ std::vector<se::DeviceMemoryBase> operand_buffers;
+ for (const auto* operand : instr->operands()) {
+ TF_ASSIGN_OR_RETURN(auto buffer,
+ input_output_allocator.AllocateBytes(
+ &stream, ShapeUtil::ByteSizeOf(operand->shape())));
+ initialize_buffer(buffer);
+ operand_buffers.push_back(buffer);
+ }
+ TF_ASSIGN_OR_RETURN(
+ auto result_buffer,
+ input_output_allocator.AllocateBytes(
+ &stream, ShapeUtil::ByteSizeOf(instr->shape().tuple_shapes(0))));
+ initialize_buffer(result_buffer);
- const bool use_winograd_nonfused = ShouldIncludeWinogradNonfusedAlgo(
- input_shape, output_shape, *params.dnums, stream_exec_);
se::dnn::ProfileResult best_result;
int64 best_result_bytes_used = 0;
+ TF_ASSIGN_OR_RETURN(auto backend_config,
+ instr->backend_config<CudnnConvBackendConfig>());
optional<F16BufferComparator> comparator;
// Use the first algorithm that's supported as reference. There isn't a
// particular reason to use it, as any algorithm suffices. Using it as the
// reference does not imply it is correct, though.
optional<AlgorithmDesc> first_algorithm;
- for (const AlgorithmDesc& alg :
- GetAlgorithms(params.kind, use_winograd_nonfused, stream_exec_)) {
+ TF_ASSIGN_OR_RETURN(CudnnConvKind kind, GetCudnnConvKind(instr));
+ for (const AlgorithmDesc& alg : GetAlgorithms(kind, stream_exec_)) {
ScratchAllocator scratch_allocator(device_ordinal, allocator);
se::dnn::ProfileResult profile_result;
VLOG(3) << "Trying algorithm " << AlgorithmToString(alg) << " for "
<< instr->ToString();
- params.algorithm = AlgorithmConfig(alg);
- bool launch_ok = RunCudnnConvolution(params, &scratch_allocator, &stream,
- &profile_result)
+ backend_config.set_algorithm(alg.algo_id());
+ backend_config.set_tensor_ops_enabled(alg.tensor_ops_enabled());
+ TF_RETURN_IF_ERROR(instr->set_backend_config(backend_config));
+ bool launch_ok = RunCudnnConvolution(instr, absl::MakeSpan(operand_buffers),
+ result_buffer, &scratch_allocator,
+ &stream, &profile_result)
.ok();
if (launch_ok && profile_result.is_valid()) {
@@ -312,7 +266,7 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
.xla_gpu_crash_on_verification_failures();
if (comparator.has_value()) {
StatusOr<bool> result = comparator->CompareEqual(
- se::DeviceMemory<Eigen::half>(*result_buf));
+ se::DeviceMemory<Eigen::half>(result_buffer));
if (!result.ok()) {
LOG(ERROR) << "Unable to compare "
<< AlgorithmToString(*first_algorithm) << " against "
@@ -330,7 +284,7 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
}
} else if (cross_check_enabled) {
auto comp = F16BufferComparator::Create(
- se::DeviceMemory<Eigen::half>(*result_buf), compiler_, allocator,
+ se::DeviceMemory<Eigen::half>(result_buffer), compiler_, allocator,
&stream);
if (comp.ok()) {
comparator.emplace(comp.ConsumeValueOrDie());
@@ -404,13 +358,14 @@ StatusOr<bool> CudnnConvolutionAlgorithmPicker::RunOnInstruction(
ShapeUtil::MakeTupleShape({instr->shape().tuple_shapes(0),
ShapeUtil::MakeShape(U8, {scratch_bytes})});
- CudnnConvBackendConfig backend_config;
+ TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config,
+ instr->backend_config<CudnnConvBackendConfig>());
backend_config.set_algorithm(algorithm);
backend_config.set_tensor_ops_enabled(tensor_ops_enabled);
HloInstruction* new_call = computation->AddInstruction(
- instr->CloneWithNewOperands(new_call_shape, {instr->mutable_operand(0),
- instr->mutable_operand(1)}));
+ instr->CloneWithNewOperands(new_call_shape, instr->operands()));
+
TF_RETURN_IF_ERROR(new_call->set_backend_config(backend_config));
// Repackage new_call so it has the same shape as the original call, namely
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h
index f79b113f8f..aeda2fc7f8 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h
@@ -30,7 +30,7 @@ namespace gpu {
// Modifies CustomCalls to cudnn convolutions, choosing the best algorithm for
// each and adding explicit scratch space to the CustomCalls.
-class CudnnConvolutionAlgorithmPicker : public HloPassInterface {
+class CudnnConvolutionAlgorithmPicker : public HloModulePass {
public:
// If the `allocator` parameter is not null, we will use it to allocate temp
// memory while timing the various convolution algorithms. If it's null,
@@ -50,7 +50,7 @@ class CudnnConvolutionAlgorithmPicker : public HloPassInterface {
StatusOr<bool> RunOnComputation(HloComputation* computation);
StatusOr<bool> RunOnInstruction(HloInstruction* instr);
StatusOr<std::tuple<int64, bool, int64>> PickBestAlgorithm(
- const HloCustomCallInstruction* instr);
+ HloCustomCallInstruction* instr);
se::StreamExecutor* stream_exec_; // never null
DeviceMemoryAllocator* allocator_; // may be null
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc
index 228379a248..ef29237301 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
+#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -35,6 +36,32 @@ namespace gpu {
namespace {
+HloInstruction* CreateCudnnConv(const char* call_target, const Shape& shape,
+ HloInstruction* lhs, HloInstruction* rhs,
+ const Window& window,
+ const ConvolutionDimensionNumbers& dnums,
+ int64 feature_group_count) {
+ HloComputation* computation = lhs->parent();
+
+ // This call returns a tuple of (conv_result, scratch_memory), where
+ // conv_result is the actual result of the convolution, and scratch_memory is
+ // temporary memory used by cudnn.
+ //
+ // At the moment, we don't know how much scratch memory this conv is going to
+ // use, so we put u8[0] in this place. Later on another pass will choose
+ // which conv algorithm to use, and at that point we'll modify the shape of
+ // this second tuple element.
+ Shape call_shape =
+ ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U8, {0})});
+
+ HloInstruction* custom_call = computation->AddInstruction(
+ HloInstruction::CreateCustomCall(call_shape, {lhs, rhs}, call_target));
+ custom_call->set_window(window);
+ custom_call->set_convolution_dimension_numbers(dnums);
+ custom_call->set_feature_group_count(feature_group_count);
+ return custom_call;
+}
+
bool CanImplementAsCudnnForwardConv(HloInstruction* conv) {
const ConvolutionDimensionNumbers& dnums =
conv->convolution_dimension_numbers();
@@ -450,6 +477,12 @@ MatchBackwardInput(HloInstruction* conv) {
return std::make_tuple(true, new_window, dnums, rhs);
}
+CudnnConvBackendConfig GetDefaultBackendConfig() {
+ CudnnConvBackendConfig config;
+ config.set_conv_result_scale(1);
+ return config;
+}
+
// Tries to rewrite a single convolution into a call to cudnn.
StatusOr<bool> RunOnInstruction(HloInstruction* conv) {
CHECK_EQ(conv->opcode(), HloOpcode::kConvolution);
@@ -462,24 +495,24 @@ StatusOr<bool> RunOnInstruction(HloInstruction* conv) {
std::tie(match, window, dnums) = MatchBackwardFilter(conv);
if (match) {
- return CreateCudnnConvBackwardFilter(
- conv->shape(), conv->mutable_operand(0), conv->mutable_operand(1),
- window, dnums, conv->feature_group_count());
+ return CreateCudnnConv(kCudnnConvBackwardFilterCallTarget, conv->shape(),
+ conv->mutable_operand(0), conv->mutable_operand(1),
+ window, dnums, conv->feature_group_count());
}
std::tie(match, window, dnums, rhs) = MatchBackwardInput(conv);
if (match) {
- return CreateCudnnConvBackwardInput(conv->shape(),
- conv->mutable_operand(0), rhs, window,
- dnums, conv->feature_group_count());
+ return CreateCudnnConv(kCudnnConvBackwardInputCallTarget, conv->shape(),
+ conv->mutable_operand(0), rhs, window, dnums,
+ conv->feature_group_count());
}
// If all else fails, try a forward convolution.
if (CanImplementAsCudnnForwardConv(conv)) {
- return CreateCudnnConvForward(conv->shape(), conv->mutable_operand(0),
- conv->mutable_operand(1), conv->window(),
- conv->convolution_dimension_numbers(),
- conv->feature_group_count());
+ return CreateCudnnConv(
+ kCudnnConvForwardCallTarget, conv->shape(), conv->mutable_operand(0),
+ conv->mutable_operand(1), conv->window(),
+ conv->convolution_dimension_numbers(), conv->feature_group_count());
}
return nullptr;
@@ -489,6 +522,9 @@ StatusOr<bool> RunOnInstruction(HloInstruction* conv) {
return false;
}
+ TF_RETURN_IF_ERROR(
+ custom_call->set_backend_config(GetDefaultBackendConfig()));
+
// The CustomCall returns a tuple (conv_result, scratch_memory). Extract out
// the conv result and replace `conv` with it.
TF_RETURN_IF_ERROR(conv->parent()->ReplaceWithNewInstruction(
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h
index fbe7e98494..8d7c6fdab5 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h
@@ -24,7 +24,7 @@ namespace gpu {
// Rewrites plain convolutions, backwards-filter convolutions, and
// backwards-input convolutions into CustomCall HLOs that call into cuDNN.
-class CudnnConvolutionRewriter : public HloPassInterface {
+class CudnnConvolutionRewriter : public HloModulePass {
public:
absl::string_view name() const override {
return "cudnn-convolution-rewriter";
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
index 2a86ac265e..89dd1bb272 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
@@ -16,6 +16,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
+#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -37,6 +39,42 @@ using se::dnn::FilterDescriptor;
using se::dnn::FilterLayout;
using se::dnn::ProfileResult;
+struct CudnnConvParams {
+ // Here are the fields related to cuDNN's fused convolution. The result thus
+ // is defined as:
+ // activation(conv_result_scale * conv(x, w) +
+ // side_input_scale * side_input + broadcast(bias))
+ //
+ // The most common fused conv is conv forward + relu/identity, for example.
+ //
+ // bias_buf is a single-dimensional array, with the length equal to the number
+ // of output features. It'll be broadcasted to the output shape in order to be
+ // added to the final results.
+ //
+ // side_input_buf, if valid, must have the same shape as the output buffer.
+ struct FusionParams {
+ se::dnn::ActivationMode mode;
+ double side_input_scale;
+ se::DeviceMemoryBase bias_buf;
+ se::DeviceMemoryBase side_input_buf; // nullable
+ };
+
+ CudnnConvKind kind;
+ const Shape* input_shape;
+ const Shape* filter_shape;
+ const Shape* output_shape;
+ se::DeviceMemoryBase input_buf;
+ se::DeviceMemoryBase filter_buf;
+ se::DeviceMemoryBase output_buf;
+ const Window* window;
+ const ConvolutionDimensionNumbers* dnums;
+ int64 feature_group_count;
+ se::dnn::AlgorithmConfig algorithm;
+ double conv_result_scale;
+
+ absl::optional<FusionParams> fusion;
+};
+
// A StreamExecutor ScratchAllocator that wraps a single XLA allocation,
// returning it (in its entirety) the first time Allocate() is called.
class ScratchBufAllocator : public se::ScratchAllocator {
@@ -92,9 +130,9 @@ Status RunCudnnConvolutionImpl(CudnnConvParams params,
VLOG(3) << "tensor_ops_enabled: "
<< algorithm.algorithm().tensor_ops_enabled();
VLOG(3) << "Convolution kind: " << CudnnConvKindToString(kind);
- VLOG(3) << "input shape: { " << ShapeUtil::HumanString(input_shape) << " }";
- VLOG(3) << "filter shape: { " << ShapeUtil::HumanString(filter_shape) << " }";
- VLOG(3) << "Output shape: { " << ShapeUtil::HumanString(output_shape) << " }";
+ VLOG(3) << "input shape: " << ShapeUtil::HumanStringWithLayout(input_shape);
+ VLOG(3) << "filter shape: " << ShapeUtil::HumanStringWithLayout(filter_shape);
+ VLOG(3) << "Output shape: " << ShapeUtil::HumanStringWithLayout(output_shape);
VLOG(3) << "Window: { " << window.ShortDebugString() << " }";
VLOG(3) << "Dim nums: { " << dnums.ShortDebugString() << " }";
@@ -186,23 +224,73 @@ Status RunCudnnConvolutionImpl(CudnnConvParams params,
switch (kind) {
case CudnnConvKind::kForward:
+ if (params.conv_result_scale != 1) {
+ return InternalError(
+ "StreamExecutor doesn't support scaled convolution: %lf.",
+ params.conv_result_scale);
+ }
stream->ThenConvolveWithAlgorithm(
input_descriptor, input_buf, filter_descriptor, filter_buf,
convolution_descriptor, output_descriptor, &output_buf,
scratch_allocator, algorithm, profile_result);
break;
case CudnnConvKind::kBackwardInput:
+ if (params.conv_result_scale != 1) {
+ return InternalError(
+ "StreamExecutor doesn't support scaled convolution: %lf.",
+ params.conv_result_scale);
+ }
stream->ThenConvolveBackwardDataWithAlgorithm(
filter_descriptor, filter_buf, output_descriptor, output_buf,
convolution_descriptor, input_descriptor, &input_buf,
scratch_allocator, algorithm, profile_result);
break;
case CudnnConvKind::kBackwardFilter:
+ if (params.conv_result_scale != 1) {
+ return InternalError(
+ "StreamExecutor doesn't support scaled convolution: %lf.",
+ params.conv_result_scale);
+ }
stream->ThenConvolveBackwardFilterWithAlgorithm(
input_descriptor, input_buf, output_descriptor, output_buf,
convolution_descriptor, filter_descriptor, &filter_buf,
scratch_allocator, algorithm, profile_result);
break;
+ case CudnnConvKind::kForwardActivation: {
+ BatchDescriptor bias_desc;
+ bias_desc.set_count(1)
+ .set_height(1)
+ .set_width(1)
+ .set_feature_map_count(
+ output_shape.dimensions(dnums.output_feature_dimension()))
+ .set_layout(output_dl);
+
+ se::DeviceMemory<T> side_input(params.fusion->side_input_buf);
+ // If there is no side input, use output as the side input.
+ if (side_input.is_null()) {
+ if (params.fusion->side_input_scale != 0) {
+ return InternalError(
+ "Side input scale is not 0, yet no side input buffer is "
+ "provided");
+ }
+ // Since side-input scale is 0, the values in the side input don't
+ // matter. The simplest thing to do would be to pass in a null buffer
+ // for the side input, but cudnn doesn't allow this. cudnn does promise
+ // that if side-input-scale is 0 the side input won't be read, so we
+ // just pass in the output buffer, since it's handy and has the correct
+ // size.
+ side_input = output_buf;
+ }
+
+ stream->ThenFusedConvolveWithAlgorithm(
+ input_descriptor, input_buf, params.conv_result_scale,
+ filter_descriptor, filter_buf, convolution_descriptor, side_input,
+ params.fusion->side_input_scale, bias_desc,
+ DeviceMemory<T>(params.fusion->bias_buf), params.fusion->mode,
+ output_descriptor, &output_buf, scratch_allocator, algorithm,
+ profile_result);
+ break;
+ }
}
if (!stream->ok()) {
@@ -214,32 +302,104 @@ Status RunCudnnConvolutionImpl(CudnnConvParams params,
return Status::OK();
}
-} // anonymous namespace
+// Returns the cudnn convolution parameters generated from conv, which must be a
+// custom-call to a cudnn convolution.
+StatusOr<CudnnConvParams> GetCudnnConvParams(
+ const HloCustomCallInstruction* conv,
+ absl::Span<se::DeviceMemoryBase> operand_buffers,
+ se::DeviceMemoryBase result_buffer) {
+ CudnnConvParams params;
-string CudnnConvKindToString(CudnnConvKind kind) {
- switch (kind) {
- case CudnnConvKind::kForward:
- return "forward";
- case CudnnConvKind::kBackwardFilter:
- return "backward_filter";
- case CudnnConvKind::kBackwardInput:
- return "backward_input";
+ TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config,
+ conv->backend_config<CudnnConvBackendConfig>());
+ const auto& target = conv->custom_call_target();
+ const auto& lhs_shape = conv->operand(0)->shape();
+ const auto& rhs_shape = conv->operand(1)->shape();
+ const auto& conv_result_shape = conv->shape().tuple_shapes(0);
+
+ params.window = &conv->window();
+ params.dnums = &conv->convolution_dimension_numbers();
+ params.feature_group_count = conv->feature_group_count();
+ params.algorithm = se::dnn::AlgorithmConfig(se::dnn::AlgorithmDesc(
+ backend_config.algorithm(), backend_config.tensor_ops_enabled()));
+ params.conv_result_scale = backend_config.conv_result_scale();
+
+ if (target == kCudnnConvForwardCallTarget) {
+ params.kind = CudnnConvKind::kForward;
+ params.input_shape = &lhs_shape;
+ params.filter_shape = &rhs_shape;
+ params.output_shape = &conv_result_shape;
+ params.input_buf = operand_buffers[0];
+ params.filter_buf = operand_buffers[1];
+ params.output_buf = result_buffer;
+ } else if (target == kCudnnConvBackwardInputCallTarget) {
+ params.kind = CudnnConvKind::kBackwardInput;
+ params.input_shape = &conv_result_shape;
+ params.filter_shape = &rhs_shape;
+ params.output_shape = &lhs_shape;
+ params.input_buf = result_buffer;
+ params.filter_buf = operand_buffers[1];
+ params.output_buf = operand_buffers[0];
+ } else if (target == kCudnnConvBackwardFilterCallTarget) {
+ params.kind = CudnnConvKind::kBackwardFilter;
+ params.input_shape = &lhs_shape;
+ params.filter_shape = &conv_result_shape;
+ params.output_shape = &rhs_shape;
+ params.input_buf = operand_buffers[0];
+ params.filter_buf = result_buffer;
+ params.output_buf = operand_buffers[1];
+ } else if (target == kCudnnConvBiasActivationForwardCallTarget) {
+ params.kind = CudnnConvKind::kForwardActivation;
+ params.input_shape = &lhs_shape;
+ params.filter_shape = &rhs_shape;
+ params.output_shape = &conv_result_shape;
+ params.fusion.emplace();
+ auto& fusion = *params.fusion;
+ if (backend_config.activation_mode() <
+ static_cast<int64>(se::dnn::ActivationMode::kNumActivationModes)) {
+ fusion.mode = static_cast<se::dnn::ActivationMode>(
+ backend_config.activation_mode());
+ } else {
+ return InternalError("Bad activation mode: %s",
+ backend_config.ShortDebugString());
+ }
+ fusion.side_input_scale = backend_config.side_input_scale();
+ params.input_buf = operand_buffers[0];
+ params.filter_buf = operand_buffers[1];
+ params.output_buf = result_buffer;
+ params.fusion->bias_buf = operand_buffers[2];
+ if (operand_buffers.size() >= 4) {
+ params.fusion->side_input_buf = operand_buffers[3];
+ }
+ } else {
+ return InternalError("Unexpected custom call target: %s", target);
}
+ return params;
}
-Status RunCudnnConvolution(CudnnConvParams params,
+} // anonymous namespace
+
+Status RunCudnnConvolution(const HloCustomCallInstruction* conv,
+ absl::Span<se::DeviceMemoryBase> operand_buffers,
+ se::DeviceMemoryBase result_buffer,
se::DeviceMemoryBase scratch_buf, se::Stream* stream,
se::dnn::ProfileResult* profile_result) {
ScratchBufAllocator scratch_allocator(scratch_buf);
- return RunCudnnConvolution(params, &scratch_allocator, stream,
- profile_result);
+ return RunCudnnConvolution(conv, operand_buffers, result_buffer,
+ &scratch_allocator, stream, profile_result);
}
-Status RunCudnnConvolution(CudnnConvParams params,
+Status RunCudnnConvolution(const HloCustomCallInstruction* conv,
+ absl::Span<se::DeviceMemoryBase> operand_buffers,
+ se::DeviceMemoryBase result_buffer,
se::ScratchAllocator* scratch_allocator,
se::Stream* stream,
se::dnn::ProfileResult* profile_result) {
- PrimitiveType output_primitive_type = params.output_shape->element_type();
+ TF_ASSIGN_OR_RETURN(CudnnConvParams params,
+ GetCudnnConvParams(conv, operand_buffers, result_buffer));
+
+ PrimitiveType output_primitive_type =
+ conv->shape().tuple_shapes(0).element_type();
switch (output_primitive_type) {
case F16:
return RunCudnnConvolutionImpl<Eigen::half>(params, scratch_allocator,
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h
index 381aa37a1b..61aec1cecc 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h
@@ -16,6 +16,9 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_RUNNER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_RUNNER_H_
+#include "absl/types/optional.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
@@ -27,52 +30,8 @@ namespace gpu {
// This file contains low-level routines for running cudnn convolutions.
-// Different types of convolutions supported by cudnn.
-//
-// A way to think about these is that a convolution is defined by three arrays
-// -- the "input", the "filter", and the "output" -- and given any two of these,
-// we can compute the third. For example, a backward-input convolution takes as
-// input a filter and an "output" and produces an "input" such that if one were
-// to do a forward convolution of "input" using filter, the result would be
-// something with the same shape as "output".
-//
-// This way of thinking is not correct if you look at the values produced. For
-// example, a backward-input convolution is not actually the mathematical
-// inverse of a forward convolution. But it's right as far as the shapes and
-// "connectivity" (i.e. which elements of the input affect which elements of
-// the output) are concerned.
-enum class CudnnConvKind {
- kForward, // input + filter => output
- kBackwardInput, // filter + output => input
- kBackwardFilter, // input + output => filter
-};
-
-struct CudnnConvParams {
- CudnnConvKind kind;
- const Shape* input_shape;
- const Shape* filter_shape;
- const Shape* output_shape;
- se::DeviceMemoryBase input_buf;
- se::DeviceMemoryBase filter_buf;
- se::DeviceMemoryBase output_buf;
- const Window* window;
- const ConvolutionDimensionNumbers* dnums;
- int64 feature_group_count;
- se::dnn::AlgorithmConfig algorithm;
-};
-
-// Converts a CudnnConvKind value to a string.
-string CudnnConvKindToString(CudnnConvKind kind);
-
// Calls into cudnn to run the specified convolution.
//
-// Note that depending on the value of CudnnConvKind, the result of this call
-// may be written into input_buf, filter_buf, or output_buf!
-//
-// At the moment convolution with half data type is implemented with cudnn
-// PSEUDO_HALF configuration, that is, the input values are half and the
-// internal computation type is float.
-//
// We provide one overload which takes a scratch buffer, and another which takes
// an allocator which is responsible for allocating the scratch space. In
// theory the second one shouldn't be necessary -- users of this function could
@@ -83,11 +42,15 @@ string CudnnConvKindToString(CudnnConvKind kind);
// allocator and take note of how much memory is used. The next time you call
// the same conv, you can provide an explicitly preallocated scratch buffer of
// that size, if you like.
-Status RunCudnnConvolution(CudnnConvParams params,
+Status RunCudnnConvolution(const HloCustomCallInstruction* conv,
+ absl::Span<se::DeviceMemoryBase> operand_buffers,
+ se::DeviceMemoryBase result_buffer,
se::DeviceMemoryBase scratch_buf, se::Stream* stream,
se::dnn::ProfileResult* profile_result = nullptr);
-Status RunCudnnConvolution(CudnnConvParams params,
+Status RunCudnnConvolution(const HloCustomCallInstruction* conv,
+ absl::Span<se::DeviceMemoryBase> operand_buffers,
+ se::DeviceMemoryBase result_buffer,
se::ScratchAllocator* scratch_allocator,
se::Stream* stream,
se::dnn::ProfileResult* profile_result = nullptr);
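
To make the comment above concrete, here is a hedged usage sketch of the two overloads: during autotuning, run with a scratch allocator and record how much scratch it handed out, then reuse a preallocated buffer of that size on later runs. Only the RunCudnnConvolution signatures come from this header; the allocator's byte-counting accessor and the surrounding plumbing are assumptions for illustration.

// Sketch only. `conv`, `operand_buffers`, `result_buffer`, and `stream` are
// assumed to be set up elsewhere; `allocator` is some se::ScratchAllocator
// implementation that can report how many bytes it allocated (hypothetical
// TotalByteSize() accessor).
Status MeasureScratchAndRun(const HloCustomCallInstruction* conv,
                            absl::Span<se::DeviceMemoryBase> operand_buffers,
                            se::DeviceMemoryBase result_buffer,
                            se::ScratchAllocator* allocator, se::Stream* stream,
                            int64* scratch_bytes_used) {
  TF_RETURN_IF_ERROR(RunCudnnConvolution(conv, operand_buffers, result_buffer,
                                         allocator, stream));
  *scratch_bytes_used = allocator->TotalByteSize();  // hypothetical
  return Status::OK();
}

// Later, with a buffer of exactly that size already allocated:
//   RunCudnnConvolution(conv, operand_buffers, result_buffer,
//                       preallocated_scratch, stream);
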
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.cc
new file mode 100644
index 0000000000..3761c19cfc
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.cc
@@ -0,0 +1,278 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.h"
+
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
+#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
+#include "tensorflow/compiler/xla/service/pattern_matcher.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+// Describes a matched pattern:
+// max(0, alpha1 * conv(x, w) + alpha2 * side_input + broadcast(bias));
+// where side_input has the shape of the output buffer, and bias is a 1D array
+// whose length is the number of output features.
+struct ConvWithRelu {
+ HloInstruction* maximum;
+ HloCustomCallInstruction* conv;
+ HloInstruction* bias;
+ HloInstruction* side_input;
+ HloConstantInstruction* alpha_conv;
+ HloConstantInstruction* alpha_side_input;
+};
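
For concreteness, a module of the shape this struct describes might look as follows; this is a hand-written sketch mirroring the TestBiasAndSideInput case added further down, and, as in those tests, the plain convolution is first lowered to the __cudnn$convForward custom call by CudnnConvolutionRewriter before this pattern can match.

// Sketch only: one HLO instance of the pattern above, written the same way
// the new tests write their modules. Shapes and names are illustrative.
constexpr char kExampleFusedConvHlo[] = R"(
  HloModule Example

  ENTRY Example {
    zero  = f32[] constant(0)
    zeros = f32[1,3,3,64] broadcast(zero), dimensions={}
    input      = f32[1,3,3,64] parameter(0)
    filter     = f32[3,3,64,64] parameter(1)
    side_input = f32[1,3,3,64] parameter(2)
    bias       = f32[64] parameter(3)
    conv = f32[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
    broadcasted_bias = f32[1,3,3,64] broadcast(bias), dimensions={3}
    sum1 = f32[1,3,3,64] add(conv, broadcasted_bias)
    sum2 = f32[1,3,3,64] add(sum1, side_input)
    ROOT relu = f32[1,3,3,64] maximum(zeros, sum2)
  })";
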
+
+absl::optional<ConvWithRelu> FindConvWithRelu(HloInstruction* instr) {
+ using match::Add;
+ using match::AddAnyOrder;
+ using match::AnyOf;
+ using match::Broadcast;
+ using match::Constant;
+ using match::GetTupleElement;
+ using match::Maximum;
+ using match::MultiplyAnyOrder;
+ using match::Op;
+
+ // The pattern we want to match:
+ // max(0, alpha1 * conv(x, w) + alpha2 * side_input + broadcast(bias));
+ //
+ // With its variants involving commute/reassociation of adds, multiplies, and
+ // max, and omission of alpha1, side_input, alpha2, or bias.
+
+ HloInstruction* relu_input;
+
+ // Match max(0, relu_input).
+ auto zero_pattern = Broadcast(match::ConstantScalar(0));
+ if (!Match(instr, Maximum(zero_pattern, Op(&relu_input))) &&
+ !Match(instr, Maximum(Op(&relu_input), zero_pattern))) {
+ return absl::nullopt;
+ }
+ HloInstruction* conv_instr = nullptr;
+ HloInstruction* alpha_conv_instr = nullptr;
+ HloInstruction* alpha_side_input_instr = nullptr;
+ HloInstruction* bias_broadcast_instr = nullptr;
+ HloInstruction* bias = nullptr;
+ HloInstruction* side_input = nullptr;
+
+ // These nodes will not be in the returned value, but we need to check them
+ // for single use.
+ HloInstruction *gte = nullptr, *add1 = nullptr, *add2 = nullptr,
+ *mul1 = nullptr, *mul2 = nullptr;
+
+ const auto bias_pattern = Broadcast(&bias_broadcast_instr, Op(&bias));
+ const auto conv_pattern = [&] {
+ auto alpha_pattern = Broadcast(Constant(&alpha_conv_instr));
+ auto conv_pattern = GetTupleElement(
+ &gte, Op(&conv_instr).WithOpcode(HloOpcode::kCustomCall), 0);
+ return AnyOf<HloInstruction>(
+ MultiplyAnyOrder(&mul1, alpha_pattern, conv_pattern), conv_pattern);
+ }();
+ const auto side_input_pattern = [&] {
+ auto alpha_pattern = Broadcast(Constant(&alpha_side_input_instr));
+    // If bias is already matched, match an arbitrary additional input as the
+    // side input. Note this may force a cheap operation (e.g. a broadcast) to
+    // be materialized into a large buffer, as large as the output buffer.
+ //
+ // TODO(timshen): If in practice there are significant false positives, we
+ // should fix it.
+ auto side_input_pattern = Op(&side_input);
+ return AnyOf<HloInstruction>(
+ MultiplyAnyOrder(&mul2, alpha_pattern, side_input_pattern),
+ side_input_pattern);
+ }();
+
+ {
+    // Try to match any of the following forms of add, in any association:
+ // addends[0]
+ // addends[0] + addends[1]
+ // addends[0] + addends[1] + addends[2]
+ //
+ // Then try to match each addend with one of the three patterns: bias, conv,
+ // or side_input. Notice that side_input matching must go last, as it
+ // also matches a conv or a bias.
+ HloInstruction* addends[3] = {nullptr, nullptr, nullptr};
+ auto add3_pattern = [&] {
+ auto add2_pattern = Add(&add1, Op(&addends[0]), Op(&addends[1]));
+ return AnyOf<HloInstruction>(
+ AddAnyOrder(&add2, add2_pattern, Op(&addends[2])), add2_pattern,
+ Op(&addends[0]));
+ }();
+ CHECK(Match(relu_input, add3_pattern));
+ for (auto addend : addends) {
+ if (addend) {
+ if (bias == nullptr && Match(addend, bias_pattern)) {
+ CHECK(bias);
+ } else if (conv_instr == nullptr && Match(addend, conv_pattern)) {
+ CHECK(conv_instr);
+ } else if (side_input == nullptr && Match(addend, side_input_pattern)) {
+ CHECK(side_input);
+ } else {
+ return absl::nullopt;
+ }
+ }
+ }
+ }
+
+ if (conv_instr == nullptr) {
+ return absl::nullopt;
+ }
+
+ for (HloInstruction* instr :
+ {conv_instr, bias_broadcast_instr, gte, add1, add2, mul1, mul2}) {
+ if (instr && instr->user_count() > 1) {
+ return absl::nullopt;
+ }
+ }
+
+ auto conv = Cast<HloCustomCallInstruction>(conv_instr);
+ auto bias_broadcast =
+ CastOrNull<HloBroadcastInstruction>(bias_broadcast_instr);
+
+ if (conv->custom_call_target() != kCudnnConvForwardCallTarget) {
+ return absl::nullopt;
+ }
+
+ if (bias_broadcast) {
+ // TODO(timshen): handle bias_broadcast_instr->dimensions() == {}.
+ if (bias_broadcast_instr->dimensions().size() != 1) {
+ return absl::nullopt;
+ }
+ if (bias_broadcast_instr->dimensions(0) !=
+ conv->convolution_dimension_numbers().output_feature_dimension()) {
+ return absl::nullopt;
+ }
+ }
+
+ return ConvWithRelu{
+ instr,
+ conv,
+ bias,
+ side_input,
+ CastOrNull<HloConstantInstruction>(alpha_conv_instr),
+ CastOrNull<HloConstantInstruction>(alpha_side_input_instr)};
+}
+
+StatusOr<std::unique_ptr<HloInstruction>> TryRewriteToCudnnForwardRelu(
+ ConvWithRelu match) {
+ auto conv = match.conv;
+
+ HloComputation* computation = conv->parent();
+ PrimitiveType element_type = conv->operand(0)->shape().element_type();
+
+ const auto get_alpha_value =
+ [](HloConstantInstruction* instr) -> StatusOr<double> {
+ TF_ASSIGN_OR_RETURN(
+ auto alpha,
+ Cast<HloConstantInstruction>(instr)->literal().Convert(F64));
+ return alpha.GetFirstElement<double>();
+ };
+
+ double alpha_conv = 1;
+ if (match.alpha_conv) {
+ TF_ASSIGN_OR_RETURN(alpha_conv, get_alpha_value(match.alpha_conv));
+ }
+
+ double alpha_side_input;
+ if (match.side_input) {
+ if (match.alpha_side_input) {
+ TF_ASSIGN_OR_RETURN(alpha_side_input,
+ get_alpha_value(match.alpha_side_input));
+ } else {
+ alpha_side_input = 1;
+ }
+ } else {
+ CHECK(match.alpha_side_input == nullptr);
+ alpha_side_input = 0;
+ }
+
+ auto bias = match.bias;
+ if (!bias) {
+ auto zero = computation->AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::Zero(element_type)));
+
+ int64 num_output_feature = conv->shape().tuple_shapes(0).dimensions(
+ conv->convolution_dimension_numbers().output_feature_dimension());
+ bias = computation->AddInstruction(HloInstruction::CreateBroadcast(
+ ShapeUtil::MakeShapeWithDescendingLayout(element_type,
+ {num_output_feature}),
+ zero, {}));
+ }
+
+ CHECK(bias);
+ std::vector<HloInstruction*> args = {conv->mutable_operand(0),
+ conv->mutable_operand(1), bias};
+ if (match.side_input) {
+ args.push_back(match.side_input);
+ }
+ auto new_conv = computation->AddInstruction(HloInstruction::CreateCustomCall(
+ conv->shape(), args, kCudnnConvBiasActivationForwardCallTarget));
+ new_conv->set_window(conv->window());
+ new_conv->set_convolution_dimension_numbers(
+ conv->convolution_dimension_numbers());
+ TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig config,
+ conv->backend_config<CudnnConvBackendConfig>());
+ config.set_activation_mode(
+ static_cast<int64>(se::dnn::ActivationMode::kRelu));
+ config.set_conv_result_scale(alpha_conv);
+ config.set_side_input_scale(alpha_side_input);
+ TF_RETURN_IF_ERROR(new_conv->set_backend_config(config));
+
+ VLOG(1) << "Rewriting " << conv->name() << " to " << new_conv->name();
+ return HloInstruction::CreateGetTupleElement(conv->shape().tuple_shapes(0),
+ new_conv, 0);
+}
+
+} // namespace
+
+StatusOr<bool> CudnnFusedConvolutionRewriter::Run(HloModule* module) {
+ bool changed = false;
+ for (HloComputation* computation : module->MakeNonfusionComputations()) {
+ std::vector<ConvWithRelu> matches;
+ int num_forward_convs = 0;
+ for (auto instr : computation->instructions()) {
+ auto match = FindConvWithRelu(instr);
+ if (match.has_value()) {
+ matches.push_back(*match);
+ }
+ if (auto call = DynCast<HloCustomCallInstruction>(instr)) {
+ if (call->custom_call_target() == kCudnnConvForwardCallTarget) {
+ num_forward_convs++;
+ }
+ }
+ }
+ VLOG(1) << "Identified cuDNN forward conv + relu: " << matches.size()
+ << " out of " << num_forward_convs << " forward convs.";
+ std::vector<std::pair<HloInstruction*, std::unique_ptr<HloInstruction>>>
+ replacements;
+ for (const ConvWithRelu& match : matches) {
+ TF_ASSIGN_OR_RETURN(auto new_instr, TryRewriteToCudnnForwardRelu(match));
+ replacements.push_back({match.maximum, std::move(new_instr)});
+ changed = true;
+ }
+ for (auto& replacement : replacements) {
+ TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction(
+ replacement.first, std::move(replacement.second)));
+ }
+ }
+ return changed;
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.h b/tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.h
new file mode 100644
index 0000000000..bd12aadded
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.h
@@ -0,0 +1,37 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_FUSED_CONVOLUTION_REWRITER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_FUSED_CONVOLUTION_REWRITER_H_
+
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
+#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
+
+namespace xla {
+namespace gpu {
+
+class CudnnFusedConvolutionRewriter : public HloModulePass {
+ public:
+ absl::string_view name() const override {
+ return "cudnn-fused-convolution-rewriter";
+ }
+
+ StatusOr<bool> Run(HloModule* module) override;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_FUSED_CONVOLUTION_REWRITER_H_
diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger.h b/tensorflow/compiler/xla/service/gpu/fusion_merger.h
index 7e3f5775b8..f19996edfe 100644
--- a/tensorflow/compiler/xla/service/gpu/fusion_merger.h
+++ b/tensorflow/compiler/xla/service/gpu/fusion_merger.h
@@ -32,7 +32,7 @@ namespace gpu {
// 2) The result of merging the fusion instruction into its users would not
// increase bytes transferred.
//
-class FusionMerger : public HloPassInterface {
+class FusionMerger : public HloModulePass {
public:
absl::string_view name() const override { return "fusion merger"; }
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc
index 75f414e47f..79c74e7e8b 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc
@@ -34,15 +34,6 @@ namespace xla {
namespace gpu {
-StatusOr<HloInstruction*> GpuCopyInsertion::FindOrInsertCopy(
- HloInstruction* hlo) {
- HloInstruction*& copy = hlo_to_copy_map_[hlo];
- if (copy == nullptr) {
- TF_ASSIGN_OR_RETURN(copy, hlo->parent()->DeepCopyInstruction(hlo));
- }
- return copy;
-}
-
StatusOr<bool> GpuCopyInsertion::Run(HloModule* module) {
CopyInsertion generic_copy_insertion;
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h
index 8ffae18fe8..4c7e38ffeb 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h
@@ -25,20 +25,11 @@ namespace gpu {
// Besides the modifications made by the generic xla::CopyInsertion, this
// GPU-specific copy insertion also materializes operands of library calls by
// inserting kCopy instructions.
-class GpuCopyInsertion : public HloPassInterface {
+class GpuCopyInsertion : public HloModulePass {
public:
absl::string_view name() const override { return "copy-insertion"; }
StatusOr<bool> Run(HloModule* module) override;
-
- protected:
- // Returns a copy of `hlo`. Looks in hlo_to_copy_map_ first to avoid making
- // duplicate copies.
- StatusOr<HloInstruction*> FindOrInsertCopy(HloInstruction* hlo);
-
- // A map containing all copies inserted to materialize operands of library
- // calls. The key is the copied instruction and the value is the copy.
- tensorflow::gtl::FlatMap<HloInstruction*, HloInstruction*> hlo_to_copy_map_;
};
} // namespace gpu
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h
index bbb3340760..9c64b4d10c 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h
@@ -23,7 +23,7 @@ namespace xla {
// This pass should run early in the HLO pipeline and checks for HLO constructs
// which are not supported by the GPU backend and cannot be removed via HLO
// transformations (eg, sparse layouts).
-class GpuHloSupportChecker : public HloPassInterface {
+class GpuHloSupportChecker : public HloModulePass {
public:
GpuHloSupportChecker() = default;
~GpuHloSupportChecker() override = default;
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc
index d033faee8d..74352f26aa 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc
@@ -21,8 +21,10 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/gpu_options.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -90,27 +92,33 @@ HeuristicLayoutAssignment(const HloInstruction* instr,
// operands and the output shape. Depending on the underlying algorithm, one of
// { NCHW, NHWC } ^ 3 = 8 different layout combinations may be chosen.
Status GpuLayoutAssignment::AddBackendConstraintsToDnnConvCustomCall(
- HloInstruction* instr, LayoutConstraints* constraints) {
- CHECK(IsCustomCallToDnnConvolution(*instr)) << instr->ToString();
- Shape input_shape;
- Shape filter_shape;
- Shape output_shape;
- const auto& target = instr->custom_call_target();
- if (target == kCudnnConvForwardCallTarget) {
- input_shape = instr->operand(0)->shape();
- filter_shape = instr->operand(1)->shape();
- output_shape = instr->shape().tuple_shapes(0);
- } else if (target == kCudnnConvBackwardInputCallTarget) {
- input_shape = instr->shape().tuple_shapes(0);
- filter_shape = instr->operand(1)->shape();
- output_shape = instr->operand(0)->shape();
- } else if (target == kCudnnConvBackwardFilterCallTarget) {
- input_shape = instr->operand(0)->shape();
- filter_shape = instr->shape().tuple_shapes(0);
- output_shape = instr->operand(1)->shape();
- } else {
- LOG(FATAL) << "Unexpected custom call target: "
- << instr->custom_call_target();
+ HloCustomCallInstruction* instr, LayoutConstraints* constraints) {
+ Shape lhs_shape = instr->operand(0)->shape();
+ Shape rhs_shape = instr->operand(1)->shape();
+ Shape result_shape = instr->shape().tuple_shapes(0);
+
+ Shape* input_shape;
+ Shape* filter_shape;
+ Shape* output_shape;
+
+ TF_ASSIGN_OR_RETURN(auto kind, GetCudnnConvKind(instr));
+ switch (kind) {
+ case CudnnConvKind::kForward:
+ case CudnnConvKind::kForwardActivation:
+ input_shape = &lhs_shape;
+ filter_shape = &rhs_shape;
+ output_shape = &result_shape;
+ break;
+ case CudnnConvKind::kBackwardInput:
+ input_shape = &result_shape;
+ filter_shape = &rhs_shape;
+ output_shape = &lhs_shape;
+ break;
+ case CudnnConvKind::kBackwardFilter:
+ input_shape = &lhs_shape;
+ filter_shape = &result_shape;
+ output_shape = &rhs_shape;
+ break;
}
{
@@ -127,8 +135,9 @@ Status GpuLayoutAssignment::AddBackendConstraintsToDnnConvCustomCall(
}
TF_ASSIGN_OR_RETURN(
- std::tie(*input_shape.mutable_layout(), *filter_shape.mutable_layout(),
- *output_shape.mutable_layout()),
+ std::tie(*input_shape->mutable_layout(),
+ *filter_shape->mutable_layout(),
+ *output_shape->mutable_layout()),
StreamExecutorConvLayoutsToXlaLayouts(
instr->convolution_dimension_numbers(), input, filter, output));
}
@@ -141,24 +150,23 @@ Status GpuLayoutAssignment::AddBackendConstraintsToDnnConvCustomCall(
instr, /*index=*/{0}));
// Set layouts of the instructions' shapes.
- if (target == kCudnnConvForwardCallTarget) {
- TF_RETURN_IF_ERROR(constraints->SetOperandLayout(input_shape, instr, 0));
- TF_RETURN_IF_ERROR(constraints->SetOperandLayout(filter_shape, instr, 1));
- TF_RETURN_IF_ERROR(
- constraints->SetBufferLayout(output_shape.layout(), *call_result_buf));
- } else if (target == kCudnnConvBackwardInputCallTarget) {
- TF_RETURN_IF_ERROR(constraints->SetOperandLayout(output_shape, instr, 0));
- TF_RETURN_IF_ERROR(constraints->SetOperandLayout(filter_shape, instr, 1));
- TF_RETURN_IF_ERROR(
- constraints->SetBufferLayout(input_shape.layout(), *call_result_buf));
- } else if (target == kCudnnConvBackwardFilterCallTarget) {
- TF_RETURN_IF_ERROR(constraints->SetOperandLayout(input_shape, instr, 0));
- TF_RETURN_IF_ERROR(constraints->SetOperandLayout(output_shape, instr, 1));
- TF_RETURN_IF_ERROR(
- constraints->SetBufferLayout(filter_shape.layout(), *call_result_buf));
- } else {
- LOG(FATAL) << "Unexpected custom call target: "
- << instr->custom_call_target();
+ TF_RETURN_IF_ERROR(constraints->SetOperandLayout(lhs_shape, instr, 0));
+ TF_RETURN_IF_ERROR(constraints->SetOperandLayout(rhs_shape, instr, 1));
+ TF_RETURN_IF_ERROR(
+ constraints->SetBufferLayout(result_shape.layout(), *call_result_buf));
+  // instr->operand(2), if it exists, is the bias buffer. There is no need to
+  // assign a layout to it, as it has only one dimension.
+
+  // instr->operand(3), if it exists, is the side input buffer.
+ if (instr->operand_count() == 4) {
+ if (kind != CudnnConvKind::kForwardActivation) {
+ return InternalError(
+ "Invalid convolution. Conv has a side input, but kind is not fused "
+ "conv forward: %s",
+ instr->ToString());
+ }
+ // The side input layout must match the output layout.
+ TF_RETURN_IF_ERROR(constraints->SetOperandLayout(*output_shape, instr, 3));
}
return Status::OK();
}
@@ -173,8 +181,8 @@ Status GpuLayoutAssignment::AddBackendConstraints(
++iterator) {
HloInstruction* instruction = *iterator;
if (IsCustomCallToDnnConvolution(*instruction)) {
- TF_RETURN_IF_ERROR(
- AddBackendConstraintsToDnnConvCustomCall(instruction, constraints));
+ TF_RETURN_IF_ERROR(AddBackendConstraintsToDnnConvCustomCall(
+ Cast<HloCustomCallInstruction>(instruction), constraints));
}
// For batched dot we require the default layout.
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h
index ce24af1cf8..e2b96a81d4 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h
@@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_LAYOUT_ASSIGNMENT_H_
#include "tensorflow/compiler/xla/service/computation_layout.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/layout_assignment.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
@@ -47,7 +48,7 @@ class GpuLayoutAssignment : public LayoutAssignment {
private:
Status AddBackendConstraintsToDnnConvCustomCall(
- HloInstruction* instr, LayoutConstraints* constraints);
+ HloCustomCallInstruction* instr, LayoutConstraints* constraints);
se::StreamExecutor* stream_executor_;
};
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
index 22f43bc08b..ec3d8f9405 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
@@ -129,6 +129,8 @@ const char* const kCudnnConvBackwardInputCallTarget =
"__cudnn$convBackwardInput";
const char* const kCudnnConvBackwardFilterCallTarget =
"__cudnn$convBackwardFilter";
+const char* const kCudnnConvBiasActivationForwardCallTarget =
+ "__cudnn$convBiasActivationForward";
bool IsCustomCallToDnnConvolution(const HloInstruction& hlo) {
if (hlo.opcode() != HloOpcode::kCustomCall) {
@@ -137,7 +139,8 @@ bool IsCustomCallToDnnConvolution(const HloInstruction& hlo) {
const auto& target = hlo.custom_call_target();
return target == kCudnnConvForwardCallTarget ||
target == kCudnnConvBackwardInputCallTarget ||
- target == kCudnnConvBackwardFilterCallTarget;
+ target == kCudnnConvBackwardFilterCallTarget ||
+ target == kCudnnConvBiasActivationForwardCallTarget;
}
bool ImplementedAsLibraryCall(const HloInstruction& hlo) {
@@ -145,59 +148,6 @@ bool ImplementedAsLibraryCall(const HloInstruction& hlo) {
IsCustomCallToDnnConvolution(hlo);
}
-static HloInstruction* CreateCudnnConv(const char* call_target,
- const Shape& shape, HloInstruction* lhs,
- HloInstruction* rhs,
- const Window& window,
- const ConvolutionDimensionNumbers& dnums,
- int64 feature_group_count) {
- HloComputation* computation = lhs->parent();
-
- // This call returns a tuple of (conv_result, scratch_memory), where
- // conv_result is the actual result of the convolution, and scratch_memory is
- // temporary memory used by cudnn.
- //
- // At the moment, we don't know how much scratch memory this conv is going to
- // use, so we put u8[0] in this place. Later on another pass will choose
- // which conv algorithm to use, and at that point we'll modify the shape of
- // this second tuple element.
- Shape call_shape =
- ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U8, {0})});
-
- HloInstruction* custom_call = computation->AddInstruction(
- HloInstruction::CreateCustomCall(call_shape, {lhs, rhs}, call_target));
- custom_call->set_window(window);
- custom_call->set_convolution_dimension_numbers(dnums);
- custom_call->set_feature_group_count(feature_group_count);
- return custom_call;
-}
-
-HloInstruction* CreateCudnnConvForward(const Shape& shape,
- HloInstruction* input,
- HloInstruction* kernel,
- const Window& window,
- const ConvolutionDimensionNumbers& dnums,
- int64 feature_group_count) {
- return CreateCudnnConv(kCudnnConvForwardCallTarget, shape, input, kernel,
- window, dnums, feature_group_count);
-}
-
-HloInstruction* CreateCudnnConvBackwardInput(
- const Shape& shape, HloInstruction* output, HloInstruction* reverse_filter,
- const Window& window, const ConvolutionDimensionNumbers& dnums,
- int64 feature_group_count) {
- return CreateCudnnConv(kCudnnConvBackwardInputCallTarget, shape, output,
- reverse_filter, window, dnums, feature_group_count);
-}
-
-HloInstruction* CreateCudnnConvBackwardFilter(
- const Shape& shape, HloInstruction* input, HloInstruction* output,
- const Window& window, const ConvolutionDimensionNumbers& dnums,
- int64 feature_group_count) {
- return CreateCudnnConv(kCudnnConvBackwardFilterCallTarget, shape, input,
- output, window, dnums, feature_group_count);
-}
-
bool IsReductionToVector(const HloInstruction& reduce) {
if (HloOpcode::kReduce != reduce.opcode()) {
return false;
@@ -288,41 +238,35 @@ llvm::Value* EmitFullWarpShuffleDown(llvm::Value* value, llvm::Value* offset,
value->getType());
}
-Status PopulateCudnnConvParams(const HloCustomCallInstruction* custom_call,
- CudnnConvParams* params) {
- TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config,
- custom_call->backend_config<CudnnConvBackendConfig>());
- const auto& target = custom_call->custom_call_target();
- const auto& lhs_shape = custom_call->operand(0)->shape();
- const auto& rhs_shape = custom_call->operand(1)->shape();
- const auto& conv_result_shape = custom_call->shape().tuple_shapes(0);
-
- params->window = &custom_call->window();
- params->dnums = &custom_call->convolution_dimension_numbers();
- params->feature_group_count = custom_call->feature_group_count();
- params->algorithm = se::dnn::AlgorithmConfig(se::dnn::AlgorithmDesc(
- backend_config.algorithm(), backend_config.tensor_ops_enabled()));
-
+StatusOr<CudnnConvKind> GetCudnnConvKind(
+ const HloCustomCallInstruction* instr) {
+ absl::string_view target = instr->custom_call_target();
if (target == kCudnnConvForwardCallTarget) {
- params->kind = CudnnConvKind::kForward;
- params->input_shape = &lhs_shape;
- params->filter_shape = &rhs_shape;
- params->output_shape = &conv_result_shape;
- } else if (target == kCudnnConvBackwardInputCallTarget) {
- params->kind = CudnnConvKind::kBackwardInput;
- params->input_shape = &conv_result_shape;
- params->filter_shape = &rhs_shape;
- params->output_shape = &lhs_shape;
- } else if (target == kCudnnConvBackwardFilterCallTarget) {
- params->kind = CudnnConvKind::kBackwardFilter;
- params->input_shape = &lhs_shape;
- params->filter_shape = &conv_result_shape;
- params->output_shape = &rhs_shape;
- } else {
- LOG(FATAL) << "Unexpected custom call target: "
- << custom_call->custom_call_target();
+ return CudnnConvKind::kForward;
+ }
+ if (target == kCudnnConvBackwardInputCallTarget) {
+ return CudnnConvKind::kBackwardInput;
+ }
+ if (target == kCudnnConvBackwardFilterCallTarget) {
+ return CudnnConvKind::kBackwardFilter;
+ }
+ if (target == kCudnnConvBiasActivationForwardCallTarget) {
+ return CudnnConvKind::kForwardActivation;
+ }
+ return InternalError("Unexpected call target: %s", target);
+}
+
+string CudnnConvKindToString(CudnnConvKind kind) {
+ switch (kind) {
+ case CudnnConvKind::kForward:
+ return "forward";
+ case CudnnConvKind::kBackwardFilter:
+ return "backward_filter";
+ case CudnnConvKind::kBackwardInput:
+ return "backward_input";
+ case CudnnConvKind::kForwardActivation:
+ return "forward with activation";
}
- return Status::OK();
}
} // namespace gpu
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
index 09c455cc1e..a64a616ab1 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
@@ -20,7 +20,6 @@ limitations under the License.
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Value.h"
-#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
@@ -30,6 +29,33 @@ limitations under the License.
namespace xla {
namespace gpu {
+// Different types of convolutions supported by cudnn.
+//
+// A way to think about these is that a convolution is defined by three arrays
+// -- the "input", the "filter", and the "output" -- and given any two of these,
+// we can compute the third. For example, a backward-input convolution takes as
+// input a filter and an "output" and produces an "input" such that if one were
+// to do a forward convolution of "input" using filter, the result would be
+// something with the same shape as "output".
+//
+// This way of thinking is not correct if you look at the values produced. For
+// example, a backward-input convolution is not actually the mathematical
+// inverse of a forward convolution. But it's right as far as the shapes and
+// "connectivity" (i.e. which elements of the input affect which elements of
+// the output) are concerned.
+enum class CudnnConvKind {
+ kForward, // input + filter => output
+ kBackwardInput, // filter + output => input
+ kBackwardFilter, // input + output => filter
+ kForwardActivation, // activation(conv(input, filter) + broadcast(bias) +
+ // (optionally) side_input) => output
+};
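
As a small illustration (not part of this change), the comment above can be restated in code: each kind consumes two of the three arrays and produces the third, with kForwardActivation behaving like kForward at the shape level.

// Illustrative only: which two arrays each CudnnConvKind consumes, and which
// one it produces, per the comment above.
struct ConvArrayRoles {
  const char* given[2];
  const char* produced;
};

inline ConvArrayRoles RolesFor(CudnnConvKind kind) {
  switch (kind) {
    case CudnnConvKind::kForward:
    case CudnnConvKind::kForwardActivation:
      return {{"input", "filter"}, "output"};
    case CudnnConvKind::kBackwardInput:
      return {{"filter", "output"}, "input"};
    case CudnnConvKind::kBackwardFilter:
      return {{"input", "output"}, "filter"};
  }
  return {{"", ""}, ""};  // Unreachable; silences missing-return warnings.
}
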
+
+StatusOr<CudnnConvKind> GetCudnnConvKind(const HloCustomCallInstruction* instr);
+
+// Converts a CudnnConvKind value to a string.
+string CudnnConvKindToString(CudnnConvKind kind);
+
constexpr int64 kWarpSize = 32;
// Returns true if `hlo` will be implemented as a call to BLAS gemm.
@@ -95,6 +121,7 @@ bool IsCustomCallToDnnBatchNorm(const HloInstruction& hlo);
extern const char* const kCudnnConvForwardCallTarget;
extern const char* const kCudnnConvBackwardInputCallTarget;
extern const char* const kCudnnConvBackwardFilterCallTarget;
+extern const char* const kCudnnConvBiasActivationForwardCallTarget;
// Returns true if `hlo` will be implemented as a call to a cuDNN convolution
// routine.
@@ -104,28 +131,6 @@ extern const char* const kCudnnConvBackwardFilterCallTarget;
// kConvolution opcode.
bool IsCustomCallToDnnConvolution(const HloInstruction& hlo);
-// Creates a CustomCall for a cudnn forward/backward-input/backward-filter conv.
-// Note that these CustomCalls return a tuple (conv_result, scratch_memory). If
-// you want just the conv result, you'll need to get-tuple-element the value
-// returned by this function.
-//
-// The created cudnn call will use the default cudnn algorithm and no scratch
-// space.
-HloInstruction* CreateCudnnConvForward(const Shape& shape,
- HloInstruction* input,
- HloInstruction* kernel,
- const Window& window,
- const ConvolutionDimensionNumbers& dnums,
- int64 feature_group_count);
-HloInstruction* CreateCudnnConvBackwardInput(
- const Shape& shape, HloInstruction* output, HloInstruction* reverse_filter,
- const Window& window, const ConvolutionDimensionNumbers& dnums,
- int64 feature_group_count);
-HloInstruction* CreateCudnnConvBackwardFilter(
- const Shape& shape, HloInstruction* input, HloInstruction* output,
- const Window& window, const ConvolutionDimensionNumbers& dnums,
- int64 feature_group_count);
-
// Returns true if `hlo` will be implemented as a library call, e.g. cuBLAS gemm
// or cuDNN convolution.
bool ImplementedAsLibraryCall(const HloInstruction& hlo);
@@ -150,11 +155,6 @@ llvm::Value* EmitPrintf(absl::string_view fmt,
llvm::Value* EmitFullWarpShuffleDown(llvm::Value* value, llvm::Value* offset,
llvm::IRBuilder<>* builder);
-// Populates params using conv, which must be a custom-call to a cudnn
-// convolution. Does not modify any buffers in the params.
-Status PopulateCudnnConvParams(const HloCustomCallInstruction* custom_call,
- CudnnConvParams* params);
-
} // namespace gpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index b669881026..c792dd2ddb 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -465,35 +465,18 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) {
if (IsCustomCallToDnnConvolution(*custom_call)) {
const auto& assn = ir_emitter_context_->buffer_assignment();
- auto lhs_slice = GetAllocationSlice(*custom_call->operand(0));
- auto rhs_slice = GetAllocationSlice(*custom_call->operand(1));
+ std::vector<BufferAllocation::Slice> operand_slices;
+ operand_slices.reserve(custom_call->operand_count());
+ for (const auto* operand : custom_call->operands()) {
+ operand_slices.push_back(GetAllocationSlice(*operand));
+ }
auto tuple_result_slice = GetAllocationSlice(*custom_call);
auto conv_result_slice = assn.GetUniqueSlice(custom_call, {0}).ValueOrDie();
auto scratch_slice = assn.GetUniqueSlice(custom_call, {1}).ValueOrDie();
- const auto& target = custom_call->custom_call_target();
- BufferAllocation::Slice input_slice, filter_slice, output_slice;
-
- if (target == kCudnnConvForwardCallTarget) {
- input_slice = lhs_slice;
- filter_slice = rhs_slice;
- output_slice = conv_result_slice;
- } else if (target == kCudnnConvBackwardInputCallTarget) {
- input_slice = conv_result_slice;
- filter_slice = rhs_slice;
- output_slice = lhs_slice;
- } else if (target == kCudnnConvBackwardFilterCallTarget) {
- input_slice = lhs_slice;
- filter_slice = conv_result_slice;
- output_slice = rhs_slice;
- } else {
- LOG(FATAL) << "Unexpected custom call target: "
- << custom_call->custom_call_target();
- }
-
thunk_sequence_->emplace_back(absl::make_unique<ConvolutionThunk>(
- Cast<HloCustomCallInstruction>(custom_call), input_slice, filter_slice,
- output_slice, scratch_slice, tuple_result_slice));
+ Cast<HloCustomCallInstruction>(custom_call), std::move(operand_slices),
+ conv_result_slice, scratch_slice, tuple_result_slice));
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
index dfdcf1875d..0b3b429710 100644
--- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
@@ -40,6 +40,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h"
+#include "tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.h"
#include "tensorflow/compiler/xla/service/gpu/fusion_merger.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_constants.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h"
@@ -208,6 +209,7 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false,
/*allow_mixed_precision=*/false);
pipeline.AddPass<CudnnConvolutionRewriter>();
+ pipeline.AddPass<CudnnFusedConvolutionRewriter>();
pipeline.AddPass<PadInsertion>();
if (IsVoltaOrLater(*stream_exec)) {
pipeline.AddPass<PadForTensorCores>();
@@ -402,7 +404,7 @@ void WarnIfBadPtxasVersion(const string& ptxas_path) {
LOG(WARNING)
<< "*** WARNING *** You are using ptxas " << vmaj << "." << vmin << "."
<< vdot
- << ", which older than 9.2.88. ptxas 9.x before 9.2.88 is known to "
+ << ", which is older than 9.2.88. ptxas 9.x before 9.2.88 is known to "
"miscompile XLA code, leading to incorrect results or "
"invalid-address errors.\n\nYou do not need to update to CUDA "
"9.2.88; cherry-picking the ptxas binary is sufficient.";
diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc
index b0061fa655..e3869b5c36 100644
--- a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc
+++ b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
@@ -36,15 +37,32 @@ static constexpr int64 kDesiredNumFeaturesFactor = 8;
// there's additional room for speedups. Achieving those speedups without also
// slowing other things down will likely require a more sophisticated heuristic,
// possibly some form of auto-tuning.
-static constexpr double kMaxBytesTouchedIncrease = 1.2;
+//
+// This value should be >= 4/3, otherwise the "dims of size 3 padded up to 4"
+// special case inside PadShape won't fire.
+static constexpr double kMaxBytesTouchedIncrease = 1.35;
// Pads the given dimensions in the given shape up to a multiple of
// kDesiredNumFeaturesFactor.
static Shape PadShape(Shape s, absl::Span<const int64> dims) {
for (int64 dim : dims) {
int64 dim_to_pad_size = s.dimensions(dim);
- int64 new_dim_to_pad_size =
- RoundUpToNearest(dim_to_pad_size, kDesiredNumFeaturesFactor);
+
+ // Round dim_to_pad_size up to the next multiple of
+ // kDesiredNumFeaturesFactor.
+ //
+ // Special case: dims of size 3 are rounded up to 4, not
+ // kDesiredNumFeaturesFactor. Empirically (and on the advice of nvidia),
+ // this helps, but as of writing, it's not supported by anything in the
+ // cudnn docs.
+ int64 new_dim_to_pad_size;
+ if (dim_to_pad_size == 3) {
+ new_dim_to_pad_size = 4;
+ } else {
+ new_dim_to_pad_size =
+ RoundUpToNearest(dim_to_pad_size, kDesiredNumFeaturesFactor);
+ }
+
s.set_dimensions(dim, new_dim_to_pad_size);
}
return s;
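
A quick standalone check of the numbers behind the two comments above (a sketch, not part of the pass): the rounding rule with kDesiredNumFeaturesFactor = 8 plus the size-3 special case, and why the 1.35 budget has to exceed 4/3.

#include <cassert>
#include <cstdint>

// Mirrors the rounding rule above: a dim of size 3 is padded to 4; everything
// else is rounded up to the next multiple of 8 (kDesiredNumFeaturesFactor).
int64_t PaddedSize(int64_t dim) {
  if (dim == 3) return 4;
  return (dim + 7) / 8 * 8;
}

int main() {
  assert(PaddedSize(3) == 4);    // special case: that dim grows by 4/3 ~= 1.33
  assert(PaddedSize(5) == 8);
  assert(PaddedSize(16) == 16);
  assert(PaddedSize(17) == 24);
  // Padding 3 -> 4 alone multiplies bytes touched along that dimension by 4/3,
  // so kMaxBytesTouchedIncrease (now 1.35) must be >= 4/3 for it to fire.
  return 0;
}
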
@@ -209,7 +227,11 @@ static std::vector<HloInstruction*> GetRelevantConvs(HloComputation* comp) {
std::vector<HloInstruction*> convs;
for (HloInstruction* instr : comp->instructions()) {
if (IsCustomCallToDnnConvolution(*instr) &&
- instr->operand(0)->shape().element_type() == F16) {
+ instr->operand(0)->shape().element_type() == F16 &&
+ // TODO(timshen): Disable for fused conv for now. Implement it if it's
+ // needed.
+ Cast<HloCustomCallInstruction>(instr)->custom_call_target() !=
+ kCudnnConvBiasActivationForwardCallTarget) {
convs.push_back(instr);
}
}
diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h
index 11dc56a64f..e592a3774e 100644
--- a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h
+++ b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h
@@ -30,7 +30,7 @@ namespace gpu {
// targeting before running this pass.
//
// TODO(jlebar): Also pad dots.
-class PadForTensorCores : public HloPassInterface {
+class PadForTensorCores : public HloModulePass {
public:
absl::string_view name() const override { return "pad for tensor cores"; }
diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
index 2a6415d0b6..b42a19e3a2 100644
--- a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
+++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
@@ -30,7 +30,8 @@ namespace gpu {
namespace {
bool IsForwardConvolutionCanonical(const HloInstruction& conv) {
- CHECK_EQ(conv.custom_call_target(), kCudnnConvForwardCallTarget);
+ CHECK(conv.custom_call_target() == kCudnnConvForwardCallTarget ||
+ conv.custom_call_target() == kCudnnConvBiasActivationForwardCallTarget);
return window_util::HasSymmetricPadding(conv.window()) &&
!window_util::HasNegativePadding(conv.window()) &&
!window_util::HasDilation(conv.window());
@@ -161,12 +162,14 @@ bool PadInsertion::CanonicalizeForwardConvolution(HloInstruction* conv) {
// The conv CustomCall returns a tuple (conv_result, scratch_buffer). Extract
// out the shape of conv_result.
- Shape old_conv_shape = conv->shape().tuple_shapes(0);
-
VLOG(1) << "Canonicalizing forward conv";
- auto new_conv = CreateCudnnConvForward(
- old_conv_shape, new_input, new_kernel, new_conv_window,
- conv->convolution_dimension_numbers(), conv->feature_group_count());
+ std::vector<HloInstruction*> operands(conv->operands().begin(),
+ conv->operands().end());
+ operands[0] = new_input;
+ operands[1] = new_kernel;
+ auto new_conv = conv->parent()->AddInstruction(
+ conv->CloneWithNewOperands(conv->shape(), operands));
+ new_conv->set_window(new_conv_window);
VLOG(1) << "Replacing:\n " << conv->ToString() << "\nwith:\n "
<< new_conv->ToString();
TF_CHECK_OK(conv->parent()->ReplaceInstruction(conv, new_conv));
@@ -242,10 +245,10 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution(
// The shape of the backward_conv CustomCall is a tuple (conv_result,
// scratch_buffer). Extract out the shape of conv_result.
- Shape backward_conv_shape = backward_conv->shape().tuple_shapes(0);
- HloInstruction* new_backward_conv = CreateCudnnConvBackwardFilter(
- backward_conv_shape, padded_input, output, new_backward_conv_window,
- backward_conv_dnums, backward_conv->feature_group_count());
+ HloInstruction* new_backward_conv =
+ computation->AddInstruction(backward_conv->CloneWithNewOperands(
+ backward_conv->shape(), {padded_input, output}));
+ new_backward_conv->set_window(new_backward_conv_window);
VLOG(1) << "Canonicalizing backward filter conv";
VLOG(1) << "Replacing:\n " << backward_conv->ToString() << "\nwith:\n "
@@ -308,9 +311,12 @@ bool PadInsertion::CanonicalizeBackwardInputConvolution(
HloInstruction* output = backward_conv->mutable_operand(0);
HloInstruction* filter = backward_conv->mutable_operand(1);
- HloInstruction* new_backward_conv_call = CreateCudnnConvBackwardInput(
- new_backward_conv_shape, output, filter, new_backward_conv_window,
- backward_conv_dnums, backward_conv->feature_group_count());
+ HloInstruction* new_backward_conv_call =
+ computation->AddInstruction(backward_conv->CloneWithNewOperands(
+ ShapeUtil::MakeTupleShape(
+ {new_backward_conv_shape, ShapeUtil::MakeShape(U8, {0})}),
+ {output, filter}));
+ new_backward_conv_call->set_window(new_backward_conv_window);
// The CustomCall created above returns a tuple (conv_result, scratch_memory).
// Extract out the two elements.
@@ -380,7 +386,8 @@ StatusOr<bool> PadInsertion::RunOnComputation(HloComputation* computation) {
}
for (HloInstruction* instruction : convs) {
const auto& target = instruction->custom_call_target();
- if (target == kCudnnConvForwardCallTarget) {
+ if (target == kCudnnConvForwardCallTarget ||
+ target == kCudnnConvBiasActivationForwardCallTarget) {
changed |= CanonicalizeForwardConvolution(instruction);
} else if (target == kCudnnConvBackwardFilterCallTarget) {
changed |= CanonicalizeBackwardFilterConvolution(instruction);
diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.h b/tensorflow/compiler/xla/service/gpu/pad_insertion.h
index a622e894ed..25cdf64c4c 100644
--- a/tensorflow/compiler/xla/service/gpu/pad_insertion.h
+++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.h
@@ -24,7 +24,7 @@ namespace gpu {
// An HLO pass that canonicalizes convolution instructions for GPU codegen. It
// inserts Pad instructions before Convolution instructions with uncanonicalized
// padding, so that they can be lowered to cuDNN convolution.
-class PadInsertion : public HloPassInterface {
+class PadInsertion : public HloModulePass {
public:
absl::string_view name() const override { return "pad insertion"; }
diff --git a/tensorflow/compiler/xla/service/gpu/tests/BUILD b/tensorflow/compiler/xla/service/gpu/tests/BUILD
index db4a33dc56..a725533567 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/tests/BUILD
@@ -25,15 +25,17 @@ filegroup(
)
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
+load(
+ "//tensorflow/core:platform/default/build_config_root.bzl",
+ "tf_cuda_tests_tags",
+)
cc_library(
name = "gpu_codegen_test",
testonly = True,
srcs = ["gpu_codegen_test.cc"],
hdrs = ["gpu_codegen_test.h"],
- tags = [
- "requires-gpu-sm35",
- ],
+ tags = tf_cuda_tests_tags(),
deps = [
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
"//tensorflow/compiler/xla/service:gpu_plugin",
@@ -48,9 +50,7 @@ cc_library(
tf_cc_test(
name = "gpu_copy_test",
srcs = ["gpu_copy_test.cc"],
- tags = [
- "requires-gpu-sm35",
- ],
+ tags = tf_cuda_tests_tags(),
deps = [
":gpu_codegen_test",
"//tensorflow/compiler/xla:literal",
@@ -67,9 +67,7 @@ tf_cc_test(
tf_cc_test(
name = "gpu_ftz_test",
srcs = ["gpu_ftz_test.cc"],
- tags = [
- "requires-gpu-sm35",
- ],
+ tags = tf_cuda_tests_tags(),
deps = [
":gpu_codegen_test",
"//tensorflow/core:test_main",
@@ -79,9 +77,7 @@ tf_cc_test(
tf_cc_test(
name = "gpu_index_test",
srcs = ["gpu_index_test.cc"],
- tags = [
- "requires-gpu-sm35",
- ],
+ tags = tf_cuda_tests_tags(),
deps = [
":gpu_codegen_test",
"//tensorflow/compiler/xla:literal",
@@ -102,9 +98,7 @@ tf_cc_test(
tf_cc_test(
name = "gpu_infeed_test",
srcs = ["infeed_test.cc"],
- tags = [
- "requires-gpu-sm35",
- ],
+ tags = tf_cuda_tests_tags(),
deps = [
":gpu_codegen_test",
"//tensorflow/compiler/xla:literal",
@@ -125,9 +119,7 @@ tf_cc_test(
tf_cc_test(
name = "gpu_kernel_tiling_test",
srcs = ["gpu_kernel_tiling_test.cc"],
- tags = [
- "requires-gpu-sm35",
- ],
+ tags = tf_cuda_tests_tags(),
deps = [
":gpu_codegen_test",
"//tensorflow/compiler/xla/service:hlo",
@@ -142,7 +134,7 @@ tf_cc_test(
tf_cc_test(
name = "gpu_ldg_test",
srcs = ["gpu_ldg_test.cc"],
- tags = ["requires-gpu-sm35"],
+ tags = tf_cuda_tests_tags(),
deps = [
":gpu_codegen_test",
"//tensorflow/compiler/xla:literal",
@@ -159,9 +151,7 @@ tf_cc_test(
tf_cc_test(
name = "gpu_noalias_test",
srcs = ["gpu_noalias_test.cc"],
- tags = [
- "requires-gpu-sm35",
- ],
+ tags = tf_cuda_tests_tags(),
deps = [
":gpu_codegen_test",
"//tensorflow/compiler/xla:literal",
@@ -178,9 +168,7 @@ tf_cc_test(
tf_cc_test(
name = "gpu_fusion_test",
srcs = ["gpu_fusion_test.cc"],
- tags = [
- "requires-gpu-sm35",
- ],
+ tags = tf_cuda_tests_tags(),
deps = [
":gpu_codegen_test",
"//tensorflow/compiler/xla/service:hlo_module_config",
@@ -194,9 +182,7 @@ tf_cc_test(
tf_cc_test(
name = "gpu_unrolling_test",
srcs = ["gpu_unrolling_test.cc"],
- tags = [
- "requires-gpu-sm35",
- ],
+ tags = tf_cuda_tests_tags(),
deps = [
":gpu_codegen_test",
"//tensorflow/compiler/xla/service:hlo_module_config",
@@ -211,9 +197,7 @@ tf_cc_test(
name = "gpu_alignment_test",
testonly = True,
srcs = ["gpu_alignment_test.cc"],
- tags = [
- "requires-gpu-sm35",
- ],
+ tags = tf_cuda_tests_tags(),
deps = [
":gpu_codegen_test",
"//tensorflow/compiler/xla/service:gpu_plugin",
@@ -225,3 +209,17 @@ tf_cc_test(
"//tensorflow/core:test_main",
],
)
+
+tf_cc_test(
+ name = "cudnn_fused_convolution_rewriter_test",
+ srcs = ["cudnn_fused_convolution_rewriter_test.cc"],
+ tags = tf_cuda_tests_tags(),
+ deps = [
+ ":gpu_codegen_test",
+ "//tensorflow/compiler/xla/service:hlo_parser",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "@com_google_absl//absl/strings",
+ ],
+)
diff --git a/tensorflow/compiler/xla/service/gpu/tests/cudnn_fused_convolution_rewriter_test.cc b/tensorflow/compiler/xla/service/gpu/tests/cudnn_fused_convolution_rewriter_test.cc
new file mode 100644
index 0000000000..5632cac186
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/tests/cudnn_fused_convolution_rewriter_test.cc
@@ -0,0 +1,283 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "absl/strings/str_replace.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+class CudnnFusedConvolutionRewriterTest : public HloTestBase {
+ protected:
+ string GetOptimizedHlo(absl::string_view hlo_string) {
+ return backend()
+ .compiler()
+ ->RunHloPasses(ParseHloString(hlo_string, GetModuleConfigForTest())
+ .ConsumeValueOrDie(),
+ backend().default_stream_executor(),
+ backend().memory_allocator())
+ .ConsumeValueOrDie()
+ ->ToString();
+ }
+
+ void TestMatchWithAllTypes(absl::string_view hlo_string) {
+ for (absl::string_view type : {"f16", "f32", "f64"}) {
+ const string hlo_with_new_type =
+ absl::StrReplaceAll(hlo_string, {{"TYPE", type}});
+ const string optimized_hlo_string = GetOptimizedHlo(hlo_with_new_type);
+ EXPECT_EQ(absl::string_view::npos,
+ optimized_hlo_string.find("__cudnn$convForward"))
+ << optimized_hlo_string;
+ EXPECT_NE(absl::string_view::npos,
+ optimized_hlo_string.find("__cudnn$convBiasActivationForward"))
+ << optimized_hlo_string;
+ EXPECT_TRUE(RunAndCompare(hlo_with_new_type, ErrorSpec{0.01}))
+ << optimized_hlo_string;
+ }
+ }
+
+ void TestNotMatchWithAllTypes(absl::string_view hlo_string) {
+ for (absl::string_view type : {"f16", "f32", "f64"}) {
+ const string hlo_with_new_type =
+ absl::StrReplaceAll(hlo_string, {{"TYPE", type}});
+ string optimized_hlo = GetOptimizedHlo(hlo_with_new_type);
+ EXPECT_NE(absl::string_view::npos,
+ optimized_hlo.find("__cudnn$convForward"))
+ << optimized_hlo;
+ EXPECT_EQ(absl::string_view::npos,
+ optimized_hlo.find("__cudnn$convBiasActivationForward"))
+ << optimized_hlo;
+ }
+ }
+};
+
+TEST_F(CudnnFusedConvolutionRewriterTest, TestConvOnly) {
+ // max(0, conv(x, w));
+ TestMatchWithAllTypes(R"(
+ HloModule Test
+
+ ENTRY Test {
+ zero = TYPE[] constant(0)
+ zeros = TYPE[1,32,9,9] broadcast(zero), dimensions={}
+
+ input = TYPE[1,17,9,9] parameter(0)
+ filter = TYPE[3,3,17,32] parameter(1)
+
+ conv = TYPE[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
+ ROOT relu = TYPE[1,32,9,9] maximum(zeros, conv)
+ })");
+}
+
+TEST_F(CudnnFusedConvolutionRewriterTest, TestBias) {
+ // max(0, conv(x, w) + bias);
+ TestMatchWithAllTypes(R"(
+ HloModule Test
+
+ ENTRY Test {
+ zero = TYPE[] constant(0)
+ zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
+
+ input = TYPE[1,3,3,64] parameter(0)
+ filter = TYPE[3,3,64,64] parameter(1)
+ bias = TYPE[64] parameter(2)
+
+ conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
+ broadcasted_bias = TYPE[1,3,3,64] broadcast(bias), dimensions={3}
+ add1 = TYPE[1,3,3,64] add(conv, broadcasted_bias)
+ ROOT relu = TYPE[1,3,3,64] maximum(zeros, add1)
+ })");
+}
+
+TEST_F(CudnnFusedConvolutionRewriterTest, TestSideInputOnly) {
+ // max(0, conv(x, w) + side_input);
+ TestMatchWithAllTypes(R"(
+ HloModule Test
+
+ ENTRY Test {
+ zero = TYPE[] constant(0)
+ zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
+
+ input = TYPE[1,3,3,64] parameter(0)
+ filter = TYPE[3,3,64,64] parameter(1)
+ side_input = TYPE[1,3,3,64] parameter(2)
+
+ conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
+ add1 = TYPE[1,3,3,64] add(conv, side_input)
+ ROOT relu = TYPE[1,3,3,64] maximum(zeros, add1)
+ })");
+}
+
+TEST_F(CudnnFusedConvolutionRewriterTest, TestBiasAndSideInput) {
+ // max(0, conv(x, w) + side_input + bias);
+ TestMatchWithAllTypes(R"(
+ HloModule Test
+
+ ENTRY Test {
+ zero = TYPE[] constant(0)
+ zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
+
+ input = TYPE[1,3,3,64] parameter(0)
+ filter = TYPE[3,3,64,64] parameter(1)
+ side_input = TYPE[1,3,3,64] parameter(2)
+ bias = TYPE[64] parameter(3)
+
+ conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
+ broadcasted_bias = TYPE[1,3,3,64] broadcast(bias), dimensions={3}
+ add1 = TYPE[1,3,3,64] add(conv, broadcasted_bias)
+ add2 = TYPE[1,3,3,64] add(add1, side_input)
+ ROOT relu = TYPE[1,3,3,64] maximum(zeros, add2)
+ })");
+}
+
+TEST_F(CudnnFusedConvolutionRewriterTest, TestScaledConv) {
+ // max(0, 0.999994934 * conv(x, w));
+ TestMatchWithAllTypes(R"(
+ HloModule Test
+
+ ENTRY Test {
+ zero = TYPE[] constant(0)
+ zeros = TYPE[1,32,9,9] broadcast(zero), dimensions={}
+ alpha_conv_scalar = TYPE[] constant(0.999994934)
+
+ input = TYPE[1,17,9,9] parameter(0)
+ filter = TYPE[3,3,17,32] parameter(1)
+
+ conv = TYPE[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
+ alpha_conv = TYPE[1,32,9,9] broadcast(alpha_conv_scalar), dimensions={}
+ scaled_conv = TYPE[1,32,9,9] multiply(conv, alpha_conv)
+ ROOT relu = TYPE[1,32,9,9] maximum(zeros, scaled_conv)
+ })");
+}
+
+TEST_F(CudnnFusedConvolutionRewriterTest, TestScaledConvAndSideInput) {
+ // max(0, conv(x, w) + 0.899994934 * side_input);
+ TestMatchWithAllTypes(R"(
+ HloModule Test
+
+ ENTRY Test {
+ zero = TYPE[] constant(0)
+ zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
+ alpha_side_input_scalar = TYPE[] constant(0.899994934)
+ alpha_side_input = TYPE[1,3,3,64] broadcast(alpha_side_input_scalar), dimensions={}
+
+ input = TYPE[1,3,3,64] parameter(0)
+ filter = TYPE[3,3,64,64] parameter(1)
+ side_input = TYPE[1,3,3,64] parameter(2)
+
+ conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
+ scaled_side_input = TYPE[1,3,3,64] multiply(side_input, alpha_side_input)
+ add1 = TYPE[1,3,3,64] add(conv, scaled_side_input)
+ ROOT relu = TYPE[1,3,3,64] maximum(zeros, add1)
+ })");
+}
+
+TEST_F(CudnnFusedConvolutionRewriterTest, TestScaledConvAndScaledSideInput) {
+ // max(0, 0.999994934 * conv(x, w) + 0.899994934 * side_input);
+ TestMatchWithAllTypes(R"(
+ HloModule Test
+
+ ENTRY Test {
+ zero = TYPE[] constant(0)
+ zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
+ alpha_conv_scalar = TYPE[] constant(0.999994934)
+ alpha_conv = TYPE[1,3,3,64] broadcast(alpha_conv_scalar), dimensions={}
+ alpha_side_input_scalar = TYPE[] constant(0.899994934)
+ alpha_side_input = TYPE[1,3,3,64] broadcast(alpha_side_input_scalar), dimensions={}
+
+ input = TYPE[1,3,3,64] parameter(0)
+ filter = TYPE[3,3,64,64] parameter(1)
+ side_input = TYPE[1,3,3,64] parameter(2)
+
+ conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
+ scaled_conv = TYPE[1,3,3,64] multiply(conv, alpha_conv)
+ scaled_side_input = TYPE[1,3,3,64] multiply(side_input, alpha_side_input)
+ add1 = TYPE[1,3,3,64] add(scaled_conv, scaled_side_input)
+ ROOT relu = TYPE[1,3,3,64] maximum(zeros, add1)
+ })");
+}
+
+TEST_F(CudnnFusedConvolutionRewriterTest,
+ TestScaledConvAndScaledSideInputWithBias) {
+ // max(0, 0.999994934 * conv(x, w) + 0.899994934 * side_input + bias);
+ TestMatchWithAllTypes(R"(
+ HloModule Test
+
+ ENTRY Test {
+ zero = TYPE[] constant(0)
+ zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
+ alpha_conv_scalar = TYPE[] constant(0.999994934)
+ alpha_conv = TYPE[1,3,3,64] broadcast(alpha_conv_scalar), dimensions={}
+ alpha_side_input_scalar = TYPE[] constant(0.899994934)
+ alpha_side_input = TYPE[1,3,3,64] broadcast(alpha_side_input_scalar), dimensions={}
+
+ input = TYPE[1,3,3,64] parameter(0)
+ filter = TYPE[3,3,64,64] parameter(1)
+ side_input = TYPE[1,3,3,64] parameter(2)
+ bias = TYPE[64] parameter(3)
+
+ conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
+ scaled_conv = TYPE[1,3,3,64] multiply(conv, alpha_conv)
+ scaled_side_input = TYPE[1,3,3,64] multiply(side_input, alpha_side_input)
+ broadcasted_bias = TYPE[1,3,3,64] broadcast(bias), dimensions={3}
+ add1 = TYPE[1,3,3,64] add(scaled_conv, broadcasted_bias)
+ add2 = TYPE[1,3,3,64] add(add1, scaled_side_input)
+ ROOT relu = TYPE[1,3,3,64] maximum(zeros, add2)
+ })");
+}
+
+TEST_F(CudnnFusedConvolutionRewriterTest, TestMatchMaxZeroOnly) {
+ // max(0.1, conv(x, w)) shouldn't match.
+ TestNotMatchWithAllTypes(R"(
+ HloModule Test
+
+ ENTRY Test {
+ point_one = TYPE[] constant(0.1)
+ point_ones = TYPE[1,32,9,9] broadcast(point_one), dimensions={}
+
+ input = TYPE[1,17,9,9] parameter(0)
+ filter = TYPE[3,3,17,32] parameter(1)
+
+ conv = TYPE[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
+ ROOT relu = TYPE[1,32,9,9] maximum(point_ones, conv)
+ })");
+}
+
+TEST_F(CudnnFusedConvolutionRewriterTest, TestMatchBroadcastedBiasOnly) {
+ // max(0, conv(x, w) + side_input1 + side_input2) shouldn't match.
+ TestNotMatchWithAllTypes(R"(
+ HloModule Test
+
+ ENTRY Test {
+ zero = TYPE[] constant(0)
+ zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
+
+ input = TYPE[1,3,3,64] parameter(0)
+ filter = TYPE[3,3,64,64] parameter(1)
+ side_input1 = TYPE[1,3,3,64] parameter(2)
+ side_input2 = TYPE[1,3,3,64] parameter(3)
+
+ conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
+ add1 = TYPE[1,3,3,64] add(conv, side_input2)
+ add2 = TYPE[1,3,3,64] add(add1, side_input1)
+ ROOT relu = TYPE[1,3,3,64] maximum(zeros, add2)
+ })");
+}
+
+} // namespace
+} // namespace gpu
+} // namespace xla
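The algebra these tests exercise is the pattern the cudnn fused convolution rewriter is meant to match and hand to cuDNN as a single fused call. As a rough scalar reference (a sketch only; the function name is made up and the real op works on whole tensors):

    #include <algorithm>

    // Per-element form of what __cudnn$convBiasActivationForward computes:
    //   out = max(0, alpha_conv * conv(x, w) + alpha_side * side_input + bias)
    float FusedConvBiasActivationReference(float conv, float side_input, float bias,
                                           float alpha_conv, float alpha_side) {
      return std::max(0.0f, alpha_conv * conv + alpha_side * side_input + bias);
    }

TestMatchWithAllTypes asserts that every instance of this pattern is rewritten into the fused custom call, while TestNotMatchWithAllTypes asserts that near misses (a non-zero clamp constant, or two side inputs with no broadcast bias) are left as a plain __cudnn$convForward.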
diff --git a/tensorflow/compiler/xla/service/heap_simulator.cc b/tensorflow/compiler/xla/service/heap_simulator.cc
index e0f3a7e0e2..2bd04259c0 100644
--- a/tensorflow/compiler/xla/service/heap_simulator.cc
+++ b/tensorflow/compiler/xla/service/heap_simulator.cc
@@ -736,4 +736,209 @@ HeapSimulator::Result LazyBestFitHeap::Finish() {
return result_;
}
+void GlobalDecreasingSizeBestFitHeap::Alloc(const BufferValue* buffer,
+ int64 size) {
+ // Degenerate case: 0-sized buffers are always allocated at offset 0.
+ if (size == 0) {
+ result_.chunk_map.emplace(buffer, Chunk{0, 0});
+ return;
+ }
+ auto emplace_result = buffer_intervals_.emplace(
+ buffer, BufferInterval{buffer, size, current_time_, -1});
+ DCHECK(emplace_result.second);
+ ++current_time_;
+}
+
+void GlobalDecreasingSizeBestFitHeap::Free(const BufferValue* buffer,
+ int64 size) {
+ // Degenerate case: 0-sized buffers are always allocated at offset 0.
+ if (size == 0) {
+ return;
+ }
+ BufferInterval& buffer_interval = FindOrDie(buffer_intervals_, buffer);
+ DCHECK_EQ(buffer_interval.buffer, buffer);
+ DCHECK_EQ(buffer_interval.size, size);
+ DCHECK_EQ(buffer_interval.end, -1);
+ buffer_interval.end = current_time_;
+ ++current_time_;
+}
+
+namespace {
+
+// Node in BufferIntervalTree that stores the alloc and free times of a buffer,
+// and the chunk assigned to it.
+struct BufferIntervalTreeNode {
+ // Alloc time.
+ int64 start;
+ // Free time.
+ int64 end;
+ // Maximum free time of all nodes in the subtree where this node is the root.
+ int64 subtree_end;
+ // Allocated chunk for the buffer.
+ HeapSimulator::Chunk chunk;
+ // Left child.
+ BufferIntervalTreeNode* left;
+ // Right child.
+ BufferIntervalTreeNode* right;
+};
+
+// An interval tree that can query buffers overlapping in time.
+class BufferIntervalTree {
+ public:
+ explicit BufferIntervalTree(int capacity) : node_storage_(capacity) {}
+
+ using Chunk = HeapSimulator::Chunk;
+
+ // Adds a buffer to the interval tree, with the time interval and allocated
+ // chunk specified.
+ void Add(int64 start, int64 end, const Chunk& chunk) {
+ int index = node_count_;
+ DCHECK_LT(index, node_storage_.size());
+ ++node_count_;
+
+ node_storage_[index] =
+ BufferIntervalTreeNode{start, end, end, chunk, nullptr, nullptr};
+
+ if (index == 0) {
+ // This is root.
+ return;
+ }
+
+ BufferIntervalTreeNode* parent = &node_storage_[0];
+ while (true) {
+ parent->subtree_end = std::max(parent->subtree_end, end);
+ if (parent->start > start) {
+ if (parent->left == nullptr) {
+ parent->left = &node_storage_[index];
+ return;
+ }
+ parent = parent->left;
+ } else {
+ if (parent->right == nullptr) {
+ parent->right = &node_storage_[index];
+ return;
+ }
+ parent = parent->right;
+ }
+ }
+ }
+
+ // Returns a vector of allocated chunks that overlap with the given time
+ // interval.
+ std::vector<Chunk> ChunksOverlappingInTime(int64 start, int64 end) {
+ std::vector<Chunk> result;
+ if (node_count_ == 0) {
+ return result;
+ }
+ std::vector<BufferIntervalTreeNode*> visiting_stack;
+ visiting_stack.push_back(&node_storage_[0]);
+ while (!visiting_stack.empty()) {
+ BufferIntervalTreeNode* top = visiting_stack.back();
+ visiting_stack.pop_back();
+ if (start > top->subtree_end) {
+ continue;
+ }
+ if (top->left != nullptr) {
+ visiting_stack.push_back(top->left);
+ }
+ if (top->start <= end && top->end >= start) {
+ result.push_back(top->chunk);
+ }
+ if (end < top->start) {
+ continue;
+ }
+ if (top->right != nullptr) {
+ visiting_stack.push_back(top->right);
+ }
+ }
+ return result;
+ }
+
+ private:
+ int64 node_count_ = 0;
+ std::vector<BufferIntervalTreeNode> node_storage_;
+};
+
+} // namespace
+
+HeapSimulator::Result GlobalDecreasingSizeBestFitHeap::Finish() {
+ std::vector<BufferInterval> sorted_buffer_intervals;
+ for (auto& entry : buffer_intervals_) {
+ sorted_buffer_intervals.push_back(entry.second);
+ }
+ std::sort(sorted_buffer_intervals.begin(), sorted_buffer_intervals.end(),
+ [](const BufferInterval& x, const BufferInterval& y) {
+ if (x.size != y.size) {
+ return x.size > y.size;
+ }
+ if (x.end - x.start != y.end - y.start) {
+ return x.end - x.start > y.end - y.start;
+ }
+ return x.buffer->id() < y.buffer->id();
+ });
+
+ BufferIntervalTree interval_tree(sorted_buffer_intervals.size());
+ for (auto& buffer_interval : sorted_buffer_intervals) {
+ auto chunks_overlapping_in_time = interval_tree.ChunksOverlappingInTime(
+ buffer_interval.start, buffer_interval.end);
+ std::sort(
+ chunks_overlapping_in_time.begin(), chunks_overlapping_in_time.end(),
+ [](const Chunk& x, const Chunk& y) { return x.offset < y.offset; });
+
+ // Find the minimum free chunk that can hold this buffer.
+ Chunk min_fit_chunk{-1, INT64_MAX};
+ auto use_free_chunk_if_smaller = [&](int64 free_offset, int64 free_size) {
+ if (free_size < buffer_interval.size) {
+ return;
+ }
+
+ if (free_size < min_fit_chunk.size) {
+ min_fit_chunk = {free_offset, free_size};
+ }
+ };
+
+ int64 offset = 0;
+ for (auto& chunk : chunks_overlapping_in_time) {
+ if (offset < chunk.offset) {
+ use_free_chunk_if_smaller(offset, chunk.offset - offset);
+ }
+ offset =
+ std::max(offset, RoundUpToNearest(chunk.chunk_end(), alignment_));
+ }
+ use_free_chunk_if_smaller(offset, result_.heap_size - offset);
+
+ if (min_fit_chunk.offset == -1) {
+ // No free chunk fits; grow the heap and place the buffer in the trailing space.
+ result_.heap_size = offset + buffer_interval.size;
+ min_fit_chunk = {offset, buffer_interval.size};
+ }
+
+ min_fit_chunk.size = buffer_interval.size;
+ const auto emplace_result =
+ result_.chunk_map.emplace(buffer_interval.buffer, min_fit_chunk);
+ DCHECK(emplace_result.second);
+
+ interval_tree.Add(buffer_interval.start, buffer_interval.end,
+ min_fit_chunk);
+ }
+ return result_;
+}
+
+HeapSimulator::Result ChooseBestHeapAlgorithm::Finish() {
+ DCHECK(!algorithms_.empty());
+ std::vector<Result> results(algorithms_.size());
+ int64 min_size = INT64_MAX;
+ int min_size_index = -1;
+ for (int i = 0; i < algorithms_.size(); ++i) {
+ results[i] = algorithms_[i]->Finish();
+ if (results[i].heap_size < min_size) {
+ min_size = results[i].heap_size;
+ min_size_index = i;
+ }
+ }
+
+ DCHECK_GE(min_size_index, 0);
+ return results[min_size_index];
+}
+
} // namespace xla
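The core placement step in GlobalDecreasingSizeBestFitHeap::Finish is a best-fit scan over the gaps between chunks that are live during the buffer's interval. A minimal standalone sketch of that scan (the free-standing Chunk struct and function name are illustrative; alignment handling is omitted):

    #include <algorithm>
    #include <cstdint>
    #include <limits>
    #include <vector>

    struct Chunk {
      int64_t offset;
      int64_t size;
      int64_t end() const { return offset + size; }
    };

    // `live` holds chunks overlapping the buffer's interval, sorted by offset.
    // Returns the chosen placement; grows *heap_size if no existing gap fits.
    Chunk BestFitGap(const std::vector<Chunk>& live, int64_t requested,
                     int64_t* heap_size) {
      Chunk best{-1, std::numeric_limits<int64_t>::max()};
      auto consider = [&](int64_t gap_offset, int64_t gap_size) {
        if (gap_size >= requested && gap_size < best.size) {
          best = {gap_offset, gap_size};
        }
      };
      int64_t offset = 0;
      for (const Chunk& c : live) {
        if (offset < c.offset) consider(offset, c.offset - offset);  // gap before c
        offset = std::max(offset, c.end());
      }
      consider(offset, *heap_size - offset);  // trailing free space
      if (best.offset == -1) {
        *heap_size = offset + requested;  // nothing fits: extend the heap
        best.offset = offset;
      }
      best.size = requested;
      return best;
    }

Because buffers are visited in decreasing size order, smaller buffers tend to drop into the gaps left between the larger ones, which is the behavior the heap simulator tests below check.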
diff --git a/tensorflow/compiler/xla/service/heap_simulator.h b/tensorflow/compiler/xla/service/heap_simulator.h
index ffbf947d5a..7d6dcc0dc9 100644
--- a/tensorflow/compiler/xla/service/heap_simulator.h
+++ b/tensorflow/compiler/xla/service/heap_simulator.h
@@ -351,6 +351,68 @@ class LazyBestFitHeap : public HeapAlgorithm {
std::set<Chunk, OrderChunkByIncreasingSize> free_;
};
+// GlobalDecreasingSizeBestFitHeap collects the live intervals of all buffers,
+// then allocates them in order of decreasing size, regardless of their
+// alloc/free times. It internally tracks the allocated buffers and their live
+// intervals; when allocating a buffer, it finds the best-fit free chunk during
+// its live interval.
+class GlobalDecreasingSizeBestFitHeap : public HeapAlgorithm {
+ public:
+ GlobalDecreasingSizeBestFitHeap(int64 alignment) : alignment_(alignment) {}
+ ~GlobalDecreasingSizeBestFitHeap() override {}
+
+ void Alloc(const BufferValue* buffer, int64 size) override;
+ void Free(const BufferValue* buffer, int64 size) override;
+ Result Finish() override;
+
+ private:
+ int64 alignment_;
+ Result result_;
+
+ // The current time represented as an integer. It increments by 1 at each
+ // Alloc or Free call.
+ int64 current_time_ = 0;
+
+ // BufferInterval stores a buffer's size and time interval.
+ struct BufferInterval {
+ const BufferValue* buffer;
+ int64 size;
+ // Alloc time of the buffer.
+ int64 start;
+ // Free time of the buffer.
+ int64 end;
+ };
+ tensorflow::gtl::FlatMap<const BufferValue*, BufferInterval>
+ buffer_intervals_;
+};
+
+// A heap algorithm that runs every algorithm added to it on the same trace and
+// returns the result with the smallest heap size.
+class ChooseBestHeapAlgorithm : public HeapAlgorithm {
+ public:
+ ChooseBestHeapAlgorithm(
+ std::unique_ptr<std::vector<std::unique_ptr<HeapAlgorithm>>> algorithms)
+ : algorithms_(std::move(*algorithms)) {}
+ ~ChooseBestHeapAlgorithm() override {}
+
+ void Alloc(const BufferValue* buffer, int64 size) override {
+ for (auto& algorithm : algorithms_) {
+ algorithm->Alloc(buffer, size);
+ }
+ }
+
+ void Free(const BufferValue* buffer, int64 size) override {
+ for (auto& algorithm : algorithms_) {
+ algorithm->Free(buffer, size);
+ }
+ }
+
+ Result Finish() override;
+
+ private:
+ std::vector<std::unique_ptr<HeapAlgorithm>> algorithms_;
+};
+
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HEAP_SIMULATOR_H_
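A rough usage sketch of ChooseBestHeapAlgorithm (buffer_a and buffer_b are assumed to be existing const BufferValue* from the caller's analysis; the alignments are arbitrary). Each registered algorithm sees the same Alloc/Free trace, and Finish() keeps whichever packs tightest:

    auto algorithms =
        absl::make_unique<std::vector<std::unique_ptr<HeapAlgorithm>>>();
    algorithms->push_back(
        absl::make_unique<GlobalDecreasingSizeBestFitHeap>(/*alignment=*/1));
    algorithms->push_back(
        absl::make_unique<GlobalDecreasingSizeBestFitHeap>(/*alignment=*/64));
    ChooseBestHeapAlgorithm chooser(std::move(algorithms));

    chooser.Alloc(buffer_a, 10);
    chooser.Alloc(buffer_b, 30);
    chooser.Free(buffer_a, 10);
    chooser.Free(buffer_b, 30);
    HeapSimulator::Result best = chooser.Finish();  // smallest heap_size wins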
diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc
index 957c4a6891..191fbf8194 100644
--- a/tensorflow/compiler/xla/service/heap_simulator_test.cc
+++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc
@@ -1021,5 +1021,135 @@ TEST_F(LazyBestFitHeapTest, Alignment) {
EXPECT_EQ(128, result.chunk_map.at(buffer_e_).offset);
}
+class GlobalDecreasingSizeBestFitHeapTest : public HeapAlgorithmTestBase {};
+
+TEST_F(GlobalDecreasingSizeBestFitHeapTest, Empty) {
+ GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/1);
+ const HeapSimulator::Result result = heap.Finish();
+ EXPECT_EQ(0, result.heap_size);
+ EXPECT_EQ(0, result.chunk_map.size());
+}
+
+TEST_F(GlobalDecreasingSizeBestFitHeapTest, DecreasingSize) {
+ // space
+ // ^
+ // | +---a---+
+ // | +-------+
+ // | +---c---+
+ // | +-------+
+ // | | b |
+ // | +-------+
+ // | +-------+
+ // | | |
+ // | | d |
+ // | +-------+
+ // -----------------> time
+ GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/1);
+ heap.Alloc(buffer_a_, 10);
+ heap.Alloc(buffer_b_, 30);
+ heap.Alloc(buffer_c_, 20);
+ heap.Alloc(buffer_d_, 40);
+ heap.Free(buffer_a_, 10);
+ heap.Free(buffer_b_, 30);
+ heap.Free(buffer_c_, 20);
+ heap.Free(buffer_d_, 40);
+
+ const HeapSimulator::Result result = heap.Finish();
+ EXPECT_EQ(100, result.heap_size);
+ EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size);
+ EXPECT_EQ(30, result.chunk_map.at(buffer_b_).size);
+ EXPECT_EQ(20, result.chunk_map.at(buffer_c_).size);
+ EXPECT_EQ(40, result.chunk_map.at(buffer_d_).size);
+
+ EXPECT_EQ(90, result.chunk_map.at(buffer_a_).offset);
+ EXPECT_EQ(40, result.chunk_map.at(buffer_b_).offset);
+ EXPECT_EQ(70, result.chunk_map.at(buffer_c_).offset);
+ EXPECT_EQ(0, result.chunk_map.at(buffer_d_).offset);
+}
+
+TEST_F(GlobalDecreasingSizeBestFitHeapTest, DecreasingSizeWithAlignment) {
+ // space
+ // ^
+ // | +-------+
+ // | +---b---+
+ // | +-------+
+ // | | |
+ // | | d |
+ // | +---a---+ +-------+
+ // |
+ // | +-------+
+ // | | |
+ // | | c |
+ // | | |
+ // | +-------+
+ // ---------------------> time
+ GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/20);
+ heap.Alloc(buffer_a_, 10);
+ heap.Alloc(buffer_b_, 20);
+ heap.Alloc(buffer_c_, 50);
+ heap.Free(buffer_a_, 10);
+ heap.Alloc(buffer_d_, 40);
+ heap.Free(buffer_b_, 20);
+ heap.Free(buffer_c_, 50);
+ heap.Free(buffer_d_, 40);
+
+ const HeapSimulator::Result result = heap.Finish();
+ EXPECT_EQ(120, result.heap_size);
+ EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size);
+ EXPECT_EQ(20, result.chunk_map.at(buffer_b_).size);
+ EXPECT_EQ(50, result.chunk_map.at(buffer_c_).size);
+ EXPECT_EQ(40, result.chunk_map.at(buffer_d_).size);
+
+ EXPECT_EQ(60, result.chunk_map.at(buffer_a_).offset);
+ EXPECT_EQ(100, result.chunk_map.at(buffer_b_).offset);
+ EXPECT_EQ(0, result.chunk_map.at(buffer_c_).offset);
+ EXPECT_EQ(60, result.chunk_map.at(buffer_d_).offset);
+}
+
+TEST_F(GlobalDecreasingSizeBestFitHeapTest, BestFit) {
+ // space
+ // ^
+ // | +-------+
+ // | +---b---+
+ // | +-------+
+ // | | d |
+ // | +--a--+ +-------+
+ // | +-------+
+ // | | |
+ // | | c |
+ // | +-------+
+ // | +-------+
+ // | | |
+ // | | e |
+ // | | |
+ // | +-------+
+ // ---------------------> time
+ GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/1);
+ heap.Alloc(buffer_a_, 10);
+ heap.Alloc(buffer_b_, 20);
+ heap.Alloc(buffer_c_, 40);
+ heap.Free(buffer_a_, 10);
+ heap.Alloc(buffer_d_, 30);
+ heap.Alloc(buffer_e_, 50);
+ heap.Free(buffer_b_, 20);
+ heap.Free(buffer_c_, 40);
+ heap.Free(buffer_d_, 30);
+ heap.Free(buffer_e_, 50);
+
+ const HeapSimulator::Result result = heap.Finish();
+ EXPECT_EQ(140, result.heap_size);
+ EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size);
+ EXPECT_EQ(20, result.chunk_map.at(buffer_b_).size);
+ EXPECT_EQ(40, result.chunk_map.at(buffer_c_).size);
+ EXPECT_EQ(30, result.chunk_map.at(buffer_d_).size);
+ EXPECT_EQ(50, result.chunk_map.at(buffer_e_).size);
+
+ EXPECT_EQ(90, result.chunk_map.at(buffer_a_).offset);
+ EXPECT_EQ(120, result.chunk_map.at(buffer_b_).offset);
+ EXPECT_EQ(50, result.chunk_map.at(buffer_c_).offset);
+ EXPECT_EQ(90, result.chunk_map.at(buffer_d_).offset);
+ EXPECT_EQ(0, result.chunk_map.at(buffer_e_).offset);
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto
index b19ec12638..caaca16f71 100644
--- a/tensorflow/compiler/xla/service/hlo.proto
+++ b/tensorflow/compiler/xla/service/hlo.proto
@@ -34,7 +34,7 @@ import "tensorflow/compiler/xla/xla_data.proto";
option cc_enable_arenas = true;
// Serialization of HloInstruction.
-// Next ID: 53
+// Next ID: 54
message HloInstructionProto {
reserved 10;
reserved "parameter_name";
@@ -124,9 +124,13 @@ message HloInstructionProto {
// The string representation of the infeed configuration.
bytes infeed_config = 27;
- // Name of a global symbol to call, only present for kCustomCall.
+ // Name of an external target (e.g., a global symbol) to call, only present
+ // for kCustomCall.
string custom_call_target = 28;
+ // Opaque string, only present for kCustomCall.
+ string custom_call_opaque = 53;
+
// Shape of outfeed request.
xla.Shape outfeed_shape = 29;
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
index 601a008d9f..0e5920af7a 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -272,10 +272,11 @@ Status HloComputation::RemoveInstruction(HloInstruction* instruction) {
<< "instruction " << instruction->name()
<< " has control successors and cannot be removed";
- TF_RET_CHECK(instruction_iterators_.count(instruction) != 0);
- auto inst_it = instruction_iterators_.at(instruction);
- (*inst_it)->set_parent(nullptr);
- instructions_.erase(inst_it);
+ auto inst_it = instruction_iterators_.find(instruction);
+ TF_RET_CHECK(inst_it != instruction_iterators_.end());
+ (*inst_it->second)->set_parent(nullptr);
+ instructions_.erase(inst_it->second);
+ instruction_iterators_.erase(inst_it);
return Status::OK();
}
@@ -916,13 +917,14 @@ std::unique_ptr<HloComputation> HloComputation::Clone(
return CloneWithReplacements(
/*replacements=*/std::unordered_map<const HloInstruction*,
std::unique_ptr<HloInstruction>>(),
- context, suffix);
+ /*extras=*/{}, context, suffix);
}
std::unique_ptr<HloComputation> HloComputation::CloneWithReplacements(
std::unordered_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
replacements,
- HloCloneContext* context, const string& suffix) {
+ absl::Span<HloInstruction*> extras, HloCloneContext* context,
+ const string& suffix) {
std::unique_ptr<HloCloneContext> context_ptr;
if (context == nullptr) {
context_ptr = absl::make_unique<HloCloneContext>(parent(), suffix);
@@ -944,6 +946,9 @@ std::unique_ptr<HloComputation> HloComputation::CloneWithReplacements(
VLOG(1) << "Cloning " << name() << " --> " << suffix << "\n";
std::vector<HloInstruction*> postorder;
+ for (HloInstruction* instr : extras) {
+ postorder.push_back(instr);
+ }
for (HloInstruction* instr : MakeInstructionPostOrder()) {
if (HloInstruction* replacement = replace(instr)) {
postorder.push_back(replacement);
diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h
index a880e9ab30..936a53bd7e 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.h
+++ b/tensorflow/compiler/xla/service/hlo_computation.h
@@ -227,7 +227,7 @@ class HloComputation {
void UpdateReachabilityThroughInstruction(
const HloInstruction* instruction, HloReachabilityMap* reachability_map);
- int64 instruction_count() const { return instructions_.size(); }
+ int64 instruction_count() const { return instruction_iterators_.size(); }
// Creates and returns a list of the embedded computations called by this
// computation. This includes all embedded computations called directly or
@@ -333,10 +333,13 @@ class HloComputation {
//
// If replacements maps a key to nullptr, we remove that instruction from the
// new computation.
+ // If instructions in the replacements map use additional instructions, those
+ // must be passed via the extras span, in post-order.
std::unique_ptr<HloComputation> CloneWithReplacements(
std::unordered_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
replacements,
- HloCloneContext* context = nullptr, const string& suffix = "clone");
+ absl::Span<HloInstruction*> extras, HloCloneContext* context = nullptr,
+ const string& suffix = "clone");
// Returns true if the given instruction can be removed from the computation.
// Parameter instructions cannot be removed without violating invariants of
@@ -436,7 +439,7 @@ class HloComputation {
// instruction pointer to location in the list for fast lookup.
using InstructionList = std::list<std::unique_ptr<HloInstruction>>;
InstructionList instructions_;
- std::unordered_map<const HloInstruction*, InstructionList::iterator>
+ tensorflow::gtl::FlatMap<const HloInstruction*, InstructionList::iterator>
instruction_iterators_;
std::vector<HloInstruction*> param_instructions_;
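A hedged sketch of calling the new CloneWithReplacements overload (old_instr, new_instr, and extra_operand are hypothetical; new_instr is assumed to use extra_operand, which is not itself part of the original computation):

    std::unordered_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
        replacements;
    replacements[old_instr] = std::move(new_instr);
    std::vector<HloInstruction*> extras = {extra_operand};  // in post-order
    std::unique_ptr<HloComputation> clone = computation->CloneWithReplacements(
        std::move(replacements), absl::MakeSpan(extras),
        /*context=*/nullptr, /*suffix=*/"clone");

Callers that have no such extra instructions pass an empty span, as Clone() now does internally.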
diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.h b/tensorflow/compiler/xla/service/hlo_constant_folding.h
index 4557983a9c..4a624cc7b8 100644
--- a/tensorflow/compiler/xla/service/hlo_constant_folding.h
+++ b/tensorflow/compiler/xla/service/hlo_constant_folding.h
@@ -23,7 +23,7 @@ namespace xla {
// A pass which performs constant folding in order to avoid unnecessary
// computation on constants.
-class HloConstantFolding : public HloPassInterface {
+class HloConstantFolding : public HloModulePass {
public:
absl::string_view name() const override { return "constant_folding"; }
diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc
index b76c50bb5b..b2005d3c21 100644
--- a/tensorflow/compiler/xla/service/hlo_creation_utils.cc
+++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/util.h"
@@ -201,6 +202,44 @@ StatusOr<HloInstruction*> MakeMapHlo(absl::Span<HloInstruction* const> operands,
HloInstruction::CreateMap(map_shape, operands, map_computation));
}
+StatusOr<HloInstruction*> MakeReduceHlo(HloInstruction* operand,
+ HloInstruction* init_value,
+ HloOpcode binary_opcode,
+ HloModule* module) {
+ DCHECK_NE(nullptr, module);
+ std::vector<int64> all_dims(ShapeUtil::Rank(operand->shape()));
+ std::iota(all_dims.begin(), all_dims.end(), 0);
+
+ auto scalar_shape = ShapeUtil::MakeShape(operand->shape().element_type(), {});
+ HloComputation* reduce_computation;
+ {
+ HloComputation::Builder b(operand->name() + ".reduce_sub_computation");
+ auto lhs = b.AddInstruction(
+ HloInstruction::CreateParameter(0, scalar_shape, "lhs"));
+ auto rhs = b.AddInstruction(
+ HloInstruction::CreateParameter(1, scalar_shape, "rhs"));
+ b.AddInstruction(
+ HloInstruction::CreateBinary(scalar_shape, binary_opcode, lhs, rhs));
+ reduce_computation = module->AddEmbeddedComputation(b.Build());
+ }
+
+ return operand->parent()->AddInstruction(HloInstruction::CreateReduce(
+ scalar_shape, operand, init_value, all_dims, reduce_computation));
+}
+
+StatusOr<HloInstruction*> MakeSelectHlo(HloInstruction* pred,
+ HloInstruction* on_true,
+ HloInstruction* on_false) {
+ HloComputation* computation = pred->parent();
+ DCHECK_EQ(computation, on_true->parent());
+ DCHECK_EQ(computation, on_false->parent());
+ TF_ASSIGN_OR_RETURN(Shape select_shape,
+ ShapeInference::InferTernaryOpShape(
+ HloOpcode::kSelect, pred, on_true, on_false));
+ return computation->AddInstruction(HloInstruction::CreateTernary(
+ select_shape, HloOpcode::kSelect, pred, on_true, on_false));
+}
+
StatusOr<HloInstruction*> CollapseFirstNDims(HloInstruction* operand, int64 n) {
CHECK_GT(n, 0);
diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.h b/tensorflow/compiler/xla/service/hlo_creation_utils.h
index b22058abb4..8e5ddbbd50 100644
--- a/tensorflow/compiler/xla/service/hlo_creation_utils.h
+++ b/tensorflow/compiler/xla/service/hlo_creation_utils.h
@@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CREATION_UTILS_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CREATION_UTILS_H_
+#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/statusor.h"
@@ -107,6 +108,35 @@ StatusOr<HloInstruction*> MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs,
StatusOr<HloInstruction*> MakeMapHlo(absl::Span<HloInstruction* const> operands,
HloComputation* map_computation);
+// Creates a Reduce HLO instruction and adds it to the computation containing
+// the operand. This will create the sub-computation needed for the reduction in
+// the given module. binary_opcode should represent a binary operation.
+StatusOr<HloInstruction*> MakeReduceHlo(HloInstruction* operand,
+ HloInstruction* init_value,
+ HloOpcode binary_opcode,
+ HloModule* module);
+
+// Creates a Select HLO instruction and adds it to the computation containing
+// the predicate. The on_true and on_false instructions must also be contained
+// in the same computation.
+StatusOr<HloInstruction*> MakeSelectHlo(HloInstruction* pred,
+ HloInstruction* on_true,
+ HloInstruction* on_false);
+
+// Creates an R1 Constant HLO instruction of the given PrimitiveType with the
+// given values and adds it to the given computation.
+template <typename NativeT>
+StatusOr<HloInstruction*> MakeR1ConstantHlo(HloComputation* computation,
+ PrimitiveType type,
+ absl::Span<const NativeT> values) {
+ Literal literal = LiteralUtil::CreateR1<NativeT>(values);
+ if (literal.shape().element_type() != type) {
+ TF_ASSIGN_OR_RETURN(literal, literal.Convert(type));
+ }
+ return computation->AddInstruction(
+ HloInstruction::CreateConstant(std::move(literal)));
+}
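A brief usage sketch of the new helpers (a sketch only; `computation`, `module`, `operand`, `init_value`, `pred`, `on_true`, and `on_false` are assumed to exist in the caller):

    // Sum-reduce `operand` over all of its dimensions down to a scalar.
    TF_ASSIGN_OR_RETURN(
        HloInstruction * sum,
        MakeReduceHlo(operand, init_value, HloOpcode::kAdd, module));

    // Elementwise select between two instructions in the same computation.
    TF_ASSIGN_OR_RETURN(HloInstruction * sel,
                        MakeSelectHlo(pred, on_true, on_false));

    // R1 constant of a specific primitive type, converting if needed.
    TF_ASSIGN_OR_RETURN(HloInstruction * indices,
                        MakeR1ConstantHlo<int32>(computation, S32, {0, 1, 2}));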
+
// -----------------------------------------------------------------------------
// Some other miscellaneous helpers to generate common HLO patterns. All of
// these add all the instructions they generate into the computation containing
diff --git a/tensorflow/compiler/xla/service/hlo_cse.h b/tensorflow/compiler/xla/service/hlo_cse.h
index a28c03599a..e4857fd3fd 100644
--- a/tensorflow/compiler/xla/service/hlo_cse.h
+++ b/tensorflow/compiler/xla/service/hlo_cse.h
@@ -25,7 +25,7 @@ namespace xla {
// and identical instructions with the same operands are commoned. The pass
// iterates over the instructions in topological order which enables the pass to
// find arbitrarily large common expressions.
-class HloCSE : public HloPassInterface {
+class HloCSE : public HloModulePass {
public:
// If is_layout_sensitive is true, then the simplifier preserves layout during
// transformation. Otherwise, layout is ignored.
diff --git a/tensorflow/compiler/xla/service/hlo_dce.h b/tensorflow/compiler/xla/service/hlo_dce.h
index 1fe69b1395..4012042672 100644
--- a/tensorflow/compiler/xla/service/hlo_dce.h
+++ b/tensorflow/compiler/xla/service/hlo_dce.h
@@ -33,7 +33,7 @@ namespace xla {
//
// This pass does not remove dead parameter instructions, as parameter
// instructions cannot be deleted.
-class HloDCE : public HloPassInterface {
+class HloDCE : public HloModulePass {
public:
~HloDCE() override {}
absl::string_view name() const override { return "dce"; }
diff --git a/tensorflow/compiler/xla/service/hlo_domain_isolator.h b/tensorflow/compiler/xla/service/hlo_domain_isolator.h
index d36631fc2f..c0bf1b9e16 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_isolator.h
+++ b/tensorflow/compiler/xla/service/hlo_domain_isolator.h
@@ -30,7 +30,7 @@ namespace xla {
// used to break an HLO graph edge connecting two instructions with different
// sharding. If a set of connected instructions have all the same sharding, no
// kDomain instruction will be placed.
-class HloDomainIsolator : public HloPassInterface {
+class HloDomainIsolator : public HloModulePass {
public:
// Creates a new kDomain instruction for the edge between the use instruction
// (the first HloInstruction argument), and the operand instruction (the
diff --git a/tensorflow/compiler/xla/service/hlo_domain_remover.h b/tensorflow/compiler/xla/service/hlo_domain_remover.h
index 97bc8ef604..0fc30fb86c 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_remover.h
+++ b/tensorflow/compiler/xla/service/hlo_domain_remover.h
@@ -26,7 +26,7 @@ namespace xla {
// Removes all the kDomain instructions of a given kind from the input module,
// and calls the normalizer to propagate the properties on the possibly new born
// instructions.
-class HloDomainRemover : public HloPassInterface {
+class HloDomainRemover : public HloModulePass {
public:
// Creates a new HloDomainRemover object tasked at removing all the kDomain
// instructions of a given kind.
diff --git a/tensorflow/compiler/xla/service/hlo_domain_verifier.h b/tensorflow/compiler/xla/service/hlo_domain_verifier.h
index 81d6d69a8c..bea5cba38d 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_verifier.h
+++ b/tensorflow/compiler/xla/service/hlo_domain_verifier.h
@@ -29,7 +29,7 @@ namespace xla {
// Verifies that the domain instructions are consistent, and that each domain is
// surrounded by the same metadata.
-class HloDomainVerifier : public HloPassInterface {
+class HloDomainVerifier : public HloModulePass {
public:
HloDomainVerifier(std::vector<string> kinds) : kinds_(std::move(kinds)) {}
diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter.h b/tensorflow/compiler/xla/service/hlo_element_type_converter.h
index 44ded2c2fa..4d2a942925 100644
--- a/tensorflow/compiler/xla/service/hlo_element_type_converter.h
+++ b/tensorflow/compiler/xla/service/hlo_element_type_converter.h
@@ -25,7 +25,7 @@ namespace xla {
// inserting Convert ops. This allows a backend to support an element type while
// only actually implementing the Convert op for that element type. This is
// generally not the fastest approach, but it works.
-class HloElementTypeConverter : public HloPassInterface {
+class HloElementTypeConverter : public HloModulePass {
public:
// eliminate_type is the type to eliminate as the input or output of ops,
// using Convert ops to replace it with replace_with_type.
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc
index 06b6d5b559..d7c39b2778 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc
@@ -496,6 +496,61 @@ Status HloEvaluator::HandleIsFinite(HloInstruction* is_finite) {
return Status::OK();
}
+Status HloEvaluator::HandleReal(HloInstruction* real) {
+ auto operand = real->operand(0);
+ switch (operand->shape().element_type()) {
+ case BF16: {
+ auto result_or = ElementWiseUnaryOpImpl<bfloat16, bfloat16>(
+ real, [](bfloat16 elem_operand) { return elem_operand; },
+ GetEvaluatedLiteralFor(operand));
+ TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
+ break;
+ }
+ case C64: {
+ auto result_or = ElementWiseUnaryOpImpl<float, complex64>(
+ real, [](complex64 elem_operand) { return std::real(elem_operand); },
+ GetEvaluatedLiteralFor(operand));
+ TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
+ break;
+ }
+ case F16: {
+ auto result_or = ElementWiseUnaryOpImpl<Eigen::half, Eigen::half>(
+ real, [](Eigen::half elem_operand) { return elem_operand; },
+ GetEvaluatedLiteralFor(operand));
+ TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
+ break;
+ }
+ case F32: {
+ auto result_or = ElementWiseUnaryOpImpl<float, float>(
+ real, [](float elem_operand) { return elem_operand; },
+ GetEvaluatedLiteralFor(operand));
+ TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
+ break;
+ }
+ case F64: {
+ auto result_or = ElementWiseUnaryOpImpl<double, double>(
+ real, [](double elem_operand) { return elem_operand; },
+ GetEvaluatedLiteralFor(operand));
+ TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
+ break;
+ }
+ default:
+ LOG(FATAL) << "HandleReal: unknown/unhandled primitive type: "
+ << PrimitiveType_Name(operand->shape().element_type());
+ }
+
+ return Status::OK();
+}
+
+Status HloEvaluator::HandleImag(HloInstruction* imag) {
+ auto result_or = ElementWiseUnaryOpImpl<float, complex64>(
+ imag, [](complex64 elem_operand) { return std::imag(elem_operand); },
+ GetEvaluatedLiteralFor(imag->operand(0)));
+
+ TF_ASSIGN_OR_RETURN(evaluated_[imag], std::move(result_or));
+ return Status::OK();
+}
+
Status HloEvaluator::HandleCompare(HloInstruction* compare) {
HloOpcode opcode = compare->opcode();
auto lhs = compare->operand(0);
@@ -1173,80 +1228,85 @@ StatusOr<Literal> EvaluateSortInternal(HloInstruction* sort,
TF_RET_CHECK(
ShapeUtil::SameDimensions(keys_literal.shape(), values_literal.shape()))
<< "Sort keys and values must have the same dimensions";
- TF_RET_CHECK(rank > 0 && rank <= 2)
- << "Sort is only supported for rank-1 and rank-2 shapes, rank is: "
- << rank;
TF_RET_CHECK(sort->operand_count() == 2) << "Expected key-value sort";
- // We need to sort and array of keys and an array of values, where the
+ // We need to sort an array of keys and an array of values, where the
// sorted order of the values is determined by the keys. The simplest(?)
// way to do this is to go to an array-of-pairs representation, sort the
// array using the keys, and then go back to pair-of-arrays.
VLOG(3) << "HandleSort keys_literal: " << keys_literal.ToString();
VLOG(3) << "HandleSort values_literal: " << values_literal.ToString();
- auto sort_r1 = [](const Literal& keys_literal,
- const Literal& values_literal) {
- const auto& keys_data = keys_literal.data<KeyType>();
- const auto& values_data = values_literal.data<ValueType>();
-
- using kv_pair = std::pair<KeyType, ValueType>;
- std::vector<kv_pair> key_value_vector;
- CHECK_EQ(keys_data.size(), values_data.size());
- key_value_vector.reserve(keys_data.size());
- for (int i = 0; i < keys_data.size(); ++i) {
- key_value_vector.push_back(std::make_pair(keys_data[i], values_data[i]));
- }
- std::sort(key_value_vector.begin(), key_value_vector.end(),
- [](const kv_pair& a, const kv_pair& b) {
- return SafeLess<KeyType>(a.first, b.first);
- });
- std::vector<KeyType> result_keys;
- std::vector<ValueType> result_values;
- for (const auto& key_value : key_value_vector) {
- result_keys.push_back(key_value.first);
- result_values.push_back(key_value.second);
- }
- Literal result_keys_literal(keys_literal.shape());
- result_keys_literal.PopulateR1(absl::Span<const KeyType>(result_keys));
- Literal result_values_literal(values_literal.shape());
- result_values_literal.PopulateR1(
- absl::Span<const ValueType>(result_values));
- return std::make_pair(std::move(result_keys_literal),
- std::move(result_values_literal));
- };
-
- Literal result_tuple;
- if (rank == 1) {
- auto result_pair = sort_r1(keys_literal, values_literal);
- result_tuple =
- LiteralUtil::MakeTuple({&result_pair.first, &result_pair.second});
- } else {
- // For R2 sort, the desired semantics are to sort each matrix row
- // independently.
- Literal keys_result_literal(keys_literal.shape());
- Literal values_result_literal(values_literal.shape());
- int64 r1_length = keys_literal.shape().dimensions(1);
- for (int64 row = 0; row < keys_literal.shape().dimensions(0); ++row) {
- TF_ASSIGN_OR_RETURN(auto keys_r1_slice,
- keys_literal.Slice({row, 0}, {row + 1, r1_length})
- .Reshape({r1_length}));
- TF_ASSIGN_OR_RETURN(auto values_r1_slice,
- values_literal.Slice({row, 0}, {row + 1, r1_length})
- .Reshape({r1_length}));
- auto r1_result_pair = sort_r1(keys_r1_slice, values_r1_slice);
- TF_ASSIGN_OR_RETURN(auto sorted_keys,
- r1_result_pair.first.Reshape({1, r1_length}));
- TF_ASSIGN_OR_RETURN(auto sorted_values,
- r1_result_pair.second.Reshape({1, r1_length}));
- TF_RETURN_IF_ERROR(keys_result_literal.CopySliceFrom(
- sorted_keys, {0, 0}, {row, 0}, {1, r1_length}));
- TF_RETURN_IF_ERROR(values_result_literal.CopySliceFrom(
- sorted_values, {0, 0}, {row, 0}, {1, r1_length}));
- }
- result_tuple =
- LiteralUtil::MakeTuple({&keys_result_literal, &values_result_literal});
+ if (rank == 0) {
+ // Nothing to sort.
+ return LiteralUtil::MakeTuple({&keys_literal, &values_literal});
}
+ Literal keys_result_literal(keys_literal.shape());
+ Literal values_result_literal(values_literal.shape());
+ std::vector<int64> zero_base(rank, 0);
+ std::vector<int64> increment(rank, 1);
+ int64 sort_dim = sort->dimensions(0);
+ int64 sort_dim_elements = keys_literal.shape().dimensions(sort_dim);
+ increment[sort_dim] = sort_dim_elements;
+ // Iterate through each dimension except 'sort_dim'.
+ TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
+ keys_literal.shape(), zero_base,
+ AsInt64Slice(keys_literal.shape().dimensions()), increment,
+ [&](absl::Span<const int64> indices) -> StatusOr<bool> {
+ // Extract a slice from the keys and values literals that correspond to
+ // exactly the row in dimension 'sort_dim'.
+ std::vector<int64> limit_indices(indices.begin(), indices.end());
+ std::for_each(limit_indices.begin(), limit_indices.end(),
+ [](int64& index) { ++index; });
+ limit_indices[sort_dim] = sort_dim_elements;
+ TF_ASSIGN_OR_RETURN(auto keys_to_sort,
+ keys_literal.Slice(indices, limit_indices)
+ .Reshape({sort_dim_elements}));
+ const auto& keys_data = keys_to_sort.data<KeyType>();
+ TF_ASSIGN_OR_RETURN(auto values_to_sort,
+ values_literal.Slice(indices, limit_indices)
+ .Reshape({sort_dim_elements}));
+ const auto& values_data = values_to_sort.data<ValueType>();
+ using kv_pair = std::pair<KeyType, ValueType>;
+ std::vector<kv_pair> key_value_vector;
+ key_value_vector.reserve(keys_data.size());
+ for (int i = 0; i < keys_data.size(); ++i) {
+ key_value_vector.push_back(
+ std::make_pair(keys_data[i], values_data[i]));
+ }
+ std::sort(key_value_vector.begin(), key_value_vector.end(),
+ [](const kv_pair& a, const kv_pair& b) {
+ return SafeLess<KeyType>(a.first, b.first);
+ });
+ std::vector<KeyType> result_keys;
+ std::vector<ValueType> result_values;
+ for (const auto& key_value : key_value_vector) {
+ result_keys.push_back(key_value.first);
+ result_values.push_back(key_value.second);
+ }
+ Literal sorted_keys(ShapeUtil::MakeShape(
+ keys_literal.shape().element_type(), {sort_dim_elements}));
+ sorted_keys.PopulateR1(absl::Span<const KeyType>(result_keys));
+ Literal sorted_values(ShapeUtil::MakeShape(
+ values_literal.shape().element_type(), {sort_dim_elements}));
+ sorted_values.PopulateR1(absl::Span<const ValueType>(result_values));
+ std::vector<int64> slice_dimensions(rank, 1);
+ slice_dimensions[sort_dim] = sort_dim_elements;
+ std::vector<int64> start_indices(rank, 0);
+ TF_ASSIGN_OR_RETURN(auto sorted_keys_reshaped,
+ sorted_keys.Reshape(slice_dimensions));
+ TF_RETURN_IF_ERROR(keys_result_literal.CopySliceFrom(
+ sorted_keys_reshaped, start_indices, indices, slice_dimensions));
+ TF_ASSIGN_OR_RETURN(auto sorted_values_reshaped,
+ sorted_values.Reshape(slice_dimensions));
+ TF_RETURN_IF_ERROR(values_result_literal.CopySliceFrom(
+ sorted_values_reshaped, start_indices, indices, slice_dimensions));
+ return true;
+ }));
+
+ Literal result_tuple;
+ result_tuple =
+ LiteralUtil::MakeTuple({&keys_result_literal, &values_result_literal});
VLOG(3) << "HandleSort result_tuple: " << result_tuple.ToString();
return std::move(result_tuple);
}
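The rewritten sort walks every 1-D line along the sort dimension by starting ShapeUtil::ForEachIndexWithStatus at the zero index with an increment of sort_dim_elements in sort_dim and 1 elsewhere. A standalone approximation of that index enumeration (not the ShapeUtil API, just a sketch of the iteration pattern):

    #include <cstdint>
    #include <functional>
    #include <vector>

    // Visit the start index of every line along `sort_dim` in a shape given by
    // `dims`, e.g. dims={2,3}, sort_dim=1 visits {0,0} and {1,0}.
    void ForEachLineStart(const std::vector<int64_t>& dims, int64_t sort_dim,
                          const std::function<void(const std::vector<int64_t>&)>& fn) {
      std::vector<int64_t> index(dims.size(), 0);
      while (true) {
        fn(index);
        int64_t d = static_cast<int64_t>(dims.size()) - 1;
        for (; d >= 0; --d) {
          if (d == sort_dim) continue;       // this dimension is taken whole
          if (++index[d] < dims[d]) break;   // incremented without a carry
          index[d] = 0;                      // carry into the next dimension
        }
        if (d < 0) return;                   // all line starts visited
      }
    }

Each visited index then becomes the base of a Slice covering the whole line; the line is reshaped to rank 1, sorted with SafeLess, and copied back with CopySliceFrom.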
@@ -1292,15 +1352,6 @@ StatusOr<Literal> EvaluateSort(HloInstruction* sort,
} // namespace
Status HloEvaluator::HandleSort(HloInstruction* sort) {
- const int64 sort_dim = sort->dimensions(0);
- const int64 rank = ShapeUtil::Rank(sort->operand(0)->shape());
- if (sort_dim != rank - 1) {
- return Unimplemented(
- "Trying to sort along dimension %d, which is not the last "
- "dimension",
- sort_dim);
- }
-
if (!ShapeUtil::IsTuple(sort->shape())) {
return DefaultAction(sort);
} else {
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h
index 21e676d671..6c2662ebae 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.h
@@ -184,6 +184,10 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
Status HandleSort(HloInstruction* sort) override;
+ Status HandleReal(HloInstruction* real) override;
+
+ Status HandleImag(HloInstruction* imag) override;
+
Status HandleReduce(HloInstruction* reduce) override;
// Returns the already-evaluated literal result for the instruction.
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
index 01e88566a5..cee11a8a21 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
@@ -66,6 +66,20 @@ class HloEvaluatorTest : public ::testing::WithParamInterface<bool>,
.ConsumeValueOrDie();
}
+ // Evaluate function that takes in a local module instead of using module_
+ // that is in HloVerifiedTestBase. Once module_ in HloVerifiedTestBase is
+ // removed, this should be the default Evaluate function.
+ Literal EvaluateWithModule(
+ HloModule* module, absl::Span<const Literal* const> arg_literals = {}) {
+ if (use_bfloat16_) {
+ // In BF16 mode, we convert all F32 type to BF16 and evaluate the module.
+ auto type_converter = HloElementTypeConverter(F32, BF16);
+ type_converter.Run(module).ValueOrDie();
+ }
+ return evaluator_->Evaluate(*module->entry_computation(), arg_literals)
+ .ConsumeValueOrDie();
+ }
+
std::unique_ptr<HloEvaluator> evaluator_;
void TestUnaryOp(HloOpcode opcode, Literal expected, Literal input,
@@ -2530,6 +2544,114 @@ ENTRY main {
expected, Evaluate({&operand, &scatter_indices, &updates})));
}
+TEST_P(HloEvaluatorTest, EvaluateScatter_NegativeIndices) {
+ const char* hlo_text = R"(
+HloModule TensorFlowScatter_NegativeIndices
+
+add_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
+ lhs = s32[] parameter(0)
+ rhs = s32[] parameter(1)
+ ROOT add = s32[] add(s32[] lhs, s32[] rhs)
+}
+
+ENTRY main {
+ operand = s32[3,3] parameter(0)
+ indices = s32[2] parameter(1)
+ updates = s32[2,3] parameter(2)
+ ROOT scatter = s32[3,3] scatter(operand, indices, updates),
+ to_apply=add_s32,
+ update_window_dims={1},
+ inserted_window_dims={0},
+ scatter_dims_to_operand_dims={0},
+ index_vector_dim=1
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(hlo_text));
+ Literal operand =
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ // No updates should happen for the negative indices.
+ Literal scatter_indices = LiteralUtil::CreateR1<int32>({-1, 2});
+ Literal updates = LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
+ EXPECT_TRUE(LiteralTestUtil::Equal(
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {77, 88, 99}}),
+ EvaluateWithModule(module.get(),
+ {&operand, &scatter_indices, &updates})));
+}
+
+TEST_P(HloEvaluatorTest, EvaluateScatter_OobIndices) {
+ const string hlo_text = R"(
+HloModule BatchDynamicUpdateSlice
+
+update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
+ lhs = s32[] parameter(0)
+ ROOT rhs = s32[] parameter(1)
+}
+
+ENTRY main {
+ operand = s32[3,3]{1,0} parameter(0)
+ indices = s32[6,2]{1,0} parameter(1)
+ updates = s32[6,1,1]{2,1,0} parameter(2)
+ ROOT scatter = s32[3,3]{1,0} scatter(operand, indices, updates),
+ to_apply=update_s32,
+ update_window_dims={1,2},
+ inserted_window_dims={},
+ scatter_dims_to_operand_dims={0,1},
+ index_vector_dim=1
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(hlo_text));
+ Literal operand =
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ // No updates should happen for the OOB indices.
+ Literal scatter_indices = LiteralUtil::CreateR2<int32>(
+ {{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483647, 1}, {1, 2}});
+ Literal updates = LiteralUtil::CreateR3<int32>(
+ {{{10}}, {{20}}, {{30}}, {{40}}, {{50}}, {{60}}});
+ EXPECT_TRUE(LiteralTestUtil::Equal(
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 30, 60}, {7, 20, 9}}),
+ EvaluateWithModule(module.get(),
+ {&operand, &scatter_indices, &updates})));
+}
+
+TEST_P(HloEvaluatorTest, EvaluateScatter_OobUpdateWindow) {
+ const char* hlo_text = R"(
+HloModule TensorFlowScatterNd_OobUpdateWindow
+
+update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
+ lhs = s32[] parameter(0)
+ ROOT rhs = s32[] parameter(1)
+}
+
+ENTRY main {
+ operand = s32[3,3,2] parameter(0)
+ indices = s32[1,2] parameter(1)
+ updates = s32[1,2,2] parameter(2)
+ ROOT scatter = s32[3,3,2] scatter(operand, indices, updates),
+ to_apply=update_s32,
+ update_window_dims={1,2},
+ inserted_window_dims={0},
+ scatter_dims_to_operand_dims={0,1},
+ index_vector_dim=1
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(hlo_text));
+ Literal operand =
+ LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
+ {{-4, 4}, {-5, 5}, {-6, 6}}, //
+ {{-7, 7}, {-8, 8}, {-9, 9}}});
+ Literal scatter_indices = LiteralUtil::CreateR2<int32>({{0, 2}});
+ Literal updates = LiteralUtil::CreateR3<int32>({{{-10, 10}, {-40, 40}}});
+ // Given the update window size of 2,2 and the index of 0,2, the update window
+ // will be OOB. So, nothing should be updated.
+ Literal expected = operand.Clone();
+ EXPECT_TRUE(LiteralTestUtil::Equal(
+ expected, EvaluateWithModule(module.get(),
+ {&operand, &scatter_indices, &updates})));
+}
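These three tests pin down the evaluator's new out-of-bounds behavior: instead of clamping a scatter index into range, an update whose window would fall outside the operand in any dimension is dropped entirely. A scalar sketch of that per-dimension check (names are illustrative, simplified from the typed visitor):

    #include <cstdint>
    #include "absl/types/span.h"

    // Returns false if the update window anchored at `scatter_index` would
    // stick out of the operand in any dimension; such updates are skipped.
    bool UpdateWindowInBounds(absl::Span<const int64_t> scatter_index,
                              absl::Span<const int64_t> window_sizes,
                              absl::Span<const int64_t> operand_dims) {
      for (int64_t i = 0; i < static_cast<int64_t>(operand_dims.size()); ++i) {
        if (scatter_index[i] < 0 ||
            scatter_index[i] > operand_dims[i] - window_sizes[i]) {
          return false;
        }
      }
      return true;
    }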
+
// Verifies that HloEvaluator evaluates a HLO instruction that performs
// element-wise comparison with 2 bfloat16 operands.
TEST_P(HloEvaluatorTest, DoesCompareBF16) {
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
index 8fb17a0033..b2d12c94b8 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
@@ -16,6 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_
+#include <cmath>
+
#include "absl/algorithm/container.h"
#include "absl/container/inlined_vector.h"
#include "absl/memory/memory.h"
@@ -41,7 +43,9 @@ template <typename T>
using is_complex64_t = std::is_same<T, complex64>;
// It's UB to use std::sort with std::less<float>, because of NaNs. Define
-// "safe" less functions which are actually strict weak orders.
+// "safe" less functions which are actually strict weak orders. -NaN and NaN
+// should appear at the beginning and end of the ordering, and -0.0 should
+// appear before 0.0.
template <
typename NativeT,
typename std::enable_if<std::is_integral<NativeT>::value>::type* = nullptr>
@@ -49,26 +53,33 @@ bool SafeLess(const NativeT& a, const NativeT& b) {
return a < b;
}
-template <typename NativeT,
- typename std::enable_if<
- std::is_floating_point<NativeT>::value ||
- std::is_same<NativeT, bfloat16>::value>::type* = nullptr>
+template <typename NativeT, typename std::enable_if<std::is_floating_point<
+ NativeT>::value>::type* = nullptr>
bool SafeLess(const NativeT& a, const NativeT& b) {
- if (std::isnan(b)) {
- return !std::isnan(a);
- } else {
- return a < b;
+ bool lhs_is_negative = std::signbit(a);
+ bool rhs_is_negative = std::signbit(b);
+ // If the signs are different, we can just compare the signs.
+ if (lhs_is_negative != rhs_is_negative) {
+ return lhs_is_negative && !rhs_is_negative;
+ }
+ bool lhs_nan = std::isnan(a);
+ bool rhs_nan = std::isnan(b);
+ // Exactly one number is nan?
+ if (lhs_nan != rhs_nan) {
+ if (lhs_nan) {
+ return lhs_is_negative;
+ }
+ return !rhs_is_negative;
}
+ return a < b;
}
-template <typename NativeT, typename std::enable_if<std::is_same<
- NativeT, Eigen::half>::value>::type* = nullptr>
+template <typename NativeT,
+ typename std::enable_if<
+ std::is_same<NativeT, bfloat16>::value ||
+ std::is_same<NativeT, Eigen::half>::value>::type* = nullptr>
bool SafeLess(const NativeT& a, const NativeT& b) {
- if (Eigen::half_impl::isnan(b)) {
- return !Eigen::half_impl::isnan(a);
- } else {
- return a < b;
- }
+ return SafeLess(static_cast<float>(a), static_cast<float>(b));
}
// Templated DfsHloVisitor for use by HloEvaluator.
@@ -78,6 +89,8 @@ bool SafeLess(const NativeT& a, const NativeT& b) {
// to this rule, notably:
// - HandleCompare and HandleIsFinite: where the resulting literal type is
// always boolean.
+// - HandleImag and HandleReal: where the resulting literal type can differ from
+// the operand type (the operand is complex for HandleImag, and either complex
+// or real for HandleReal).
// These operations are handled outside of the parent HloEvaluator handlers
// instead of from within TypedVisitor.
//
@@ -318,14 +331,6 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return HandleFloor<ReturnT>(floor);
}
- Status HandleImag(HloInstruction* imag) override {
- TF_ASSIGN_OR_RETURN(parent_->evaluated_[imag],
- ElementWiseUnaryOp(imag, [](ElementwiseT elem_operand) {
- return std::imag(elem_operand);
- }));
- return Status::OK();
- }
-
Status HandleLog(HloInstruction* log) override {
TF_ASSIGN_OR_RETURN(parent_->evaluated_[log],
ElementWiseUnaryOp(log, [](ElementwiseT elem_operand) {
@@ -673,14 +678,6 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return Status::OK();
}
- Status HandleReal(HloInstruction* real) override {
- TF_ASSIGN_OR_RETURN(parent_->evaluated_[real],
- ElementWiseUnaryOp(real, [](ElementwiseT elem_operand) {
- return std::real(elem_operand);
- }));
- return Status::OK();
- }
-
template <typename NativeT, typename std::enable_if<std::is_floating_point<
NativeT>::value>::type* = nullptr>
Status HandleRemainder(HloInstruction* remainder) {
@@ -1527,47 +1524,55 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
!std::is_same<NativeT, bool>::value>::type* = nullptr>
Status HandleSort(HloInstruction* sort) {
auto keys = sort->operand(0);
- auto rank = ShapeUtil::Rank(keys->shape());
- TF_RET_CHECK(rank > 0 && rank <= 2)
- << "Sort is only supported for R1 and R2 shapes";
TF_RET_CHECK(sort->operand_count() == 1)
<< "Typed visitor does not support key-value sort";
const Literal& keys_literal = parent_->GetEvaluatedLiteralFor(keys);
-
- auto sort_r1 = [this](const Literal& keys_literal) {
- VLOG(3) << "HandleSort keys_literal: " << keys_literal.ToString();
- const auto& keys_data = keys_literal.data<ReturnT>();
-
- std::vector<ReturnT> result_data(keys_data.begin(), keys_data.end());
- std::sort(result_data.begin(), result_data.end(),
- [](const ReturnT& a, const ReturnT& b) {
- return SafeLess<ReturnT>(a, b);
- });
- Literal result_literal(keys_literal.shape());
- result_literal.PopulateR1(absl::Span<const ReturnT>(result_data));
- VLOG(3) << "HandleSort result_literal: " << result_literal.ToString();
- return result_literal;
- };
-
- if (rank == 1) {
- parent_->evaluated_[sort] = std::move(sort_r1(keys_literal));
- } else {
- // For R2 sort, the desired semantics are to sort each matrix row
- // independently.
- Literal result_literal(keys_literal.shape());
- int64 r1_length = keys->shape().dimensions(1);
- for (int64 row = 0; row < keys->shape().dimensions(0); ++row) {
- TF_ASSIGN_OR_RETURN(auto r1_slice,
- keys_literal.Slice({row, 0}, {row + 1, r1_length})
- .Reshape({r1_length}));
- auto r1_result = sort_r1(r1_slice);
- TF_ASSIGN_OR_RETURN(r1_result, r1_result.Reshape({1, r1_length}));
- TF_RETURN_IF_ERROR(result_literal.CopySliceFrom(
- r1_result, {0, 0}, {row, 0}, {1, r1_length}));
- }
- parent_->evaluated_[sort] = std::move(result_literal);
+ int64 sort_dim = sort->dimensions(0);
+ int64 sort_dim_elements = keys->shape().dimensions(sort_dim);
+ int64 rank = ShapeUtil::Rank(keys->shape());
+ if (rank == 0) {
+ // Nothing to sort.
+ parent_->evaluated_[sort] = keys_literal.Clone();
+ return Status::OK();
}
+ Literal result_literal(keys_literal.shape());
+ std::vector<int64> zero_base(rank, 0);
+ std::vector<int64> increment(rank, 1);
+ increment[sort_dim] = sort_dim_elements;
+ // Iterate over every index combination of the dimensions other than 'sort_dim'.
+ TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
+ keys->shape(), zero_base, AsInt64Slice(keys->shape().dimensions()),
+ increment, [&](absl::Span<const int64> indices) -> StatusOr<bool> {
+ // Extract a slice from the literal that corresponds to exactly the
+ // row in dimension 'sort_dim'.
+ std::vector<int64> limit_indices(indices.begin(), indices.end());
+ std::for_each(limit_indices.begin(), limit_indices.end(),
+ [](int64& index) { ++index; });
+ limit_indices[sort_dim] = sort_dim_elements;
+ TF_ASSIGN_OR_RETURN(auto row_to_sort,
+ keys_literal.Slice(indices, limit_indices)
+ .Reshape({sort_dim_elements}));
+ const auto& row_data = row_to_sort.data<NativeT>();
+
+ std::vector<NativeT> result_data(row_data.begin(), row_data.end());
+ std::sort(result_data.begin(), result_data.end(),
+ [](const NativeT& a, const NativeT& b) {
+ return SafeLess<NativeT>(a, b);
+ });
+ Literal sorted_row(ShapeUtil::MakeShape(keys->shape().element_type(),
+ {sort_dim_elements}));
+ sorted_row.PopulateR1(absl::Span<const NativeT>(result_data));
+ std::vector<int64> slice_dimensions(rank, 1);
+ slice_dimensions[sort_dim] = sort_dim_elements;
+ TF_ASSIGN_OR_RETURN(auto sorted_row_reshaped,
+ sorted_row.Reshape(slice_dimensions));
+ std::vector<int64> start_indices(rank, 0);
+ TF_RETURN_IF_ERROR(result_literal.CopySliceFrom(
+ sorted_row_reshaped, start_indices, indices, slice_dimensions));
+ return true;
+ }));
+ parent_->evaluated_[sort] = std::move(result_literal);
return Status::OK();
}
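
To illustrate the new per-dimension semantics, here is a standalone sketch (hypothetical SortAlongDim helper, plain C++ with no XLA dependencies) of the same idea: visit the start index of every slice along the chosen dimension, gather that slice, sort it, and write it back.

    #include <algorithm>
    #include <cstdio>
    #include <vector>

    // Sorts a row-major dense array independently along one dimension.
    void SortAlongDim(std::vector<float>& data, const std::vector<int>& dims,
                      int sort_dim) {
      // Row-major strides.
      std::vector<int> strides(dims.size(), 1);
      for (int i = static_cast<int>(dims.size()) - 2; i >= 0; --i)
        strides[i] = strides[i + 1] * dims[i + 1];
      int row_len = dims[sort_dim];
      int row_stride = strides[sort_dim];
      // Number of slices = product of all dimensions except sort_dim.
      int num_rows = 1;
      for (int i = 0; i < static_cast<int>(dims.size()); ++i)
        if (i != sort_dim) num_rows *= dims[i];
      std::vector<int> index(dims.size(), 0);
      for (int r = 0; r < num_rows; ++r) {
        // Linear offset of the first element of this slice.
        int base = 0;
        for (size_t i = 0; i < dims.size(); ++i) base += index[i] * strides[i];
        std::vector<float> row(row_len);
        for (int k = 0; k < row_len; ++k) row[k] = data[base + k * row_stride];
        std::sort(row.begin(), row.end());
        for (int k = 0; k < row_len; ++k) data[base + k * row_stride] = row[k];
        // Advance 'index' over every dimension except sort_dim (odometer style).
        for (int i = static_cast<int>(dims.size()) - 1; i >= 0; --i) {
          if (i == sort_dim) continue;
          if (++index[i] < dims[i]) break;
          index[i] = 0;
        }
      }
    }

    int main() {
      std::vector<float> a = {4, 1, 3, 2, 8, 5, 7, 6};  // shape {2, 4}
      SortAlongDim(a, {2, 4}, /*sort_dim=*/1);
      for (float x : a) std::printf("%g ", x);  // 1 2 3 4 5 6 7 8
      std::printf("\n");
    }
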
@@ -2265,19 +2270,16 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
// be 1.
int64 update_dim_size =
update_dim == -1 ? 1 : updates_shape.dimensions(update_dim);
- // Clamp the scatter index so that the scatter region fits in the
- // operand. input_scatter_index_clamped[i] =
- // clamp(input_scatter_index[i], 0,
- // operand_shape.dimensions(i) -
- // update_dim_size);
- input_scatter_index_clamped[i] =
- std::min(operand_shape.dimensions(i) - update_dim_size,
- std::max(0LL, input_scatter_index[i]));
+ // If any part of the update region is out-of-bounds, then do not
+ // perform any update on the input.
+ if ((input_scatter_index[i] < 0) ||
+ (input_scatter_index[i] >
+ operand_shape.dimensions(i) - update_dim_size)) {
+ return true;
+ }
}
for (int i = 0, e = input_index.size(); i < e; i++) {
- input_index[i] = input_scatter_index_clamped[i] + input_window_index[i];
- DCHECK_GE(input_index[i], 0);
- DCHECK_LT(input_index[i], operand_shape.dimensions(i));
+ input_index[i] = input_scatter_index[i] + input_window_index[i];
}
auto result_value_literal =
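
The scatter hunk above replaces clamping with skipping: if any coordinate of the update window would fall outside the operand, the whole update for that scatter index is dropped. A small sketch of that bounds rule follows (hypothetical helper and parameter names, plain C++).

    #include <cstdint>
    #include <vector>

    // Returns true if an update window of size update_size[i] placed at start[i]
    // fits inside an operand with dimensions operand_dims[i].
    bool UpdateWindowInBounds(const std::vector<int64_t>& start,
                              const std::vector<int64_t>& update_size,
                              const std::vector<int64_t>& operand_dims) {
      for (size_t i = 0; i < start.size(); ++i) {
        if (start[i] < 0 || start[i] > operand_dims[i] - update_size[i]) {
          return false;  // Any out-of-bounds coordinate discards the whole update.
        }
      }
      return true;
    }
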
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
index 287ba84b3b..13a74fd8a1 100644
--- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
@@ -1110,7 +1110,7 @@ string HloDotDumper::GetInstructionNodeMetadata(const HloInstruction* instr) {
instr->metadata().source_line()));
}
- return StrJoin(lines, "<br/>");
+ return StrJoin(lines, "\n");
}
string HloDotDumper::GetInstructionNodeBackendConfig(
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index e905f2983a..23787dbc8a 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -379,7 +379,8 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
break;
case HloOpcode::kCustomCall:
instruction = CreateCustomCall(proto.shape(), all_operands(),
- proto.custom_call_target());
+ proto.custom_call_target(),
+ proto.custom_call_opaque());
if (proto.has_window()) {
static_cast<HloCustomCallInstruction*>(instruction.get())
->set_window(proto.window());
@@ -1108,9 +1109,9 @@ bool HloInstruction::HasSideEffect() const {
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCustomCall(
const Shape& shape, absl::Span<HloInstruction* const> operands,
- absl::string_view custom_call_target) {
- return absl::make_unique<HloCustomCallInstruction>(shape, operands,
- custom_call_target);
+ absl::string_view custom_call_target, absl::string_view opaque) {
+ return absl::make_unique<HloCustomCallInstruction>(
+ shape, operands, custom_call_target, opaque);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTuple(
@@ -2423,7 +2424,7 @@ template <typename Visitor>
static Status PostOrderDFS(HloInstruction* root, Visitor* visitor,
const InternalCompareFunction* operand_order,
bool ignore_control_predecessors) {
- visitor->ReserveVisitStates(root->GetModule()->NumUniqueInstructionIds());
+ visitor->ReserveVisitStates(root->GetModule()->instruction_count());
// dfs_stack holds pairs of <HloInstruction*->unique_id(), HloInstruction*>.
//
@@ -2910,6 +2911,26 @@ std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind) {
return os << ToString(kind);
}
+bool HloPtrComparator::operator()(const HloInstruction* const& lhs,
+ const HloInstruction* const& rhs) const {
+ if (rhs == nullptr) {
+ // Nothing compares less than nullptr.
+ return false;
+ }
+ if (lhs == nullptr) {
+ return true;
+ }
+ auto lhs_module = lhs->GetModule();
+ auto rhs_module = rhs->GetModule();
+ CHECK((lhs_module == nullptr && rhs_module == nullptr) ||
+ (lhs_module != nullptr && rhs_module != nullptr));
+ if (lhs_module != nullptr &&
+ lhs_module->unique_id() != rhs_module->unique_id()) {
+ return lhs_module->unique_id() < rhs_module->unique_id();
+ }
+ return lhs->unique_id() < rhs->unique_id();
+}
+
bool HloInstruction::CouldBeBitcast() const {
switch (opcode_) {
case HloOpcode::kTranspose:
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 4f6cac1396..009bd3bab3 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -718,10 +718,11 @@ class HloInstruction {
HloComputation* computation);
// Creates a custom call instruction that applies the given custom call target
- // to the given operands. "shape" is the resultant shape.
+ // to the given operands. "opaque" can be an arbitrary string with a
+ // backend-specific interpretation. "shape" is the resultant shape.
static std::unique_ptr<HloInstruction> CreateCustomCall(
const Shape& shape, absl::Span<HloInstruction* const> operands,
- absl::string_view custom_call_target);
+ absl::string_view custom_call_target, absl::string_view opaque = "");
// Creates a tuple instruction with the given elements. This is a convenience
// wrapper around CreateVariadic.
@@ -1616,6 +1617,10 @@ class HloInstruction {
InstructionVector operands_;
// The set of control predecessors of this instruction.
+ // Note that the order of the instructions in the vector influences the order
+ // computed in HloComputation::ComputeInstructionPostOrder, which may
+ // influence the result of the compilation by changing the scheduling. We are
+ // not sure if it matters.
std::vector<HloInstruction*> control_predecessors_;
// The users of this instruction. Users are HLOs where this instruction is an
@@ -1689,21 +1694,9 @@ std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind);
// To make the iteration order over the map deterministic, the comparator
// should not be using the pointer values, but rather an intrinsic property of
// the hlo. Exception: null pointer values compare less than non-null.
-//
-// Note that this cannot be used for HLO instructions across multiple modules
-// since the id of HLO instructions are only unique within each HLO module.
struct HloPtrComparator {
bool operator()(const HloInstruction* const& lhs,
- const HloInstruction* const& rhs) const {
- if (rhs == nullptr) {
- // Nothing compares less than nullptr.
- return false;
- }
- if (lhs == nullptr) {
- return true;
- }
- return lhs->unique_id() < rhs->unique_id();
- }
+ const HloInstruction* const& rhs) const;
};
template <typename ValueT>
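
As a standalone illustration (toy types, not the XLA classes) of why the comparator now orders by module id before instruction id: per-module unique ids can collide across modules, so a map keyed on them alone would not iterate deterministically.

    #include <cstdio>
    #include <map>

    struct FakeInstr {
      int module_id;       // Stands in for HloModule::unique_id().
      int instruction_id;  // Stands in for HloInstruction::unique_id().
    };

    struct FakePtrComparator {
      bool operator()(const FakeInstr* lhs, const FakeInstr* rhs) const {
        if (rhs == nullptr) return false;  // Nothing compares less than nullptr.
        if (lhs == nullptr) return true;
        if (lhs->module_id != rhs->module_id)
          return lhs->module_id < rhs->module_id;
        return lhs->instruction_id < rhs->instruction_id;
      }
    };

    int main() {
      FakeInstr a{0, 7}, b{1, 7};  // Same per-module id, different modules.
      std::map<const FakeInstr*, const char*, FakePtrComparator> m;
      m[&b] = "module 1";
      m[&a] = "module 0";
      for (const auto& kv : m) std::printf("%s\n", kv.second);  // module 0, then 1
    }
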
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index e92882c22a..cd71bc3323 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -1830,9 +1830,10 @@ HloSelectAndScatterInstruction::CloneWithNewOperandsImpl(
HloCustomCallInstruction::HloCustomCallInstruction(
const Shape& shape, absl::Span<HloInstruction* const> operands,
- absl::string_view custom_call_target)
+ absl::string_view custom_call_target, absl::string_view opaque)
: HloInstruction(HloOpcode::kCustomCall, shape),
custom_call_target_(custom_call_target.begin(), custom_call_target.end()),
+ opaque_(opaque.begin(), opaque.end()),
feature_group_count_(1) {
for (auto operand : operands) {
AppendOperand(operand);
@@ -1849,6 +1850,7 @@ HloInstructionProto HloCustomCallInstruction::ToProto() const {
*convolution_dimension_numbers_;
}
proto.set_custom_call_target(custom_call_target_);
+ proto.set_custom_call_opaque(opaque_);
proto.set_feature_group_count(feature_group_count_);
return proto;
}
@@ -1872,6 +1874,11 @@ std::vector<string> HloCustomCallInstruction::ExtraAttributesToStringImpl(
// an HloComputation.
extra.push_back(
StrCat("custom_call_target=\"", CEscape(custom_call_target_), "\""));
+ // If the opaque string becomes enormous we may want to reconsider printing
+ // this inline and consider other options.
+ if (!opaque_.empty()) {
+ extra.push_back(StrCat("opaque=\"", CEscape(opaque_), "\""));
+ }
return extra;
}
@@ -1897,7 +1904,8 @@ bool HloCustomCallInstruction::IdenticalSlowPath(
if (feature_group_count_ != casted_other.feature_group_count_) {
return false;
}
- return custom_call_target_ == casted_other.custom_call_target_;
+ return custom_call_target_ == casted_other.custom_call_target_ &&
+ opaque_ == casted_other.opaque_;
}
std::unique_ptr<HloInstruction>
@@ -1905,7 +1913,7 @@ HloCustomCallInstruction::CloneWithNewOperandsImpl(
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
auto cloned = absl::make_unique<HloCustomCallInstruction>(
- shape, new_operands, custom_call_target());
+ shape, new_operands, custom_call_target(), opaque());
if (window_ != nullptr) {
cloned->set_window(*window_);
}
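
A sketch of how the new 'opaque' plumbing might be used, assuming the surrounding XLA headers and build environment; the custom-call target, payload format, and helper names below are purely illustrative and not part of the patch.

    #include <memory>

    #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
    #include "tensorflow/compiler/xla/service/hlo_instruction.h"
    #include "tensorflow/compiler/xla/service/hlo_instructions.h"
    #include "tensorflow/compiler/xla/shape_util.h"

    namespace xla {

    // Attaches a backend-specific payload when building the custom call. The
    // opaque string survives ToProto(), ToString() and Clone().
    std::unique_ptr<HloInstruction> MakeAnnotatedCustomCall(HloInstruction* arg) {
      Shape shape = ShapeUtil::MakeShape(F32, {16});
      return HloInstruction::CreateCustomCall(shape, {arg},
                                              /*custom_call_target=*/"my_kernel",
                                              /*opaque=*/"tile_size=128");
    }

    // Backends downcast to HloCustomCallInstruction to read the payload back.
    const string& ReadOpaque(const HloInstruction* instr) {
      return Cast<HloCustomCallInstruction>(instr)->opaque();
    }

    }  // namespace xla
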
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h
index 2d7bc83855..9c22f5db7e 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.h
+++ b/tensorflow/compiler/xla/service/hlo_instructions.h
@@ -1070,7 +1070,8 @@ class HloCustomCallInstruction : public HloInstruction {
public:
explicit HloCustomCallInstruction(const Shape& shape,
absl::Span<HloInstruction* const> operands,
- absl::string_view custom_call_target);
+ absl::string_view custom_call_target,
+ absl::string_view opaque);
const Window& window() const override {
CHECK(window_ != nullptr);
return *window_;
@@ -1090,6 +1091,7 @@ class HloCustomCallInstruction : public HloInstruction {
convolution_dimension_numbers_ =
absl::make_unique<ConvolutionDimensionNumbers>(dnums);
}
+ const string& opaque() const { return opaque_; }
const string& custom_call_target() const { return custom_call_target_; }
void set_feature_group_count(int64 feature_group_count) {
feature_group_count_ = feature_group_count;
@@ -1109,8 +1111,10 @@ class HloCustomCallInstruction : public HloInstruction {
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
- // Name of a global symbol to call, only present for kCustomCall.
+ // Name of a global symbol to call.
string custom_call_target_;
+ // Opaque string interpreted by the backend.
+ string opaque_;
// Describes the window in a windowed operation such as convolution.
std::unique_ptr<Window> window_;
// Describes the dimension numbers used for a convolution.
diff --git a/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc b/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc
index 3a1dd471c6..5bf055f3c0 100644
--- a/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc
@@ -219,6 +219,33 @@ void PropagateLivenessToParameterCallers(
}
}
+// If a live instruction is within a computation used by a control flow
+// operation, marks the related control flow instructions live as well.
+void PropagateLivenessThroughControlFlow(
+ const HloInstruction* instruction,
+ HloLivenessAnalysis::HloIndexMap* live_index_map, Worklist* worklist,
+ Workset* workset, CallGraph* call_graph) {
+ const CallGraphNode& call_graph_node =
+ call_graph->GetNode(instruction->parent());
+ if (call_graph_node.context() == CallContext::kSequential) {
+ for (const CallSite& callsite : call_graph_node.caller_callsites()) {
+ HloInstruction* caller = callsite.instruction();
+ if (caller->opcode() == HloOpcode::kWhile) {
+ // If a live instruction is within the %while body or condition
+ // computation, mark the predicate value returned by the condition
+ // computation live as well.
+ MarkLiveAtIndex(caller->while_condition()->root_instruction(), {},
+ live_index_map, worklist, workset);
+ } else if (caller->opcode() == HloOpcode::kConditional) {
+ // If a live instruction is within the true or false branches of a
+ // conditional, we mark the predicate operand live as well.
+ MarkLiveAtIndex(caller->operand(0), {}, live_index_map, worklist,
+ workset);
+ }
+ }
+ }
+}
+
} // namespace
HloLivenessAnalysis::HloLivenessAnalysis(const HloModule& module)
@@ -257,12 +284,10 @@ void HloLivenessAnalysis::RunAnalysis() {
} else if (instruction->opcode() == HloOpcode::kGetTupleElement) {
PropagateLivenessThroughGTE(instruction, &live_index_map_, &worklist,
&workset);
- } else if (instruction->opcode() == HloOpcode::kWhile &&
- ShapeUtil::IsTuple(instruction->shape())) {
+ } else if (instruction->opcode() == HloOpcode::kWhile) {
PropagateLivenessThroughWhile(instruction, &live_index_map_, &worklist,
&workset);
- } else if (instruction->opcode() == HloOpcode::kParameter &&
- ShapeUtil::IsTuple(instruction->shape())) {
+ } else if (instruction->opcode() == HloOpcode::kParameter) {
PropagateLivenessToParameterCallers(instruction, &live_index_map_,
&worklist, &workset,
call_graph_.get());
@@ -277,6 +302,8 @@ void HloLivenessAnalysis::RunAnalysis() {
MarkLiveAtAllIndices(operand, &live_index_map_, &worklist, &workset);
}
}
+ PropagateLivenessThroughControlFlow(instruction, &live_index_map_,
+ &worklist, &workset, call_graph_.get());
}
}
diff --git a/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc
index 01b625c29c..e0ae1173c6 100644
--- a/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc
@@ -398,5 +398,89 @@ TEST_F(HloLivenessAnalysisTest, WhileWithLiveTupleElements) {
EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "loop_var.1"), {2}));
}
+TEST_F(HloLivenessAnalysisTest, WhileWithOutfeed) {
+ auto module = ParseHloString(R"(
+ HloModule OutfeedLoop
+ WhileBody {
+ body_param = (s32[]) parameter(0)
+ token = token[] after-all()
+ constant.2 = s32[] constant(2)
+ outfeed_tuple = (s32[]) outfeed(constant.2, token)
+ get-tuple-element.1 = s32[] get-tuple-element(body_param), index=0
+ constant.1 = s32[] constant(1)
+ add = s32[] add(get-tuple-element.1, constant.1)
+ ROOT tuple = (s32[]) tuple(add)
+ }
+ WhileCondition {
+ cond_param = (s32[]) parameter(0)
+ get-tuple-element.3 = s32[] get-tuple-element(cond_param), index=0
+ constant.2 = s32[] constant(10)
+ ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2)
+ }
+ ENTRY SimpleLoop {
+ constant.3 = s32[] constant(0)
+ tuple.1 = (s32[]) tuple(constant.3)
+ while = (s32[]) while(tuple.1), condition=WhileCondition,
+ body=WhileBody
+ ROOT rtuple = () tuple()
+ })")
+ .ValueOrDie();
+
+ const HloLivenessAnalysis& liveness = RunLiveness(module.get());
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "add"), {}));
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.3"), {}));
+}
+
+TEST_F(HloLivenessAnalysisTest, NestedWhileWithOutfeed) {
+ auto module = ParseHloString(R"(
+ HloModule OutfeedLoop
+ InnerWhileBody {
+ body_param = (s32[]) parameter(0)
+ token = token[] after-all()
+ constant.2 = s32[] constant(2)
+ outfeed_tuple = (s32[]) outfeed(constant.2, token)
+ get-tuple-element.1 = s32[] get-tuple-element(body_param), index=0
+ constant.1 = s32[] constant(1)
+ add = s32[] add(get-tuple-element.1, constant.1)
+ ROOT tuple = (s32[]) tuple(add)
+ }
+ InnerWhileCondition {
+ cond_param = (s32[]) parameter(0)
+ get-tuple-element.3 = s32[] get-tuple-element(cond_param), index=0
+ constant.2 = s32[] constant(10)
+ ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2)
+ }
+ OuterWhileCondition {
+ cond_param.2 = (s32[]) parameter(0)
+ get-tuple-element.5 = s32[] get-tuple-element(cond_param.2), index=0
+ constant.5 = s32[] constant(5)
+ ROOT less-than.2 = pred[] less-than(get-tuple-element.5, constant.5)
+ }
+ OuterWhileBody {
+ body_param.2 = (s32[]) parameter(0)
+ get-tuple-element.8 = s32[] get-tuple-element(body_param.2), index=0
+ constant.6 = s32[] constant(0)
+ tuple.2 = (s32[]) tuple(constant.6)
+ inner_while = (s32[]) while(tuple.2), condition=InnerWhileCondition,
+ body=InnerWhileBody
+ constant.7 = s32[] constant(1)
+ add.2 = s32[] add(get-tuple-element.8, constant.7)
+ ROOT rtuple = (s32[]) tuple(add.2)
+ }
+ ENTRY SimpleLoop {
+ constant.3 = s32[] constant(0)
+ tuple.1 = (s32[]) tuple(constant.3)
+ while = (s32[]) while(tuple.1), condition=OuterWhileCondition,
+ body=OuterWhileBody
+ ROOT rtuple = () tuple()
+ })")
+ .ValueOrDie();
+
+ const HloLivenessAnalysis& liveness = RunLiveness(module.get());
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "add"), {}));
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "add.2"), {}));
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.3"), {}));
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc b/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc
index c7ec88d450..6a4e766788 100644
--- a/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc
+++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc
@@ -400,7 +400,7 @@ StatusOr<HloInstructionSequence> DFSMemoryScheduler(
memory_by_computation) {
// These variables are a hack to prevent overflows.
int64 cumulative_total_size = 0;
- int64 total_hlos = computation.parent()->NumUniqueInstructionIds();
+ int64 total_hlos = computation.parent()->instruction_count();
tensorflow::gtl::FlatMap<const HloInstruction*, int64> extra_users;
tensorflow::gtl::FlatMap<const HloInstruction*, int64> total_sizes;
for (const HloInstruction* hlo : computation.MakeInstructionPostOrder()) {
diff --git a/tensorflow/compiler/xla/service/hlo_memory_scheduler.h b/tensorflow/compiler/xla/service/hlo_memory_scheduler.h
index 5e02868eba..9964c6fdd7 100644
--- a/tensorflow/compiler/xla/service/hlo_memory_scheduler.h
+++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler.h
@@ -90,7 +90,7 @@ StatusOr<HloInstructionSequence> ScheduleComputation(
// A pass which schedules the HLO instructions in a module. The HloModule's
// schedule field is set to the resulting HloSchedule using
// HloModule::set_schedule.
-class HloMemoryScheduler : public HloPassInterface {
+class HloMemoryScheduler : public HloModulePass {
public:
// size_function is the function returning the number of bytes required for a
// LogicalBuffer. algorithm is the memory scheduling algorithm to use. If not
@@ -109,7 +109,7 @@ class HloMemoryScheduler : public HloPassInterface {
// A trivial pass which clears the schedule currently set on the
// HloModule. After this pass runs, HloModule::has_schedule will return false.
-class HloDescheduler : public HloPassInterface {
+class HloDescheduler : public HloModulePass {
public:
HloDescheduler() = default;
~HloDescheduler() override = default;
diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h
index 3bc2d13781..735804e827 100644
--- a/tensorflow/compiler/xla/service/hlo_module.h
+++ b/tensorflow/compiler/xla/service/hlo_module.h
@@ -63,6 +63,7 @@ class HloModule {
// tests). The versioned handle is used by the service in the compilation
// cache. A default configuration is created for this module.
explicit HloModule(const string& name, const HloModuleConfig& config);
+ virtual ~HloModule() {}
// Adds an entry computation to the module. A module can only have one entry
// computation. Returns a pointer to the newly added computation.
@@ -87,6 +88,7 @@ class HloModule {
const std::unordered_map<HloComputation*, HloComputation*>& replacements);
const string& name() const { return name_; }
+ void set_name(string name) { name_ = std::move(name); }
// Returns a deep copy of this module including all computations.
std::unique_ptr<HloModule> Clone(const string& suffix = "clone") const;
@@ -255,7 +257,7 @@ class HloModule {
std::unique_ptr<HloComputation> computation, bool is_entry,
bool uniquify_identifiers);
- const string name_;
+ string name_;
HloModuleConfig config_;
HloComputation* entry_computation_ = nullptr;
std::vector<std::unique_ptr<HloComputation>> computations_;
diff --git a/tensorflow/compiler/xla/service/hlo_module_dce.cc b/tensorflow/compiler/xla/service/hlo_module_dce.cc
index f7be5cae22..31d26cc51e 100644
--- a/tensorflow/compiler/xla/service/hlo_module_dce.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_dce.cc
@@ -50,9 +50,7 @@ StatusOr<bool> RunWhileDCE(HloModule* module, HloLivenessAnalysis* liveness) {
auto* while_body_root = while_body_comp->root_instruction();
if (!ShapeUtil::IsTuple(xla_while->shape()) ||
- while_body_root->opcode() != HloOpcode::kTuple ||
- while_body_comp->HasSideEffect() ||
- xla_while->while_condition()->HasSideEffect()) {
+ while_body_root->opcode() != HloOpcode::kTuple) {
// Only run DCE on tuple-shaped while loops where body root is Tuple,
// with no I/O instructions.
VLOG(1) << "WhileDCE SKIP while: " << xla_while->ToString();
diff --git a/tensorflow/compiler/xla/service/hlo_module_dce.h b/tensorflow/compiler/xla/service/hlo_module_dce.h
index 12ca2340a6..d472211d2a 100644
--- a/tensorflow/compiler/xla/service/hlo_module_dce.h
+++ b/tensorflow/compiler/xla/service/hlo_module_dce.h
@@ -28,7 +28,7 @@ namespace xla {
// Sweeps through live instructions which cross computation boundaries (kWhile),
// and removes code at dead shape indices.
//
-class HloModuleDCE : public HloPassInterface {
+class HloModuleDCE : public HloModulePass {
public:
~HloModuleDCE() override {}
absl::string_view name() const override { return "hlo-module-dce"; }
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc
index 9c01862a4b..83352ef91b 100644
--- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc
@@ -392,22 +392,28 @@ Status HloModuleGroupMetadata::AddCompanion(HloInstruction* instruction1,
if (!ContainsKey(companion_set_index_, instruction1) &&
!ContainsKey(companion_set_index_, instruction2)) {
companion_sets_.push_back(
- absl::make_unique<std::unordered_set<HloInstruction*>>());
+ absl::make_unique<std::vector<HloInstruction*>>());
auto companion_set = companion_sets_.back().get();
- companion_set->insert(instruction1);
- companion_set->insert(instruction2);
+ companion_set->push_back(instruction1);
+ companion_set->push_back(instruction2);
companion_set_index_[instruction1] = companion_sets_.size() - 1;
companion_set_index_[instruction2] = companion_sets_.size() - 1;
} else if (!ContainsKey(companion_set_index_, instruction1)) {
- companion_sets_[companion_set_index_[instruction2]]->insert(instruction1);
+ companion_sets_[companion_set_index_[instruction2]]->push_back(
+ instruction1);
companion_set_index_[instruction1] = companion_set_index_[instruction2];
} else if (!ContainsKey(companion_set_index_, instruction2)) {
- companion_sets_[companion_set_index_[instruction1]]->insert(instruction2);
+ companion_sets_[companion_set_index_[instruction1]]->push_back(
+ instruction2);
companion_set_index_[instruction2] = companion_set_index_[instruction1];
} else if (companion_set_index_[instruction1] !=
companion_set_index_[instruction2]) {
- companion_sets_[companion_set_index_[instruction1]]->insert(
- Companions(instruction2).begin(), Companions(instruction2).end());
+ // At any point while building the companion sets, each instruction belongs
+ // to at most 1 companion set, so the union of two companion sets is
+ // concatenating two disjoint sets.
+ absl::c_copy(Companions(instruction2),
+ std::back_inserter(
+ *companion_sets_[companion_set_index_[instruction1]]));
int64 index_to_remove = companion_set_index_[instruction2];
for (HloInstruction* hlo : Companions(instruction2)) {
companion_set_index_[hlo] = companion_set_index_[instruction1];
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h
index 768b0c7eb3..278d94cdd3 100644
--- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h
+++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h
@@ -169,14 +169,14 @@ class HloModuleGroupMetadata {
// Returns the companion instructions for the given instruction.
//
// Precondition: IsCompanionWhile(instruction) is true.
- const std::unordered_set<HloInstruction*>& Companions(
+ const std::vector<HloInstruction*>& Companions(
const HloInstruction* instruction) const {
CHECK_EQ(companion_set_index_.count(instruction), 1);
return companion_set(companion_set_index_.at(instruction));
}
// Returns the companion set at the given index.
- const std::unordered_set<HloInstruction*>& companion_set(int64 index) const {
+ const std::vector<HloInstruction*>& companion_set(int64 index) const {
CHECK_LT(index, companion_sets_.size());
return *companion_sets_[index];
}
@@ -187,7 +187,7 @@ class HloModuleGroupMetadata {
}
// Returns the list of all companion sets in the HLO module group.
- const std::vector<std::unique_ptr<std::unordered_set<HloInstruction*>>>&
+ const std::vector<std::unique_ptr<std::vector<HloInstruction*>>>&
companion_sets() const {
return companion_sets_;
}
@@ -247,8 +247,7 @@ class HloModuleGroupMetadata {
void DumpCollectedStats() const;
// List of all companion instructions sets in the module.
- std::vector<std::unique_ptr<std::unordered_set<HloInstruction*>>>
- companion_sets_;
+ std::vector<std::unique_ptr<std::vector<HloInstruction*>>> companion_sets_;
// Map from each companion while instruction to the index into companion_set_.
tensorflow::gtl::FlatMap<const HloInstruction*, int64> companion_set_index_;
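
A standalone sketch (simplified types) of the union step performed in AddCompanion above: since each instruction belongs to at most one companion set, merging two sets is a plain append of disjoint vectors, which keeps a stable, run-to-run deterministic order, unlike rehashing into an unordered_set.

    #include <cstdio>
    #include <map>
    #include <vector>

    using Instr = int;  // Stands in for HloInstruction*.

    // Appends set 'merge' onto set 'keep' and repoints the index map entries.
    void MergeCompanionSets(std::vector<std::vector<Instr>>& sets,
                            std::map<Instr, int>& set_index, int keep, int merge) {
      for (Instr hlo : sets[merge]) set_index[hlo] = keep;
      sets[keep].insert(sets[keep].end(), sets[merge].begin(), sets[merge].end());
      sets[merge].clear();
    }

    int main() {
      std::vector<std::vector<Instr>> sets = {{1, 2}, {3, 4}};
      std::map<Instr, int> index = {{1, 0}, {2, 0}, {3, 1}, {4, 1}};
      MergeCompanionSets(sets, index, /*keep=*/0, /*merge=*/1);
      for (Instr i : sets[0]) std::printf("%d ", i);  // Always 1 2 3 4.
      std::printf("\n");
    }
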
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_test.cc b/tensorflow/compiler/xla/service/hlo_module_group_test.cc
index ebf790ba6f..b7b12cb72b 100644
--- a/tensorflow/compiler/xla/service/hlo_module_group_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_group_test.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
+#include "tensorflow/compiler/xla/service/hlo_module_group_metadata.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
@@ -137,6 +138,69 @@ ENTRY %entry (a: f32[]) -> f32[] {
::testing::ElementsAre(op::Parameter()));
}
+// Tests that the order of companion instructions in the companion set doesn't
+// change across runs.
+TEST_F(HloModuleGroupTest, ModuleGroupCompanionOrder) {
+ // A simple while loop template for core i sending to core i+1.
+ constexpr char text[] = R"(
+HloModule module_%d
+
+while_cond {
+ ROOT p = pred[] constant(true)
+}
+
+while_body {
+ param = s32[] parameter(0)
+ token.s = token[] after-all()
+ token.r = token[] after-all()
+ send = (s32[], u32[], token[]) send(param, token.s), channel_id=%d
+ send-done = token[] send-done(send), channel_id=%d
+ recv = (s32[], u32[], token[]) recv(token.r), channel_id=%d
+ ROOT recv-done = (s32[], token[]) recv-done(recv), channel_id=%d
+}
+
+ENTRY entry {
+ while_init = s32[] constant(1)
+ ROOT while = s32[] while(while_init), condition=while_cond, body=while_body
+}
+)";
+
+ // Try creating the module and the metadata kTrialCount times and check the
+ // companion instructions remain in the same order.
+ const int64 kTrialCount = 5;
+ const int64 kDeviceCount = 10;
+ std::vector<int64> companion_order;
+
+ for (int64 t = 0; t < kTrialCount; ++t) {
+ HloModuleGroup group(TestName());
+ for (int64 i = 0; i < kDeviceCount; ++i) {
+ const int64 send_channel = i;
+ const int64 recv_channel = i == 0 ? kDeviceCount - 1 : i - 1;
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<HloModule> module,
+ ParseHloString(absl::StrFormat(text, i, send_channel, send_channel,
+ recv_channel, recv_channel)));
+ group.push_back(std::move(module));
+ }
+ ASSERT_EQ(group.modules().size(), kDeviceCount);
+
+ TF_ASSERT_OK_AND_ASSIGN(auto metadata,
+ HloModuleGroupMetadata::Build(group.modules()));
+ ASSERT_EQ(metadata->companion_sets().size(), 1);
+
+ std::vector<int64> module_ids;
+ for (HloInstruction* companion : *metadata->companion_sets()[0]) {
+ module_ids.push_back(metadata->GetModuleId(companion->GetModule()));
+ }
+
+ if (t == 0) {
+ companion_order = module_ids;
+ } else {
+ EXPECT_TRUE(absl::c_equal(companion_order, module_ids));
+ }
+ }
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index 11caa89c54..25b70740e3 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -64,14 +64,11 @@ class HloParser {
public:
using LocTy = HloLexer::LocTy;
- explicit HloParser(absl::string_view str, const HloModuleConfig& config)
- : lexer_(str), config_(config) {}
+ explicit HloParser(absl::string_view str) : lexer_(str) {}
- // Runs the parser. Returns false if an error occurred.
- bool Run();
-
- // Returns the parsed HloModule.
- std::unique_ptr<HloModule> ConsumeHloModule() { return std::move(module_); }
+ // Runs the parser and constructs the resulting HLO in the given (empty)
+ // HloModule. Returns false if an error occurred.
+ bool Run(HloModule* module);
// Returns the error information.
string GetError() const { return StrJoin(error_, "\n"); }
@@ -98,8 +95,8 @@ class HloParser {
const string& name, const optional<Shape>& shape = nullopt);
// ParseXXX returns false if an error occurred.
- bool ParseHloModule();
- bool ParseComputations();
+ bool ParseHloModule(HloModule* module);
+ bool ParseComputations(HloModule* module);
bool ParseComputation(HloComputation** entry_computation);
bool ParseInstructionList(HloComputation::Builder* builder,
string* root_name);
@@ -293,9 +290,7 @@ class HloParser {
computation_pool_;
HloLexer lexer_;
- std::unique_ptr<HloModule> module_;
std::vector<std::unique_ptr<HloComputation>> computations_;
- const HloModuleConfig config_;
std::vector<string> error_;
// Function that gets invoked when we try to resolve an instruction
@@ -349,9 +344,9 @@ bool HloParser::TokenError(absl::string_view msg) {
return Error(lexer_.GetLoc(), msg);
}
-bool HloParser::Run() {
+bool HloParser::Run(HloModule* module) {
lexer_.Lex();
- return ParseHloModule();
+ return ParseHloModule(module);
}
std::pair<HloInstruction*, HloParser::LocTy>* HloParser::FindInstruction(
@@ -366,7 +361,7 @@ std::pair<HloInstruction*, HloParser::LocTy>* HloParser::FindInstruction(
}
// ::= 'HloModule' name computations
-bool HloParser::ParseHloModule() {
+bool HloParser::ParseHloModule(HloModule* module) {
if (lexer_.GetKind() != TokKind::kw_HloModule) {
return TokenError("expects HloModule");
}
@@ -385,22 +380,20 @@ bool HloParser::ParseHloModule() {
return false;
}
- module_ = absl::make_unique<HloModule>(name, config_);
-
- if (!ParseComputations()) {
+ module->set_name(name);
+ if (!ParseComputations(module)) {
return false;
}
if (is_scheduled.has_value() && *is_scheduled) {
- TF_CHECK_OK(
- module_->set_schedule(ScheduleFromInstructionOrder(module_.get())));
+ TF_CHECK_OK(module->set_schedule(ScheduleFromInstructionOrder(module)));
}
return true;
}
// computations ::= (computation)+
-bool HloParser::ParseComputations() {
+bool HloParser::ParseComputations(HloModule* module) {
HloComputation* entry_computation = nullptr;
do {
if (!ParseComputation(&entry_computation)) {
@@ -416,21 +409,20 @@ bool HloParser::ParseComputations() {
if ((entry_computation != nullptr &&
computations_[i].get() != entry_computation) ||
(entry_computation == nullptr && i != computations_.size() - 1)) {
- module_->AddEmbeddedComputation(std::move(computations_[i]));
+ module->AddEmbeddedComputation(std::move(computations_[i]));
continue;
}
- auto computation =
- module_->AddEntryComputation(std::move(computations_[i]));
+ auto computation = module->AddEntryComputation(std::move(computations_[i]));
// The parameters and result layouts were set to default layout. Here we
// set the layouts to what the hlo text says.
for (int p = 0; p < computation->num_parameters(); p++) {
const Shape& param_shape = computation->parameter_instruction(p)->shape();
- TF_CHECK_OK(module_->mutable_entry_computation_layout()
+ TF_CHECK_OK(module->mutable_entry_computation_layout()
->mutable_parameter_layout(p)
->CopyLayoutFromShape(param_shape));
}
const Shape& result_shape = computation->root_instruction()->shape();
- TF_CHECK_OK(module_->mutable_entry_computation_layout()
+ TF_CHECK_OK(module->mutable_entry_computation_layout()
->mutable_result_layout()
->CopyLayoutFromShape(result_shape));
}
@@ -1274,11 +1266,13 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
}
case HloOpcode::kCustomCall: {
optional<string> custom_call_target;
+ optional<string> opaque;
optional<Window> window;
optional<ConvolutionDimensionNumbers> dnums;
optional<int64> feature_group_count;
attrs["custom_call_target"] = {/*required=*/true, AttrTy::kString,
&custom_call_target};
+ attrs["opaque"] = {/*required=*/false, AttrTy::kString, &opaque};
attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window};
attrs["dim_labels"] = {/*required=*/false,
AttrTy::kConvolutionDimensionNumbers, &dnums};
@@ -1287,8 +1281,9 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
return false;
}
- instruction = builder->AddInstruction(HloInstruction::CreateCustomCall(
- shape, operands, *custom_call_target));
+ instruction = builder->AddInstruction(
+ HloInstruction::CreateCustomCall(shape, operands, *custom_call_target,
+ opaque.has_value() ? *opaque : ""));
if (window.has_value()) {
instruction->set_window(*window);
}
@@ -3247,53 +3242,62 @@ Status HloParser::ParseSingleInstruction(HloComputation::Builder* builder,
StatusOr<std::unique_ptr<HloModule>> ParseHloString(
absl::string_view str, const HloModuleConfig& config) {
- HloParser parser(str, config);
- if (!parser.Run()) {
+ auto module = absl::make_unique<HloModule>(/*name=*/"", config);
+ HloParser parser(str);
+ if (!parser.Run(module.get())) {
return InvalidArgument("Syntax error:\n%s", parser.GetError());
}
- return parser.ConsumeHloModule();
+ return std::move(module);
}
StatusOr<std::unique_ptr<HloModule>> ParseHloString(absl::string_view str) {
- HloModuleConfig config;
- return ParseHloString(str, config);
+ auto module = absl::make_unique<HloModule>(/*name=*/"", HloModuleConfig());
+ HloParser parser(str);
+ if (!parser.Run(module.get())) {
+ return InvalidArgument("Syntax error:\n%s", parser.GetError());
+ }
+ return std::move(module);
+}
+
+Status ParseHloString(absl::string_view str, HloModule* module) {
+ TF_RET_CHECK(module->computation_count() == 0);
+ HloParser parser(str);
+ if (!parser.Run(module)) {
+ return InvalidArgument("Syntax error:\n%s", parser.GetError());
+ }
+ return Status::OK();
}
StatusOr<std::unique_ptr<HloModule>> ParseHloOpToModule(
absl::string_view str, absl::string_view name) {
- HloModuleConfig config;
- HloParser parser(str, config);
+ HloParser parser(str);
auto builder = absl::make_unique<HloComputation::Builder>(string(name));
string root_name;
TF_RETURN_IF_ERROR(parser.ParseSingleInstruction(builder.get(), &root_name));
std::unique_ptr<HloComputation> computation = builder->Build();
- auto module = absl::make_unique<HloModule>(string(name), config);
+ auto module = absl::make_unique<HloModule>(string(name), HloModuleConfig());
module->AddEntryComputation(std::move(computation));
return std::move(module);
}
StatusOr<HloSharding> ParseSharding(absl::string_view str) {
- HloModuleConfig config;
- HloParser parser(str, config);
+ HloParser parser(str);
return parser.ParseShardingOnly();
}
StatusOr<Window> ParseWindow(absl::string_view str) {
- HloModuleConfig config;
- HloParser parser(str, config);
+ HloParser parser(str);
return parser.ParseWindowOnly();
}
StatusOr<ConvolutionDimensionNumbers> ParseConvolutionDimensionNumbers(
absl::string_view str) {
- HloModuleConfig config;
- HloParser parser(str, config);
+ HloParser parser(str);
return parser.ParseConvolutionDimensionNumbersOnly();
}
StatusOr<PaddingConfig> ParsePaddingConfig(absl::string_view str) {
- HloModuleConfig config;
- HloParser parser(str, config);
+ HloParser parser(str);
return parser.ParsePaddingConfigOnly();
}
diff --git a/tensorflow/compiler/xla/service/hlo_parser.h b/tensorflow/compiler/xla/service/hlo_parser.h
index 1882a184da..3696035514 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.h
+++ b/tensorflow/compiler/xla/service/hlo_parser.h
@@ -30,18 +30,23 @@ namespace xla {
// For details about the syntax accepted by this parser, see
// g3doc/hlo_parser.md.
-// The api of the hlo parser. Given a string in the HloModule::ToString()
-// format, parses the string and creates a HloModule with the given config.
+// Given a string in the HloModule::ToString() format, parses the string and
+// creates a HloModule with the given config.
StatusOr<std::unique_ptr<HloModule>> ParseHloString(
absl::string_view str, const HloModuleConfig& config);
+// Given a string in the HloModule::ToString() format, parses the string and
+// builds the HloModule in place at the given module pointer. 'module' must
+// point to an empty module (no computations).
+Status ParseHloString(absl::string_view str, HloModule* module);
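
A sketch (assuming the XLA build environment; the wrapper name is illustrative) of how a caller might use the in-place overload: construct and own the module with the desired config, then let the parser populate it, mirroring what the updated ParseHloString overloads in hlo_parser.cc do.

    #include "absl/memory/memory.h"
    #include "absl/strings/string_view.h"
    #include "tensorflow/compiler/xla/service/hlo_module.h"
    #include "tensorflow/compiler/xla/service/hlo_parser.h"
    #include "tensorflow/compiler/xla/status_macros.h"
    #include "tensorflow/compiler/xla/statusor.h"

    namespace xla {

    StatusOr<std::unique_ptr<HloModule>> ParseWithConfig(
        absl::string_view text, const HloModuleConfig& config) {
      auto module = absl::make_unique<HloModule>(/*name=*/"", config);
      // 'module' must be empty (no computations); ParseHloString checks this.
      TF_RETURN_IF_ERROR(ParseHloString(text, module.get()));
      return std::move(module);
    }

    }  // namespace xla
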
+
// Parses the text for a single HLO operation into an HLO module with a function
// that runs that operation (with the same parameters) as its entry computation.
StatusOr<std::unique_ptr<HloModule>> ParseHloOpToModule(
absl::string_view str, absl::string_view name = "single_op");
-// The api of the hlo parser. Given a string in the HloModule::ToString()
-// format, parses the string and creates a HloModule with default config.
+// Given a string in the HloModule::ToString() format, parses the string and
+// creates a HloModule with default config.
StatusOr<std::unique_ptr<HloModule>> ParseHloString(absl::string_view str);
// Parses the result of HloSharding::ToString(), e.g. "{replicated}".
diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc
index cca50fab54..96db96bdb9 100644
--- a/tensorflow/compiler/xla/service/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc
@@ -1004,6 +1004,18 @@ ENTRY CustomCall {
)"
},
+// CustomCall with opaque value.
+{
+"CustomCallWithOpaque",
+R"(HloModule custom_call
+
+ENTRY CustomCall {
+ constant = f32[1]{0} constant({12345})
+ ROOT custom-call = f32[1,2,3]{0,2,1} custom-call(constant), custom_call_target="foo\"bar", opaque="this string is opaque"
+}
+
+)"
+},
// Variables with non-default names
{
"NonDefaultNames",
diff --git a/tensorflow/compiler/xla/service/hlo_pass_interface.h b/tensorflow/compiler/xla/service/hlo_pass_interface.h
index f1ad0f9b01..fdaac34386 100644
--- a/tensorflow/compiler/xla/service/hlo_pass_interface.h
+++ b/tensorflow/compiler/xla/service/hlo_pass_interface.h
@@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PASS_INTERFACE_H_
#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_module_group.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
@@ -25,15 +26,45 @@ limitations under the License.
namespace xla {
// Base class for HLO passes. These are used with the HloPassPipeline to
-// organize a sequence of passes.
+// organize a sequence of passes. An HLO pass should not extend this class
+// directly; it should extend HloModulePass or HloModuleGroupPass.
class HloPassInterface {
public:
virtual ~HloPassInterface() = default;
virtual absl::string_view name() const = 0;
- // Run the pass on the given HLO module. Return whether it modified the
+ // Run the pass on the given HLO module. Returns whether it modified the
// module.
virtual StatusOr<bool> Run(HloModule* module) = 0;
+
+ // Run the pass on the given HLO module group. Returns whether it modified the
+ // module group. Ideally, the module group variant would be named "Run" as
+ // well, but C++ does not handle overloaded virtual methods well.
+ virtual StatusOr<bool> RunOnModuleGroup(HloModuleGroup* module_group) = 0;
+};
+
+// Base class for passes which are module-scoped.
+class HloModulePass : public HloPassInterface {
+ public:
+ // Runs the pass on a module group by iterating through each module in the
+ // group.
+ StatusOr<bool> RunOnModuleGroup(HloModuleGroup* module_group) override {
+ bool changed = false;
+ for (HloModule* module : module_group->modules()) {
+ TF_ASSIGN_OR_RETURN(bool module_changed, Run(module));
+ changed |= module_changed;
+ }
+ return changed;
+ }
+};
+
+// Base class for passes which are module-group scoped. These passes cannot run
+// on an HLO module.
+class HloModuleGroupPass : public HloPassInterface {
+ public:
+ StatusOr<bool> Run(HloModule* module) override {
+ return InternalError("Module group pass cannot be run on a module");
+ }
};
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc
index 6e4ed0de62..8c2f928ca1 100644
--- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc
+++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc
@@ -17,7 +17,6 @@ limitations under the License.
#include <functional>
-#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
@@ -29,108 +28,128 @@ limitations under the License.
#include "tensorflow/core/platform/logging.h"
namespace xla {
-namespace {
-using absl::StrAppend;
-using absl::StrCat;
-
-void DumpModuleGraph(const HloModule& module, const string& message) {
- hlo_graph_dumper::MaybeDumpHloModule(module, message);
- VLOG(3) << "HLO " << message << ":";
- XLA_VLOG_LINES(3, module.ToString());
+template <typename HloT>
+Status HloPassPipeline::RunInvariantCheckers(
+ HloT* hlo, absl::string_view after_pass_name) {
+ for (auto& invariant_checker : invariant_checkers_) {
+ VLOG(1) << " Invariant checker " << invariant_checker->name();
+ StatusOr<bool> changed_status = RunHelper(invariant_checker.get(), hlo);
+ VLOG(1) << " Invariant checker done " << invariant_checker->name();
+ if (!changed_status.ok()) {
+ VLOG(2) << "Failed invariant check:";
+ XLA_VLOG_LINES(2, hlo->ToString());
+ return Status(changed_status.status().code(),
+ absl::StrCat(changed_status.status().error_message(),
+ "\n\nFailed after ", after_pass_name));
+ }
+ TF_RET_CHECK(!changed_status.ValueOrDie())
+ << "invariant checkers must not change the graph";
+ }
+ return Status::OK();
}
-void DumpModuleProto(const HloModule& module, const string& dump_to,
- const string& pipeline_name, const string& pass_name) {
- static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED);
- static auto* const module_id_to_pass_number =
- new tensorflow::gtl::FlatMap<int64, int64>();
-
- tensorflow::mutex_lock lock(mu);
- const int64 pass_number = (*module_id_to_pass_number)[module.unique_id()]++;
+template <typename HloT>
+StatusOr<bool> HloPassPipeline::RunPassesInternal(
+ HloT* hlo, absl::Span<HloPassInterface* const> passes) {
+ string last_pass_name = "pipeline-start";
+ TF_RETURN_IF_ERROR(RunInvariantCheckers(hlo, last_pass_name));
+ bool changed = false;
+ for (HloPassInterface* pass : passes) {
+ VLOG(1) << " HLO pass " << pass->name();
+ MaybeDumpHlo(*hlo,
+ /*after_pass_name=*/last_pass_name,
+ /*before_pass_name=*/pass->name());
+ TF_ASSIGN_OR_RETURN(bool pass_changed, RunHelper(pass, hlo));
+ changed |= pass_changed;
+ TF_RETURN_IF_ERROR(RunInvariantCheckers(hlo, pass->name()));
+ last_pass_name = string(pass->name());
+ }
+ MaybeDumpHlo(*hlo,
+ /*after_pass_name=*/last_pass_name,
+ /*before_pass_name=*/"pipeline-end");
+ return changed;
+}
- const string mod_name = SanitizeFileName(
- absl::StrFormat("module_%04d.%04d.%s.after_%s", module.unique_id(),
- pass_number, pipeline_name, pass_name));
+std::vector<HloPassInterface*> HloPassPipeline::GetEnabledPasses(
+ const DebugOptions& debug_options) {
+ auto repeated_field = debug_options.xla_disable_hlo_passes();
+ tensorflow::gtl::FlatSet<string> disabled_pass_names(repeated_field.begin(),
+ repeated_field.end());
+ if (!disabled_pass_names.empty()) {
+ VLOG(1) << "Passes disabled by --xla_disable_hlo_passes: "
+ << absl::StrJoin(disabled_pass_names, ", ");
+ }
- TF_QCHECK_OK(protobuf_util::DumpProtoToDirectory(MakeHloProto(module),
- dump_to, mod_name));
+ std::vector<HloPassInterface*> enabled_passes;
+ for (auto& pass : passes_) {
+ if (disabled_pass_names.count(string(pass->name())) == 0) {
+ enabled_passes.push_back(pass.get());
+ }
+ }
+ return enabled_passes;
}
-} // namespace
-StatusOr<bool> HloPassPipeline::Run(HloModule* module) {
- run_called_ = true;
+void HloPassPipeline::MaybeDumpHlo(const HloModule& module,
+ absl::string_view after_pass_name,
+ absl::string_view before_pass_name) {
+ const string& proto_dump_path =
+ module.config().debug_options().xla_dump_per_pass_hlo_proto_to();
+ if (!proto_dump_path.empty()) {
+ static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED);
+ static auto* const module_id_to_pass_number =
+ new tensorflow::gtl::FlatMap<int64, int64>();
+
+ tensorflow::mutex_lock lock(mu);
+ const int64 pass_number = (*module_id_to_pass_number)[module.unique_id()]++;
+
+ const string filename = SanitizeFileName(
+ absl::StrFormat("module_%04d.%04d.%s.after_%s", module.unique_id(),
+ pass_number, name(), after_pass_name));
+
+ TF_QCHECK_OK(protobuf_util::DumpProtoToDirectory(
+ MakeHloProto(module), proto_dump_path, filename));
+ }
- VLOG(1) << "Running HLO pass pipeline " << name();
+ const string message =
+ StrCat("after ", after_pass_name, ", before ", before_pass_name);
+ hlo_graph_dumper::MaybeDumpHloModule(module, message);
+ VLOG(3) << "HLO " << message << ":";
+ XLA_VLOG_LINES(3, module.ToString());
+}
- auto repeated_field =
- module->config().debug_options().xla_disable_hlo_passes();
- tensorflow::gtl::FlatSet<string> disabled_passes(repeated_field.begin(),
- repeated_field.end());
- if (!disabled_passes.empty()) {
- VLOG(1) << "Passes disabled by --xla_disable_hlo_passes: "
- << absl::StrJoin(disabled_passes, ", ");
+void HloPassPipeline::MaybeDumpHlo(const HloModuleGroup& module_group,
+ absl::string_view after_pass_name,
+ absl::string_view before_pass_name) {
+ for (const HloModule* module : module_group.modules()) {
+ MaybeDumpHlo(*module, after_pass_name, before_pass_name);
}
+}
- auto run_invariant_checkers = [this,
- module](const string& message) -> Status {
- for (auto& invariant_checker : invariant_checkers_) {
- VLOG(1) << " Invariant checker " << invariant_checker->name();
- StatusOr<bool> changed_status = invariant_checker->Run(module);
- VLOG(1) << " Invariant checker done " << invariant_checker->name();
- if (!changed_status.ok()) {
- VLOG(2) << "Module failed invariant check:";
- XLA_VLOG_LINES(2, module->ToString());
- return Status(changed_status.status().code(),
- StrCat(changed_status.status().error_message(),
- "\n\nFailed ", message));
- }
- TF_RET_CHECK(!changed_status.ValueOrDie())
- << "invariant checkers must not change the graph";
- }
- return Status::OK();
- };
+StatusOr<bool> HloPassPipeline::Run(HloModule* module) {
+ run_called_ = true;
- string prefix = StrCat(name(), ": pipeline start");
- bool changed = false;
- string message;
- TF_RETURN_IF_ERROR(
- run_invariant_checkers(StrCat("before running pipeline: ", name())));
- const string xla_dump_per_pass_hlo_proto_to =
- module->config().debug_options().xla_dump_per_pass_hlo_proto_to();
- if (!xla_dump_per_pass_hlo_proto_to.empty()) {
- DumpModuleProto(*module, xla_dump_per_pass_hlo_proto_to, string(name()),
- "pipeline_start");
- }
+ VLOG(1) << "Running HLO pass pipeline on module " << module->name() << ": "
+ << name();
- for (auto& pass : passes_) {
- if (disabled_passes.count(string(pass->name())) > 0) {
- VLOG(1) << " Skipping HLO pass " << pass->name()
- << ", disabled by --xla_disable_hlo_passes";
- continue;
- }
+ return RunPassesInternal(module,
+ GetEnabledPasses(module->config().debug_options()));
+}
- VLOG(1) << " HLO pass " << pass->name();
+StatusOr<bool> HloPassPipeline::RunOnModuleGroup(HloModuleGroup* module_group) {
+ run_called_ = true;
- // Emit label containing: "after foo-pass, before bar-pass".
- message.clear();
- StrAppend(&message, prefix, ", before ", pass->name());
- DumpModuleGraph(*module, message);
-
- TF_ASSIGN_OR_RETURN(bool changed_this_pass, pass->Run(module));
- TF_RETURN_IF_ERROR(
- run_invariant_checkers(StrCat("after running pass: ", pass->name())));
- if (!xla_dump_per_pass_hlo_proto_to.empty()) {
- DumpModuleProto(*module, xla_dump_per_pass_hlo_proto_to, string(name()),
- string(pass->name()));
- }
+ VLOG(1) << "Running HLO pass pipeline on module group "
+ << module_group->name() << ": " << name();
- changed |= changed_this_pass;
- prefix.clear();
- StrAppend(&prefix, name(), ": after ", pass->name());
+ if (module_group->modules().empty()) {
+ VLOG(1) << "Module group is empty. Nothing to do.";
+ return false;
}
- DumpModuleGraph(*module, prefix + ", pipeline end");
- return changed;
+
+ return RunPassesInternal(
+ module_group,
+ GetEnabledPasses(module_group->module(0).config().debug_options()));
}
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.h b/tensorflow/compiler/xla/service/hlo_pass_pipeline.h
index 1d41a4dac1..09e7033ea4 100644
--- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.h
+++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.h
@@ -22,6 +22,7 @@ limitations under the License.
#include <vector>
#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/statusor.h"
@@ -61,10 +62,45 @@ class HloPassPipeline : public HloPassInterface {
return *pass;
}
- // Run all passes on the given HLO module.
StatusOr<bool> Run(HloModule* module) override;
+ StatusOr<bool> RunOnModuleGroup(HloModuleGroup* module_group) override;
private:
+ // Returns the set of passes which are enabled. DebugOptions can selectively
+ // disable passes via the --xla_disable_hlo_passes flag.
+ std::vector<HloPassInterface*> GetEnabledPasses(
+ const DebugOptions& debug_options);
+
+ // Maybe dumps the given module or module group, depending on flag values
+ // contained in the module config's DebugOptions.
+ void MaybeDumpHlo(const HloModuleGroup& module_group,
+ absl::string_view after_pass_name,
+ absl::string_view before_pass_name);
+ void MaybeDumpHlo(const HloModule& module, absl::string_view after_pass_name,
+ absl::string_view before_pass_name);
+
+ // Runs the invariant checker on the given HLO. HloT can be either HloModule
+ // or HloModuleGroup.
+ template <typename HloT>
+ Status RunInvariantCheckers(HloT* hlo, absl::string_view after_pass_name);
+
+ // Helper which runs the given pass on the given HLO. HloT can be either
+ // HloModule or HloModuleGroup.
+ template <typename HloT>
+ StatusOr<bool> RunPassesInternal(HloT* hlo,
+ absl::Span<HloPassInterface* const> passes);
+
+ // Helpers which run the given passes on the given HLO construct. These
+ // helpers enable templating of the core of the pipeline logic by providing
+ // HloModule and HloModuleGroup specific methods with the same name.
+ static StatusOr<bool> RunHelper(HloPassInterface* pass, HloModule* module) {
+ return pass->Run(module);
+ }
+ static StatusOr<bool> RunHelper(HloPassInterface* pass,
+ HloModuleGroup* module_group) {
+ return pass->RunOnModuleGroup(module_group);
+ }
+
const string name_;
std::vector<std::unique_ptr<HloPassInterface>> passes_;
std::vector<std::unique_ptr<HloPassInterface>> invariant_checkers_;
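
A standalone sketch (toy types, plain C++) of the dispatch trick the RunHelper overloads enable: one templated driver body serves both HloModule and HloModuleGroup by resolving to the right entry point through overloading rather than virtual dispatch.

    #include <cstdio>

    struct ToyModule {};
    struct ToyModuleGroup {};

    struct ToyPass {
      bool Run(ToyModule*) { std::puts("pass on module"); return true; }
      bool RunOnModuleGroup(ToyModuleGroup*) {
        std::puts("pass on module group");
        return true;
      }
    };

    // Overloads with the same name pick the right entry point per HLO type.
    bool RunHelper(ToyPass* pass, ToyModule* module) { return pass->Run(module); }
    bool RunHelper(ToyPass* pass, ToyModuleGroup* group) {
      return pass->RunOnModuleGroup(group);
    }

    template <typename HloT>
    bool RunPipeline(ToyPass* pass, HloT* hlo) {
      return RunHelper(pass, hlo);  // Same driver body for both types.
    }

    int main() {
      ToyPass pass;
      ToyModule module;
      ToyModuleGroup group;
      RunPipeline(&pass, &module);
      RunPipeline(&pass, &group);
    }
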
diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline_test.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline_test.cc
new file mode 100644
index 0000000000..ee8cb12b23
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline_test.cc
@@ -0,0 +1,259 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
+
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace xla {
+namespace {
+
+class HloPassPipelineTest : public HloVerifiedTestBase {
+ protected:
+ StatusOr<HloModuleGroup> ParseModuleGroup(
+ absl::Span<const string> hlo_strings) {
+ HloModuleGroup group(TestName());
+ for (const string& hlo_string : hlo_strings) {
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(hlo_string));
+ group.push_back(std::move(module));
+ }
+ return std::move(group);
+ }
+};
+
+// A module pass which renames instructions named 'foo' to 'bar'.
+class FooToBarModulePass : public HloModulePass {
+ absl::string_view name() const override { return "foo2bar"; }
+
+ StatusOr<bool> Run(HloModule* module) override {
+ bool changed = false;
+ for (HloComputation* computation : module->computations()) {
+ for (HloInstruction* instruction : computation->instructions()) {
+ if (instruction->name() == "foo") {
+ instruction->SetAndSanitizeName("bar");
+ changed = true;
+ }
+ }
+ }
+ return changed;
+ }
+};
+
+// A module group pass which renames instructions named 'baz' to 'qux'.
+class BazToQuxModuleGroupPass : public HloModuleGroupPass {
+ absl::string_view name() const override { return "baz2qux"; }
+
+ StatusOr<bool> RunOnModuleGroup(HloModuleGroup* module_group) override {
+ bool changed = false;
+ for (HloModule* module : module_group->modules()) {
+ for (HloComputation* computation : module->computations()) {
+ for (HloInstruction* instruction : computation->instructions()) {
+ if (instruction->name() == "baz") {
+ instruction->SetAndSanitizeName("qux");
+ changed = true;
+ }
+ }
+ }
+ }
+ return changed;
+ }
+};
+
+// An invariant checker pass which returns an error if there exists an
+// instruction named 'bar'.
+class BarBlowerUpper : public HloModulePass {
+ absl::string_view name() const override { return "bar-blower-upper"; }
+
+ StatusOr<bool> Run(HloModule* module) override {
+ for (HloComputation* computation : module->computations()) {
+ for (HloInstruction* instruction : computation->instructions()) {
+ if (instruction->name() == "bar") {
+ return InternalError("Module has instruction named bar");
+ }
+ }
+ }
+ return false;
+ }
+};
+
+TEST_F(HloPassPipelineTest, ModulePassChanged) {
+ // Test an HLO module pass which changes a module.
+ const string module_str = R"(
+HloModule ModulePassChanged
+
+ENTRY main {
+ a = f32[] parameter(0)
+ b = f32[] parameter(1)
+ ROOT foo = f32[] multiply(a, b)
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(module_str));
+ HloPassPipeline pipeline(TestName());
+ pipeline.AddPass<FooToBarModulePass>();
+
+ HloInstruction* root = module->entry_computation()->root_instruction();
+ EXPECT_EQ(root->name(), "foo");
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, pipeline.Run(module.get()));
+ EXPECT_TRUE(changed);
+ EXPECT_EQ(root->name(), "bar");
+}
+
+TEST_F(HloPassPipelineTest, ModulePassUnchanged) {
+ // Test an HLO module pass which does not change a module.
+ const string module_str = R"(
+HloModule ModulePassUnchanged
+
+ENTRY main {
+ a = f32[] parameter(0)
+ b = f32[] parameter(1)
+ ROOT blahblah = f32[] multiply(a, b)
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(module_str));
+ HloPassPipeline pipeline(TestName());
+ pipeline.AddPass<FooToBarModulePass>();
+
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, pipeline.Run(module.get()));
+ EXPECT_FALSE(changed);
+}
+
+TEST_F(HloPassPipelineTest, MixedPipeline) {
+ // Test a pipeline with both a module pass and a module group pass.
+ const string module_0_str = R"(
+HloModule MixedPipeline.1
+
+ENTRY main {
+ a = f32[] parameter(0)
+ b = f32[] parameter(1)
+ ROOT baz = f32[] multiply(a, b)
+}
+)";
+ const string module_1_str = R"(
+HloModule MixedPipeline.0
+
+ENTRY main {
+ a = f32[] parameter(0)
+ b = f32[] parameter(1)
+ ROOT foo = f32[] multiply(a, b)
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(HloModuleGroup module_group,
+ ParseModuleGroup({module_0_str, module_1_str}));
+
+ HloPassPipeline pipeline(TestName());
+ pipeline.AddPass<BazToQuxModuleGroupPass>();
+ pipeline.AddPass<FooToBarModulePass>();
+
+ HloInstruction* root0 =
+ module_group.module(0).entry_computation()->root_instruction();
+ HloInstruction* root1 =
+ module_group.module(1).entry_computation()->root_instruction();
+ EXPECT_EQ(root0->name(), "baz");
+ EXPECT_EQ(root1->name(), "foo");
+
+ TF_ASSERT_OK_AND_ASSIGN(bool changed,
+ pipeline.RunOnModuleGroup(&module_group));
+ EXPECT_TRUE(changed);
+
+ EXPECT_EQ(root0->name(), "qux");
+ EXPECT_EQ(root1->name(), "bar");
+}
+
+TEST_F(HloPassPipelineTest, InvariantChecker) {
+ const string module_str = R"(
+HloModule InvariantChecker
+
+ENTRY main {
+ a = f32[] parameter(0)
+ b = f32[] parameter(1)
+ ROOT foo = f32[] multiply(a, b)
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(module_str));
+ {
+ // Run a pipeline with just the invariant checker. It should not fail
+ // because there is no 'bar' instruction in the module.
+ HloPassPipeline pipeline(TestName());
+ pipeline.AddInvariantChecker<BarBlowerUpper>();
+
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, pipeline.Run(module.get()));
+ EXPECT_FALSE(changed);
+ }
+
+ {
+ // Run a pipeline which renames 'foo' to 'bar' then an invariant checker
+ // which fails if there is an instruction named 'bar'.
+ HloPassPipeline pipeline(TestName());
+ pipeline.AddInvariantChecker<BarBlowerUpper>();
+ pipeline.AddPass<FooToBarModulePass>();
+
+ Status status = pipeline.Run(module.get()).status();
+ ASSERT_IS_NOT_OK(status);
+ EXPECT_THAT(status.error_message(),
+ ::testing::HasSubstr("Module has instruction named bar"));
+ EXPECT_THAT(status.error_message(),
+ ::testing::HasSubstr("Failed after foo2bar"));
+ }
+
+ {
+    // Run the invariant-checker-only pipeline again. It should fail this time.
+ HloPassPipeline pipeline(TestName());
+ pipeline.AddInvariantChecker<BarBlowerUpper>();
+
+ Status status = pipeline.Run(module.get()).status();
+ ASSERT_IS_NOT_OK(status);
+ EXPECT_THAT(status.error_message(),
+ ::testing::HasSubstr("Module has instruction named bar"));
+ EXPECT_THAT(status.error_message(),
+ ::testing::HasSubstr("Failed after pipeline-start"));
+ }
+}
+
+TEST_F(HloPassPipelineTest, ModuleGroupPassOnModule) {
+ // Running a module group pass on a module should produce an error.
+ const string module_str = R"(
+HloModule ModuleGroupPassOnModule
+
+ENTRY main {
+ a = f32[] parameter(0)
+ b = f32[] parameter(1)
+ ROOT foo = f32[] multiply(a, b)
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(module_str));
+ HloPassPipeline pipeline(TestName());
+ pipeline.AddPass<BazToQuxModuleGroupPass>();
+
+ Status status = pipeline.Run(module.get()).status();
+ ASSERT_IS_NOT_OK(status);
+ EXPECT_THAT(
+ status.error_message(),
+ ::testing::HasSubstr("Module group pass cannot be run on a module"));
+}
+
+} // namespace
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
index bd6dd79b67..a438671936 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
@@ -1198,6 +1198,12 @@ StatusOr<bool> HloRematerialization::Run(HloModule* module) {
<< HumanReadableNumBytes(memory_limit_bytes_);
XLA_VLOG_LINES(3, "Before HloRematerialization:\n" + module->ToString());
+ // Initialize pass object state.
+ computation_peak_memory_.clear();
+ rematerialized_computations_.clear();
+ instructions_rematerialized_ = 0;
+ net_instructions_added_ = 0;
+
TF_RET_CHECK(module->has_schedule());
TF_ASSIGN_OR_RETURN(points_to_analysis_, TuplePointsToAnalysis::Run(module));
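The lines added above reset the rematerializer's per-run members at the top of Run(), so one HloRematerialization object can be applied to more than one module without peak-memory data or counters from an earlier run leaking into the next. A small sketch of the same reset-on-entry convention, with purely illustrative names:

    // Illustrative only: a pass whose Run() keeps per-run scratch in members.
    #include <string>
    #include <unordered_map>

    class CountingPass {
     public:
      int Run(const std::string& module_text) {
        // Re-initialize per-run state first, mirroring the clear()/zeroing
        // above; otherwise a second Run() starts from stale values.
        counts_.clear();
        total_ = 0;
        for (char c : module_text) {
          ++counts_[c];
          ++total_;
        }
        return total_;
      }

     private:
      std::unordered_map<char, int> counts_;  // per-run scratch
      int total_ = 0;                         // per-run counter
    };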
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.h b/tensorflow/compiler/xla/service/hlo_rematerialization.h
index e2aaf18b3e..7330d73c09 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization.h
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.h
@@ -33,7 +33,7 @@ namespace xla {
// CSE will undo the effects of this optimization and should not be run after
// this pass. In general, this pass should be run very late, immediately before
// code generation.
-class HloRematerialization : public HloPassInterface {
+class HloRematerialization : public HloModulePass {
public:
using ShapeSizeFunction = std::function<int64(const Shape&)>;
diff --git a/tensorflow/compiler/xla/service/hlo_subcomputation_unification.h b/tensorflow/compiler/xla/service/hlo_subcomputation_unification.h
index d1cf644f82..fa34bddde1 100644
--- a/tensorflow/compiler/xla/service/hlo_subcomputation_unification.h
+++ b/tensorflow/compiler/xla/service/hlo_subcomputation_unification.h
@@ -22,7 +22,7 @@ namespace xla {
// Unify subcomputations of a `HloModule`: if any computations are equal, choose
// one arbitrarily to use and delete the others.
-class HloSubcomputationUnification : public HloPassInterface {
+class HloSubcomputationUnification : public HloModulePass {
public:
absl::string_view name() const override {
return "subcomputation-unification";
diff --git a/tensorflow/compiler/xla/service/hlo_value.cc b/tensorflow/compiler/xla/service/hlo_value.cc
index 773fc7d225..8549487702 100644
--- a/tensorflow/compiler/xla/service/hlo_value.cc
+++ b/tensorflow/compiler/xla/service/hlo_value.cc
@@ -131,6 +131,7 @@ bool MayUseOperandValue(int64 operand_number, const ShapeIndex& index,
CHECK_LE(operand_number, 2);
return operand_number == 0 || index.empty();
+ case HloOpcode::kDomain:
case HloOpcode::kTuple:
// These instructions always pass through their operands transparently.
return false;
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index 50f39cbcb5..6eb6658904 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -1057,6 +1057,7 @@ Status VerifySendsAndRecvs(const HloModule& module) {
} // namespace
StatusOr<bool> HloVerifier::Run(HloModule* module) {
+ TF_RET_CHECK(!module->name().empty());
TF_RETURN_IF_ERROR(VerifyHloStructure(module));
TF_RETURN_IF_ERROR(VerifySendsAndRecvs(*module));
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h
index 42e3027bf1..0cde4a31af 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.h
+++ b/tensorflow/compiler/xla/service/hlo_verifier.h
@@ -151,7 +151,7 @@ class ShapeVerifier : public DfsHloVisitor {
// HLO pass that verifies invariants of HLO instructions for each computation in
// the module.
-class HloVerifier : public HloPassInterface {
+class HloVerifier : public HloModulePass {
public:
using ShapeVerifierFactory = std::function<std::unique_ptr<ShapeVerifier>()>;
diff --git a/tensorflow/compiler/xla/service/implicit_broadcast_remover.h b/tensorflow/compiler/xla/service/implicit_broadcast_remover.h
index 85bb4a8b24..9c48b7db61 100644
--- a/tensorflow/compiler/xla/service/implicit_broadcast_remover.h
+++ b/tensorflow/compiler/xla/service/implicit_broadcast_remover.h
@@ -25,7 +25,7 @@ namespace xla {
// Pass which replaces all implicit broadcasts with their equivalent sequence of
// explicit broadcast and reshape instructions.
-class ImplicitBroadcastRemover : public HloPassInterface {
+class ImplicitBroadcastRemover : public HloModulePass {
public:
ImplicitBroadcastRemover() {}
~ImplicitBroadcastRemover() override {}
diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.h b/tensorflow/compiler/xla/service/indexed_array_analysis.h
index df9cbab915..3e238f97a0 100644
--- a/tensorflow/compiler/xla/service/indexed_array_analysis.h
+++ b/tensorflow/compiler/xla/service/indexed_array_analysis.h
@@ -366,7 +366,7 @@ class IndexedArrayAnalysis {
// A pass that prints all non-trivial results returned by IndexedArrayAnalysis.
// This pass is a no-op if !VLOG_IS_ON(2) so it should be fine to
// unconditionally add to the regular HLO pass pipeline.
-class IndexedArrayAnalysisPrinterPass : public HloPassInterface {
+class IndexedArrayAnalysisPrinterPass : public HloModulePass {
public:
absl::string_view name() const override;
StatusOr<bool> Run(HloModule* module) override;
diff --git a/tensorflow/compiler/xla/service/inliner.h b/tensorflow/compiler/xla/service/inliner.h
index efa8ed3abc..e20af08fb7 100644
--- a/tensorflow/compiler/xla/service/inliner.h
+++ b/tensorflow/compiler/xla/service/inliner.h
@@ -24,7 +24,7 @@ namespace xla {
 // A pass which performs inlining. This can result, for example, in a function
 // that was previously invoked through Map being applied directly to the
 // forwarded operands (i.e., map({X, Y}, max) -> max(X, Y)).
-class Inliner : public HloPassInterface {
+class Inliner : public HloModulePass {
public:
~Inliner() override = default;
absl::string_view name() const override { return "inline"; }
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc
index 3fdc2cee9a..e884122fcb 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/instruction_fusion.cc
@@ -188,13 +188,20 @@ bool InstructionFusion::EffectivelyAtMostUnary(HloInstruction* hlo) {
bool InstructionFusion::CanFuseOnAllPaths(
HloInstruction* producer, HloInstruction* consumer,
- const HloInstructionSet& do_not_duplicate) {
+ const HloInstructionSet& do_not_fuse,
+ tensorflow::gtl::FlatMap<std::pair<HloInstruction*, HloInstruction*>, bool>*
+ result_cache) {
if (consumer == producer) {
return true;
}
if (!consumer->IsFusible()) {
return false;
}
+ auto cache_it = result_cache->find(std::make_pair(producer, consumer));
+ if (cache_it != result_cache->end()) {
+ return cache_it->second;
+ }
+ bool result = true;
for (int64 i = 0, e = consumer->operand_count(); i < e; ++i) {
auto* consumer_operand = consumer->mutable_operand(i);
// If the operand is not on a path to the producer, it doesn't matter
@@ -202,20 +209,23 @@ bool InstructionFusion::CanFuseOnAllPaths(
if (!reachability_->IsReachable(producer, consumer_operand)) {
continue;
}
- if (do_not_duplicate.count(consumer_operand) > 0 ||
- !ShouldFuse(consumer, i)) {
- return false;
+ if (do_not_fuse.count(consumer_operand) > 0 || !ShouldFuse(consumer, i)) {
+ result = false;
+ break;
}
// The producer is reachable from consumer_operand which means we need
// to be able to fuse consumer_operand into consumer in order for
// producer to be fusible into consumer on all paths.
// Perform the recursive step: make sure producer can be fused into
// consumer_operand on all paths.
- if (!CanFuseOnAllPaths(producer, consumer_operand, do_not_duplicate)) {
- return false;
+ if (!CanFuseOnAllPaths(producer, consumer_operand, do_not_fuse,
+ result_cache)) {
+ result = false;
+ break;
}
}
- return true;
+ result_cache->emplace(std::make_pair(producer, consumer), result);
+ return result;
}
InstructionFusion::HloInstructionSet
@@ -231,6 +241,8 @@ InstructionFusion::ComputeGloballyUnfusible(
// fusing operations that require duplication later depending on
// is_expensive_().
HloInstructionSet do_not_duplicate;
+ tensorflow::gtl::FlatMap<std::pair<HloInstruction*, HloInstruction*>, bool>
+ can_fuse_on_all_paths_result_cache;
for (HloInstruction* consumer : post_order) {
for (HloInstruction* producer : consumer->operands()) {
if (do_not_duplicate.count(producer) > 0) {
@@ -286,7 +298,8 @@ InstructionFusion::ComputeGloballyUnfusible(
      // A will not be allowed to be fused into B, as it cannot be fused via
// all paths.
if (producer->IsFusible() &&
- CanFuseOnAllPaths(producer, consumer, do_not_duplicate)) {
+ CanFuseOnAllPaths(producer, consumer, do_not_duplicate,
+ &can_fuse_on_all_paths_result_cache)) {
continue;
}
do_not_duplicate.insert(producer);
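The result_cache threaded through CanFuseOnAllPaths above memoizes the verdict for each <producer, consumer> pair, so the recursion over fusion paths revisits a pair at most once instead of re-walking shared subgraphs. A standalone sketch of that memoized recursive predicate, using std::map and made-up node types rather than the XLA classes (the reachability and ShouldFuse checks of the real pass are omitted):

    #include <map>
    #include <utility>
    #include <vector>

    struct Node {
      std::vector<Node*> operands;  // edges toward producers
    };

    // Returns true iff fusion is possible along every operand path from
    // `consumer` back to `producer`, caching one bool per (producer, consumer).
    bool CanFuseOnAllPaths(Node* producer, Node* consumer,
                           std::map<std::pair<Node*, Node*>, bool>* cache) {
      if (producer == consumer) return true;
      auto it = cache->find({producer, consumer});
      if (it != cache->end()) return it->second;  // reuse an earlier verdict

      bool result = true;
      for (Node* operand : consumer->operands) {
        if (!CanFuseOnAllPaths(producer, operand, cache)) {
          result = false;
          break;
        }
      }
      (*cache)[{producer, consumer}] = result;  // memoize before returning
      return result;
    }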
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.h b/tensorflow/compiler/xla/service/instruction_fusion.h
index c1fde8ecfc..c1ec3b18a1 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.h
+++ b/tensorflow/compiler/xla/service/instruction_fusion.h
@@ -56,7 +56,7 @@ class FusionQueue {
// with the intent that the loops which compute their values will be fused in
 // code generation. Derived classes define the ShouldFuse method to select
 // which instructions to fuse.
-class InstructionFusion : public HloPassInterface {
+class InstructionFusion : public HloModulePass {
public:
explicit InstructionFusion(
std::function<bool(const HloInstruction& instruction)> is_expensive,
@@ -151,8 +151,15 @@ class InstructionFusion : public HloPassInterface {
// Whether or not we can fuse producer into consumer on all paths
// from the producer to the consumer where nodes are HLOs and edges are uses.
- bool CanFuseOnAllPaths(HloInstruction* producer, HloInstruction* consumer,
- const HloInstructionSet& do_not_fuse);
+ //
+  // The caller must pass a result_cache map keyed by <producer, consumer>; it
+  // memoizes the result of each call so that repeated queries for the same
+  // pair are not recomputed.
+ bool CanFuseOnAllPaths(
+ HloInstruction* producer, HloInstruction* consumer,
+ const HloInstructionSet& do_not_fuse,
+ tensorflow::gtl::FlatMap<std::pair<HloInstruction*, HloInstruction*>,
+ bool>* result_cache);
// Computes the set of nodes that we do not want to fuse into any of their
// consumers based on a global analysis of the HLO graph.
diff --git a/tensorflow/compiler/xla/service/layout_assignment.h b/tensorflow/compiler/xla/service/layout_assignment.h
index cf545031d3..e29c199c42 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.h
+++ b/tensorflow/compiler/xla/service/layout_assignment.h
@@ -281,7 +281,7 @@ class ChannelLayoutConstraints {
// HLO pass which assigns layouts to all instructions in the HLO module while
// satisfying all necessary invariants and minimizing cost.
-class LayoutAssignment : public HloPassInterface {
+class LayoutAssignment : public HloModulePass {
public:
// entry_computation_layout is modified to populate a layout for the result in
// the case that no particular layout is requested.
diff --git a/tensorflow/compiler/xla/service/logical_buffer_analysis.cc b/tensorflow/compiler/xla/service/logical_buffer_analysis.cc
index eaa09591b7..ec52a24d78 100644
--- a/tensorflow/compiler/xla/service/logical_buffer_analysis.cc
+++ b/tensorflow/compiler/xla/service/logical_buffer_analysis.cc
@@ -54,7 +54,7 @@ Status LogicalBufferAnalysis::Analyze() {
// so reserve 10% more than the number of instructions to avoid frequent
// resizes.
logical_buffers_.clear();
- logical_buffers_.reserve((module_->NumUniqueInstructionIds() * 11) / 10);
+ logical_buffers_.reserve((module_->instruction_count() * 11) / 10);
// We filter out fusion computations, and get to them through fusion
// instructions. This is because it's possible to have orphaned (unreachable)
diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.h b/tensorflow/compiler/xla/service/multi_output_fusion.h
index d2c52651c4..0344626b26 100644
--- a/tensorflow/compiler/xla/service/multi_output_fusion.h
+++ b/tensorflow/compiler/xla/service/multi_output_fusion.h
@@ -44,7 +44,7 @@ namespace xla {
// Note that the reachability map is updated based on the original computation.
// This works because the reachability is monotonically increasing with
// instruction fusion.
-class MultiOutputFusion : public HloPassInterface {
+class MultiOutputFusion : public HloModulePass {
public:
MultiOutputFusion(int64 fuel) : fuel_(fuel) {}
diff --git a/tensorflow/compiler/xla/service/name_uniquer.cc b/tensorflow/compiler/xla/service/name_uniquer.cc
index bd8fb17a23..ac2f79674f 100644
--- a/tensorflow/compiler/xla/service/name_uniquer.cc
+++ b/tensorflow/compiler/xla/service/name_uniquer.cc
@@ -39,8 +39,10 @@ NameUniquer::NameUniquer(const string& separator) {
}
/*static*/ string NameUniquer::GetSanitizedName(const string& name) {
+ if (name.empty()) {
+ return "";
+ }
string result = name;
- CHECK(!result.empty()) << "name should not be empty";
char c = static_cast<unsigned char>(result[0]);
if (!isalpha(c) && c != '_') {
result[0] = '_';
diff --git a/tensorflow/compiler/xla/service/pattern_matcher.h b/tensorflow/compiler/xla/service/pattern_matcher.h
index 4869db79e7..380cde0e6a 100644
--- a/tensorflow/compiler/xla/service/pattern_matcher.h
+++ b/tensorflow/compiler/xla/service/pattern_matcher.h
@@ -17,8 +17,12 @@ limitations under the License.
#define TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_H_
#include "absl/strings/string_view.h"
+#include "absl/utility/utility.h"
#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -116,15 +120,82 @@ namespace xla {
// .WithOperand(1, Op(&c))
// .WithOperand(2, Op(&d))
//
+
+struct MatchOption {
+  // If true, actually capture the matched item into the user-supplied pointer.
+ bool capture;
+};
+
template <typename Value, typename Pattern>
-bool Match(Value* value, const Pattern& pattern) {
- return pattern.Match(value);
+bool Match(Value* value, const Pattern& pattern,
+ MatchOption option = {/*.capture=*/true}) {
+ if (option.capture) {
+ auto new_option = option;
+ new_option.capture = false;
+ if (!pattern.Match(value, new_option)) {
+ return false;
+ }
+ }
+ return pattern.Match(value, option);
}
namespace match {
namespace detail {
+template <typename Item, typename... Patterns>
+class AllOfPattern {
+ public:
+ explicit AllOfPattern(const Patterns&... patterns) : patterns_(patterns...) {}
+
+ bool Match(const Item* item, MatchOption option) const {
+ bool matched = MatchImpl(item, option, std::integral_constant<size_t, 0>());
+ // This invariant is guaranteed by the top-level Match and AnyOf.
+ DCHECK(matched || !option.capture);
+ return matched;
+ }
+
+ bool Match(Item* item, MatchOption option) const {
+ bool matched = MatchImpl(item, option, std::integral_constant<size_t, 0>());
+ // This invariant is guaranteed by the top-level Match and AnyOf.
+ DCHECK(matched || !option.capture);
+ return matched;
+ }
+
+ private:
+ template <typename ItemType, size_t index>
+ bool MatchImpl(ItemType* item, MatchOption option,
+ std::integral_constant<size_t, index>) const {
+ return std::get<index>(patterns_).Match(item, option) &&
+ MatchImpl(item, option, std::integral_constant<size_t, index + 1>());
+ }
+
+ template <typename ItemType>
+ bool MatchImpl(ItemType* item, MatchOption option,
+ std::integral_constant<size_t, sizeof...(Patterns)>) const {
+ return true;
+ }
+
+ std::tuple<Patterns...> patterns_;
+};
+
+} // namespace detail
+
+// Returns a pattern that represents the conjunction of all input patterns. All
+// patterns need to match in order to have the AllOf pattern match.
+//
+// TODO(timshen): Currently AllOf is still nested, e.g. AllOf<AllOf<A>, B> is
+// not AllOf<A, B>. We might want to flatten the AllOf type structure if the
+// C++ compile error message gets annoying.
+template <typename Item, typename... Patterns>
+detail::AllOfPattern<typename std::remove_const<Item>::type, Patterns...> AllOf(
+ const Patterns&... patterns) {
+ return detail::AllOfPattern<typename std::remove_const<Item>::type,
+ Patterns...>(patterns...);
+}
+
+namespace detail {
+
template <typename LayoutType, typename Impl>
class LayoutPattern;
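With the MatchOption parameter introduced above, the top-level Match runs a pattern twice when capture is requested: a first pass with capture disabled so a failed match cannot leave user pointers half-written, then a second pass that performs the captures. A minimal standalone sketch of that capture-gated protocol; IntPattern is an invented toy pattern, not part of the XLA matcher:

    #include <cassert>

    struct MatchOption {
      bool capture;  // if true, write matched items into user-supplied pointers
    };

    // Toy pattern: matches any non-negative int and optionally captures it.
    struct IntPattern {
      int* captured;
      bool Match(const int* value, MatchOption option) const {
        if (value == nullptr || *value < 0) return false;
        if (option.capture && captured != nullptr) *captured = *value;
        return true;
      }
    };

    template <typename Value, typename Pattern>
    bool Match(Value* value, const Pattern& pattern,
               MatchOption option = {/*capture=*/true}) {
      if (option.capture) {
        MatchOption no_capture = option;
        no_capture.capture = false;
        // Dry run with capture off: on failure nothing has been captured.
        if (!pattern.Match(value, no_capture)) return false;
      }
      return pattern.Match(value, option);  // real run, captures allowed
    }

    int main() {
      int captured = -1;
      int good = 42, bad = -7;
      IntPattern pattern{&captured};
      assert(Match(&good, pattern) && captured == 42);
      captured = -1;
      assert(!Match(&bad, pattern) && captured == -1);  // no stray capture on failure
      return 0;
    }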
@@ -132,57 +203,61 @@ class LayoutPattern;
// nullptr.
class LayoutPatternBaseImpl {
public:
- bool Match(const ::xla::Layout* layout) const { return layout != nullptr; }
+ bool Match(const ::xla::Layout* layout, MatchOption option) const {
+ return layout != nullptr;
+ }
};
// A LayoutPattern implementation that matches only if the layout equals a
// Layout proto.
-template <typename Previous>
class LayoutPatternEqualImpl {
public:
- explicit constexpr LayoutPatternEqualImpl(const Previous& previous,
- const ::xla::Layout* layout)
- : previous_(previous), layout_(layout) {}
+ explicit constexpr LayoutPatternEqualImpl(const ::xla::Layout* layout)
+ : layout_(layout) {}
- bool Match(const ::xla::Layout* layout) const {
- return previous_.Match(layout) && LayoutUtil::Equal(*layout_, *layout);
+ bool Match(const ::xla::Layout* layout, MatchOption option) const {
+ return LayoutUtil::Equal(*layout_, *layout);
}
private:
- Previous previous_;
const ::xla::Layout* layout_;
};
// A LayoutPattern implementation that matches only if the layout has a given
// format.
-template <typename Previous>
class LayoutPatternFormatImpl {
public:
- explicit constexpr LayoutPatternFormatImpl(const Previous& previous,
- Format format)
- : previous_(previous), format_(format) {}
+ explicit constexpr LayoutPatternFormatImpl(Format format) : format_(format) {}
- bool Match(const ::xla::Layout* layout) const {
- return previous_.Match(layout) && layout->format() == format_;
+ bool Match(const ::xla::Layout* layout, MatchOption option) const {
+ return layout->format() == format_;
}
private:
- Previous previous_;
Format format_;
};
// A pattern that matches Layouts.
template <typename LayoutType, typename Impl>
class LayoutPattern {
+ private:
+ template <typename NewImpl>
+ LayoutPattern<LayoutType, AllOfPattern<::xla::Layout, Impl, NewImpl>>
+ AppendImpl(NewImpl new_impl) const {
+ return LayoutPattern<LayoutType,
+ AllOfPattern<::xla::Layout, Impl, NewImpl>>(
+ AllOf<Layout>(impl_, std::move(new_impl)), matched_layout_);
+ }
+
public:
explicit constexpr LayoutPattern(const Impl& impl,
LayoutType** matched_layout)
: impl_(impl), matched_layout_(matched_layout) {}
// Returns true and captures the layout iff it matches the pattern.
- bool Match(const ::xla::Layout* layout) const {
- if (impl_.Match(layout)) {
- if (matched_layout_) {
+ bool Match(const ::xla::Layout* layout, MatchOption option) const {
+ if (impl_.Match(layout, option)) {
+ if (option.capture && matched_layout_) {
*matched_layout_ = layout;
}
return true;
@@ -191,9 +266,9 @@ class LayoutPattern {
}
// Returns true and captures the layout iff it matches the pattern.
- bool Match(::xla::Layout* layout) const {
- if (impl_.Match(layout)) {
- if (matched_layout_) {
+ bool Match(::xla::Layout* layout, MatchOption option) const {
+ if (impl_.Match(layout, option)) {
+ if (option.capture && matched_layout_) {
*matched_layout_ = layout;
}
return true;
@@ -203,24 +278,21 @@ class LayoutPattern {
// Modifies the pattern to match only if the layout equals the given proto.
// The layout must outlive the returned pattern.
- constexpr LayoutPattern<LayoutType, LayoutPatternEqualImpl<Impl>> EqualTo(
- const ::xla::Layout* layout) const {
- return LayoutPattern<LayoutType, LayoutPatternEqualImpl<Impl>>(
- LayoutPatternEqualImpl<Impl>(impl_, layout), matched_layout_);
+ constexpr auto EqualTo(const ::xla::Layout* layout) const
+ -> decltype(this->AppendImpl(LayoutPatternEqualImpl(layout))) {
+ return AppendImpl(LayoutPatternEqualImpl(layout));
}
// Modifies the pattern to match only if the layout has a dense format.
- constexpr LayoutPattern<LayoutType, LayoutPatternFormatImpl<Impl>>
- WithDenseFormat() const {
- return LayoutPattern<LayoutType, LayoutPatternFormatImpl<Impl>>(
- LayoutPatternFormatImpl<Impl>(impl_, DENSE), matched_layout_);
+ constexpr auto WithDenseFormat() const
+ -> decltype(this->AppendImpl(LayoutPatternFormatImpl(DENSE))) {
+ return AppendImpl(LayoutPatternFormatImpl(DENSE));
}
// Modifies the pattern to match only if the layout has a sparse format.
- constexpr LayoutPattern<LayoutType, LayoutPatternFormatImpl<Impl>>
- WithSparseFormat() const {
- return LayoutPattern<LayoutType, LayoutPatternFormatImpl<Impl>>(
- LayoutPatternFormatImpl<Impl>(impl_, SPARSE), matched_layout_);
+ constexpr auto WithSparseFormat() const
+ -> decltype(this->AppendImpl(LayoutPatternFormatImpl(SPARSE))) {
+ return AppendImpl(LayoutPatternFormatImpl(SPARSE));
}
private:
@@ -228,8 +300,72 @@ class LayoutPattern {
LayoutType** matched_layout_;
};
+template <typename Item, typename... Patterns>
+class AnyOfPattern {
+ public:
+ explicit AnyOfPattern(const Patterns&... patterns) : patterns_(patterns...) {}
+
+ bool Match(const Item* item, MatchOption option) const {
+ return MatchImpl(item, option, std::integral_constant<size_t, 0>());
+ }
+
+ bool Match(Item* item, MatchOption option) const {
+ return MatchImpl(item, option, std::integral_constant<size_t, 0>());
+ }
+
+ private:
+ template <typename ItemType, size_t index>
+ bool MatchImpl(ItemType* item, MatchOption option,
+ std::integral_constant<size_t, index>) const {
+ auto new_option = option;
+ new_option.capture = false;
+ // Try to match the sub-pattern without capturing behavior.
+ if (std::get<index>(patterns_).Match(item, new_option)) {
+ // Capture the branch.
+ if (option.capture) {
+ // TODO(timshen): Currently the behavior can be exponential. Optimize it
+ // with memoization or recording the matched sub-pattern index, if it
+ // takes too long to run.
+ //
+        // Specifically, the "memoization" approach is to create an empty
+        // container keyed by (pattern, instruction), with the value recording
+        // whether that pair matched.
+ //
+ // Alternatively, we may run the pattern matching with captures off, but
+ // instead record a "trace" somewhere, indicating how exactly the
+ // pattern matches the input. For example, the trace information for
+        // AnyOf will be a runtime number indicating which sub-pattern matched.
+ // Then we run another pass to do captures only with the help of the
+ // trace.
+ bool ret = std::get<index>(patterns_).Match(item, option);
+ DCHECK(ret);
+ }
+ return true;
+ }
+ return MatchImpl(item, option, std::integral_constant<size_t, index + 1>());
+ }
+
+ template <typename ItemType>
+ bool MatchImpl(ItemType* item, MatchOption option,
+ std::integral_constant<size_t, sizeof...(Patterns)>) const {
+ return false;
+ }
+
+ std::tuple<Patterns...> patterns_;
+};
+
} // namespace detail
+// Returns a pattern that represents the logical disjunction of the input
+// patterns. The returned pattern matches from left to right, and stops on the
+// first match.
+template <typename Item, typename... Patterns>
+detail::AnyOfPattern<typename std::remove_const<Item>::type, Patterns...> AnyOf(
+ const Patterns&... patterns) {
+ return detail::AnyOfPattern<typename std::remove_const<Item>::type,
+ Patterns...>(patterns...);
+}
+
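AnyOf above tries its sub-patterns left to right and stops at the first one that matches; when capture is on, only that winning branch is matched a second time, so losing branches never write their captures (which is what the DCHECK in AnyOfPattern relies on). A compact two-alternative sketch of that behavior with invented string patterns, not the XLA matcher types:

    #include <iostream>
    #include <string>

    struct MatchOption { bool capture; };

    // Toy pattern: matches strings with a given prefix, optionally capturing them.
    struct PrefixPattern {
      std::string prefix;
      const std::string** captured;
      bool Match(const std::string* s, MatchOption option) const {
        if (s->compare(0, prefix.size(), prefix) != 0) return false;
        if (option.capture && captured != nullptr) *captured = s;
        return true;
      }
    };

    // Two-way AnyOf: try `a`, then `b`; only the winning branch captures.
    struct AnyOf2 {
      PrefixPattern a, b;
      bool Match(const std::string* s, MatchOption option) const {
        MatchOption no_capture = option;
        no_capture.capture = false;
        for (const PrefixPattern* p : {&a, &b}) {
          if (p->Match(s, no_capture)) {
            if (option.capture) p->Match(s, option);  // re-run winning branch only
            return true;
          }
        }
        return false;
      }
    };

    int main() {
      const std::string* hit = nullptr;
      std::string value = "foo_bar";
      AnyOf2 any{{"foo", &hit}, {"baz", &hit}};
      std::cout << any.Match(&value, {/*capture=*/true}) << "\n";  // prints 1; hit == &value
    }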
// Creates a layout pattern that will capture the matched layout in the
// argument.
inline constexpr detail::LayoutPattern<const ::xla::Layout,
@@ -258,172 +394,145 @@ class ShapePattern;
// nullptr.
class ShapePatternBaseImpl {
public:
- bool Match(const ::xla::Shape* shape) const { return shape != nullptr; }
+ bool Match(const ::xla::Shape* shape, MatchOption option) const {
+ return shape != nullptr;
+ }
};
// A ShapePattern implementation that matches only if the shape equals a Shape
// proto.
-template <typename Previous>
class ShapePatternEqualImpl {
public:
- explicit constexpr ShapePatternEqualImpl(const Previous& previous,
- const ::xla::Shape* shape)
- : previous_(previous), shape_(shape) {}
+ explicit constexpr ShapePatternEqualImpl(const ::xla::Shape* shape)
+ : shape_(shape) {}
- bool Match(const ::xla::Shape* shape) const {
- return previous_.Match(shape) && ShapeUtil::Equal(*shape_, *shape);
+ bool Match(const ::xla::Shape* shape, MatchOption option) const {
+ return ShapeUtil::Equal(*shape_, *shape);
}
private:
- Previous previous_;
const ::xla::Shape* shape_;
};
// A ShapePattern implementation that matches only if the shape is compatible to
// a Shape proto.
-template <typename Previous>
class ShapePatternCompatibleImpl {
public:
- explicit constexpr ShapePatternCompatibleImpl(const Previous& previous,
- const ::xla::Shape* shape)
- : previous_(previous), shape_(shape) {}
+ explicit constexpr ShapePatternCompatibleImpl(const ::xla::Shape* shape)
+ : shape_(shape) {}
- bool Match(const ::xla::Shape* shape) const {
- return previous_.Match(shape) && ShapeUtil::Compatible(*shape_, *shape);
+ bool Match(const ::xla::Shape* shape, MatchOption option) const {
+ return ShapeUtil::Compatible(*shape_, *shape);
}
private:
- Previous previous_;
const ::xla::Shape* shape_;
};
// A ShapePattern implementation that matches only if the shape has a given
// element type.
-template <typename Previous>
class ShapePatternElementTypeImpl {
public:
- explicit constexpr ShapePatternElementTypeImpl(const Previous& previous,
- PrimitiveType element_type)
- : previous_(previous), element_type_(element_type) {}
+ explicit constexpr ShapePatternElementTypeImpl(PrimitiveType element_type)
+ : element_type_(element_type) {}
- bool Match(const ::xla::Shape* shape) const {
- return previous_.Match(shape) && shape->element_type() == element_type_;
+ bool Match(const ::xla::Shape* shape, MatchOption option) const {
+ return shape->element_type() == element_type_;
}
private:
- Previous previous_;
PrimitiveType element_type_;
};
// A ShapePattern implementation that matches only if the shape is scalar.
-template <typename Previous>
class ShapePatternIsScalarImpl {
public:
- explicit constexpr ShapePatternIsScalarImpl(const Previous& previous)
- : previous_(previous) {}
+ explicit constexpr ShapePatternIsScalarImpl() {}
- bool Match(const ::xla::Shape* shape) const {
- return previous_.Match(shape) && ShapeUtil::IsScalar(*shape);
+ bool Match(const ::xla::Shape* shape, MatchOption option) const {
+ return ShapeUtil::IsScalar(*shape);
}
-
- private:
- Previous previous_;
};
// A ShapePattern implementation that matches only if the shape is an array
-template <typename Previous>
class ShapePatternIsArrayImpl {
public:
- explicit constexpr ShapePatternIsArrayImpl(const Previous& previous)
- : previous_(previous) {}
+ explicit constexpr ShapePatternIsArrayImpl() {}
- bool Match(const ::xla::Shape* shape) const {
- return previous_.Match(shape) && ShapeUtil::IsArray(*shape);
+ bool Match(const ::xla::Shape* shape, MatchOption option) const {
+ return ShapeUtil::IsArray(*shape);
}
-
- private:
- Previous previous_;
};
// A ShapePattern implementation that matches only if the shape is a tuple.
-template <typename Previous>
class ShapePatternIsTupleImpl {
public:
- explicit constexpr ShapePatternIsTupleImpl(const Previous& previous)
- : previous_(previous) {}
+ explicit constexpr ShapePatternIsTupleImpl() {}
- bool Match(const ::xla::Shape* shape) const {
- return previous_.Match(shape) && ShapeUtil::IsTuple(*shape);
+ bool Match(const ::xla::Shape* shape, MatchOption option) const {
+ return ShapeUtil::IsTuple(*shape);
}
-
- private:
- Previous previous_;
};
// A ShapePattern implementation that matches only if the shape has a given
// rank.
-template <typename Previous>
class ShapePatternRankImpl {
public:
- explicit constexpr ShapePatternRankImpl(const Previous& previous, int64 rank)
- : previous_(previous), rank_(rank) {}
+ explicit constexpr ShapePatternRankImpl(int64 rank) : rank_(rank) {}
- bool Match(const ::xla::Shape* shape) const {
- return previous_.Match(shape) && ShapeUtil::Rank(*shape) == rank_;
+ bool Match(const ::xla::Shape* shape, MatchOption option) const {
+ return ShapeUtil::Rank(*shape) == rank_;
}
private:
- Previous previous_;
int64 rank_;
};
// A ShapePattern implementation that matches only if the shape has a layout
// that matches a given pattern.
-template <typename Previous, typename LayoutType, typename LayoutImpl>
+template <typename LayoutType, typename LayoutImpl>
class ShapePatternLayoutImpl {
public:
explicit constexpr ShapePatternLayoutImpl(
- const Previous& previous,
const LayoutPattern<LayoutType, LayoutImpl>& layout)
- : previous_(previous), layout_(layout) {}
+ : layout_(layout) {}
- bool Match(const ::xla::Shape* shape) const {
- return previous_.Match(shape) && LayoutUtil::HasLayout(*shape) &&
- layout_.Match(&shape->layout());
+ bool Match(const ::xla::Shape* shape, MatchOption option) const {
+ return LayoutUtil::HasLayout(*shape) &&
+ layout_.Match(&shape->layout(), option);
}
- bool Match(Shape* shape) const {
- return previous_.Match(shape) && LayoutUtil::HasLayout(*shape) &&
- layout_.Match(shape->mutable_layout());
+ bool Match(Shape* shape, MatchOption option) const {
+ return LayoutUtil::HasLayout(*shape) &&
+ layout_.Match(shape->mutable_layout(), option);
}
private:
- Previous previous_;
LayoutPattern<LayoutType, LayoutImpl> layout_;
};
// A ShapePattern implementation that matches only if the shape has a subshape
// that matches a given pattern.
-template <typename Previous, typename SubshapeType, typename SubshapeImpl>
+template <typename SubshapeType, typename SubshapeImpl>
class ShapePatternSubshapeImpl {
public:
explicit ShapePatternSubshapeImpl(
- const Previous& previous, ShapeIndexView index,
+ ShapeIndexView index,
const ShapePattern<SubshapeType, SubshapeImpl>& subshape)
- : previous_(previous), index_(index), subshape_(subshape) {}
+ : index_(index), subshape_(subshape) {}
- bool Match(const ::xla::Shape* shape) const {
- return previous_.Match(shape) && ShapeUtil::IndexIsValid(*shape, index_) &&
- subshape_.Match(&ShapeUtil::GetSubshape(*shape, index_));
+ bool Match(const ::xla::Shape* shape, MatchOption option) const {
+ return ShapeUtil::IndexIsValid(*shape, index_) &&
+ subshape_.Match(&ShapeUtil::GetSubshape(*shape, index_), option);
}
- bool Match(::xla::Shape* shape) const {
- return previous_.Match(shape) && ShapeUtil::IndexIsValid(*shape, index_) &&
- subshape_.Match(ShapeUtil::GetMutableSubshape(shape, index_));
+ bool Match(::xla::Shape* shape, MatchOption option) const {
+ return ShapeUtil::IndexIsValid(*shape, index_) &&
+ subshape_.Match(ShapeUtil::GetMutableSubshape(shape, index_),
+ option);
}
private:
- Previous previous_;
ShapeIndexView index_;
ShapePattern<SubshapeType, SubshapeImpl> subshape_;
};
@@ -431,14 +540,22 @@ class ShapePatternSubshapeImpl {
// A pattern that matches Shapes.
template <typename ShapeType, typename Impl>
class ShapePattern {
+ private:
+ template <typename NewImpl>
+ ShapePattern<ShapeType, AllOfPattern<::xla::Shape, Impl, NewImpl>> AppendImpl(
+ NewImpl new_impl) const {
+ return ShapePattern<ShapeType, AllOfPattern<::xla::Shape, Impl, NewImpl>>(
+ AllOf<Shape>(impl_, std::move(new_impl)), matched_shape_);
+ }
+
public:
explicit constexpr ShapePattern(const Impl& impl, ShapeType** matched_shape)
: impl_(impl), matched_shape_(matched_shape) {}
// Returns true and captures the shape iff it matches the pattern.
- bool Match(const ::xla::Shape* shape) const {
- if (impl_.Match(shape)) {
- if (matched_shape_) {
+ bool Match(const ::xla::Shape* shape, MatchOption option) const {
+ if (impl_.Match(shape, option)) {
+ if (option.capture && matched_shape_) {
*matched_shape_ = shape;
}
return true;
@@ -447,9 +564,9 @@ class ShapePattern {
}
// Returns true and captures the shape iff it matches the pattern.
- bool Match(::xla::Shape* shape) const {
- if (impl_.Match(shape)) {
- if (matched_shape_) {
+ bool Match(::xla::Shape* shape, MatchOption option) const {
+ if (impl_.Match(shape, option)) {
+ if (option.capture && matched_shape_) {
*matched_shape_ = shape;
}
return true;
@@ -459,108 +576,90 @@ class ShapePattern {
// Modifies the pattern to match only if the shape equals the given proto.
// The layout must outlive the returned pattern.
- constexpr ShapePattern<ShapeType, ShapePatternEqualImpl<Impl>> EqualTo(
- const ::xla::Shape* shape) const {
- return ShapePattern<ShapeType, ShapePatternEqualImpl<Impl>>(
- ShapePatternEqualImpl<Impl>(impl_, shape), matched_shape_);
+ constexpr auto EqualTo(const ::xla::Shape* shape) const
+ -> decltype(this->AppendImpl(ShapePatternEqualImpl(shape))) {
+ return AppendImpl(ShapePatternEqualImpl(shape));
}
// Modifies the pattern to match only if the shape is compatible to the given
// proto. The layout must outlive the returned pattern.
- constexpr ShapePattern<ShapeType, ShapePatternCompatibleImpl<Impl>>
- CompatibleTo(const ::xla::Shape* shape) const {
- return ShapePattern<ShapeType, ShapePatternCompatibleImpl<Impl>>(
- ShapePatternCompatibleImpl<Impl>(impl_, shape), matched_shape_);
+ constexpr auto CompatibleTo(const ::xla::Shape* shape) const
+ -> decltype(this->AppendImpl(ShapePatternCompatibleImpl(shape))) {
+ return AppendImpl(ShapePatternCompatibleImpl(shape));
}
// Modifies the pattern to match only if the shape has the given element type.
- constexpr ShapePattern<ShapeType, ShapePatternElementTypeImpl<Impl>>
- WithElementType(PrimitiveType element_type) const {
- return ShapePattern<ShapeType, ShapePatternElementTypeImpl<Impl>>(
- ShapePatternElementTypeImpl<Impl>(impl_, element_type), matched_shape_);
+ constexpr auto WithElementType(PrimitiveType element_type) const
+ -> decltype(this->AppendImpl(ShapePatternElementTypeImpl(element_type))) {
+ return AppendImpl(ShapePatternElementTypeImpl(element_type));
}
// Modifies the pattern to match only if the shape is scalar.
- constexpr ShapePattern<ShapeType, ShapePatternIsScalarImpl<Impl>> IsScalar()
- const {
- return ShapePattern<ShapeType, ShapePatternIsScalarImpl<Impl>>(
- ShapePatternIsScalarImpl<Impl>(impl_), matched_shape_);
+ constexpr auto IsScalar() const
+ -> decltype(this->AppendImpl(ShapePatternIsScalarImpl())) {
+ return AppendImpl(ShapePatternIsScalarImpl());
}
// Modifies the pattern to match only if the shape is an array.
- constexpr ShapePattern<ShapeType, ShapePatternIsArrayImpl<Impl>> IsArray()
- const {
- return ShapePattern<ShapeType, ShapePatternIsArrayImpl<Impl>>(
- ShapePatternIsArrayImpl<Impl>(impl_), matched_shape_);
+ constexpr auto IsArray() const
+ -> decltype(this->AppendImpl(ShapePatternIsArrayImpl())) {
+ return AppendImpl(ShapePatternIsArrayImpl());
}
// Modifies the pattern to match only if the shape is a tuple.
- constexpr ShapePattern<ShapeType, ShapePatternIsTupleImpl<Impl>> IsTuple()
- const {
- return ShapePattern<ShapeType, ShapePatternIsTupleImpl<Impl>>(
- ShapePatternIsTupleImpl<Impl>(impl_), matched_shape_);
+ constexpr auto IsTuple() const
+ -> decltype(this->AppendImpl(ShapePatternIsTupleImpl())) {
+ return AppendImpl(ShapePatternIsTupleImpl());
}
// Modifies the pattern to match only if the shape has the given rank.
- constexpr ShapePattern<ShapeType, ShapePatternRankImpl<Impl>> WithRank(
- int64 rank) const {
- return ShapePattern<ShapeType, ShapePatternRankImpl<Impl>>(
- ShapePatternRankImpl<Impl>(impl_, rank), matched_shape_);
+ constexpr auto WithRank(int64 rank) const
+ -> decltype(this->AppendImpl(ShapePatternRankImpl(rank))) {
+ return AppendImpl(ShapePatternRankImpl(rank));
}
// Modifies the pattern to match only if the shape has a layout that matches
// the given pattern.
template <typename LayoutType, typename LayoutImpl>
- constexpr ShapePattern<ShapeType,
- ShapePatternLayoutImpl<Impl, LayoutType, LayoutImpl>>
- WithLayout(const LayoutPattern<LayoutType, LayoutImpl>& layout) const {
- return ShapePattern<ShapeType,
- ShapePatternLayoutImpl<Impl, LayoutType, LayoutImpl>>(
- ShapePatternLayoutImpl<Impl, LayoutType, LayoutImpl>(impl_, layout),
- matched_shape_);
- }
-
- constexpr ShapePattern<
- ShapeType,
- ShapePatternLayoutImpl<Impl, const ::xla::Layout,
- LayoutPatternEqualImpl<LayoutPatternBaseImpl>>>
- WithLayoutEqualTo(const ::xla::Layout* layout) const {
+ auto WithLayout(const LayoutPattern<LayoutType, LayoutImpl>& layout) const
+ -> decltype(this->AppendImpl(
+ ShapePatternLayoutImpl<LayoutType, LayoutImpl>(layout))) {
+ return AppendImpl(ShapePatternLayoutImpl<LayoutType, LayoutImpl>(layout));
+ }
+
+ constexpr auto WithLayoutEqualTo(const ::xla::Layout* layout) const
+ -> decltype(this->WithLayout(Layout().EqualTo(layout))) {
return WithLayout(Layout().EqualTo(layout));
}
- constexpr ShapePattern<
- ShapeType,
- ShapePatternLayoutImpl<Impl, const ::xla::Layout,
- LayoutPatternFormatImpl<LayoutPatternBaseImpl>>>
- IsDenseArray() const {
+ constexpr auto IsDenseArray() const
+ -> decltype(this->WithLayout(Layout().WithDenseFormat())) {
return WithLayout(Layout().WithDenseFormat());
}
- constexpr ShapePattern<
- ShapeType,
- ShapePatternLayoutImpl<Impl, const ::xla::Layout,
- LayoutPatternFormatImpl<LayoutPatternBaseImpl>>>
- IsSparseArray() const {
+ constexpr auto IsSparseArray() const
+ -> decltype(this->WithLayout(Layout().WithSparseFormat())) {
return WithLayout(Layout().WithSparseFormat());
}
// Modifies the pattern to match only if the shape has a subshape that matches
// the given pattern.
template <typename SubshapeType, typename SubshapeImpl>
+ auto WithSubshape(ShapeIndexView index,
+ const ShapePattern<SubshapeType, SubshapeImpl>& subshape)
+ const -> decltype(this->AppendImpl(
+ ShapePatternSubshapeImpl<SubshapeType, SubshapeImpl>(index,
+ subshape))) {
+ return AppendImpl(
+ ShapePatternSubshapeImpl<SubshapeType, SubshapeImpl>(index, subshape));
+ }
+
ShapePattern<ShapeType,
- ShapePatternSubshapeImpl<Impl, SubshapeType, SubshapeImpl>>
- WithSubshape(ShapeIndexView index,
- const ShapePattern<SubshapeType, SubshapeImpl>& subshape) const {
- return ShapePattern<
- ShapeType, ShapePatternSubshapeImpl<Impl, SubshapeType, SubshapeImpl>>(
- ShapePatternSubshapeImpl<Impl, SubshapeType, SubshapeImpl>(impl_, index,
- subshape),
- matched_shape_);
- }
-
- ShapePattern<ShapeType, ShapePatternSubshapeImpl<
- Impl, const ::xla::Shape,
- ShapePatternEqualImpl<ShapePatternBaseImpl>>>
+ AllOfPattern<Shape, Impl,
+ ShapePatternSubshapeImpl<
+ const ::xla::Shape,
+ AllOfPattern<::xla::Shape, ShapePatternBaseImpl,
+ ShapePatternEqualImpl>>>>
WithSubshapeEqualTo(ShapeIndexView index, const ::xla::Shape* shape) const {
return WithSubshape(index,
ShapePattern<const ::xla::Shape, ShapePatternBaseImpl>(
@@ -568,9 +667,12 @@ class ShapePattern {
.EqualTo(shape));
}
- ShapePattern<ShapeType, ShapePatternSubshapeImpl<
- Impl, const ::xla::Shape,
- ShapePatternCompatibleImpl<ShapePatternBaseImpl>>>
+ ShapePattern<ShapeType,
+ AllOfPattern<Shape, Impl,
+ ShapePatternSubshapeImpl<
+ const ::xla::Shape,
+ AllOfPattern<::xla::Shape, ShapePatternBaseImpl,
+ ShapePatternCompatibleImpl>>>>
WithSubshapeCompatibleTo(ShapeIndexView index,
const ::xla::Shape* shape) const {
return WithSubshape(index,
@@ -611,159 +713,169 @@ class HloInstructionPattern;
// instruction is not nullptr.
class HloInstructionPatternBaseImpl {
public:
- bool Match(const ::xla::HloInstruction* inst) const {
+ bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
return inst != nullptr;
}
};
// An HloInstructionPattern implementation that matches only if the instruction
// has a given name.
-template <typename Previous>
class HloInstructionPatternNameImpl {
public:
- explicit HloInstructionPatternNameImpl(const Previous& previous,
- absl::string_view name)
- : previous_(previous), name_(name) {}
+ explicit HloInstructionPatternNameImpl(absl::string_view name)
+ : name_(name) {}
- bool Match(const ::xla::HloInstruction* inst) const {
- return previous_.Match(inst) && inst->name() == name_;
+ bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
+ return inst->name() == name_;
}
private:
- Previous previous_;
absl::string_view name_;
};
// An HloInstructionPattern implementation that matches only if the instruction
// has a given opcode.
-template <typename Previous>
class HloInstructionPatternOpcodeImpl {
public:
- explicit constexpr HloInstructionPatternOpcodeImpl(const Previous& previous,
- HloOpcode opcode,
+ explicit constexpr HloInstructionPatternOpcodeImpl(HloOpcode opcode,
bool invert)
- : previous_(previous), opcode_(opcode), invert_(invert) {}
+ : opcode_(opcode), invert_(invert) {}
- bool Match(const ::xla::HloInstruction* inst) const {
- return previous_.Match(inst) && (invert_ ^ (inst->opcode() == opcode_));
+ bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
+ return (invert_ ^ (inst->opcode() == opcode_));
}
private:
- Previous previous_;
HloOpcode opcode_;
bool invert_;
};
// An HloInstructionPattern implementation that matches only if the instruction
// has a shape that matches a given pattern.
-template <typename Previous, typename ShapeType, typename ShapeImpl>
+template <typename ShapeType, typename ShapeImpl>
class HloInstructionPatternShapeImpl {
public:
explicit constexpr HloInstructionPatternShapeImpl(
- const Previous& previous, const ShapePattern<ShapeType, ShapeImpl>& shape)
- : previous_(previous), shape_(shape) {}
+ const ShapePattern<ShapeType, ShapeImpl>& shape)
+ : shape_(shape) {}
- bool Match(const ::xla::HloInstruction* inst) const {
- return previous_.Match(inst) && shape_.Match(&inst->shape());
+ bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
+ return shape_.Match(&inst->shape(), option);
}
- bool Match(::xla::HloInstruction* inst) const {
- return previous_.Match(inst) && shape_.Match(inst->mutable_shape());
+ bool Match(::xla::HloInstruction* inst, MatchOption option) const {
+ return shape_.Match(inst->mutable_shape(), option);
}
private:
- Previous previous_;
ShapePattern<ShapeType, ShapeImpl> shape_;
};
// An HloInstructionPattern implementation that matches only if the instruction
// has an operand that matches a given pattern.
-template <typename Previous, typename OperandType, typename OperandImpl>
+template <typename OperandType, typename OperandImpl>
class HloInstructionPatternOperandImpl {
public:
explicit constexpr HloInstructionPatternOperandImpl(
- const Previous& previous, int64 operand_index,
+ int64 operand_index,
const HloInstructionPattern<OperandType, OperandImpl>& operand)
- : previous_(previous), operand_index_(operand_index), operand_(operand) {}
+ : operand_index_(operand_index), operand_(operand) {}
- bool Match(const ::xla::HloInstruction* inst) const {
- return previous_.Match(inst) && operand_index_ < inst->operand_count() &&
- operand_.Match(inst->operand(operand_index_));
+ bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
+ return operand_index_ < inst->operand_count() &&
+ operand_.Match(inst->operand(operand_index_), option);
}
- bool Match(::xla::HloInstruction* inst) const {
- return previous_.Match(inst) && operand_index_ < inst->operand_count() &&
- operand_.Match(inst->mutable_operand(operand_index_));
+ bool Match(::xla::HloInstruction* inst, MatchOption option) const {
+ return operand_index_ < inst->operand_count() &&
+ operand_.Match(inst->mutable_operand(operand_index_), option);
}
private:
- Previous previous_;
int64 operand_index_;
HloInstructionPattern<OperandType, OperandImpl> operand_;
};
// An HloInstructionPattern implementation that matches only if the instruction
// is a fusion node with a particular kind.
-template <typename Previous>
class HloInstructionPatternFusionKindImpl {
public:
explicit constexpr HloInstructionPatternFusionKindImpl(
- const Previous& previous, ::xla::HloInstruction::FusionKind kind)
- : previous_(previous), kind_(kind) {}
+ ::xla::HloInstruction::FusionKind kind)
+ : kind_(kind) {}
- bool Match(const ::xla::HloInstruction* inst) const {
- return previous_.Match(inst) && inst->opcode() == HloOpcode::kFusion &&
- inst->fusion_kind() == kind_;
+ bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
+ return inst->opcode() == HloOpcode::kFusion && inst->fusion_kind() == kind_;
}
- bool Match(::xla::HloInstruction* inst) const {
- return previous_.Match(inst) && inst->opcode() == HloOpcode::kFusion &&
- inst->fusion_kind() == kind_;
+ bool Match(::xla::HloInstruction* inst, MatchOption option) const {
+ return inst->opcode() == HloOpcode::kFusion && inst->fusion_kind() == kind_;
}
private:
- Previous previous_;
::xla::HloInstruction::FusionKind kind_;
};
// An HloInstructionPattern implementation that matches only if the instruction
// is a kGetTupleElement with a particular tuple index.
-template <typename Previous>
class HloInstructionPatternTupleIndexImpl {
public:
- explicit constexpr HloInstructionPatternTupleIndexImpl(
- const Previous& previous, int64 tuple_index)
- : previous_(previous), tuple_index_(tuple_index) {}
+ explicit constexpr HloInstructionPatternTupleIndexImpl(int64 tuple_index)
+ : tuple_index_(tuple_index) {}
- bool Match(const ::xla::HloInstruction* inst) const {
- return previous_.Match(inst) &&
- inst->opcode() == HloOpcode::kGetTupleElement &&
+ bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
+ return inst->opcode() == HloOpcode::kGetTupleElement &&
inst->tuple_index() == tuple_index_;
}
- bool Match(::xla::HloInstruction* inst) const {
- return previous_.Match(inst) &&
- inst->opcode() == HloOpcode::kGetTupleElement &&
+ bool Match(::xla::HloInstruction* inst, MatchOption option) const {
+ return inst->opcode() == HloOpcode::kGetTupleElement &&
inst->tuple_index() == tuple_index_;
}
private:
- Previous previous_;
int64 tuple_index_;
};
+template <typename ItemType, typename Predicate>
+class HloPredicatePatternImpl {
+ public:
+ explicit HloPredicatePatternImpl(Predicate pred) : pred_(std::move(pred)) {}
+
+ bool Match(const ItemType* item, MatchOption option) const {
+ return pred_(item);
+ }
+
+ bool Match(ItemType* item, MatchOption option) const { return pred_(item); }
+
+ private:
+ Predicate pred_;
+};
+
+struct PatternFriend;
+
// A pattern that matches HloInstructions.
template <typename HloInstructionType, typename Impl>
class HloInstructionPattern {
+ private:
+ template <typename NewImpl>
+ HloInstructionPattern<HloInstructionType,
+ AllOfPattern<::xla::HloInstruction, Impl, NewImpl>>
+ AppendImpl(NewImpl new_impl) const {
+ return HloInstructionPattern<
+ HloInstructionType, AllOfPattern<::xla::HloInstruction, Impl, NewImpl>>(
+ AllOf<HloInstruction>(impl_, std::move(new_impl)), matched_inst_);
+ }
+
public:
explicit constexpr HloInstructionPattern(const Impl& impl,
HloInstructionType** matched_inst)
: impl_(impl), matched_inst_(matched_inst) {}
// Returns true and captures the instruction iff it matches the pattern.
- bool Match(const ::xla::HloInstruction* inst) const {
- if (impl_.Match(inst)) {
- if (matched_inst_) {
+ bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
+ if (impl_.Match(inst, option)) {
+ if (option.capture && matched_inst_) {
*matched_inst_ = inst;
}
return true;
@@ -772,9 +884,9 @@ class HloInstructionPattern {
}
// Returns true and captures the instruction iff it matches the pattern.
- bool Match(::xla::HloInstruction* inst) const {
- if (impl_.Match(inst)) {
- if (matched_inst_) {
+ bool Match(::xla::HloInstruction* inst, MatchOption option) const {
+ if (impl_.Match(inst, option)) {
+ if (option.capture && matched_inst_) {
*matched_inst_ = inst;
}
return true;
@@ -783,102 +895,87 @@ class HloInstructionPattern {
}
// Modifies the pattern to match only if the instruction has the given name.
- HloInstructionPattern<HloInstructionType, HloInstructionPatternNameImpl<Impl>>
- WithName(absl::string_view name) const {
- return HloInstructionPattern<HloInstructionType,
- HloInstructionPatternNameImpl<Impl>>(
- HloInstructionPatternNameImpl<Impl>(impl_, name), matched_inst_);
+ auto WithName(absl::string_view name) const
+ -> decltype(this->AppendImpl(HloInstructionPatternNameImpl(name))) {
+ return AppendImpl(HloInstructionPatternNameImpl(name));
}
// Modifies the pattern to match only if the instruction has the given opcode.
- constexpr HloInstructionPattern<HloInstructionType,
- HloInstructionPatternOpcodeImpl<Impl>>
- WithOpcode(HloOpcode opcode) const {
- return HloInstructionPattern<HloInstructionType,
- HloInstructionPatternOpcodeImpl<Impl>>(
- HloInstructionPatternOpcodeImpl<Impl>(impl_, opcode, false),
- matched_inst_);
+ auto WithOpcode(HloOpcode opcode) const
+ -> decltype(this->AppendImpl(HloInstructionPatternOpcodeImpl(opcode,
+ false))) {
+ return AppendImpl(HloInstructionPatternOpcodeImpl(opcode, false));
}
// Modifies the pattern to match only if the instruction does not have the
// given opcode.
- constexpr HloInstructionPattern<HloInstructionType,
- HloInstructionPatternOpcodeImpl<Impl>>
- WithoutOpcode(HloOpcode opcode) const {
- return HloInstructionPattern<HloInstructionType,
- HloInstructionPatternOpcodeImpl<Impl>>(
- HloInstructionPatternOpcodeImpl<Impl>(impl_, opcode, true),
- matched_inst_);
+ auto WithoutOpcode(HloOpcode opcode) const
+ -> decltype(this->AppendImpl(HloInstructionPatternOpcodeImpl(opcode,
+ true))) {
+ return AppendImpl(HloInstructionPatternOpcodeImpl(opcode, true));
}
// Modifies the pattern to match only if the instruction is a constant.
- constexpr HloInstructionPattern<HloInstructionType,
- HloInstructionPatternOpcodeImpl<Impl>>
- IsConstant() const {
+ constexpr auto IsConstant() const
+ -> decltype(this->WithOpcode(HloOpcode::kConstant)) {
return WithOpcode(HloOpcode::kConstant);
}
// Modifies the pattern to match only if the instruction is not a constant.
- constexpr HloInstructionPattern<HloInstructionType,
- HloInstructionPatternOpcodeImpl<Impl>>
- IsNonConstant() const {
+ constexpr auto IsNonConstant() const
+ -> decltype(this->WithoutOpcode(HloOpcode::kConstant)) {
return WithoutOpcode(HloOpcode::kConstant);
}
// Modifies the pattern to match only if the instruction has a shape that
// matches the given pattern.
template <typename ShapeType, typename ShapeImpl>
- constexpr HloInstructionPattern<
- HloInstructionType,
- HloInstructionPatternShapeImpl<Impl, ShapeType, ShapeImpl>>
- WithShape(const ShapePattern<ShapeType, ShapeImpl>& shape) const {
- return HloInstructionPattern<
- HloInstructionType,
- HloInstructionPatternShapeImpl<Impl, ShapeType, ShapeImpl>>(
- HloInstructionPatternShapeImpl<Impl, ShapeType, ShapeImpl>(impl_,
- shape),
- matched_inst_);
+ constexpr auto WithShape(const ShapePattern<ShapeType, ShapeImpl>& shape)
+ const -> decltype(this->AppendImpl(
+ HloInstructionPatternShapeImpl<ShapeType, ShapeImpl>(shape))) {
+ return AppendImpl(
+ HloInstructionPatternShapeImpl<ShapeType, ShapeImpl>(shape));
}
// Modifies the pattern to match only if the instruction has an operand that
// matches the given pattern.
template <typename OperandType, typename OperandImpl>
- constexpr HloInstructionPattern<
- HloInstructionType,
- HloInstructionPatternOperandImpl<Impl, OperandType, OperandImpl>>
- WithOperand(
+ constexpr auto WithOperand(
int64 operand_index,
- const HloInstructionPattern<OperandType, OperandImpl>& operand) const {
- return HloInstructionPattern<
- HloInstructionType,
- HloInstructionPatternOperandImpl<Impl, OperandType, OperandImpl>>(
- HloInstructionPatternOperandImpl<Impl, OperandType, OperandImpl>(
- impl_, operand_index, operand),
- matched_inst_);
+ const HloInstructionPattern<OperandType, OperandImpl>& operand) const
+ -> decltype(this->AppendImpl(
+ HloInstructionPatternOperandImpl<OperandType, OperandImpl>(
+ operand_index, operand))) {
+ return AppendImpl(
+ HloInstructionPatternOperandImpl<OperandType, OperandImpl>(
+ operand_index, operand));
}
// Modifies the pattern to match only if the instruction is a fusion node with
// the given kind.
- constexpr HloInstructionPattern<HloInstructionType,
- HloInstructionPatternFusionKindImpl<Impl>>
- WithFusionKind(HloInstruction::FusionKind kind) const {
- return HloInstructionPattern<HloInstructionType,
- HloInstructionPatternFusionKindImpl<Impl>>(
- HloInstructionPatternFusionKindImpl<Impl>(impl_, kind), matched_inst_);
+ constexpr auto WithFusionKind(HloInstruction::FusionKind kind) const
+ -> decltype(this->AppendImpl(HloInstructionPatternFusionKindImpl(kind))) {
+ return AppendImpl(HloInstructionPatternFusionKindImpl(kind));
}
// Modifies the pattern to match only if the instruction is a
// get-tuple-element with the given tuple index.
- constexpr HloInstructionPattern<HloInstructionType,
- HloInstructionPatternTupleIndexImpl<Impl>>
- WithTupleIndex(int64 tuple_index) const {
- return HloInstructionPattern<HloInstructionType,
- HloInstructionPatternTupleIndexImpl<Impl>>(
- HloInstructionPatternTupleIndexImpl<Impl>(impl_, tuple_index),
- matched_inst_);
+ constexpr auto WithTupleIndex(int64 tuple_index) const -> decltype(
+ this->AppendImpl(HloInstructionPatternTupleIndexImpl(tuple_index))) {
+ return AppendImpl(HloInstructionPatternTupleIndexImpl(tuple_index));
}
private:
+ template <typename Predicate>
+ constexpr auto WithPredicate(Predicate pred) const -> decltype(
+ this->AppendImpl(HloPredicatePatternImpl<HloInstruction, Predicate>(
+ std::move(pred)))) {
+ return AppendImpl(
+ HloPredicatePatternImpl<HloInstruction, Predicate>(std::move(pred)));
+ }
+
+ friend struct PatternFriend;
+
Impl impl_;
HloInstructionType** matched_inst_;
};
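A minimal usage sketch of the new MatchOption plumbing (it mirrors the TestNoCapture test added further down; `root` stands in for any HloInstruction* being inspected): a successful match with capture disabled leaves the bound pointer untouched.

    const HloInstruction* captured = nullptr;
    // Matches a constant root, but does not write into `captured` because
    // capture is disabled for this probe.
    bool matched = Match(root, match::Constant(&captured), {/*capture=*/false});
    // matched == true for a constant root; captured stays nullptr.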
@@ -1005,31 +1102,50 @@ XLA_UNOP_PATTERN(Transpose)
.WithOperand(0, std::forward<Lhs>(lhs)) \
.WithOperand(1, std::forward<Rhs>(rhs)); \
}
-XLA_BINOP_PATTERN(Add)
+
+#define XLA_COMMUTATIVE_BINOP_PATTERN(NAME) \
+ XLA_BINOP_PATTERN(NAME) \
+ \
+ template <typename Lhs, typename Rhs> \
+ inline auto NAME##AnyOrder(Lhs&& lhs, Rhs&& rhs) \
+ ->decltype(AnyOf<HloInstruction>(NAME(lhs, rhs), NAME(rhs, lhs))) { \
+ return AnyOf<HloInstruction>(NAME(lhs, rhs), NAME(rhs, lhs)); \
+ } \
+ \
+ template <typename HloInstructionType, typename Lhs, typename Rhs> \
+ inline auto NAME##AnyOrder(HloInstructionType** matched_inst, Lhs&& lhs, \
+ Rhs&& rhs) \
+ ->decltype(AnyOf<HloInstructionType>(NAME(matched_inst, lhs, rhs), \
+ NAME(matched_inst, rhs, lhs))) { \
+ return AnyOf<HloInstructionType>(NAME(matched_inst, lhs, rhs), \
+ NAME(matched_inst, rhs, lhs)); \
+ }
+XLA_COMMUTATIVE_BINOP_PATTERN(Add)
XLA_BINOP_PATTERN(Atan2)
XLA_BINOP_PATTERN(Divide)
XLA_BINOP_PATTERN(Complex)
XLA_BINOP_PATTERN(Dot)
-XLA_BINOP_PATTERN(Eq)
+XLA_COMMUTATIVE_BINOP_PATTERN(Eq)
XLA_BINOP_PATTERN(Gather)
XLA_BINOP_PATTERN(Ge)
XLA_BINOP_PATTERN(Gt)
XLA_BINOP_PATTERN(Le)
XLA_BINOP_PATTERN(Lt)
-XLA_BINOP_PATTERN(Maximum)
-XLA_BINOP_PATTERN(Minimum)
-XLA_BINOP_PATTERN(Multiply)
-XLA_BINOP_PATTERN(Ne)
+XLA_COMMUTATIVE_BINOP_PATTERN(Maximum)
+XLA_COMMUTATIVE_BINOP_PATTERN(Minimum)
+XLA_COMMUTATIVE_BINOP_PATTERN(Multiply)
+XLA_COMMUTATIVE_BINOP_PATTERN(Ne)
XLA_BINOP_PATTERN(Outfeed)
XLA_BINOP_PATTERN(Power)
XLA_BINOP_PATTERN(Remainder)
XLA_BINOP_PATTERN(Send)
XLA_BINOP_PATTERN(Subtract)
-XLA_BINOP_PATTERN(And)
-XLA_BINOP_PATTERN(Or)
+XLA_COMMUTATIVE_BINOP_PATTERN(And)
+XLA_COMMUTATIVE_BINOP_PATTERN(Or)
XLA_BINOP_PATTERN(ShiftLeft)
XLA_BINOP_PATTERN(ShiftRightArithmetic)
XLA_BINOP_PATTERN(ShiftRightLogical)
+#undef XLA_COMMUTATIVE_BINOP_PATTERN
#undef XLA_BINOP_PATTERN
// Helpers for ternary instructions.
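A short sketch of the *AnyOrder helpers generated by XLA_COMMUTATIVE_BINOP_PATTERN above (the names follow the macro expansion and are exercised by the MultiplyAnyOrder test below): the commutative form tries both operand orders, so a pattern written as add(constant, x) also matches add(x, constant).

    const HloInstruction* c = nullptr;
    // Binds `c` to the constant operand regardless of which side it is on.
    bool ok = Match(root, match::AddAnyOrder(match::Constant(&c),
                                             match::NonConstant()));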
@@ -1070,6 +1186,30 @@ XLA_TERNOP_PATTERN(Clamp);
XLA_TERNOP_PATTERN(Select);
#undef XLA_TERNOP_PATTERN
+namespace detail {
+struct PatternFriend {
+ template <typename T>
+ static auto ConstantScalar(T constant) -> decltype(
+ Constant()
+ .WithShape(match::Shape().IsScalar())
+ .WithPredicate(
+ std::declval<std::function<bool(const HloInstruction*)>>())) {
+ std::function<bool(const HloInstruction*)> pred =
+ [constant](const HloInstruction* instr) {
+ const auto& literal = Cast<HloConstantInstruction>(instr)->literal();
+ auto status_or_const = LiteralUtil::CreateR0(constant).Convert(
+ literal.shape().element_type());
+ return status_or_const.ok() &&
+ literal == status_or_const.ConsumeValueOrDie();
+ };
+
+ return Constant()
+ .WithShape(match::Shape().IsScalar())
+ .WithPredicate(std::move(pred));
+ }
+};
+} // namespace detail
+
// Helpers for matching non-constant instructions.
inline auto NonConstant() -> decltype(Op().IsNonConstant()) {
return Op().IsNonConstant();
@@ -1107,6 +1247,12 @@ inline auto GetTupleElement(HloInstructionType** matched_inst, Arg&& arg,
.WithTupleIndex(tuple_index);
}
+template <typename T>
+inline auto ConstantScalar(T constant)
+ -> decltype(detail::PatternFriend::ConstantScalar(constant)) {
+ return detail::PatternFriend::ConstantScalar(constant);
+}
+
} // namespace match
} // namespace xla
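As a composition sketch (assumed usage, combining helpers introduced by this change rather than code taken from it): match::ConstantScalar nests inside other matchers, for example to find a multiply by the scalar 42 in either operand order while capturing the other operand.

    const HloInstruction* x = nullptr;
    bool ok = Match(root, match::MultiplyAnyOrder(match::Op(&x),
                                                  match::ConstantScalar(42)));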
diff --git a/tensorflow/compiler/xla/service/pattern_matcher_test.cc b/tensorflow/compiler/xla/service/pattern_matcher_test.cc
index a530581c34..3ab7b7fd71 100644
--- a/tensorflow/compiler/xla/service/pattern_matcher_test.cc
+++ b/tensorflow/compiler/xla/service/pattern_matcher_test.cc
@@ -211,5 +211,188 @@ TEST(PatternMatcherTest, GetTupleElement) {
EXPECT_TRUE(Match(root, match::GetTupleElement(match::Op(), 1)));
}
+TEST(PatternMatcherTest, AnyOf) {
+ constexpr char kModuleStr[] = R"(
+ HloModule test_module ENTRY test { ROOT constant = f16[] constant(1) })";
+ TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr));
+ auto* root = hlo_module->entry_computation()->root_instruction();
+
+ EXPECT_TRUE(
+ Match(root, match::AnyOf<HloInstruction>(match::ConstantScalar(0),
+ match::ConstantScalar(1))));
+ EXPECT_TRUE(
+ Match(root, match::AnyOf<HloInstruction>(match::ConstantScalar(1),
+ match::ConstantScalar(0))));
+ EXPECT_FALSE(
+ Match(root, match::AnyOf<HloInstruction>(match::ConstantScalar(0),
+ match::ConstantScalar(2))));
+}
+
+TEST(PatternMatcherTest, ConstantScalar) {
+ constexpr char kModuleStr[] = R"(
+ HloModule test_module ENTRY test { ROOT constant = f16[] constant(42) })";
+ TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr));
+ auto* root = hlo_module->entry_computation()->root_instruction();
+
+ EXPECT_TRUE(Match(root, match::ConstantScalar(42)));
+ EXPECT_FALSE(Match(root, match::ConstantScalar(41)));
+ EXPECT_FALSE(Match(root, match::ConstantScalar(0)));
+}
+
+TEST(PatternMatcherTest, NoMatchConstantScalar) {
+ constexpr char kModuleStr[] = R"(
+ HloModule test_module ENTRY test { ROOT v = f16[] parameter(0) })";
+ TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr));
+ auto* root = hlo_module->entry_computation()->root_instruction();
+
+ EXPECT_FALSE(Match(root, match::ConstantScalar(42)));
+}
+
+TEST(PatternMatcherTest, MultiplyAnyOrder) {
+ using match::ConstantScalar;
+ using match::MultiplyAnyOrder;
+
+ constexpr char kModuleStr[] = R"(
+ HloModule test_module
+ ENTRY test {
+ lhs = f16[] constant(42)
+ rhs = f16[] constant(52)
+ ROOT multiply = f16[] multiply(lhs, rhs)
+ })";
+ TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr));
+ auto* root = hlo_module->entry_computation()->root_instruction();
+ const HloInstruction* instr;
+
+ EXPECT_TRUE(Match(
+ root, MultiplyAnyOrder(&instr, ConstantScalar(42), ConstantScalar(52))));
+ EXPECT_TRUE(Match(
+ root, MultiplyAnyOrder(&instr, ConstantScalar(52), ConstantScalar(42))));
+}
+
+TEST(PatternMatcherTest, AnyOfShortCircuit) {
+ using match::AnyOf;
+ using match::Multiply;
+ using match::Op;
+
+ constexpr char kModuleStr[] = R"(
+ HloModule test_module
+ ENTRY test {
+ lhs = f16[] constant(42)
+ rhs = f16[] constant(52)
+ ROOT multiply = f16[] multiply(lhs, rhs)
+ })";
+ TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr));
+ auto* root = hlo_module->entry_computation()->root_instruction();
+
+ {
+ const HloInstruction* mul = nullptr;
+ const HloInstruction* any = nullptr;
+
+ ASSERT_TRUE(Match(
+ root, AnyOf<HloInstruction>(Multiply(&mul, Op(), Op()), Op(&any))));
+ EXPECT_NE(nullptr, mul);
+ EXPECT_EQ(nullptr, any);
+ }
+ {
+ const HloInstruction* mul = nullptr;
+ const HloInstruction* any = nullptr;
+
+ ASSERT_TRUE(Match(
+ root, AnyOf<HloInstruction>(Op(&any), Multiply(&mul, Op(), Op()))));
+ EXPECT_NE(nullptr, any);
+ EXPECT_EQ(nullptr, mul);
+ }
+}
+
+TEST(PatternMatcherTest, AllOf) {
+ using match::AllOf;
+ using match::Broadcast;
+ using match::Constant;
+ using match::Op;
+
+ constexpr char kModuleStr[] = R"(
+ HloModule test_module ENTRY test { ROOT constant = f16[] constant(1) })";
+ TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr));
+ auto* root = hlo_module->entry_computation()->root_instruction();
+
+ auto scalar_pattern = Constant().WithShape(match::Shape().IsScalar());
+ auto f16_pattern = Constant().WithShape(match::Shape().WithElementType(F16));
+ ASSERT_TRUE(Match(root, scalar_pattern));
+ ASSERT_TRUE(Match(root, f16_pattern));
+ EXPECT_TRUE(Match(root, AllOf<HloInstruction>(scalar_pattern, f16_pattern)));
+ EXPECT_TRUE(Match(root, AllOf<HloInstruction>(f16_pattern, scalar_pattern)));
+ EXPECT_FALSE(
+ Match(root, AllOf<HloInstruction>(Broadcast(Op()), f16_pattern)));
+ EXPECT_FALSE(
+ Match(root, AllOf<HloInstruction>(Broadcast(Op()), scalar_pattern)));
+}
+
+TEST(PatternMatcherTest, AllOfNoCaptureIfNotMatch) {
+ using match::AllOf;
+ using match::Broadcast;
+ using match::Constant;
+ using match::Op;
+
+ constexpr char kModuleStr[] = R"(
+ HloModule test_module
+ ENTRY test {
+ ROOT v = f16[] constant(42)
+ })";
+ TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr));
+ auto* root = hlo_module->entry_computation()->root_instruction();
+
+ const HloInstruction* constant = nullptr;
+ ASSERT_FALSE(
+ Match(root, AllOf<HloInstruction>(Constant(&constant), Broadcast(Op()))));
+ EXPECT_EQ(nullptr, constant);
+ ASSERT_TRUE(Match(root, Constant(&constant)));
+ EXPECT_NE(nullptr, constant);
+}
+
+TEST(PatternMatcherTest, TestNoCapture) {
+ using match::Constant;
+
+ constexpr char kModuleStr[] = R"(
+ HloModule test_module
+ ENTRY test {
+ ROOT v = f16[] constant(42)
+ })";
+ TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr));
+ auto* root = hlo_module->entry_computation()->root_instruction();
+
+ const HloInstruction* constant = nullptr;
+ ASSERT_TRUE(Match(root, Constant(&constant), {/*capture=*/false}));
+ EXPECT_EQ(nullptr, constant);
+}
+
+TEST(PatternMatcherTest, TestCaptureMatchedSubPatternForAnyOf) {
+ using match::Add;
+ using match::AddAnyOrder;
+ using match::AnyOf;
+ using match::Op;
+
+ constexpr char kModuleStr[] = R"(
+ HloModule test_module
+ ENTRY test {
+ u = f16[] parameter(0)
+ v = f16[] parameter(1)
+ ROOT add = f16[] add(u, v)
+ })";
+ TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr));
+ auto* root = hlo_module->entry_computation()->root_instruction();
+
+ const HloInstruction* addend0 = nullptr;
+ const HloInstruction* addend1 = nullptr;
+ const HloInstruction* addend2 = nullptr;
+ auto add2_pattern = Add(Op(&addend0), Op(&addend1));
+ auto add3_pattern = AnyOf<HloInstruction>(
+ AddAnyOrder(add2_pattern, Op(&addend2)), add2_pattern, Op(&addend0));
+
+ ASSERT_TRUE(Match(root, add3_pattern));
+ EXPECT_NE(nullptr, addend0);
+ EXPECT_NE(nullptr, addend1);
+ EXPECT_EQ(nullptr, addend2);
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/platform_util.cc b/tensorflow/compiler/xla/service/platform_util.cc
index 178a78ede0..c522e7ae23 100644
--- a/tensorflow/compiler/xla/service/platform_util.cc
+++ b/tensorflow/compiler/xla/service/platform_util.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "absl/strings/ascii.h"
#include "absl/strings/str_join.h"
+#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
@@ -217,9 +218,12 @@ PlatformUtil::GetStreamExecutors(se::Platform* platform) {
if (platform->id() == se::host::kHostPlatformId) {
// On host "devices", StreamExecutor exports a device for each hardware
// thread. Because we parallelize a single computation across threads, it
- // doesn't make sense to expose these as separate devices, so fix the number
- // of devices to one.
- device_count = 1;
+ // doesn't make sense to expose these as separate devices, so by default we
+ // fix the number of devices to one. However we do let the user override
+ // this behavior to help run tests on the host that run models in parallel
+ // across multiple devices.
+ device_count = legacy_flags::GetDebugOptionsFromFlags()
+ .xla_force_host_platform_device_count();
}
std::vector<se::StreamExecutor*> stream_executors(device_count, nullptr);
VLOG(1) << "Initializing devices";
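A hypothetical test sketch of the new override (the flag wiring is exercised by the multiple_devices_on_host_test target added later in this change; the exact call spelling below is an assumption, not part of the diff):

    // Assumes the test binary was started with
    //   --xla_force_host_platform_device_count=4
    TF_ASSERT_OK_AND_ASSIGN(se::Platform* platform,
                            PlatformUtil::GetPlatform("Host"));
    TF_ASSERT_OK_AND_ASSIGN(std::vector<se::StreamExecutor*> executors,
                            PlatformUtil::GetStreamExecutors(platform));
    EXPECT_EQ(4, executors.size());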
diff --git a/tensorflow/compiler/xla/service/reduce_precision_insertion.h b/tensorflow/compiler/xla/service/reduce_precision_insertion.h
index 256b231e3a..4bb22428f3 100644
--- a/tensorflow/compiler/xla/service/reduce_precision_insertion.h
+++ b/tensorflow/compiler/xla/service/reduce_precision_insertion.h
@@ -29,7 +29,7 @@ namespace xla {
// HLO pass which inserts reduce-precision instructions into the HLO graph, for
// purposes of experimenting with the effects of reduced-precision storage of
// intermediate values.
-class ReducePrecisionInsertion : public HloPassInterface {
+class ReducePrecisionInsertion : public HloModulePass {
using InstructionFilterFunction = std::function<bool(const HloInstruction*)>;
public:
diff --git a/tensorflow/compiler/xla/service/reshape_mover.h b/tensorflow/compiler/xla/service/reshape_mover.h
index 1e86a0823a..a3db439e34 100644
--- a/tensorflow/compiler/xla/service/reshape_mover.h
+++ b/tensorflow/compiler/xla/service/reshape_mover.h
@@ -24,7 +24,7 @@ namespace xla {
// This now only moves them outputward across elementwise ops all whose operands
// are equivalent Reshapes or Transposes, but in future could potentially move
// them inputward also.
-class ReshapeMover : public HloPassInterface {
+class ReshapeMover : public HloModulePass {
public:
absl::string_view name() const override { return "reshape-mover"; }
diff --git a/tensorflow/compiler/xla/service/scatter_expander.cc b/tensorflow/compiler/xla/service/scatter_expander.cc
index 2f4b2667c4..de7aee262e 100644
--- a/tensorflow/compiler/xla/service/scatter_expander.cc
+++ b/tensorflow/compiler/xla/service/scatter_expander.cc
@@ -155,6 +155,53 @@ static StatusOr<HloInstruction*> ExpandIndexVectorIntoOperandSpace(
return MakeConcatHlo(expanded_index_components, /*dimension=*/0);
}
+static StatusOr<HloInstruction*> CheckIndexValidity(
+ HloComputation* computation, HloInstruction* index,
+ absl::Span<const int64> operand_dims, absl::Span<const int64> window_sizes,
+ HloModule* module) {
+ DCHECK_NE(nullptr, module);
+ DCHECK_EQ(operand_dims.size(), window_sizes.size());
+
+ // Valid range for the index: [0, operand_dims - window_sizes]
+
+ // Check if the index has any negative values.
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * zero_index,
+ BroadcastZeros(computation, index->shape().element_type(),
+ AsInt64Slice(index->shape().dimensions())));
+ TF_ASSIGN_OR_RETURN(HloInstruction * negative_index_check,
+ MakeBinaryHlo(HloOpcode::kLe, zero_index, index));
+
+ // Check if the index is OOB w.r.t. the operand dimensions and window sizes.
+ std::vector<int64> max_valid_index(operand_dims.size());
+ for (int i = 0; i < operand_dims.size(); ++i) {
+ max_valid_index[i] = operand_dims[i] - window_sizes[i];
+ }
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * max_valid_index_constant,
+ MakeR1ConstantHlo<int64>(computation, index->shape().element_type(),
+ max_valid_index));
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * oob_index_check,
+ MakeBinaryHlo(HloOpcode::kGe, max_valid_index_constant, index));
+
+ // Combine the results of the two checks above.
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * valid_index,
+ MakeBinaryHlo(HloOpcode::kAnd, negative_index_check, oob_index_check));
+
+ // Reduce the index validity check vector into a scalar predicate.
+ auto reduction_init = computation->AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * valid_index_reduced,
+ MakeReduceHlo(valid_index, reduction_init, HloOpcode::kAnd, module));
+
+ // Return a broadcasted value of the scalar predicate to the same size as the
+ // window.
+ return MakeBroadcastHlo(valid_index_reduced, {}, window_sizes);
+}
+
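A small worked example of the range computed above (numbers are illustrative only): for a 5x7 operand and a 2x3 update window, max_valid_index is {5 - 2, 7 - 3} = {3, 4}, so a scatter index is applied only when 0 <= index[i] <= max_valid_index[i] holds in every dimension. An index of {3, 4} is still in range, while {4, 0} or {-1, 2} fails the check and, via the select introduced below, leaves the original operand slice unchanged.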
// Body of the while loop that performs the scatter operation using other HLOs.
static StatusOr<std::vector<HloInstruction*>> ScatterLoopBody(
HloInstruction* scatter, HloInstruction* induction_var,
@@ -222,7 +269,16 @@ static StatusOr<std::vector<HloInstruction*>> ScatterLoopBody(
InsertDegenerateDims(update_slice_for_scatter,
AsInt64Slice(dim_numbers.inserted_window_dims())));
- // Extact the slice to update from `operand` tensor.
+ // Note that the following transformation assumes that both DynamicSlice and
+ // DynamicUpdateSlice follow the same semantics for OOB indices. For example,
+ // if there are negative indices and DynamicSlice uses "clamping" semantics,
+ // then the extracted data will be "shifted". Since DynamicUpdateSlice also
+ // follows the same "clamping" semantics, writing the update will also be
+ // "shifted" by exactly the same amount. So, this transformation is correct as
+ // long as the semantics of handling OOB indices remain the same in
+ // DynamicSlice and DynamicUpdateSlice.
+
+ // Extract the slice to update from `operand` tensor.
const Shape& update_slice_shape = update_slice_with_dims_inserted->shape();
TF_ASSIGN_OR_RETURN(
HloInstruction * operand_slice_to_update,
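To make the clamping note above concrete (an illustrative example under the hypothesized clamping semantics, not something asserted by the code): with a length-5 operand, a length-3 window, and a requested start index of 4, DynamicSlice would clamp the start to 5 - 3 = 2 and read elements [2, 5); DynamicUpdateSlice would clamp the same request to 2 as well, so the write lands on exactly the window that was read and the shift cancels out.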
@@ -237,10 +293,24 @@ static StatusOr<std::vector<HloInstruction*>> ScatterLoopBody(
MakeMapHlo({operand_slice_to_update, update_slice_with_dims_inserted},
scatter->to_apply()));
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * is_index_valid,
+ CheckIndexValidity(
+ operand->parent(), scatter_slice_start,
+ AsInt64Slice(operand->shape().dimensions()),
+ AsInt64Slice(update_slice_with_dims_inserted->shape().dimensions()),
+ scatter->GetModule()));
+
+ // Select the updated operand only if the index is valid. If not, select the
+ // original value.
+ TF_ASSIGN_OR_RETURN(HloInstruction * update_to_apply,
+ MakeSelectHlo(is_index_valid, updated_operand_slice,
+ operand_slice_to_update));
+
// Write the updated value of the slice into `operand` tensor.
- TF_ASSIGN_OR_RETURN(HloInstruction * updated_operand,
- MakeDynamicUpdateSliceHlo(operand, updated_operand_slice,
- scatter_slice_start));
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * updated_operand,
+ MakeDynamicUpdateSliceHlo(operand, update_to_apply, scatter_slice_start));
return StatusOr<std::vector<HloInstruction*>>{
{updated_operand, scatter_indices, updates}};
diff --git a/tensorflow/compiler/xla/service/scatter_expander.h b/tensorflow/compiler/xla/service/scatter_expander.h
index 14f062c89c..559a85dccf 100644
--- a/tensorflow/compiler/xla/service/scatter_expander.h
+++ b/tensorflow/compiler/xla/service/scatter_expander.h
@@ -20,7 +20,7 @@ limitations under the License.
namespace xla {
-class ScatterExpander : public HloPassInterface {
+class ScatterExpander : public HloModulePass {
public:
absl::string_view name() const override { return "scatter_expander"; }
StatusOr<bool> Run(HloModule* module) override;
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index 74bdf2a2e3..7194b2cafd 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -1665,10 +1665,11 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
if (input_features != kernel_input_features * feature_group_count) {
return InvalidArgument(
"Expected LHS feature dimension (value %d) to match RHS "
- "input feature dimension * feature_group_count (value %d); "
+ "input feature dimension * feature_group_count (value %d * %d = %d); "
"got <conv>(%s, %s)\n"
"Dimension numbers: {%s}.",
- input_features, kernel_input_features * feature_group_count,
+ input_features, kernel_input_features, feature_group_count,
+ kernel_input_features * feature_group_count,
ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs),
dnums.DebugString());
}
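As a concrete reading of the improved message (numbers invented for illustration): with an LHS feature dimension of 8, an RHS input feature dimension of 3, and a feature_group_count of 2, the error previously reported only the product ("value 6"); it now reports "value 3 * 2 = 6", making it clear whether the kernel's feature count or the group count is the mismatched factor.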
diff --git a/tensorflow/compiler/xla/service/stream_pool.cc b/tensorflow/compiler/xla/service/stream_pool.cc
index 5d1cd1c442..ec09dff924 100644
--- a/tensorflow/compiler/xla/service/stream_pool.cc
+++ b/tensorflow/compiler/xla/service/stream_pool.cc
@@ -28,8 +28,14 @@ StreamPool::Ptr StreamPool::BorrowStream(se::StreamExecutor* executor) {
// Re-use an existing stream from the pool.
stream = std::move(streams_.back());
streams_.pop_back();
- VLOG(1) << stream->DebugStreamPointers()
- << " StreamPool reusing existing stream";
+ if (stream->ok()) {
+ VLOG(1) << stream->DebugStreamPointers()
+ << " StreamPool reusing existing stream";
+ } else {
+ VLOG(1) << stream->DebugStreamPointers()
+ << " stream was not ok, StreamPool deleting";
+ stream = nullptr;
+ }
}
}
diff --git a/tensorflow/compiler/xla/service/stream_pool_test.cc b/tensorflow/compiler/xla/service/stream_pool_test.cc
index aaf5c37b0d..92f47579d3 100644
--- a/tensorflow/compiler/xla/service/stream_pool_test.cc
+++ b/tensorflow/compiler/xla/service/stream_pool_test.cc
@@ -132,5 +132,39 @@ TEST_F(StreamPoolTest, BadStreamDiscarded) {
EXPECT_EQ(stream2_ptr, stream3_ptr);
}
+TEST_F(StreamPoolTest, BadStreamAfterReturnDiscarded) {
+ std::unique_ptr<se::StreamExecutor> executor = NewStreamExecutor();
+ StreamPool pool;
+
+ // Borrow a stream.
+ StreamPool::Ptr stream1 = pool.BorrowStream(executor.get());
+ EXPECT_TRUE(stream1->ok());
+
+ // Return the stream, but hold a handle to it.
+ se::Stream* stream1_ptr = stream1.get();
+ stream1 = nullptr;
+
+ // Now that stream1 is back in the pool, force an error on the stream. Here
+ // we call a method that requires DNN support, which we know the Host
+ // platform doesn't support.
+ stream1_ptr->ThenDepthConcatenate({}, {}, nullptr);
+ EXPECT_FALSE(stream1_ptr->ok());
+
+ // Borrow stream2.
+ StreamPool::Ptr stream2 = pool.BorrowStream(executor.get());
+ EXPECT_TRUE(stream2->ok());
+
+ // The underlying streams should be different. They would have been
+ // the same, but since we forced an error on stream1, it cannot be
+ // put back into the pool. Sadly we can't just check:
+ // EXPECT_NE(stream1_ptr, stream2_ptr);
+ //
+ // The above should hold logically, but it may fail if the new
+ // stream instance allocated for stream2 happens to reside in the
+ // same memory address as stream1, which has been deleted.
+ //
+ // The stream2->ok() check above serves as a good-enough substitute.
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/transpose_folding.h b/tensorflow/compiler/xla/service/transpose_folding.h
index 3e5aa2db60..f95f982eb8 100644
--- a/tensorflow/compiler/xla/service/transpose_folding.h
+++ b/tensorflow/compiler/xla/service/transpose_folding.h
@@ -23,7 +23,7 @@ namespace xla {
// HLO pass that folds transpose operators into Dot operators, where the Dot
// operator is implemented by a GEMM kernel that can transpose its inputs.
-class TransposeFolding : public HloPassInterface {
+class TransposeFolding : public HloModulePass {
public:
using OperandIndices = std::vector<int64>;
diff --git a/tensorflow/compiler/xla/service/tuple_simplifier.h b/tensorflow/compiler/xla/service/tuple_simplifier.h
index 8c91d6e69d..e126a53023 100644
--- a/tensorflow/compiler/xla/service/tuple_simplifier.h
+++ b/tensorflow/compiler/xla/service/tuple_simplifier.h
@@ -25,7 +25,7 @@ namespace xla {
// A pass which simplifies patterns of Tuple and GetTupleElement instructions in
// the module.
-class TupleSimplifier : public HloPassInterface {
+class TupleSimplifier : public HloModulePass {
public:
TupleSimplifier() : TupleSimplifier(/*exclude_entry_computation=*/false) {}
explicit TupleSimplifier(bool exclude_entry_computation);
diff --git a/tensorflow/compiler/xla/service/while_loop_constant_sinking.h b/tensorflow/compiler/xla/service/while_loop_constant_sinking.h
index 2dba7d7f75..577bad6c70 100644
--- a/tensorflow/compiler/xla/service/while_loop_constant_sinking.h
+++ b/tensorflow/compiler/xla/service/while_loop_constant_sinking.h
@@ -50,7 +50,7 @@ namespace xla {
// conditions as well.
//
// TODO(b/79121449): We should also sink broadcasts of constants.
-class WhileLoopConstantSinking : public HloPassInterface {
+class WhileLoopConstantSinking : public HloModulePass {
public:
~WhileLoopConstantSinking() override = default;
diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h
index 2cdf20ce80..3031899f71 100644
--- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h
+++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h
@@ -25,7 +25,7 @@ namespace xla {
// HLO pass that rewrites while loops to hoist loop invariant instructions in
// the while body into the computation that contains the while instruction.
-class WhileLoopInvariantCodeMotion : public HloPassInterface {
+class WhileLoopInvariantCodeMotion : public HloModulePass {
public:
// If `hoist_constants` is true then constants are always hoisted out of while
// loop bodies. Otherwise they are only hoisted out if they enable other
diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.cc b/tensorflow/compiler/xla/service/while_loop_simplifier.cc
index 6a7bfe3f12..9a74f22395 100644
--- a/tensorflow/compiler/xla/service/while_loop_simplifier.cc
+++ b/tensorflow/compiler/xla/service/while_loop_simplifier.cc
@@ -252,7 +252,7 @@ static StatusOr<bool> TryRemoveDeadWhileParams(HloInstruction* while_op) {
// Create the new while condition, body, and init value.
std::unique_ptr<HloComputation> new_while_cond =
while_cond->CloneWithReplacements(
- make_while_computation_replacements(while_cond));
+ make_while_computation_replacements(while_cond), /*extras=*/{});
std::unordered_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
while_body_replacements = make_while_computation_replacements(while_body);
@@ -265,7 +265,8 @@ static StatusOr<bool> TryRemoveDeadWhileParams(HloInstruction* while_op) {
while_body_replacements.emplace(
while_body_root, HloInstruction::CreateTuple(new_while_body_root_elems));
std::unique_ptr<HloComputation> new_while_body =
- while_body->CloneWithReplacements(std::move(while_body_replacements));
+ while_body->CloneWithReplacements(std::move(while_body_replacements),
+ /*extras=*/{});
// Add a new while_init instruction that repackages the old while_init
// instruction's elements. We rely on the AlgebraicSimplifier and DCE to
diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.h b/tensorflow/compiler/xla/service/while_loop_simplifier.h
index 78024f14dc..0bc5a0107b 100644
--- a/tensorflow/compiler/xla/service/while_loop_simplifier.h
+++ b/tensorflow/compiler/xla/service/while_loop_simplifier.h
@@ -30,7 +30,7 @@ namespace xla {
// - Elements of a while loop's tuple that the loop doesn't use are removed
// from the tuple.
//
-class WhileLoopSimplifier : public HloPassInterface {
+class WhileLoopSimplifier : public HloModulePass {
public:
~WhileLoopSimplifier() override {}
absl::string_view name() const override { return "simplify-while-loops"; }
diff --git a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h
index a7f0e207eb..87294120d5 100644
--- a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h
+++ b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h
@@ -21,7 +21,7 @@ limitations under the License.
// HLO pass that replaces zero sized Hlos with a zero sized constant literal.
namespace xla {
-class ZeroSizedHloElimination : public HloPassInterface {
+class ZeroSizedHloElimination : public HloModulePass {
public:
StatusOr<bool> Run(HloModule* module) override;
absl::string_view name() const override {
diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc
index 96c80fd577..020c167ee9 100644
--- a/tensorflow/compiler/xla/shape_util.cc
+++ b/tensorflow/compiler/xla/shape_util.cc
@@ -422,8 +422,11 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
}
/* static */ int64 ShapeUtil::ElementsIn(const Shape& shape) {
- CHECK(IsArray(shape)) << ShapeUtil::HumanString(shape);
- CHECK_EQ(shape.dimensions_size(), Rank(shape));
+ DCHECK(IsArray(shape)) << ShapeUtil::HumanString(shape);
+ DCHECK_EQ(shape.dimensions_size(), Rank(shape));
+ if (shape.dimensions().size() == 1) {
+ return shape.dimensions()[0];
+ }
return std::accumulate<decltype(shape.dimensions().begin()), int64>(
shape.dimensions().begin(), shape.dimensions().end(), 1LL,
std::multiplies<int64>());
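A quick sketch of the behavior after this change (ordinary ShapeUtil calls, shown only for illustration):

    // Rank-1 shapes take the new fast path; higher ranks still use the product.
    Shape vec = ShapeUtil::MakeShape(F32, {7});
    Shape mat = ShapeUtil::MakeShape(F32, {3, 4});
    int64 n_vec = ShapeUtil::ElementsIn(vec);  // 7, returned directly
    int64 n_mat = ShapeUtil::ElementsIn(mat);  // 3 * 4 = 12, via std::accumulate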
diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h
index 623ae39de8..d8bb27beae 100644
--- a/tensorflow/compiler/xla/shape_util.h
+++ b/tensorflow/compiler/xla/shape_util.h
@@ -22,6 +22,7 @@ limitations under the License.
#include <initializer_list>
#include <string>
+#include "absl/base/macros.h"
#include "absl/container/inlined_vector.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
@@ -479,8 +480,7 @@ class ShapeUtil {
// Shorthand for testing whether a shape is of a given element type and
// sequence of dimensions.
- //
- // DEPRECATED: Use Equal() instead.
+ ABSL_DEPRECATED("Use Equal() instead.")
static bool ShapeIs(const Shape& shape, PrimitiveType element_type,
std::initializer_list<int64> dimensions);
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 30e3077edb..f474ecb18c 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -29,6 +29,10 @@ load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_suites"
load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_test_macros")
load("//tensorflow:tensorflow.bzl", "tf_cc_binary")
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
+load(
+ "//tensorflow/core:platform/default/build_config_root.bzl",
+ "tf_cuda_tests_tags",
+)
# Generate test_suites for all backends, named "${backend}_tests".
generate_backend_suites()
@@ -150,11 +154,31 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/service:hlo_verifier",
"//tensorflow/core:lib",
- "//tensorflow/core:test",
+ "@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/memory",
],
)
+tf_cc_test(
+ name = "hlo_verified_test_base_test",
+ srcs = ["hlo_verified_test_base_test.cc"],
+ deps = [
+ ":hlo_test_base",
+ ":hlo_verified_test_base",
+ ":test_macros_cpu",
+ ":test_utils",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla/client:xla_builder",
+ "//tensorflow/compiler/xla/client:xla_computation",
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_parser",
+ "//tensorflow/compiler/xla/service:hlo_verifier",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
tf_cc_binary(
name = "local_client_aot_test_helper",
srcs = ["local_client_aot_test_helper.cc"],
@@ -1797,7 +1821,7 @@ xla_test(
tf_cc_test(
name = "llvm_compiler_test",
srcs = ["llvm_compiler_test.cc"],
- tags = ["requires-gpu-sm35"],
+ tags = tf_cuda_tests_tags(),
deps = [
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:test_helpers",
@@ -2096,7 +2120,7 @@ tf_cc_test(
name = "sample_file_test",
srcs = ["sample_file_test.cc"],
data = ["isolated_convolution.hlo"],
- tags = ["requires-gpu-sm35"],
+ tags = tf_cuda_tests_tags(),
deps = [
":hlo_test_base",
"//tensorflow/compiler/xla:test",
@@ -2144,3 +2168,21 @@ xla_test(
"//tensorflow/core:lib",
],
)
+
+tf_cc_test(
+ name = "multiple_devices_on_host_test",
+ srcs = ["multiple_devices_on_host_test.cc"],
+ args = ["--xla_force_host_platform_device_count=4"],
+ deps = [
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:test",
+ "//tensorflow/compiler/xla/client:client_library",
+ "//tensorflow/compiler/xla/client:xla_builder",
+ "//tensorflow/compiler/xla/service:cpu_plugin",
+ "//tensorflow/compiler/xla/service:platform_util",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ "@com_google_absl//absl/synchronization",
+ ],
+)
diff --git a/tensorflow/compiler/xla/tests/build_defs.bzl b/tensorflow/compiler/xla/tests/build_defs.bzl
index 53f2c3bfbf..05d4d04034 100644
--- a/tensorflow/compiler/xla/tests/build_defs.bzl
+++ b/tensorflow/compiler/xla/tests/build_defs.bzl
@@ -3,256 +3,266 @@
load("@local_config_cuda//cuda:build_defs.bzl", "cuda_is_configured")
load("//tensorflow/compiler/xla/tests:plugin.bzl", "plugins")
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
+load(
+ "//tensorflow/core:platform/default/build_config_root.bzl",
+ "tf_cuda_tests_tags",
+)
all_backends = ["cpu", "gpu"] + plugins.keys()
def filter_backends(backends):
- """Removes "gpu" from a backend list if CUDA is not enabled.
-
- This allows us to simply hardcode lists including "gpu" here and in the
- BUILD file, without causing failures when CUDA isn't enabled.'
-
- Args:
- backends: A list of backends to filter.
-
- Returns:
- The filtered list of backends.
- """
- if cuda_is_configured():
- return backends
- else:
- return [backend for backend in backends if backend != "gpu"]
-
-
-def xla_test(name,
- srcs,
- deps,
- xla_test_library_deps=[],
- backends=[],
- blacklisted_backends=[],
- args=[],
- tags=[],
- copts=[],
- data=[],
- backend_tags={},
- backend_args={},
- **kwargs):
- """Generates cc_test targets for the given XLA backends.
-
- This rule generates a cc_test target for one or more XLA backends and also a
- platform-agnostic cc_library rule. The arguments are identical to cc_test with
- two additions: 'backends' and 'backend_args'. 'backends' specifies the
- backends to generate tests for ("cpu", "gpu"), and
- 'backend_args'/'backend_tags' specifies backend-specific args parameters to
- use when generating the cc_test.
-
- The name of the cc_tests are the provided name argument with the backend name
- appended, and the cc_library target name is the provided name argument with
- "_lib" appended. For example, if name parameter is "foo_test", then the cpu
- test target will be "foo_test_cpu" and the cc_library target is "foo_lib".
-
- The cc_library target can be used to link with other plugins outside of
- xla_test.
-
- The build rule also defines a test suite ${name} which includes the tests for
- each of the supported backends.
-
- Each generated cc_test target has a tag indicating which backend the test is
- for. This tag is of the form "xla_${BACKEND}" (eg, "xla_cpu"). These
- tags can be used to gather tests for a particular backend into a test_suite.
-
- Examples:
-
- # Generates the targets: foo_test_cpu and foo_test_gpu.
- xla_test(
- name = "foo_test",
- srcs = ["foo_test.cc"],
- backends = ["cpu", "gpu"],
- deps = [...],
- )
+ """Removes "gpu" from a backend list if CUDA is not enabled.
- # Generates the targets: bar_test_cpu and bar_test_gpu. bar_test_cpu
- # includes the additional arg "--special_cpu_flag".
- xla_test(
- name = "bar_test",
- srcs = ["bar_test.cc"],
- backends = ["cpu", "gpu"],
- backend_args = {"cpu": ["--special_cpu_flag"]}
- deps = [...],
- )
+ This allows us to simply hardcode lists including "gpu" here and in the
+ BUILD file, without causing failures when CUDA isn't enabled.
- The build rule defines the preprocessor macro XLA_TEST_BACKEND_${BACKEND}
- to the value 1 where ${BACKEND} is the uppercase name of the backend.
-
- Args:
- name: Name of the target.
- srcs: Sources for the target.
- deps: Dependencies of the target.
- xla_test_library_deps: If set, the generated test targets will depend on the
- respective cc_libraries generated by the xla_test_library rule.
- backends: A list of backends to generate tests for. Supported values: "cpu",
- "gpu". If this list is empty, the test will be generated for all supported
- backends.
- blacklisted_backends: A list of backends to NOT generate tests for.
- args: Test arguments for the target.
- tags: Tags for the target.
- copts: Additional copts to pass to the build.
- data: Additional data to pass to the build.
- backend_tags: A dict mapping backend name to list of additional tags to
- use for that target.
- backend_args: A dict mapping backend name to list of additional args to
- use for that target.
- **kwargs: Additional keyword arguments to pass to native.cc_test.
- """
- test_names = []
- if not backends:
- backends = all_backends
-
- backends = [backend for backend in backends
- if backend not in blacklisted_backends]
-
- native.cc_library(
- name="%s_lib" % name,
- srcs=srcs,
- copts=copts,
- testonly=True,
- deps=deps + ["//tensorflow/compiler/xla/tests:test_macros_header"],
- )
-
- for backend in filter_backends(backends):
- test_name = "%s_%s" % (name, backend)
- this_backend_tags = ["xla_%s" % backend]
- this_backend_copts = []
- this_backend_args = backend_args.get(backend, [])
- this_backend_data = []
- if backend == "cpu":
- backend_deps = ["//tensorflow/compiler/xla/service:cpu_plugin"]
- backend_deps += ["//tensorflow/compiler/xla/tests:test_macros_cpu"]
- elif backend == "gpu":
- backend_deps = ["//tensorflow/compiler/xla/service:gpu_plugin"]
- backend_deps += ["//tensorflow/compiler/xla/tests:test_macros_gpu"]
- this_backend_tags += ["requires-gpu-sm35"]
- elif backend in plugins:
- backend_deps = []
- backend_deps += plugins[backend]["deps"]
- this_backend_copts += plugins[backend]["copts"]
- this_backend_tags += plugins[backend]["tags"]
- this_backend_args += plugins[backend]["args"]
- this_backend_data += plugins[backend]["data"]
- else:
- fail("Unknown backend %s" % backend)
-
- if xla_test_library_deps:
- for lib_dep in xla_test_library_deps:
- backend_deps += ["%s_%s" % (lib_dep, backend)]
-
- tf_cc_test(
- name=test_name,
- srcs=srcs,
- tags=tags + backend_tags.get(backend, []) + this_backend_tags,
- extra_copts=copts + ["-DXLA_TEST_BACKEND_%s=1" % backend.upper()] +
- this_backend_copts,
- args=args + this_backend_args,
- deps=deps + backend_deps,
- data=data + this_backend_data,
- **kwargs)
-
- test_names.append(test_name)
-
- native.test_suite(name=name, tests=test_names)
-
-def xla_test_library(name,
- srcs,
- hdrs=[],
- deps=[],
- backends=[]):
- """Generates cc_library targets for the given XLA backends.
-
- This rule forces the sources to be compiled for each backend so that the
- backend specific macros could expand correctly. It's useful when test targets
- in different directories referring to the same sources but test with different
- arguments.
-
- Examples:
-
- # Generates the targets: foo_test_library_cpu and foo_test_gpu.
- xla_test_library(
- name = "foo_test_library",
- srcs = ["foo_test.cc"],
- backends = ["cpu", "gpu"],
- deps = [...],
- )
- # Then use the xla_test rule to generate test targets:
- xla_test(
- name = "foo_test",
- srcs = [],
- backends = ["cpu", "gpu"],
- deps = [...],
- xla_test_library_deps = [":foo_test_library"],
- )
+ Args:
+ backends: A list of backends to filter.
- Args:
- name: Name of the target.
- srcs: Sources for the target.
- hdrs: Headers for the target.
- deps: Dependencies of the target.
- backends: A list of backends to generate libraries for.
- Supported values: "cpu", "gpu". If this list is empty, the
- library will be generated for all supported backends.
- """
-
- if not backends:
- backends = all_backends
-
- for backend in filter_backends(backends):
- this_backend_copts = []
- if backend in ["cpu", "gpu"]:
- backend_deps = ["//tensorflow/compiler/xla/tests:test_macros_%s" % backend]
- elif backend in plugins:
- backend_deps = plugins[backend]["deps"]
- this_backend_copts += plugins[backend]["copts"]
+ Returns:
+ The filtered list of backends.
+ """
+ if cuda_is_configured():
+ return backends
else:
- fail("Unknown backend %s" % backend)
+ return [backend for backend in backends if backend != "gpu"]
+
+def xla_test(
+ name,
+ srcs,
+ deps,
+ xla_test_library_deps = [],
+ backends = [],
+ blacklisted_backends = [],
+ args = [],
+ tags = [],
+ copts = [],
+ data = [],
+ backend_tags = {},
+ backend_args = {},
+ **kwargs):
+ """Generates cc_test targets for the given XLA backends.
+
+ This rule generates a cc_test target for one or more XLA backends and also a
+ platform-agnostic cc_library rule. The arguments are identical to cc_test with
+ two additions: 'backends' and 'backend_args'. 'backends' specifies the
+ backends to generate tests for ("cpu", "gpu"), and
+ 'backend_args'/'backend_tags' specifies backend-specific args parameters to
+ use when generating the cc_test.
+
+ The name of the cc_tests are the provided name argument with the backend name
+ appended, and the cc_library target name is the provided name argument with
+ "_lib" appended. For example, if name parameter is "foo_test", then the cpu
+ test target will be "foo_test_cpu" and the cc_library target is "foo_lib".
+
+ The cc_library target can be used to link with other plugins outside of
+ xla_test.
+
+ The build rule also defines a test suite ${name} which includes the tests for
+ each of the supported backends.
+
+ Each generated cc_test target has a tag indicating which backend the test is
+ for. This tag is of the form "xla_${BACKEND}" (eg, "xla_cpu"). These
+ tags can be used to gather tests for a particular backend into a test_suite.
+
+ Examples:
+
+ # Generates the targets: foo_test_cpu and foo_test_gpu.
+ xla_test(
+ name = "foo_test",
+ srcs = ["foo_test.cc"],
+ backends = ["cpu", "gpu"],
+ deps = [...],
+ )
+
+ # Generates the targets: bar_test_cpu and bar_test_gpu. bar_test_cpu
+ # includes the additional arg "--special_cpu_flag".
+ xla_test(
+ name = "bar_test",
+ srcs = ["bar_test.cc"],
+ backends = ["cpu", "gpu"],
+ backend_args = {"cpu": ["--special_cpu_flag"]}
+ deps = [...],
+ )
+
+ The build rule defines the preprocessor macro XLA_TEST_BACKEND_${BACKEND}
+ to the value 1 where ${BACKEND} is the uppercase name of the backend.
+
+ Args:
+ name: Name of the target.
+ srcs: Sources for the target.
+ deps: Dependencies of the target.
+ xla_test_library_deps: If set, the generated test targets will depend on the
+ respective cc_libraries generated by the xla_test_library rule.
+ backends: A list of backends to generate tests for. Supported values: "cpu",
+ "gpu". If this list is empty, the test will be generated for all supported
+ backends.
+ blacklisted_backends: A list of backends to NOT generate tests for.
+ args: Test arguments for the target.
+ tags: Tags for the target.
+ copts: Additional copts to pass to the build.
+ data: Additional data to pass to the build.
+ backend_tags: A dict mapping backend name to list of additional tags to
+ use for that target.
+ backend_args: A dict mapping backend name to list of additional args to
+ use for that target.
+ **kwargs: Additional keyword arguments to pass to native.cc_test.
+ """
+ test_names = []
+ if not backends:
+ backends = all_backends
+
+ backends = [
+ backend
+ for backend in backends
+ if backend not in blacklisted_backends
+ ]
native.cc_library(
- name = "%s_%s" % (name, backend),
+ name = "%s_lib" % name,
srcs = srcs,
+ copts = copts,
testonly = True,
- hdrs = hdrs,
- copts = ["-DXLA_TEST_BACKEND_%s=1" % backend.upper()]
- + this_backend_copts,
- deps = deps + backend_deps,
+ deps = deps + ["//tensorflow/compiler/xla/tests:test_macros_header"],
)
-
-def generate_backend_suites(backends=[]):
- if not backends:
- backends = all_backends
- for backend in filter_backends(backends):
- native.test_suite(name="%s_tests" % backend,
- tags = ["xla_%s" % backend])
-
-
-def generate_backend_test_macros(backends=[]):
- if not backends:
- backends = all_backends
- for backend in filter_backends(backends):
- manifest = ""
- if backend in plugins:
- manifest = plugins[backend]["disabled_manifest"]
-
- native.cc_library(
- name="test_macros_%s" % backend,
- testonly = True,
- srcs = ["test_macros.cc"],
- hdrs = ["test_macros.h"],
- copts = [
- "-DXLA_PLATFORM=\\\"%s\\\"" % backend.upper(),
- "-DXLA_DISABLED_MANIFEST=\\\"%s\\\"" % manifest,
- ],
- deps = [
- "//tensorflow/compiler/xla:types",
- "//tensorflow/core:lib",
- "//tensorflow/core:regexp_internal",
- "//tensorflow/core:test",
- ])
+ for backend in filter_backends(backends):
+ test_name = "%s_%s" % (name, backend)
+ this_backend_tags = ["xla_%s" % backend]
+ this_backend_copts = []
+ this_backend_args = backend_args.get(backend, [])
+ this_backend_data = []
+ if backend == "cpu":
+ backend_deps = ["//tensorflow/compiler/xla/service:cpu_plugin"]
+ backend_deps += ["//tensorflow/compiler/xla/tests:test_macros_cpu"]
+ elif backend == "gpu":
+ backend_deps = ["//tensorflow/compiler/xla/service:gpu_plugin"]
+ backend_deps += ["//tensorflow/compiler/xla/tests:test_macros_gpu"]
+ this_backend_tags += tf_cuda_tests_tags()
+ elif backend in plugins:
+ backend_deps = []
+ backend_deps += plugins[backend]["deps"]
+ this_backend_copts += plugins[backend]["copts"]
+ this_backend_tags += plugins[backend]["tags"]
+ this_backend_args += plugins[backend]["args"]
+ this_backend_data += plugins[backend]["data"]
+ else:
+ fail("Unknown backend %s" % backend)
+
+ if xla_test_library_deps:
+ for lib_dep in xla_test_library_deps:
+ backend_deps += ["%s_%s" % (lib_dep, backend)]
+
+ tf_cc_test(
+ name = test_name,
+ srcs = srcs,
+ tags = tags + backend_tags.get(backend, []) + this_backend_tags,
+ extra_copts = copts + ["-DXLA_TEST_BACKEND_%s=1" % backend.upper()] +
+ this_backend_copts,
+ args = args + this_backend_args,
+ deps = deps + backend_deps,
+ data = data + this_backend_data,
+ **kwargs
+ )
+
+ test_names.append(test_name)
+
+ native.test_suite(name = name, tests = test_names)
+
+def xla_test_library(
+ name,
+ srcs,
+ hdrs = [],
+ deps = [],
+ backends = []):
+ """Generates cc_library targets for the given XLA backends.
+
+ This rule forces the sources to be compiled for each backend so that the
+ backend-specific macros can expand correctly. It's useful when test targets
+ in different directories refer to the same sources but test with different
+ arguments.
+
+ Examples:
+
+ # Generates the targets: foo_test_library_cpu and foo_test_gpu.
+ xla_test_library(
+ name = "foo_test_library",
+ srcs = ["foo_test.cc"],
+ backends = ["cpu", "gpu"],
+ deps = [...],
+ )
+ # Then use the xla_test rule to generate test targets:
+ xla_test(
+ name = "foo_test",
+ srcs = [],
+ backends = ["cpu", "gpu"],
+ deps = [...],
+ xla_test_library_deps = [":foo_test_library"],
+ )
+
+ Args:
+ name: Name of the target.
+ srcs: Sources for the target.
+ hdrs: Headers for the target.
+ deps: Dependencies of the target.
+ backends: A list of backends to generate libraries for.
+ Supported values: "cpu", "gpu". If this list is empty, the
+ library will be generated for all supported backends.
+ """
+
+ if not backends:
+ backends = all_backends
+
+ for backend in filter_backends(backends):
+ this_backend_copts = []
+ if backend in ["cpu", "gpu"]:
+ backend_deps = ["//tensorflow/compiler/xla/tests:test_macros_%s" % backend]
+ elif backend in plugins:
+ backend_deps = plugins[backend]["deps"]
+ this_backend_copts += plugins[backend]["copts"]
+ else:
+ fail("Unknown backend %s" % backend)
+
+ native.cc_library(
+ name = "%s_%s" % (name, backend),
+ srcs = srcs,
+ testonly = True,
+ hdrs = hdrs,
+ copts = ["-DXLA_TEST_BACKEND_%s=1" % backend.upper()] +
+ this_backend_copts,
+ deps = deps + backend_deps,
+ )
+
+def generate_backend_suites(backends = []):
+ if not backends:
+ backends = all_backends
+ for backend in filter_backends(backends):
+ native.test_suite(
+ name = "%s_tests" % backend,
+ tags = ["xla_%s" % backend, "-broken", "manual"],
+ )
+
+def generate_backend_test_macros(backends = []):
+ if not backends:
+ backends = all_backends
+ for backend in filter_backends(backends):
+ manifest = ""
+ if backend in plugins:
+ manifest = plugins[backend]["disabled_manifest"]
+
+ native.cc_library(
+ name = "test_macros_%s" % backend,
+ testonly = True,
+ srcs = ["test_macros.cc"],
+ hdrs = ["test_macros.h"],
+ copts = [
+ "-DXLA_PLATFORM=\\\"%s\\\"" % backend.upper(),
+ "-DXLA_DISABLED_MANIFEST=\\\"%s\\\"" % manifest,
+ ],
+ deps = [
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:regexp_internal",
+ "//tensorflow/core:test",
+ ],
+ )
diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc
index 8f86c528d0..8bd0a729b7 100644
--- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc
+++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc
@@ -21,64 +21,68 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/test.h"
namespace xla {
-HloVerifiedTestBase::HloVerifiedTestBase(bool layout_sensitive,
- bool allow_mixed_precision)
- : HloTestBase(
- /*verifier_layout_sensitive=*/layout_sensitive,
- /*allow_mixed_precision_in_hlo_verifier=*/allow_mixed_precision) {}
-
-HloVerifiedTestBase::~HloVerifiedTestBase() {
- // We can't call the ASSERT or EXPECT test macros in destructors, so we
- // perform HLO verification in TearDown, and use the CHECK here to ensure
- // users don't accidentally override the verification.
- CHECK(tear_down_called_)
- << "TearDown was never called; subclasses of HloVerifiedTestBase that "
- << "override TearDown must call the superclass TearDown.";
-}
-
-void HloVerifiedTestBase::TearDown() {
- EXPECT_FALSE(tear_down_called_)
- << "TearDown called more than once; it should be called exactly once.";
- tear_down_called_ = true;
- if (module_) {
- VerifyModule(module_.get());
+Status VerifiedHloModule::Verify() {
+ if (computation_count() == 0) {
+ // The computation was never built. Nothing to verify.
+ return Status::OK();
}
- for (int i = 0; i < modules_.size(); ++i) {
- VerifyModule(modules_.at(i).get());
- }
- HloTestBase::TearDown();
+ return verifier_.Run(this).status();
}
-void HloVerifiedTestBase::VerifyModule(HloModule* module) {
- xla::StatusOr<bool> mutated = verifier().Run(module);
- if (!mutated.ok()) {
- ADD_FAILURE() << "HloVerifier failed: " << mutated.status();
- } else {
- EXPECT_FALSE(mutated.ValueOrDie())
- << "HloVerifier should never mutate the HloModule";
+void VerifiedHloModule::VerifyOrAddFailure(const string& message) {
+ Status status = Verify();
+ if (!status.ok()) {
+ ADD_FAILURE() << "HloVerifier failed on module " << name()
+ << (message.empty() ? "" : absl::StrCat(" (", message, ")"))
+ << ": " << status;
}
}
+HloVerifiedTestBase::HloVerifiedTestBase(bool layout_sensitive,
+ bool allow_mixed_precision)
+ : HloTestBase(
+ /*verifier_layout_sensitive=*/layout_sensitive,
+ /*allow_mixed_precision_in_hlo_verifier=*/allow_mixed_precision),
+ verifier_layout_sensitive_(layout_sensitive),
+ allow_mixed_precision_in_hlo_verifier_(allow_mixed_precision) {}
+
HloModule& HloVerifiedTestBase::module() {
if (!module_) {
- module_ = HloTestBase::CreateNewModule();
+ module_ = CreateNewVerifiedModule(TestName());
}
return *module_;
}
HloModule* HloVerifiedTestBase::CreateNewModule(const string& name) {
- modules_.emplace_back(HloTestBase::CreateNewModule());
+ modules_.emplace_back(CreateNewVerifiedModule(name));
return modules_.back().get();
}
void HloVerifiedTestBase::ParseAndVerifyModule(absl::string_view hlo_text,
const HloModuleConfig& config) {
CHECK(!module_) << "Called ParseModule when test already has a module.";
- TF_ASSERT_OK_AND_ASSIGN(module_, ParseHloString(hlo_text, config));
- VerifyModule(module_.get());
+ module_ = CreateNewVerifiedModule(TestName());
+ TF_CHECK_OK(ParseHloString(hlo_text, module_.get()));
+ module_->VerifyOrAddFailure("after parsing");
}
+
+StatusOr<std::unique_ptr<VerifiedHloModule>>
+HloVerifiedTestBase::ParseAndReturnVerifiedModule(
+ absl::string_view hlo_text, const HloModuleConfig& config) {
+ auto module = CreateNewVerifiedModule(TestName());
+ TF_RETURN_IF_ERROR(ParseHloString(hlo_text, module.get()));
+ TF_RETURN_IF_ERROR(module->Verify());
+ return std::move(module);
+}
+
+std::unique_ptr<VerifiedHloModule> HloVerifiedTestBase::CreateNewVerifiedModule(
+ const string& name) {
+ return absl::make_unique<VerifiedHloModule>(
+ name, GetModuleConfigForTest(), verifier_layout_sensitive_,
+ allow_mixed_precision_in_hlo_verifier_);
+}
+
} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h
index 8fbc4fa753..388a99bb36 100644
--- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h
+++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h
@@ -20,53 +20,84 @@ limitations under the License.
#include <memory>
#include <utility>
+#include "absl/base/macros.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
namespace xla {
-// A base class for HLO tests that stores a default HloModule, and automatically
-// performs verification on that module on tear-down.
+// An HLO module derived class which verifies itself on destruction. This class
+// is intended to be used in unit tests. Any verification errors are raised via
+// ADD_FAILURE.
+class VerifiedHloModule : public HloModule {
+ public:
+ VerifiedHloModule(const string& name, const HloModuleConfig& config,
+ bool verifier_layout_sensitive,
+ bool allow_mixed_precision_in_hlo_verifier)
+ : HloModule(name, config),
+ verifier_(verifier_layout_sensitive,
+ allow_mixed_precision_in_hlo_verifier) {}
+
+ ~VerifiedHloModule() override { VerifyOrAddFailure("in destructor"); }
+
+ // Verifies the module using HloVerifier and returns the status.
+ Status Verify();
+
+ // Verifies the module and flags any error with ADD_FAILURE. 'message' is
+ // included in the failure message.
+ void VerifyOrAddFailure(const string& message);
+
+ private:
+ HloVerifier verifier_;
+};
+
+// A base class for HLO tests that stores a default VerifiedHloModule.
class HloVerifiedTestBase : public HloTestBase {
protected:
- explicit HloVerifiedTestBase(bool layout_sensitive = false,
- bool allow_mixed_precision = false);
- ~HloVerifiedTestBase() override;
+ HloVerifiedTestBase(bool layout_sensitive = false,
+ bool allow_mixed_precision = false);
// Constructs a default shape verifier.
std::unique_ptr<ShapeVerifier> MakeShapeVerifier();
- // Performs verification on the default HloModule returned by module().
- // Automatically called by the testing framework for each test.
- //
- // REQUIRED: subclasses that override TearDown() must call this explicitly.
- void TearDown() override;
-
// Returns the default HloModule, lazily creating it if necessary via
// HloTestBase::CreateNewModule().
+ ABSL_DEPRECATED("Use CreateNewVerifiedModule() instead.")
HloModule& module();
+
+ ABSL_DEPRECATED("Use ParseAndReturnVerifiedModule() instead.")
void ParseAndVerifyModule(absl::string_view hlo_text,
const HloModuleConfig& config = HloModuleConfig());
+ // Parses the given string and returns the module as a VerifiedHloModule.
+ StatusOr<std::unique_ptr<VerifiedHloModule>> ParseAndReturnVerifiedModule(
+ absl::string_view hlo_text,
+ const HloModuleConfig& config = HloModuleConfig());
+
// Creates a new module for a test, and stores it in modules_ so it can be
// verified. Intentionally hides HloTestBase::CreateNewModule, to prevent
// creation of unverified modules.
+ ABSL_DEPRECATED("Use CreateNewVerifiedModule() instead.")
HloModule* CreateNewModule(const string& name = TestName());
- private:
- void VerifyModule(HloModule* module);
+ // Creates and returns a verified HLO module with the given name.
+ std::unique_ptr<VerifiedHloModule> CreateNewVerifiedModule(
+ const string& name = TestName());
+ private:
// It is confusing to store modules created by module() and CreateNewModule()
// in different fields, but it allows us to migrate tests to
// HloVerifiedTestBase more easily, so it's a win because we can verify more
// modules. See b/80488902.
//
// Lazily populated. Access via module().
- std::unique_ptr<HloModule> module_;
+ std::unique_ptr<VerifiedHloModule> module_;
+
// Populated by calls to CreateNewModule.
- std::vector<std::unique_ptr<HloModule>> modules_;
+ std::vector<std::unique_ptr<VerifiedHloModule>> modules_;
- bool tear_down_called_ = false;
+ bool verifier_layout_sensitive_;
+ bool allow_mixed_precision_in_hlo_verifier_;
};
} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base_test.cc b/tensorflow/compiler/xla/tests/hlo_verified_test_base_test.cc
new file mode 100644
index 0000000000..5c0263e811
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base_test.cc
@@ -0,0 +1,158 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
+
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_verifier.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+// This class includes unit tests which are expected to fail because they
+// intentionally build invalid HLO modules. Unfortunately, TensorFlow doesn't
+// appear to include the necessary gunit pieces to test this test machinery
+// (that would require the macro EXPECT_NONFATAL_FAILURE). Those tests are
+// therefore disabled; they can be run manually (with disabled tests enabled)
+// and their failures compared against expectations by hand.
+class HloVerifiedTestBaseTest : public HloVerifiedTestBase {};
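+
+// Sketch only (not compiled here): if the gtest-spi macro were available, the
+// negative tests below could assert on the expected failure directly, e.g.
+//
+//   EXPECT_NONFATAL_FAILURE(
+//       { /* build an invalid module and let it verify on destruction */ },
+//       "shape");   // the failure substring here is illustrative only
+//
+// instead of being disabled.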
+
+XLA_TEST_F(HloVerifiedTestBaseTest, NoModule) {
+ // Test shouldn't fail if no module is created at all.
+}
+
+XLA_TEST_F(HloVerifiedTestBaseTest, GoodLazilyCreatedModule) {
+ // Use module() to lazily create an empty module, build it up, and verify no
+ // failures.
+ HloModule& hlo_module = module();
+ auto builder = HloComputation::Builder(TestName());
+ auto input = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
+ builder.AddInstruction(
+ HloInstruction::CreateUnary(input->shape(), HloOpcode::kNegate, input));
+ hlo_module.AddEntryComputation(builder.Build());
+}
+
+// This test is expected to fail. See test class comment.
+XLA_TEST_F(HloVerifiedTestBaseTest, DISABLED_BadLazilyCreatedModule) {
+ // Use module() to lazily create an empty module and build up an invalid
+ // module.
+ HloModule& hlo_module = module();
+ auto builder = HloComputation::Builder(TestName());
+ auto input = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
+ builder.AddInstruction(
+ HloInstruction::CreateUnary(input->shape(), HloOpcode::kNegate, input));
+ hlo_module.AddEntryComputation(builder.Build());
+
+ *hlo_module.entry_computation()->root_instruction()->mutable_shape() =
+ ShapeUtil::MakeShape(PRED, {1, 2, 3});
+}
+
+XLA_TEST_F(HloVerifiedTestBaseTest, GoodCreateNewModule) {
+ // Call CreateNewModule and build up a valid module.
+ HloModule* module = CreateNewModule();
+ auto builder = HloComputation::Builder(TestName());
+ auto input = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
+ builder.AddInstruction(
+ HloInstruction::CreateUnary(input->shape(), HloOpcode::kNegate, input));
+ module->AddEntryComputation(builder.Build());
+}
+
+// This test is expected to fail. See test class comment.
+XLA_TEST_F(HloVerifiedTestBaseTest, DISABLED_BadCreateNewModule) {
+ // Call CreateNewModule and build up an invalid module.
+ HloModule* module = CreateNewModule();
+ auto builder = HloComputation::Builder(TestName());
+ auto input = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
+ builder.AddInstruction(
+ HloInstruction::CreateUnary(input->shape(), HloOpcode::kNegate, input));
+ module->AddEntryComputation(builder.Build());
+
+ *module->entry_computation()->root_instruction()->mutable_shape() =
+ ShapeUtil::MakeShape(PRED, {1, 2, 3});
+}
+
+XLA_TEST_F(HloVerifiedTestBaseTest, ParseAndVerifyModuleGood) {
+ const char* const hlo_string = R"(
+HloModule ParseAndVerifyModuleGood
+
+ENTRY entry {
+ x = f32[] parameter(0)
+ y = f32[] parameter(1)
+ ROOT add = f32[] add(x,y)
+}
+)";
+
+ ParseAndVerifyModule(hlo_string);
+ EXPECT_EQ(module().entry_computation()->instruction_count(), 3);
+}
+
+XLA_TEST_F(HloVerifiedTestBaseTest, ParseAndReturnVerifiedModuleGood) {
+ const char* const hlo_string = R"(
+HloModule ParseAndReturnVerifiedModuleGood
+
+ENTRY entry {
+ x = f32[] parameter(0)
+ y = f32[] parameter(1)
+ ROOT add = f32[] add(x,y)
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(hlo_string));
+ EXPECT_EQ(module->entry_computation()->instruction_count(), 3);
+}
+
+XLA_TEST_F(HloVerifiedTestBaseTest, ParseAndReturnVerifiedModuleInvalidText) {
+ const char* const hlo_string = R"(
+HloModule ParseAndReturnVerifiedModuleGood
+
+ENTRY entry {
+ x = f32[] parameter(0)
+ y = f32[] parameter(1)
+ ROOT add = f32[] add(x,y)
+}
+
+RANDOM GARBAGE
+)";
+
+ ASSERT_IS_NOT_OK(ParseAndReturnVerifiedModule(hlo_string).status());
+}
+
+// This test is expected to fail. See test class comment.
+XLA_TEST_F(HloVerifiedTestBaseTest, DISABLED_ParseAndReturnVerifiedModuleBad) {
+ const char* const hlo_string = R"(
+HloModule ParseAndReturnVerifiedModuleBad
+
+ENTRY entry {
+ x = f32[] parameter(0)
+ y = f32[] parameter(1)
+ ROOT add = f32[1234] add(x,y)
+}
+)";
+
+ ASSERT_IS_NOT_OK(ParseAndReturnVerifiedModule(hlo_string).status());
+}
+
+} // namespace
+} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/multiple_devices_on_host_test.cc b/tensorflow/compiler/xla/tests/multiple_devices_on_host_test.cc
new file mode 100644
index 0000000000..c530591c6e
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/multiple_devices_on_host_test.cc
@@ -0,0 +1,120 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "absl/synchronization/mutex.h"
+#include "tensorflow/compiler/xla/client/client_library.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace {
+StatusOr<XlaComputation> BuildComputation() {
+ XlaBuilder b("computation");
+ Shape scalar_s32 = ShapeUtil::MakeShape(S32, {});
+ XlaOp infeed = InfeedWithToken(CreateToken(&b), scalar_s32);
+ return b.Build(
+ OutfeedWithToken(GetTupleElement(infeed, 0) +
+ ConstantLiteral(&b, LiteralUtil::CreateR0<int32>(1)),
+ GetTupleElement(infeed, 1), scalar_s32, ""));
+}
+
+void CompileAndExecute(
+ LocalExecutable* executable, int device_ordinal, LocalClient* client,
+ absl::Mutex* results_mutex,
+ std::vector<std::pair<int, StatusOr<ScopedShapedBuffer>>>* results) {
+ xla::ExecutableRunOptions execute_options;
+ execute_options.set_intra_op_thread_pool(
+ client->backend().eigen_intra_op_thread_pool_device());
+ execute_options.set_device_ordinal(device_ordinal);
+ execute_options.set_allocator(
+ xla::ClientLibrary::GetXlaService(client->platform())
+ ->backend()
+ .memory_allocator());
+ StatusOr<ScopedShapedBuffer> result = executable->Run({}, execute_options);
+ {
+ absl::MutexLock lock(results_mutex);
+ results->emplace_back(device_ordinal, std::move(result));
+ }
+}
+
+void TestWithDeviceCount(const int device_count) {
+ // Run `device_count` copies of the XLA program built by BuildComputation.
+ TF_ASSERT_OK_AND_ASSIGN(
+ se::Platform* const platform,
+ perftools::gputools::MultiPlatformManager::PlatformWithName("Host"));
+ xla::LocalClientOptions client_options;
+ client_options.set_platform(platform);
+ TF_ASSERT_OK_AND_ASSIGN(
+ LocalClient* const client,
+ xla::ClientLibrary::GetOrCreateLocalClient(client_options));
+
+ TF_ASSERT_OK_AND_ASSIGN(XlaComputation xla_computation, BuildComputation());
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<LocalExecutable> executable,
+ client->Compile(xla_computation, {}, xla::ExecutableBuildOptions{}));
+ std::vector<tensorflow::Thread*> threads;
+ absl::Mutex results_mutex;
+ std::vector<std::pair<int, StatusOr<ScopedShapedBuffer>>> results;
+ tensorflow::Env* env = tensorflow::Env::Default();
+ for (int device_ordinal = 0; device_ordinal < device_count;
+ device_ordinal++) {
+ tensorflow::Thread* t = env->StartThread(
+ tensorflow::ThreadOptions{}, absl::StrCat("thread-", device_ordinal),
+ [&executable, device_ordinal, client, &results_mutex, &results] {
+ CompileAndExecute(executable.get(), device_ordinal, client,
+ &results_mutex, &results);
+ });
+ threads.push_back(t);
+ }
+
+ for (int device_ordinal = 0; device_ordinal < device_count;
+ device_ordinal++) {
+ TF_ASSERT_OK(client->TransferToInfeedLocal(
+ LiteralUtil::CreateR0<int32>(device_ordinal * 100), device_ordinal));
+ }
+
+ for (int device_ordinal = 0; device_ordinal < device_count;
+ device_ordinal++) {
+ TF_ASSERT_OK_AND_ASSIGN(Literal outfeed,
+ client->TransferFromOutfeedLocal(
+ ShapeUtil::MakeShape(S32, {}), device_ordinal));
+ EXPECT_EQ(outfeed, LiteralUtil::CreateR0<int32>(device_ordinal * 100 + 1));
+ }
+
+ for (int device_ordinal = 0; device_ordinal < device_count;
+ device_ordinal++) {
+ delete threads[device_ordinal];
+ }
+
+ for (int device_ordinal = 0; device_ordinal < device_count;
+ device_ordinal++) {
+ TF_ASSERT_OK(results[device_ordinal].second.status());
+ }
+}
+
+// NB! This test requires --xla_force_host_platform_device_count=4
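+//
+// Sketch: the flag is typically supplied through the XLA_FLAGS environment
+// variable, e.g.
+//   XLA_FLAGS=--xla_force_host_platform_device_count=4 <test binary>
+// The exact invocation depends on the build/test setup.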
+
+TEST(MultipleDeviceOnHostTest, OneDevice) { TestWithDeviceCount(1); }
+
+TEST(MultipleDeviceOnHostTest, TwoDevices) { TestWithDeviceCount(2); }
+
+TEST(MultipleDeviceOnHostTest, ThreeDevices) { TestWithDeviceCount(3); }
+
+TEST(MultipleDeviceOnHostTest, FourDevices) { TestWithDeviceCount(4); }
+} // namespace
+} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc
index 63491a90bf..c25ccafaf8 100644
--- a/tensorflow/compiler/xla/tests/reduce_window_test.cc
+++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc
@@ -1303,11 +1303,19 @@ struct R1ReduceWindowTestData {
/*pad_high=*/{0},
/*reducer=*/Reducer::kAdd},
+ // The pattern generated by inclusive scan (cumsum/cumprod).
{/*base_bounds=*/{4096}, /*window_bounds=*/{4096},
/*strides=*/{1},
/*pad_low=*/{4095},
/*pad_high=*/{0},
/*reducer=*/Reducer::kMax},
+
+ // The pattern generated by exclusive scan (cumsum/cumprod).
+ {/*base_bounds=*/{4096}, /*window_bounds=*/{4096},
+ /*strides=*/{1},
+ /*pad_low=*/{4096},
+ /*pad_high=*/{0},
+ /*reducer=*/Reducer::kMax},
};
string R1ReduceWindowTestDataToString(
diff --git a/tensorflow/compiler/xla/tests/scatter_test.cc b/tensorflow/compiler/xla/tests/scatter_test.cc
index d20dba028a..b21dd56045 100644
--- a/tensorflow/compiler/xla/tests/scatter_test.cc
+++ b/tensorflow/compiler/xla/tests/scatter_test.cc
@@ -507,6 +507,36 @@ ENTRY main {
RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
+XLA_TEST_F(ScatterTest, OutOfBoundsUpdateWindow) {
+ const char* hlo_text = R"(
+HloModule TensorFlowScatterNd_OobUpdateWindow
+
+update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
+ lhs = s32[] parameter(0)
+ ROOT rhs = s32[] parameter(1)
+}
+
+ENTRY main {
+ operand = s32[3,3,2] parameter(0)
+ indices = s32[1,2] parameter(1)
+ updates = s32[1,2,2] parameter(2)
+ ROOT scatter = s32[3,3,2] scatter(operand, indices, updates),
+ to_apply=update_s32,
+ update_window_dims={1,2},
+ inserted_window_dims={0},
+ scatter_dims_to_operand_dims={0,1},
+ index_vector_dim=1
+}
+)";
+ Literal operand =
+ LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
+ {{-4, 4}, {-5, 5}, {-6, 6}}, //
+ {{-7, 7}, {-8, 8}, {-9, 9}}});
+ Literal scatter_indices = LiteralUtil::CreateR2<int32>({{0, 2}});
+ Literal updates = LiteralUtil::CreateR3<int32>({{{-10, 10}, {-40, 40}}});
+ RunTest(hlo_text, &operand, &scatter_indices, &updates);
+}
+
XLA_TEST_F(ScatterTest, OneScalarIndex) {
const char* hlo_text = R"(
HloModule OneScalarIndex
diff --git a/tensorflow/compiler/xla/tests/slice_test.cc b/tensorflow/compiler/xla/tests/slice_test.cc
index a40c2d7de6..2cc33ab096 100644
--- a/tensorflow/compiler/xla/tests/slice_test.cc
+++ b/tensorflow/compiler/xla/tests/slice_test.cc
@@ -412,6 +412,7 @@ INSTANTIATE_TEST_CASE_P(
R2Spec{511, 513, {{129, 300}}, {{400, 500}}, {{7, 11}}, {{0, 1}}}, //
R2Spec{511, 513, {{129, 300}}, {{400, 500}}, {{11, 7}}, {{1, 0}}}, //
R2Spec{511, 513, {{129, 300}}, {{400, 500}}, {{11, 7}}, {{0, 1}}}, //
+ R2Spec{8672, 512, {{8, 0}}, {{8672, 512}}, {{542, 1}}, {{1, 0}}}, //
R2Spec{
511, 513, {{129, 300}}, {{400, 500}}, {{101, 129}}, {{1, 0}}}, //
R2Spec{
diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc
index 7abd8651d5..8b1b9e1519 100644
--- a/tensorflow/compiler/xla/tests/while_test.cc
+++ b/tensorflow/compiler/xla/tests/while_test.cc
@@ -763,9 +763,7 @@ TEST_F(WhileTest, TwoWhileLoopsAndSharedBody) {
ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
-// Test while nodes that share the while body computation.
-// TODO(b/37245345): Fails on GPU backend.
-TEST_F(WhileTest, DISABLED_ON_GPU(WhileLoopsWithSharedBodyAndInit)) {
+TEST_F(WhileTest, WhileLoopsWithSharedBodyAndInit) {
std::vector<Shape> shape_elements = {ShapeUtil::MakeShape(S32, {}),
ShapeUtil::MakeShape(F32, {10})};
Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements);
diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto
index b53f89d63b..60d25a6407 100644
--- a/tensorflow/compiler/xla/xla.proto
+++ b/tensorflow/compiler/xla/xla.proto
@@ -200,6 +200,15 @@ message DebugOptions {
// among different algorithms.
bool xla_gpu_crash_on_verification_failures = 101;
+ // Force the host platform to pretend that there are this many host
+ // "devices". All of these devices are backed by the same threadpool.
+ // Defaults to 1.
+ //
+ // Setting this to anything other than 1 can increase overhead from context
+ // switching, but we let users override it to help run tests on the host that
+ // run models in parallel across multiple devices.
+ int32 xla_force_host_platform_device_count = 102;
+
// Extra options to pass to the compilation backend (e.g. LLVM); specific
// interpretation of these values is left to the backend.
map<string, string> xla_backend_extra_options = 500;
diff --git a/tensorflow/compiler/xrt/tests/BUILD b/tensorflow/compiler/xrt/tests/BUILD
index 09ab4ed95f..b6dcfc4eb9 100644
--- a/tensorflow/compiler/xrt/tests/BUILD
+++ b/tensorflow/compiler/xrt/tests/BUILD
@@ -8,6 +8,10 @@ package(
)
load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test", "tf_cc_test")
+load(
+ "//tensorflow/core:platform/default/build_config_root.bzl",
+ "tf_cuda_tests_tags",
+)
cc_library(
name = "raw_api_test_lib",
@@ -57,7 +61,7 @@ tf_cuda_cc_test(
size = "medium",
srcs = [],
args = ["--xla_test_device=XLA_GPU"],
- tags = ["requires-gpu-sm35"],
+ tags = tf_cuda_tests_tags(),
deps = [
":raw_api_test_lib",
"//tensorflow/compiler/jit:xla_gpu_device",
diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD
index d98a24994c..98dff965a9 100644
--- a/tensorflow/contrib/BUILD
+++ b/tensorflow/contrib/BUILD
@@ -60,7 +60,6 @@ py_library(
"//tensorflow/contrib/learn",
"//tensorflow/contrib/legacy_seq2seq:seq2seq_py",
"//tensorflow/contrib/libsvm",
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/contrib/linear_optimizer:sdca_estimator_py",
"//tensorflow/contrib/linear_optimizer:sdca_ops_py",
"//tensorflow/contrib/lite/python:lite",
@@ -113,25 +112,18 @@ py_library(
"//tensorflow/python:util",
"//tensorflow/python/estimator:estimator_py",
] + if_mpi(["//tensorflow/contrib/mpi_collectives:mpi_collectives_py"]) + select({
- "//tensorflow:with_kafka_support_windows_override": [],
- "//tensorflow:with_kafka_support": [
+ "//tensorflow:linux_s390x": [],
+ "//tensorflow:windows": [],
+ "//conditions:default": [
+ "//tensorflow/contrib/bigtable",
+ "//tensorflow/contrib/cloud:cloud_py",
+ "//tensorflow/contrib/fused_conv:fused_conv_py", # unresolved symbols, need to export more symbols
"//tensorflow/contrib/kafka",
- ],
- "//conditions:default": [],
- }) + select({
- "//tensorflow:with_aws_support_windows_override": [],
- "//tensorflow:with_aws_support": [
"//tensorflow/contrib/kinesis",
+ "//tensorflow/contrib/tensorrt:init_py",
+ "//tensorflow/contrib/ffmpeg:ffmpeg_ops_py",
],
- "//conditions:default": [],
- }) + if_not_windows_cuda([
- "//tensorflow/contrib/fused_conv:fused_conv_py", # unresolved symbols, need to export more symbols
- ]) + if_not_windows([
- "//tensorflow/contrib/bigtable", # depends on bigtable
- "//tensorflow/contrib/cloud:cloud_py", # doesn't compile on Windows
- "//tensorflow/contrib/tensorrt:init_py", # doesn't compile on windows
- "//tensorflow/contrib/ffmpeg:ffmpeg_ops_py",
- ]),
+ }),
)
cc_library(
@@ -140,7 +132,6 @@ cc_library(
deps = [
"//tensorflow/contrib/boosted_trees:boosted_trees_kernels",
"//tensorflow/contrib/coder:all_kernels",
- "//tensorflow/contrib/data/kernels:dataset_kernels",
"//tensorflow/contrib/factorization/kernels:all_kernels",
"//tensorflow/contrib/hadoop:dataset_kernels",
"//tensorflow/contrib/input_pipeline:input_pipeline_ops_kernels",
@@ -155,20 +146,14 @@ cc_library(
] + if_mpi(["//tensorflow/contrib/mpi_collectives:mpi_collectives_py"]) + if_cuda([
"//tensorflow/contrib/nccl:nccl_kernels",
]) + select({
- "//tensorflow:with_kafka_support_windows_override": [],
- "//tensorflow:with_kafka_support": [
+ "//tensorflow:linux_s390x": [],
+ "//tensorflow:windows": [],
+ "//conditions:default": [
"//tensorflow/contrib/kafka:dataset_kernels",
- ],
- "//conditions:default": [],
- }) + select({
- "//tensorflow:with_aws_support_windows_override": [],
- "//tensorflow:with_aws_support": [
"//tensorflow/contrib/kinesis:dataset_kernels",
+ "//tensorflow/contrib/tensorrt:trt_engine_op_kernel",
],
- "//conditions:default": [],
- }) + if_not_windows([
- "//tensorflow/contrib/tensorrt:trt_engine_op_kernel",
- ]),
+ }),
)
cc_library(
@@ -177,8 +162,6 @@ cc_library(
deps = [
"//tensorflow/contrib/boosted_trees:boosted_trees_ops_op_lib",
"//tensorflow/contrib/coder:all_ops",
- "//tensorflow/contrib/data:dataset_ops_op_lib",
- "//tensorflow/contrib/data:indexed_dataset_ops_op_lib",
"//tensorflow/contrib/factorization:all_ops",
"//tensorflow/contrib/framework:all_ops",
"//tensorflow/contrib/hadoop:dataset_ops_op_lib",
@@ -194,18 +177,12 @@ cc_library(
"//tensorflow/contrib/text:all_ops",
"//tensorflow/contrib/tpu:all_ops",
] + select({
- "//tensorflow:with_kafka_support_windows_override": [],
- "//tensorflow:with_kafka_support": [
+ "//tensorflow:linux_s390x": [],
+ "//tensorflow:windows": [],
+ "//conditions:default": [
"//tensorflow/contrib/kafka:dataset_ops_op_lib",
- ],
- "//conditions:default": [],
- }) + select({
- "//tensorflow:with_aws_support_windows_override": [],
- "//tensorflow:with_aws_support": [
"//tensorflow/contrib/kinesis:dataset_ops_op_lib",
+ "//tensorflow/contrib/tensorrt:trt_engine_op_op_lib",
],
- "//conditions:default": [],
- }) + if_not_windows([
- "//tensorflow/contrib/tensorrt:trt_engine_op_op_lib",
- ]),
+ }),
)
diff --git a/tensorflow/contrib/__init__.py b/tensorflow/contrib/__init__.py
index 9478e42b46..e71b0e0ae3 100644
--- a/tensorflow/contrib/__init__.py
+++ b/tensorflow/contrib/__init__.py
@@ -63,7 +63,6 @@ from tensorflow.contrib import labeled_tensor
from tensorflow.contrib import layers
from tensorflow.contrib import learn
from tensorflow.contrib import legacy_seq2seq
-from tensorflow.contrib import linalg
from tensorflow.contrib import linear_optimizer
from tensorflow.contrib import lookup
from tensorflow.contrib import losses
diff --git a/tensorflow/contrib/all_reduce/python/all_reduce_test.py b/tensorflow/contrib/all_reduce/python/all_reduce_test.py
index b3f5d92259..9a8f62b986 100644
--- a/tensorflow/contrib/all_reduce/python/all_reduce_test.py
+++ b/tensorflow/contrib/all_reduce/python/all_reduce_test.py
@@ -149,7 +149,7 @@ class AllReduceTest(test_util.TensorFlowTestCase):
num_devices = num_workers * num_gpus
dev_list = ["/replica:0/task:0/device:CPU:0"
for _ in range(num_devices)]
- with self.test_session():
+ with self.cached_session():
input_tensors = self._buildInitialVars(shape, dev_list)
un_op = lambda x: math_ops.div(
x, constant_op.constant(num_devices, dtype=types_pb2.DT_FLOAT))
diff --git a/tensorflow/contrib/autograph/README.md b/tensorflow/contrib/autograph/README.md
index 6ea2db72c4..8c277b59e8 100644
--- a/tensorflow/contrib/autograph/README.md
+++ b/tensorflow/contrib/autograph/README.md
@@ -4,147 +4,6 @@
[deprecated](https://github.com/tensorflow/community/pull/18), AutoGraph is
moving into TensorFlow core.
-The new code location is `tensorflow/python/autograph`.
+The new code location is `tensorflow/python/autograph`. Please refer to the
+README.md file in that directory.
**
-
-IMPORTANT: AutoGraph is beta software, and under active development. Expect rough edges and bugs, but if you try it, we appreciate early feedback! We'd also love contributions ([please see our contributing guidelines](CONTRIBUTING.md) and our [style guide](STYLE_GUIDE.md)).
-
-AutoGraph is a Python to TensorFlow compiler.
-
-With AutoGraph, you can write [Eager style](https://www.tensorflow.org/guide/eager) code in a concise manner, and run it as a TensorFlow graph. AutoGraph uses source code transformation and partial evaluation to generate Python code that builds an equivalent TensorFlow subgraph. The result is code that behaves like ops and can be freely combined with other TensorFlow ops. [Please see this file for which parts of the Python language we currently support](LIMITATIONS.md).
-
-For example, this Python function:
-
-```
-def f(x):
- if x < 0:
- x = -x
- return x
-```
-
-would be converted to this:
-
-```
-def graph_mode_f(x):
- with tf.name_scope('f'):
-
- def if_true():
- with tf.name_scope('if_true'):
- x_1, = x,
- x_1 = tf.negative(x_1)
- return x_1,
-
- def if_false():
- with tf.name_scope('if_false'):
- x_1, = x,
- return x_1,
- x = ag__.utils.run_cond(tf.greater(x, 0), if_true, if_false)
- return x
-```
-
-so you can use it like an op:
-
-```
-with tf.Graph().as_default():
- x = tf.constant(-1.0)
-
- converted_f = autograph.to_graph(f)
- y = converted_f(x)
-
- with tf.Session() as sess:
- print(sess.run(y))
- # Output: 1
-```
-
-# Getting started
-
-Use AutoGraph in one of the following ways, described below:
-
- 1. Annotations (simpler)
- 2. Functional API (more flexible)
-
-To get started, install the latest nightly TensorFlow build:
-
-```shell
-pip install -U tf-nightly
-```
-
-Then import the `autograph` module from `tf.contrib`:
-
-```
-from tensorflow.contrib import autograph as ag
-```
-
-### Related links
-
-Articles:
-
- * [TensorFlow blog post](https://medium.com/tensorflow/autograph-converts-python-into-tensorflow-graphs-b2a871f87ec7)
-
-Interactive notebooks:
-
- * [Quick guide](https://colab.research.google.com/github/tensorflow/models/blob/master/samples/core/guide/autograph.ipynb)
- * [RNN trained using Keras and Estimators](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/examples/notebooks/rnn_keras_estimator.ipynb)
- * [Demo from the TF Dev Summit 2018](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/examples/notebooks/dev_summit_2018_demo.ipynb)
- * [Basic control flow speed test](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/examples/notebooks/ag_vs_eager_collatz_speed_test.ipynb)
- * [MNIST training speed test](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/examples/notebooks/ag_vs_eager_mnist_speed_test.ipynb)
- * [Basic algorithm samples](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/examples/notebooks/algorithms.ipynb)
- * [Introductory workshop support notebook](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/examples/notebooks/workshop.ipynb)
-
-## Using with annotations
-
-Annotating a function or class with `@convert` converts it in place:
-
-```
-@ag.convert()
-def f(x):
- if x < 0:
- x = -x
- return x
-```
-
-... so that it always outputs TensorFlow code:
-
-```
-with tf.Graph().as_default():
- x = tf.constant(-1)
-
- y = f(x)
-
- with tf.Session() as sess:
- print(sess.run(y))
- # Output: 1
-```
-
-## Using the functional API
-
-The functional API allows you to convert an existing function, class or object after it was defined:
-
-```
-converted_f = ag.to_graph(f)
-
-print(converted_f(tf.constant(-1)))
-# Output: Tensor
-
-print(f(-1))
-# Output: 1
-```
-
-You can use the functional API to inspect the generated code as well:
-
-```
-print(ag.to_code(f))
-# Output: <Python and TensorFlow code>
-```
-
-## Filing bugs and feature requests
-
-### Reporting a bug
-
- - If AutoGraph-generated code is compiling and running, but producing an incorrect result, send us a minimal reproduction case that includes the original Eager code, the inputs and if possible, the outputs or the error message.
- - If AutoGraph-generated code is compiling, but not running, send us a minimal reproduction case that includes the original Eager code, the inputs and if possible, the outputs or the error message.
- - If AutoGraph-generated code is not compiling, send us two minimal pieces of code. First, the Eager code that you would like to write, and second, the Graph code that you would like AutoGraph to have generated for you.
-
-### Requesting a feature
-
-If you’d like AutoGraph to convert a feature of Python or TF that we currently don’t handle, please let us know by filing a bug. We’ll make it as easy as possible to interact with us through there.
diff --git a/tensorflow/contrib/batching/python/ops/batch_ops_test.py b/tensorflow/contrib/batching/python/ops/batch_ops_test.py
index 7846814546..01ee8703a9 100644
--- a/tensorflow/contrib/batching/python/ops/batch_ops_test.py
+++ b/tensorflow/contrib/batching/python/ops/batch_ops_test.py
@@ -43,7 +43,7 @@ class BatchOpsTest(test.TestCase):
def testBasicBatch(self):
"""Tests that a single batched tensor executes together and only once."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
batched, index, _ = batch_ops.batch(
[inp], num_batch_threads=1, max_batch_size=2,
@@ -83,7 +83,7 @@ class BatchOpsTest(test.TestCase):
def testBatchWithPadding(self):
"""Test that batching with padding up to an allowed batch size works."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[2])
batched, index, _ = batch_ops.batch(
[inp], num_batch_threads=1, max_batch_size=10,
@@ -113,7 +113,7 @@ class BatchOpsTest(test.TestCase):
def testMultipleBatch(self):
"""Tests that multiple batched tensors execute together."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
inp0 = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
inp1 = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
batched, _, _ = batch_ops.batch(
@@ -152,7 +152,7 @@ class BatchOpsTest(test.TestCase):
def testIllegalBatchDifferentDim0Sizes(self):
"""Tests illegally feeding tensors with different dim0 sizes."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
inp0 = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
inp1 = array_ops.placeholder(dtype=dtypes.int32, shape=[2])
batched, index, _ = batch_ops.batch(
@@ -166,7 +166,7 @@ class BatchOpsTest(test.TestCase):
def testBasicUnbatch(self):
"""Tests that batch and unbatch work together."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
batched, index, id_t = batch_ops.batch(
[inp], num_batch_threads=1, max_batch_size=10,
@@ -190,7 +190,8 @@ class BatchOpsTest(test.TestCase):
def testBasicUnbatchV1Decorated(self):
"""Tests that the batch_function_v1 decorator works."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
+
@batch_ops.batch_function_v1(1, 10, 100000)
def computation(in_t):
return in_t + 1
@@ -211,7 +212,7 @@ class BatchOpsTest(test.TestCase):
def testBasicUnbatchDecorated(self):
"""Tests that the batch_function decorator works."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# TODO(apassos): Removing this line causes test flakiness! Ideally should
# be investigated.
default_inp = array_ops.placeholder_with_default(2, shape=[]) # pylint: disable=unused-variable
@@ -236,7 +237,7 @@ class BatchOpsTest(test.TestCase):
def testBatchDecoratedWithCapturedInput(self):
"""Tests that the batch_function decorator works."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
captured_inp0 = array_ops.placeholder_with_default(2, shape=[])
captured_inp1 = array_ops.placeholder_with_default(1, shape=[])
@@ -260,7 +261,7 @@ class BatchOpsTest(test.TestCase):
def testBatchFunctionOp(self):
"""Tests that the batch_function op works."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
@function.Defun(dtypes.int32)
def computation(in_t):
@@ -289,7 +290,7 @@ class BatchOpsTest(test.TestCase):
def testBatchFunctionOpWithCapturedInput(self):
"""Tests that batch_function op works with captured input."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
captured_inp0 = array_ops.placeholder_with_default(2, shape=[])
captured_inp1 = array_ops.placeholder_with_default(1, shape=[])
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
@@ -323,7 +324,7 @@ class BatchOpsTest(test.TestCase):
def testBatchFunctionOpWithInputError(self):
"""Tests that batch_function op works with error in the inputs."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
@function.Defun(dtypes.int32, dtypes.int32)
@@ -346,7 +347,7 @@ class BatchOpsTest(test.TestCase):
def testBasicUnbatchDecoratedWithReshape(self):
"""Tests that the batch_function decorator works."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
@batch_ops.batch_function(1, 10, 100000)
def computation(in_t):
@@ -368,7 +369,7 @@ class BatchOpsTest(test.TestCase):
def testUnbatchTimeout(self):
"""Tests that the unbatch timeout works."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
batched, index, id_t = batch_ops.batch(
[inp], num_batch_threads=1, max_batch_size=2,
@@ -410,7 +411,7 @@ class BatchOpsTest(test.TestCase):
def testUnbatchGrad(self):
"""Tests that batch and unbatch are differentiable."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
inp = array_ops.placeholder(dtype=dtypes.float32, shape=[1])
batched, index, id_t = batch_ops.batch(
[inp], num_batch_threads=1, max_batch_size=2,
diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py
index 9e6a146f67..13215ffabf 100644
--- a/tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py
+++ b/tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py
@@ -42,7 +42,7 @@ class ExpectationImportanceSampleTest(test.TestCase):
def test_normal_integral_mean_and_var_correctly_estimated(self):
n = int(1e6)
- with self.test_session():
+ with self.cached_session():
mu_p = constant_op.constant([-1.0, 1.0], dtype=dtypes.float64)
mu_q = constant_op.constant([0.0, 0.0], dtype=dtypes.float64)
sigma_p = constant_op.constant([0.5, 0.5], dtype=dtypes.float64)
@@ -72,7 +72,7 @@ class ExpectationImportanceSampleTest(test.TestCase):
# Test that importance sampling can correctly estimate the probability that
# the product of components in a MultivariateNormal are > 0.
n = 1000
- with self.test_session():
+ with self.cached_session():
p = mvn_diag_lib.MultivariateNormalDiag(
loc=[0.], scale_diag=[1.0, 1.0])
q = mvn_diag_lib.MultivariateNormalDiag(
@@ -99,7 +99,7 @@ class ExpectationImportanceSampleLogspaceTest(test.TestCase):
def test_normal_distribution_second_moment_estimated_correctly(self):
# Test the importance sampled estimate against an analytical result.
n = int(1e6)
- with self.test_session():
+ with self.cached_session():
mu_p = constant_op.constant([0.0, 0.0], dtype=dtypes.float64)
mu_q = constant_op.constant([-1.0, 1.0], dtype=dtypes.float64)
sigma_p = constant_op.constant([1.0, 2 / 3.], dtype=dtypes.float64)
@@ -127,7 +127,7 @@ class GetSamplesTest(test.TestCase):
"""Test the private method 'get_samples'."""
def test_raises_if_both_z_and_n_are_none(self):
- with self.test_session():
+ with self.cached_session():
dist = normal_lib.Normal(loc=0., scale=1.)
z = None
n = None
@@ -136,7 +136,7 @@ class GetSamplesTest(test.TestCase):
_get_samples(dist, z, n, seed)
def test_raises_if_both_z_and_n_are_not_none(self):
- with self.test_session():
+ with self.cached_session():
dist = normal_lib.Normal(loc=0., scale=1.)
z = dist.sample(seed=42)
n = 1
@@ -145,7 +145,7 @@ class GetSamplesTest(test.TestCase):
_get_samples(dist, z, n, seed)
def test_returns_n_samples_if_n_provided(self):
- with self.test_session():
+ with self.cached_session():
dist = normal_lib.Normal(loc=0., scale=1.)
z = None
n = 10
@@ -154,7 +154,7 @@ class GetSamplesTest(test.TestCase):
self.assertEqual((10,), z.get_shape())
def test_returns_z_if_z_provided(self):
- with self.test_session():
+ with self.cached_session():
dist = normal_lib.Normal(loc=0., scale=1.)
z = dist.sample(10, seed=42)
n = None
@@ -166,7 +166,7 @@ class GetSamplesTest(test.TestCase):
class ExpectationTest(test.TestCase):
def test_works_correctly(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x = constant_op.constant([-1e6, -100, -10, -1, 1, 10, 100, 1e6])
p = normal_lib.Normal(loc=x, scale=1.)
@@ -213,7 +213,7 @@ class ExpectationTest(test.TestCase):
rtol=0.05, atol=0.)
def test_docstring_example_normal(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
num_draws = int(1e5)
mu_p = constant_op.constant(0.)
mu_q = constant_op.constant(1.)
diff --git a/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py b/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py
index 9afe3df585..18d40fc1df 100644
--- a/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py
+++ b/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py
@@ -27,6 +27,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
+from tensorflow.python.util import deprecation
__all__ = [
'expectation',
@@ -66,7 +67,7 @@ def expectation_importance_sampler(f,
shape broadcastable to `q.batch_shape`.
For example, `log_p` works "just like" `sampling_dist_q.log_prob`.
sampling_dist_q: The sampling distribution.
- `tf.contrib.distributions.Distribution`.
+ `tfp.distributions.Distribution`.
`float64` `dtype` recommended.
`log_p` and `q` should be supported on the same set.
z: `Tensor` of samples from `q`, produced by `q.sample` for some `n`.
@@ -141,7 +142,7 @@ def expectation_importance_sampler_logspace(
shape broadcastable to `q.batch_shape`.
For example, `log_p` works "just like" `q.log_prob`.
sampling_dist_q: The sampling distribution.
- `tf.contrib.distributions.Distribution`.
+ `tfp.distributions.Distribution`.
`float64` `dtype` recommended.
`log_p` and `q` should be supported on the same set.
z: `Tensor` of samples from `q`, produced by `q.sample` for some `n`.
@@ -188,6 +189,12 @@ def _logspace_mean(log_values):
return log_mean_of_values
+@deprecation.deprecated(
+ '2018-10-01',
+ 'The tf.contrib.bayesflow library has moved to '
+ 'TensorFlow Probability (https://github.com/tensorflow/probability). '
+ 'Use `tfp.monte_carlo.expectation` instead.',
+ warn_once=True)
def expectation(f, samples, log_prob=None, use_reparametrization=True,
axis=0, keep_dims=False, name=None):
r"""Computes the Monte-Carlo approximation of \\(E_p[f(X)]\\).
@@ -236,17 +243,17 @@ def expectation(f, samples, log_prob=None, use_reparametrization=True,
Example Use:
```python
- bf = tf.contrib.bayesflow
- ds = tf.contrib.distributions
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
# Monte-Carlo approximation of a reparameterized distribution, e.g., Normal.
num_draws = int(1e5)
- p = ds.Normal(loc=0., scale=1.)
- q = ds.Normal(loc=1., scale=2.)
- exact_kl_normal_normal = ds.kl_divergence(p, q)
+ p = tfd.Normal(loc=0., scale=1.)
+ q = tfd.Normal(loc=1., scale=2.)
+ exact_kl_normal_normal = tfd.kl_divergence(p, q)
# ==> 0.44314718
- approx_kl_normal_normal = bf.expectation(
+ approx_kl_normal_normal = tfp.monte_carlo.expectation(
f=lambda x: p.log_prob(x) - q.log_prob(x),
samples=p.sample(num_draws, seed=42),
log_prob=p.log_prob,
@@ -260,9 +267,9 @@ def expectation(f, samples, log_prob=None, use_reparametrization=True,
num_draws = int(1e5)
p = ds.Gamma(concentration=1., rate=1.)
q = ds.Gamma(concentration=2., rate=3.)
- exact_kl_gamma_gamma = ds.kl_divergence(p, q)
+ exact_kl_gamma_gamma = tfd.kl_divergence(p, q)
# ==> 0.37999129
- approx_kl_gamma_gamma = bf.expectation(
+ approx_kl_gamma_gamma = tfp.monte_carlo.expectation(
f=lambda x: p.log_prob(x) - q.log_prob(x),
samples=p.sample(num_draws, seed=42),
log_prob=p.log_prob,
@@ -278,7 +285,7 @@ def expectation(f, samples, log_prob=None, use_reparametrization=True,
KL-divergence, the following is preferred:
```python
- approx_kl_p_q = bf.monte_carlo_csiszar_f_divergence(
+ approx_kl_p_q = tfp.vi.monte_carlo_csiszar_f_divergence(
f=bf.kl_reverse,
p_log_prob=q.log_prob,
q=p,
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc
index 11f530e82a..2c6317157d 100644
--- a/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc
@@ -28,6 +28,7 @@ class BigtableLookupDatasetOp : public UnaryDatasetOpKernel {
DatasetBase** output) override {
BigtableTableResource* table;
OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 1), &table));
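+    // LookupResource() added a reference to the table resource; release it
+    // when this method returns.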
+ core::ScopedUnref scoped_unref(table);
std::vector<string> column_families;
std::vector<string> columns;
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc
index 5cab729d9c..92a3658667 100644
--- a/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc
@@ -31,6 +31,7 @@ class BigtablePrefixKeyDatasetOp : public DatasetOpKernel {
BigtableTableResource* resource;
OP_REQUIRES_OK(ctx,
LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
+ core::ScopedUnref scoped_unref(resource);
*output = new Dataset(ctx, resource, std::move(prefix));
}
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc
index 4dc4647bd2..bd8805a382 100644
--- a/tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc
@@ -34,6 +34,7 @@ class BigtableRangeKeyDatasetOp : public DatasetOpKernel {
BigtableTableResource* resource;
OP_REQUIRES_OK(ctx,
LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
+ core::ScopedUnref scoped_unref(resource);
*output =
new Dataset(ctx, resource, std::move(start_key), std::move(end_key));
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc
index 736775bdac..01608dc6bc 100644
--- a/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc
@@ -38,6 +38,7 @@ class BigtableSampleKeyPairsDatasetOp : public DatasetOpKernel {
BigtableTableResource* resource;
OP_REQUIRES_OK(ctx,
LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
+ core::ScopedUnref scoped_unref(resource);
OP_REQUIRES(ctx, prefix.empty() || start_key.empty(),
errors::InvalidArgument(
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc
index 208b7b3e08..9b60e0a667 100644
--- a/tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc
@@ -28,6 +28,7 @@ class BigtableSampleKeysDatasetOp : public DatasetOpKernel {
BigtableTableResource* resource;
OP_REQUIRES_OK(ctx,
LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
+ core::ScopedUnref scoped_unref(resource);
*output = new Dataset(ctx, resource);
}
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc
index 9407855fe8..688289a4e2 100644
--- a/tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc
@@ -67,6 +67,7 @@ class BigtableScanDatasetOp : public DatasetOpKernel {
BigtableTableResource* resource;
OP_REQUIRES_OK(ctx,
LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
+ core::ScopedUnref scoped_unref(resource);
const uint64 num_outputs = columns.size() + 1;
std::vector<PartialTensorShape> output_shapes;
diff --git a/tensorflow/contrib/bigtable/python/kernel_tests/bigtable_ops_test.py b/tensorflow/contrib/bigtable/python/kernel_tests/bigtable_ops_test.py
index e36f7f32c6..316da9ebe1 100644
--- a/tensorflow/contrib/bigtable/python/kernel_tests/bigtable_ops_test.py
+++ b/tensorflow/contrib/bigtable/python/kernel_tests/bigtable_ops_test.py
@@ -61,7 +61,7 @@ class BigtableOpsTest(test.TestCase):
n = itr.get_next()
expected = list(self.COMMON_ROW_KEYS)
expected.reverse()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._writeCommonValues(sess)
sess.run(itr.initializer)
for i in range(3):
@@ -84,7 +84,7 @@ class BigtableOpsTest(test.TestCase):
expected_keys.reverse()
expected_values = list(self.COMMON_VALUES)
expected_values.reverse()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._writeCommonValues(sess)
sess.run(itr.initializer)
for i in range(3):
@@ -125,7 +125,7 @@ class BigtableOpsTest(test.TestCase):
expected_keys = list(self.COMMON_ROW_KEYS)
expected_values = list(self.COMMON_VALUES)
expected_tuples = zip(expected_keys, expected_values)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._writeCommonValues(sess)
sess.run(itr.initializer)
for i, elem in enumerate(expected_tuples):
@@ -144,7 +144,7 @@ class BigtableOpsTest(test.TestCase):
itr = ds.make_initializable_iterator()
n = itr.get_next()
expected_key = self.COMMON_ROW_KEYS[0]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._writeCommonValues(sess)
sess.run(itr.initializer)
output = sess.run(n)
@@ -163,7 +163,7 @@ class BigtableOpsTest(test.TestCase):
def runSampleKeyPairsTest(self, ds, expected_key_pairs):
itr = ds.make_initializable_iterator()
n = itr.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._writeCommonValues(sess)
sess.run(itr.initializer)
for i, elems in enumerate(expected_key_pairs):
@@ -219,7 +219,7 @@ class BigtableOpsTest(test.TestCase):
ds = bigtable_api._BigtableSampleKeyPairsDataset(
self._table, prefix="r", start="r1", end="")
itr = ds.make_initializable_iterator()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaises(errors.InvalidArgumentError):
sess.run(itr.initializer)
@@ -227,7 +227,7 @@ class BigtableOpsTest(test.TestCase):
ds = bigtable_api._BigtableSampleKeyPairsDataset(
self._table, prefix="r", start="", end="r3")
itr = ds.make_initializable_iterator()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaises(errors.InvalidArgumentError):
sess.run(itr.initializer)
@@ -235,7 +235,7 @@ class BigtableOpsTest(test.TestCase):
ds = self._table.parallel_scan_prefix(prefix="r", cf1="c1")
itr = ds.make_initializable_iterator()
n = itr.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._writeCommonValues(sess)
sess.run(itr.initializer)
expected_values = list(zip(self.COMMON_ROW_KEYS, self.COMMON_VALUES))
@@ -253,7 +253,7 @@ class BigtableOpsTest(test.TestCase):
ds = self._table.parallel_scan_range(start="r1", end="r4", cf1="c1")
itr = ds.make_initializable_iterator()
n = itr.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._writeCommonValues(sess)
sess.run(itr.initializer)
expected_values = list(zip(self.COMMON_ROW_KEYS, self.COMMON_VALUES))
diff --git a/tensorflow/contrib/bigtable/python/ops/bigtable_api.py b/tensorflow/contrib/bigtable/python/ops/bigtable_api.py
index 3e1b622867..cf56822ff4 100644
--- a/tensorflow/contrib/bigtable/python/ops/bigtable_api.py
+++ b/tensorflow/contrib/bigtable/python/ops/bigtable_api.py
@@ -575,7 +575,7 @@ def _normalize_columns(columns, provided_kwargs):
return normalized
-class _BigtableKeyDataset(dataset_ops.Dataset):
+class _BigtableKeyDataset(dataset_ops.DatasetSource):
"""_BigtableKeyDataset is an abstract class representing the keys of a table.
"""
@@ -645,7 +645,7 @@ class _BigtableSampleKeysDataset(_BigtableKeyDataset):
table=self._table._resource) # pylint: disable=protected-access
-class _BigtableLookupDataset(dataset_ops.Dataset):
+class _BigtableLookupDataset(dataset_ops.DatasetSource):
"""_BigtableLookupDataset represents a dataset that retrieves values for keys.
"""
@@ -678,7 +678,7 @@ class _BigtableLookupDataset(dataset_ops.Dataset):
columns=self._columns)
-class _BigtableScanDataset(dataset_ops.Dataset):
+class _BigtableScanDataset(dataset_ops.DatasetSource):
"""_BigtableScanDataset represents a dataset that retrieves keys and values.
"""
@@ -715,7 +715,7 @@ class _BigtableScanDataset(dataset_ops.Dataset):
probability=self._probability)
-class _BigtableSampleKeyPairsDataset(dataset_ops.Dataset):
+class _BigtableSampleKeyPairsDataset(dataset_ops.DatasetSource):
"""_BigtableSampleKeyPairsDataset returns key pairs from a Bigtable table.
"""
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/BUILD b/tensorflow/contrib/boosted_trees/estimator_batch/BUILD
index 5fcb19a47a..14b6fc4ac2 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/BUILD
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/BUILD
@@ -173,6 +173,7 @@ py_library(
py_test(
name = "dnn_tree_combined_estimator_test",
size = "medium",
+ timeout = "long",
srcs = ["dnn_tree_combined_estimator_test.py"],
srcs_version = "PY2AND3",
tags = [
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py
index 78232fa0a6..48f12a64f9 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py
@@ -51,6 +51,7 @@ def make_custom_export_strategy(name,
feature_columns: A list of feature columns.
export_input_fn: A function that takes no arguments and returns an
`InputFnOps`.
+ use_core_columns: A boolean, whether core feature columns were used.
Returns:
An `ExportStrategy`.
@@ -196,7 +197,7 @@ def convert_to_universal_format(dtec, sorted_feature_names,
matching_id.int64_value = split.feature_id
node.custom_left_child_test.Pack(categorical_test)
else:
- raise ValueError("Unexpected node type %s", node_type)
+ raise ValueError("Unexpected node type %s" % node_type)
node.left_child_id.value = split.left_id
node.right_child_id.value = split.right_id
return model_and_features
@@ -236,7 +237,7 @@ def _get_feature_importances(dtec, feature_names, num_dense_floats,
assert tree_node.node_metadata.gain == 0
continue
else:
- raise ValueError("Unexpected split type %s", node_type)
+ raise ValueError("Unexpected split type %s" % node_type)
# Apply shrinkage factor. It is important since it is not always uniform
# across different trees.
sums[split_column] += (
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py
index 839eedd3a8..83a8dee632 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py
@@ -188,7 +188,8 @@ class CoreDNNBoostedTreeCombinedTest(test_util.TensorFlowTestCase):
# Train for a few steps.
est.train(input_fn=_train_input_fn, steps=1000)
- # 10 steps for dnn, 3 for 1 tree of depth 3 + 1 after the tree finished
+ # 10 steps for dnn + 3 for 1 tree of depth 3 + 1 after the tree finished
+ # + 1 for resource variables.
self._assert_checkpoint(est.model_dir, global_step=14)
res = est.evaluate(input_fn=_eval_input_fn, steps=1)
self.assertLess(0.5, res["auc"])
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py
index c155128c0e..d7b14e00ba 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py
@@ -238,8 +238,8 @@ class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase):
output_leaf_index=False)
classifier.fit(input_fn=_train_input_fn, steps=15)
- # When no override of global steps, 5 steps were used.
- self._assert_checkpoint(classifier.model_dir, global_step=5)
+ # When no override of global steps, 6 steps were used.
+ self._assert_checkpoint(classifier.model_dir, global_step=6)
def testOverridesGlobalSteps(self):
learner_config = learner_pb2.LearnerConfig()
diff --git a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc
index 51e0c2e431..8edb5d6c64 100644
--- a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc
+++ b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc
@@ -579,13 +579,6 @@ class BuildSparseInequalitySplitsOp : public OpKernel {
const int end_index =
partition_boundaries[non_empty_partitions[root_idx]][j + 1]
.start_index;
- CHECK(bucket_ids_and_dimensions(start_index, 1) ==
- bucket_ids_and_dimensions(end_index - 1, 1))
- << "For bucket " << bucket_ids_and_dimensions(start_index, 0)
- << " the dimension was "
- << bucket_ids_and_dimensions(start_index, 1) << " and for "
- << bucket_ids_and_dimensions(end_index - 1, 0) << " "
- << bucket_ids_and_dimensions(end_index - 1, 1);
if (bucket_ids_and_dimensions(start_index, 0) == bias_feature_id) {
// 0-dimension case which has a first bucket for catch all feature.
CHECK(bucket_ids_and_dimensions(start_index, 1) == 0)
@@ -746,21 +739,22 @@ class BuildCategoricalEqualitySplitsOp : public OpKernel {
// Find the number of unique partitions before we allocate the output.
std::vector<int32> partition_boundaries;
- std::vector<int32> non_empty_partitions;
- for (int i = 0; i < partition_ids.size() - 1; ++i) {
+ partition_boundaries.push_back(0);
+ for (int i = 1; i < partition_ids.size(); ++i) {
// Make sure the input is sorted by partition_ids;
- CHECK_LE(partition_ids(i), partition_ids(i + 1));
- if (i == 0 || partition_ids(i) != partition_ids(i - 1)) {
+ OP_REQUIRES(context, partition_ids(i - 1) <= partition_ids(i),
+ errors::InvalidArgument("Partition IDs must be sorted."));
+ if (partition_ids(i) != partition_ids(i - 1)) {
partition_boundaries.push_back(i);
- // Some partitions might only have bias feature. We don't want to split
- // those so check that the partition has at least 2 features.
- if (partition_ids(i) == partition_ids(i + 1)) {
- non_empty_partitions.push_back(partition_boundaries.size() - 1);
- }
}
}
- if (partition_ids.size() > 0) {
- partition_boundaries.push_back(partition_ids.size());
+ std::vector<int32> non_empty_partitions;
+ partition_boundaries.push_back(partition_ids.size());
+ for (int i = 0; i < partition_boundaries.size() - 1; ++i) {
+ // We want to ignore partitions with only the bias term.
+ if (partition_boundaries[i + 1] - partition_boundaries[i] >= 2) {
+ non_empty_partitions.push_back(i);
+ }
}
int num_elements = non_empty_partitions.size();
Tensor* output_partition_ids_t = nullptr;
diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py
index 94ea7bc2eb..a2f708081a 100644
--- a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py
+++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py
@@ -170,7 +170,7 @@ class EqualitySplitHandlerTest(test_util.TensorFlowTestCase):
self.assertEqual(1, split_node.feature_id)
def testObliviousFeatureSplitGeneration(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# The data looks like the following:
# Example | Gradients | Partition | Feature ID |
# i0 | (0.2, 0.12) | 1 | 1 |
@@ -577,6 +577,92 @@ class EqualitySplitHandlerTest(test_util.TensorFlowTestCase):
self.assertEqual(len(gains), 0)
self.assertEqual(len(splits), 0)
+ def testLastOneEmpty(self):
+ with self.cached_session() as sess:
+ # The data looks like the following:
+ # Example | Gradients | Partition | Feature ID |
+ # i0 | (0.2, 0.12) | 0 | 1,2 |
+ # i1 | (-0.5, 0.07) | 0 | |
+ # i2 | (1.2, 0.2) | 0 | 2 |
+ # i3 | (4.0, 0.13) | 1 | |
+ gradients = array_ops.constant([0.2, -0.5, 1.2, 4.0])
+ hessians = array_ops.constant([0.12, 0.07, 0.2, 0.13])
+ partition_ids = [0, 0, 0, 1]
+ indices = [[0, 0], [0, 1], [2, 0]]
+ values = array_ops.constant([1, 2, 2], dtype=dtypes.int64)
+
+ gradient_shape = tensor_shape.scalar()
+ hessian_shape = tensor_shape.scalar()
+ class_id = -1
+
+ split_handler = categorical_split_handler.EqualitySplitHandler(
+ l1_regularization=0.1,
+ l2_regularization=1,
+ tree_complexity_regularization=0,
+ min_node_weight=0,
+ sparse_int_column=sparse_tensor.SparseTensor(indices, values, [4, 1]),
+ feature_column_group_id=0,
+ gradient_shape=gradient_shape,
+ hessian_shape=hessian_shape,
+ multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS,
+ init_stamp_token=0)
+ resources.initialize_resources(resources.shared_resources()).run()
+
+ empty_gradients, empty_hessians = get_empty_tensors(
+ gradient_shape, hessian_shape)
+ example_weights = array_ops.ones([4, 1], dtypes.float32)
+
+ update_1 = split_handler.update_stats_sync(
+ 0,
+ partition_ids,
+ gradients,
+ hessians,
+ empty_gradients,
+ empty_hessians,
+ example_weights,
+ is_active=array_ops.constant([True, True]))
+ with ops.control_dependencies([update_1]):
+ are_splits_ready, partitions, gains, splits = (
+ split_handler.make_splits(0, 1, class_id))
+ are_splits_ready, partitions, gains, splits = (
+ sess.run([are_splits_ready, partitions, gains, splits]))
+ self.assertTrue(are_splits_ready)
+ self.assertAllEqual([0], partitions)
+
+ # Check the split on partition 0.
+ # -(0.2 + 1.2 - 0.1) / (0.12 + 0.2 + 1)
+ expected_left_weight = -0.9848484848484846
+
+ # (0.2 + 1.2 - 0.1) ** 2 / (0.12 + 0.2 + 1)
+ expected_left_gain = 1.2803030303030298
+
+ # -(-0.5 + 0.1) / (0.07 + 1)
+ expected_right_weight = 0.37383177570093457
+
+ # (-0.5 + 0.1) ** 2 / (0.07 + 1)
+ expected_right_gain = 0.14953271028037385
+
+ # (0.2 + -0.5 + 1.2 - 0.1) ** 2 / (0.12 + 0.07 + 0.2 + 1)
+ expected_bias_gain = 0.46043165467625885
+
+ split_info = split_info_pb2.SplitInfo()
+ split_info.ParseFromString(splits[0])
+ left_child = split_info.left_child.vector
+ right_child = split_info.right_child.vector
+ split_node = split_info.split_node.categorical_id_binary_split
+
+ self.assertEqual(0, split_node.feature_column)
+
+ self.assertEqual(2, split_node.feature_id)
+
+ self.assertAllClose(
+ expected_left_gain + expected_right_gain - expected_bias_gain, gains[0],
+ 0.00001)
+
+ self.assertAllClose([expected_left_weight], left_child.value, 0.00001)
+
+ self.assertAllClose([expected_right_weight], right_child.value, 0.00001)
+
if __name__ == "__main__":
googletest.main()
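The expected values in `testLastOneEmpty` follow the regularized leaf-weight and gain formulas spelled out in its comments; the small sketch below (plain Python, assuming l1=0.1 shrinkage toward zero and l2=1) reproduces them.

```python
def shrink(g, l1):
    """L1 shrinkage of the gradient sum toward zero, as in the test comments."""
    if g > l1:
        return g - l1
    if g < -l1:
        return g + l1
    return 0.0

def leaf_weight(g, h, l1=0.1, l2=1.0):
    return -shrink(g, l1) / (h + l2)

def gain(g, h, l1=0.1, l2=1.0):
    return shrink(g, l1) ** 2 / (h + l2)

# Partition 0: feature id 2 goes left (examples i0, i2), the rest go right (i1).
left_g, left_h = 0.2 + 1.2, 0.12 + 0.2
right_g, right_h = -0.5, 0.07

print(leaf_weight(left_g, left_h))    # -0.98484848...
print(gain(left_g, left_h))           # 1.28030303...
print(leaf_weight(right_g, right_h))  # 0.37383177...
print(gain(right_g, right_h))         # 0.14953271...
# Split gain checked against gains[0]: left + right - bias.
print(gain(left_g, left_h) + gain(right_g, right_h)
      - gain(left_g + right_g, left_h + right_h))
```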
diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
index c7eb2493a8..8531e97f90 100644
--- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
+++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
@@ -402,13 +402,13 @@ class GradientBoostedDecisionTreeModel(object):
self._feature_columns = feature_columns
self._learner_config_serialized = learner_config.SerializeToString()
self._num_quantiles = num_quantiles
- self._max_tree_depth = variables.Variable(
+ self._max_tree_depth = variables.VariableV1(
initial_value=self._learner_config.constraints.max_tree_depth)
- self._attempted_trees = variables.Variable(
+ self._attempted_trees = variables.VariableV1(
initial_value=array_ops.zeros([], dtypes.int64),
trainable=False,
name="attempted_trees")
- self._finalized_trees = variables.Variable(
+ self._finalized_trees = variables.VariableV1(
initial_value=array_ops.zeros([], dtypes.int64),
trainable=False,
name="finalized_trees")
@@ -770,28 +770,28 @@ class GradientBoostedDecisionTreeModel(object):
fc_name_idx += 1
# Create ensemble stats variables.
- num_layer_examples = variables.Variable(
+ num_layer_examples = variables.VariableV1(
initial_value=array_ops.zeros([], dtypes.int64),
name="num_layer_examples",
trainable=False)
- num_layer_steps = variables.Variable(
+ num_layer_steps = variables.VariableV1(
initial_value=array_ops.zeros([], dtypes.int64),
name="num_layer_steps",
trainable=False)
- num_layers = variables.Variable(
+ num_layers = variables.VariableV1(
initial_value=array_ops.zeros([], dtypes.int64),
name="num_layers",
trainable=False)
- active_tree = variables.Variable(
+ active_tree = variables.VariableV1(
initial_value=array_ops.zeros([], dtypes.int64),
name="active_tree",
trainable=False)
- active_layer = variables.Variable(
+ active_layer = variables.VariableV1(
initial_value=array_ops.zeros([], dtypes.int64),
name="active_layer",
trainable=False)
# Variable that becomes false once bias centering is done.
- continue_centering = variables.Variable(
+ continue_centering = variables.VariableV1(
initial_value=self._center_bias,
name="continue_centering",
trainable=False)
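The `Variable` → `VariableV1` swaps in this file (and in the tests below) pin these internal counters to the legacy reference-variable behavior so they keep working once the default `tf.Variable` moves to V2 semantics. A minimal sketch of the pattern, assuming the TF 1.x internal modules used above:

```python
# Minimal sketch of the pattern used above (TF 1.x graph mode assumed).
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables

# VariableV1 requests the legacy ref-variable semantics explicitly, so these
# counters keep behaving the same once the default Variable becomes V2.
num_layers = variables.VariableV1(
    initial_value=array_ops.zeros([], dtypes.int64),
    name="num_layers",
    trainable=False)
```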
diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py
index 73e41bc457..6d20a2e7f4 100644
--- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py
+++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py
@@ -86,7 +86,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
def testExtractFeatures(self):
"""Tests feature extraction."""
- with self.test_session():
+ with self.cached_session():
features = {}
features["dense_float"] = array_ops.zeros([2, 1], dtypes.float32)
features["sparse_float"] = sparse_tensor.SparseTensor(
@@ -128,7 +128,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
def testExtractFeaturesWithTransformation(self):
"""Tests feature extraction."""
- with self.test_session():
+ with self.cached_session():
features = {}
features["dense_float"] = array_ops.zeros([2, 1], dtypes.float32)
features["sparse_float"] = sparse_tensor.SparseTensor(
@@ -178,7 +178,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
def testExtractFeaturesFromCoreFeatureColumns(self):
"""Tests feature extraction when using core columns."""
- with self.test_session():
+ with self.cached_session():
features = {}
# Sparse float column does not exist in core, so only dense numeric and
# categorical.
@@ -213,7 +213,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
def testTrainFnChiefNoBiasCentering(self):
"""Tests the train function running on chief without bias centering."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
ensemble_handle = model_ops.tree_ensemble_variable(
stamp_token=0, tree_ensemble_config="", name="tree_ensemble")
learner_config = learner_pb2.LearnerConfig()
@@ -239,7 +239,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
predictions = array_ops.constant(
[[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32)
partition_ids = array_ops.zeros([4], dtypes.int32)
- ensemble_stamp = variables.Variable(
+ ensemble_stamp = variables.VariableV1(
initial_value=0,
name="ensemble_stamp",
trainable=False,
@@ -316,7 +316,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
self.assertProtoEquals(expected_tree, output.trees[0])
def testObliviousDecisionTreeAsWeakLearner(self):
- with self.test_session():
+ with self.cached_session():
ensemble_handle = model_ops.tree_ensemble_variable(
stamp_token=0, tree_ensemble_config="", name="tree_ensemble")
learner_config = learner_pb2.LearnerConfig()
@@ -473,7 +473,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
def testTrainFnChiefSparseAndDense(self):
"""Tests the train function with sparse and dense features."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
ensemble_handle = model_ops.tree_ensemble_variable(
stamp_token=0, tree_ensemble_config="", name="tree_ensemble")
learner_config = learner_pb2.LearnerConfig()
@@ -503,7 +503,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
predictions = array_ops.constant(
[[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32)
partition_ids = array_ops.zeros([4], dtypes.int32)
- ensemble_stamp = variables.Variable(
+ ensemble_stamp = variables.VariableV1(
initial_value=0,
name="ensemble_stamp",
trainable=False,
@@ -580,7 +580,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
def testTrainFnChiefScalingNumberOfExamples(self):
"""Tests the train function running on chief without bias centering."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
ensemble_handle = model_ops.tree_ensemble_variable(
stamp_token=0, tree_ensemble_config="", name="tree_ensemble")
learner_config = learner_pb2.LearnerConfig()
@@ -607,7 +607,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
predictions = array_ops.constant(
[[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32)
partition_ids = array_ops.zeros([4], dtypes.int32)
- ensemble_stamp = variables.Variable(
+ ensemble_stamp = variables.VariableV1(
initial_value=0,
name="ensemble_stamp",
trainable=False,
@@ -685,7 +685,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
def testTrainFnChiefWithBiasCentering(self):
"""Tests the train function running on chief with bias centering."""
- with self.test_session():
+ with self.cached_session():
ensemble_handle = model_ops.tree_ensemble_variable(
stamp_token=0, tree_ensemble_config="", name="tree_ensemble")
learner_config = learner_pb2.LearnerConfig()
@@ -711,7 +711,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
predictions = array_ops.constant(
[[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32)
partition_ids = array_ops.zeros([4], dtypes.int32)
- ensemble_stamp = variables.Variable(
+ ensemble_stamp = variables.VariableV1(
initial_value=0,
name="ensemble_stamp",
trainable=False,
@@ -757,7 +757,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
def testTrainFnNonChiefNoBiasCentering(self):
"""Tests the train function running on worker without bias centering."""
- with self.test_session():
+ with self.cached_session():
ensemble_handle = model_ops.tree_ensemble_variable(
stamp_token=0, tree_ensemble_config="", name="tree_ensemble")
learner_config = learner_pb2.LearnerConfig()
@@ -783,7 +783,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
predictions = array_ops.constant(
[[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32)
partition_ids = array_ops.zeros([4], dtypes.int32)
- ensemble_stamp = variables.Variable(
+ ensemble_stamp = variables.VariableV1(
initial_value=0,
name="ensemble_stamp",
trainable=False,
@@ -821,7 +821,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
def testTrainFnNonChiefWithCentering(self):
"""Tests the train function running on worker with bias centering."""
- with self.test_session():
+ with self.cached_session():
ensemble_handle = model_ops.tree_ensemble_variable(
stamp_token=0, tree_ensemble_config="", name="tree_ensemble")
learner_config = learner_pb2.LearnerConfig()
@@ -847,7 +847,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
predictions = array_ops.constant(
[[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32)
partition_ids = array_ops.zeros([4], dtypes.int32)
- ensemble_stamp = variables.Variable(
+ ensemble_stamp = variables.VariableV1(
initial_value=0,
name="ensemble_stamp",
trainable=False,
@@ -885,7 +885,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
def testPredictFn(self):
"""Tests the predict function."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create ensemble with one bias node.
ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
text_format.Merge(
@@ -939,7 +939,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
def testPredictFnWithLeafIndexAdvancedLeft(self):
"""Tests the predict function with output leaf ids."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create ensemble with one bias node.
ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
text_format.Merge(
@@ -1051,7 +1051,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
def testTrainFnMulticlassFullHessian(self):
"""Tests the GBDT train for multiclass full hessian."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
ensemble_handle = model_ops.tree_ensemble_variable(
stamp_token=0, tree_ensemble_config="", name="tree_ensemble")
@@ -1090,7 +1090,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
weights = array_ops.ones([batch_size, 1], dtypes.float32)
partition_ids = array_ops.zeros([batch_size], dtypes.int32)
- ensemble_stamp = variables.Variable(
+ ensemble_stamp = variables.VariableV1(
initial_value=0,
name="ensemble_stamp",
trainable=False,
@@ -1155,7 +1155,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
def testTrainFnMulticlassDiagonalHessian(self):
"""Tests the GBDT train for multiclass diagonal hessian."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
ensemble_handle = model_ops.tree_ensemble_variable(
stamp_token=0, tree_ensemble_config="", name="tree_ensemble")
@@ -1194,7 +1194,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
weights = array_ops.ones([batch_size, 1], dtypes.float32)
partition_ids = array_ops.zeros([batch_size], dtypes.int32)
- ensemble_stamp = variables.Variable(
+ ensemble_stamp = variables.VariableV1(
initial_value=0,
name="ensemble_stamp",
trainable=False,
@@ -1259,7 +1259,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
def testTrainFnMulticlassTreePerClass(self):
"""Tests the GBDT train for multiclass tree per class strategy."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
ensemble_handle = model_ops.tree_ensemble_variable(
stamp_token=0, tree_ensemble_config="", name="tree_ensemble")
@@ -1299,7 +1299,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
weights = array_ops.ones([batch_size, 1], dtypes.float32)
partition_ids = array_ops.zeros([batch_size], dtypes.int32)
- ensemble_stamp = variables.Variable(
+ ensemble_stamp = variables.VariableV1(
initial_value=0,
name="ensemble_stamp",
trainable=False,
@@ -1374,7 +1374,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
def testTrainFnChiefFeatureSelectionReachedLimitNoGoodSplit(self):
"""Tests the train function running on chief with feature selection."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
ensemble_handle = model_ops.tree_ensemble_variable(
stamp_token=0, tree_ensemble_config="", name="tree_ensemble")
learner_config = learner_pb2.LearnerConfig()
@@ -1405,7 +1405,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
predictions = array_ops.constant(
[[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32)
partition_ids = array_ops.zeros([4], dtypes.int32)
- ensemble_stamp = variables.Variable(
+ ensemble_stamp = variables.VariableV1(
initial_value=0,
name="ensemble_stamp",
trainable=False,
@@ -1493,7 +1493,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
def testTrainFnChiefFeatureSelectionWithGoodSplits(self):
"""Tests the train function running on chief with feature selection."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
ensemble_handle = model_ops.tree_ensemble_variable(
stamp_token=0, tree_ensemble_config="", name="tree_ensemble")
learner_config = learner_pb2.LearnerConfig()
@@ -1524,7 +1524,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
predictions = array_ops.constant(
[[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32)
partition_ids = array_ops.zeros([4], dtypes.int32)
- ensemble_stamp = variables.Variable(
+ ensemble_stamp = variables.VariableV1(
initial_value=0,
name="ensemble_stamp",
trainable=False,
@@ -1610,7 +1610,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
def testTrainFnChiefFeatureSelectionReachedLimitIncrementAttemptedLayer(self):
"""Tests the train function running on chief with feature selection."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
tree = tree_ensemble_config.trees.add()
@@ -1656,7 +1656,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
predictions = array_ops.constant(
[[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32)
partition_ids = array_ops.zeros([4], dtypes.int32)
- ensemble_stamp = variables.Variable(
+ ensemble_stamp = variables.VariableV1(
initial_value=0,
name="ensemble_stamp",
trainable=False,
@@ -1720,7 +1720,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
def testResetModelBeforeAndAfterSplit(self):
"""Tests whether resetting works."""
- with self.test_session():
+ with self.cached_session():
# First build a small tree and train it to verify training works.
ensemble_handle = model_ops.tree_ensemble_variable(
stamp_token=0, tree_ensemble_config="", name="tree_ensemble")
@@ -1854,7 +1854,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
def testResetModelNonChief(self):
"""Tests the reset function on a non-chief worker."""
- with self.test_session():
+ with self.cached_session():
# Create ensemble with one bias node.
ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
text_format.Merge(
@@ -1930,7 +1930,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
def testResetModelWithCenterBias(self):
"""Tests the reset function running on chief with bias centering."""
- with self.test_session():
+ with self.cached_session():
ensemble_handle = model_ops.tree_ensemble_variable(
stamp_token=0, tree_ensemble_config="", name="tree_ensemble")
learner_config = learner_pb2.LearnerConfig()
diff --git a/tensorflow/contrib/boosted_trees/python/utils/losses_test.py b/tensorflow/contrib/boosted_trees/python/utils/losses_test.py
index ccb8509c03..cc22504c8f 100644
--- a/tensorflow/contrib/boosted_trees/python/utils/losses_test.py
+++ b/tensorflow/contrib/boosted_trees/python/utils/losses_test.py
@@ -45,7 +45,7 @@ class LossesTest(test_util.TensorFlowTestCase):
eps = 0.2
- with self.test_session():
+ with self.cached_session():
predictions_tensor = constant_op.constant(
prediction_logits, dtype=dtypes.float32)
loss_for_positives, _ = losses.per_example_exp_loss(
@@ -84,7 +84,7 @@ class LossesTest(test_util.TensorFlowTestCase):
predictions = np.array(
[[0.123], [23.2], [233], [52], [3]], dtype=np.float32)
- with self.test_session():
+ with self.cached_session():
loss_tensor, _ = losses.per_example_squared_loss(labels, weights,
predictions)
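The `test_session()` → `cached_session()` renames throughout these test files switch to the session helper that reuses a single cached session per test. A short hypothetical test illustrating the pattern:

```python
# Hypothetical test illustrating the cached_session() pattern used above.
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import test_util
from tensorflow.python.platform import googletest


class CachedSessionExampleTest(test_util.TensorFlowTestCase):

  def testAdd(self):
    # cached_session() reuses one session for the test, unlike the deprecated
    # test_session(), which this changeset replaces.
    with self.cached_session() as sess:
      total = constant_op.constant(1.0) + constant_op.constant(2.0)
      self.assertAllClose(3.0, sess.run(total))


if __name__ == "__main__":
  googletest.main()
```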
diff --git a/tensorflow/contrib/cmake/CMakeLists.txt b/tensorflow/contrib/cmake/CMakeLists.txt
index ebcabb4223..c6d6f04168 100644
--- a/tensorflow/contrib/cmake/CMakeLists.txt
+++ b/tensorflow/contrib/cmake/CMakeLists.txt
@@ -353,7 +353,7 @@ endif()
# MKL Support
if (tensorflow_ENABLE_MKL_SUPPORT)
- add_definitions(-DINTEL_MKL -DEIGEN_USE_VML)
+ add_definitions(-DINTEL_MKL -DEIGEN_USE_VML -DENABLE_MKL)
include(mkl)
list(APPEND tensorflow_EXTERNAL_LIBRARIES ${mkl_STATIC_LIBRARIES})
list(APPEND tensorflow_EXTERNAL_DEPENDENCIES mkl_copy_shared_to_destination)
diff --git a/tensorflow/contrib/cmake/README.md b/tensorflow/contrib/cmake/README.md
index 789dab81ed..77242b34fd 100644
--- a/tensorflow/contrib/cmake/README.md
+++ b/tensorflow/contrib/cmake/README.md
@@ -17,7 +17,7 @@ Linux.
Current Status
--------------
-CMake can be used to build TensorFlow on Windows. See the [getting started documentation](https://www.tensorflow.org/install/install_windows)
+CMake can be used to build TensorFlow on Windows. See the [getting started documentation](https://www.tensorflow.org/install/source_windows)
for instructions on how to install a pre-built TensorFlow package on Windows.
### Current known limitations
diff --git a/tensorflow/contrib/cmake/python_modules.txt b/tensorflow/contrib/cmake/python_modules.txt
index fb871acae9..2975b167ec 100644
--- a/tensorflow/contrib/cmake/python_modules.txt
+++ b/tensorflow/contrib/cmake/python_modules.txt
@@ -132,7 +132,6 @@ tensorflow/contrib/cudnn_rnn/python
tensorflow/contrib/cudnn_rnn/python/layers
tensorflow/contrib/cudnn_rnn/python/ops
tensorflow/contrib/data
-tensorflow/contrib/data/kernels
tensorflow/contrib/data/python
tensorflow/contrib/data/python/kernel_tests
tensorflow/contrib/data/python/kernel_tests/serialization
@@ -273,9 +272,6 @@ tensorflow/contrib/libsvm
tensorflow/contrib/libsvm/python
tensorflow/contrib/libsvm/python/kernel_tests
tensorflow/contrib/libsvm/python/ops
-tensorflow/contrib/linalg
-tensorflow/contrib/linalg/python
-tensorflow/contrib/linalg/python/ops
tensorflow/contrib/linear_optimizer
tensorflow/contrib/linear_optimizer/kernels
tensorflow/contrib/linear_optimizer/kernels/g3doc
@@ -409,7 +405,6 @@ tensorflow/contrib/summary
tensorflow/contrib/tensorboard
tensorflow/contrib/tensorboard/plugins
tensorflow/contrib/tensorboard/plugins/projector
-tensorflow/contrib/tensorboard/plugins/trace
# TODO(sami): Add cmake implementations.
# tensorflow/contrib/tensorrt/python
# tensorflow/contrib/tensorrt/python/ops
diff --git a/tensorflow/contrib/cmake/python_protos.txt b/tensorflow/contrib/cmake/python_protos.txt
index cf1ee2ad76..42afbd9105 100644
--- a/tensorflow/contrib/cmake/python_protos.txt
+++ b/tensorflow/contrib/cmake/python_protos.txt
@@ -12,7 +12,6 @@ tensorflow/contrib/mpi_collectives
tensorflow/contrib/session_bundle
tensorflow/contrib/tensor_forest/proto
tensorflow/contrib/tensorboard/plugins/projector
-tensorflow/contrib/tensorboard/plugins/trace
tensorflow/contrib/tpu/proto
tensorflow/contrib/tpu/profiler
tensorflow/contrib/training/python/training
diff --git a/tensorflow/contrib/cmake/tf_tests.cmake b/tensorflow/contrib/cmake/tf_tests.cmake
index 2c878c1716..ed31351d9e 100644
--- a/tensorflow/contrib/cmake/tf_tests.cmake
+++ b/tensorflow/contrib/cmake/tf_tests.cmake
@@ -183,7 +183,6 @@ if (tensorflow_BUILD_PYTHON_TESTS)
file(GLOB_RECURSE tf_test_src_py
${tf_test_src_py}
"${tensorflow_source_dir}/tensorflow/contrib/legacy_seq2seq/*_test.py"
- "${tensorflow_source_dir}/tensorflow/contrib/linalg/*_test.py"
"${tensorflow_source_dir}/tensorflow/contrib/graph_editor/*_test.py"
"${tensorflow_source_dir}/tensorflow/contrib/bayesflow/*_test.py"
"${tensorflow_source_dir}/tensorflow/contrib/framework/*_test.py"
diff --git a/tensorflow/contrib/coder/python/ops/coder_ops_test.py b/tensorflow/contrib/coder/python/ops/coder_ops_test.py
index d5e14e7a64..f5431ca1ff 100644
--- a/tensorflow/contrib/coder/python/ops/coder_ops_test.py
+++ b/tensorflow/contrib/coder/python/ops/coder_ops_test.py
@@ -45,7 +45,7 @@ class CoderOpsTest(test.TestCase):
decoded = coder_ops.range_decode(
encoded, array_ops.shape(data), cdf, precision=14)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertAllEqual(*sess.run((data, decoded)))
diff --git a/tensorflow/contrib/compiler/BUILD b/tensorflow/contrib/compiler/BUILD
index d7583be6d8..f83386b8a4 100644
--- a/tensorflow/contrib/compiler/BUILD
+++ b/tensorflow/contrib/compiler/BUILD
@@ -5,7 +5,10 @@ package(default_visibility = [":friends"])
package_group(
name = "friends",
includes = ["//tensorflow/compiler/jit:friends"],
- packages = ["//tensorflow/..."],
+ packages = [
+ "//tensorflow/...",
+ "//third_party/py/tensor2tensor/...",
+ ],
)
load("//tensorflow:tensorflow.bzl", "tf_py_test")
@@ -53,12 +56,16 @@ py_library(
srcs = ["xla.py"],
srcs_version = "PY2AND3",
deps = [
+ "//tensorflow/compiler/jit:xla_ops_py",
+ "//tensorflow/contrib/tpu:tpu_lib",
"//tensorflow/python:array_ops",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework_ops",
"//tensorflow/python:platform",
+ "//tensorflow/python:summary_op_util",
"//tensorflow/python:util",
- "//tensorflow/python/estimator:model_fn",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python/estimator:estimator_py",
],
)
diff --git a/tensorflow/contrib/compiler/jit_test.py b/tensorflow/contrib/compiler/jit_test.py
index 42b3b9f026..3e631b5909 100644
--- a/tensorflow/contrib/compiler/jit_test.py
+++ b/tensorflow/contrib/compiler/jit_test.py
@@ -173,7 +173,7 @@ class JITTest(test.TestCase):
class CompilationEnabledInGradientTest(test.TestCase):
def testCompilationInGradient(self):
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant([[3.]])
y_nc = math_ops.matmul(x, x, name="not_compiled")
with jit.experimental_jit_scope():
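For context, the compiled/uncompiled split exercised by this test comes from wrapping only part of the graph in the experimental JIT scope; a minimal sketch (TF 1.x, `tensorflow.contrib.compiler.jit`):

```python
# Minimal sketch of the scope used in the test above.
from tensorflow.contrib.compiler import jit
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import math_ops

x = constant_op.constant([[3.]])
y_nc = math_ops.matmul(x, x, name="not_compiled")     # outside the scope
with jit.experimental_jit_scope():
  y_c = math_ops.matmul(y_nc, y_nc, name="compiled")  # marked for XLA JIT
```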
diff --git a/tensorflow/contrib/compiler/xla.py b/tensorflow/contrib/compiler/xla.py
index 60f5af1662..873b03580d 100644
--- a/tensorflow/contrib/compiler/xla.py
+++ b/tensorflow/contrib/compiler/xla.py
@@ -12,20 +12,29 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
-"""xla provides experimental xla support API."""
+"""xla is an experimental library that provides XLA support APIs."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import collections
+import contextlib
from six.moves import xrange # pylint: disable=redefined-builtin
+from tensorflow.compiler.jit.ops import xla_ops
+from tensorflow.contrib.tpu.python.tpu import tpu_function
from tensorflow.core.framework import attr_value_pb2
+from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import summary_op_util
+from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat
+from tensorflow.python.util import function_utils
+from tensorflow.python.util import tf_decorator
_XLA_COMPILE_ATTR = '_xla_compile_id'
_MAX_WARNING_LINES = 5
@@ -51,6 +60,30 @@ _UNSUPPORTED_OPS = set([
])
+def compile(computation, inputs=None): # pylint: disable=redefined-builtin
+ """Builds an operator that compiles and runs `computation` with XLA.
+
+ Args:
+ computation: A Python function that builds a computation to apply to the
+ input. If the function takes n inputs, 'inputs' should be a list of n
+ tensors.
+
+ `computation` may return a list of operations and tensors. Tensors must
+ come before operations in the returned list. The return value of
+ `compile` is a list of tensors corresponding to the tensors from the
+ output of `computation`.
+
+ All `Operation`s returned from `computation` will be executed when
+ evaluating any of the returned output tensors.
+ inputs: A list of input tensors or `None` (equivalent to an empty list).
+
+ Returns:
+ A list of output tensors.
+ """
+ # pylint: disable=protected-access
+ return _compile_internal(computation, inputs)
+
+
class XLACompileContext(control_flow_ops.XLAControlFlowContext):
"""A `ControlFlowContext` for nodes inside an XLA computation cluster.
@@ -206,3 +239,409 @@ class XLACompileContext(control_flow_ops.XLAControlFlowContext):
if self.GetWhileContext():
return self.GetWhileContext().back_prop
return False
+
+
+def _compile_internal(computation, inputs=None):
+  """Builds graph operators that compile and symbolically execute computation.
+
+ Args:
+ computation: A Python function that builds the computation to compile and
+ execute.
+    inputs: A list of input tensors or `None` (equivalent to `[]`). Its order
+      should match the ordering of the computation's arguments.
+  Returns:
+    A list of output tensors from computation.
+  Raises:
+    ValueError: If any element of the computation outputs is neither an
+      operation nor a value that can be converted to a tensor.
+ TypeError: If `inputs` is not a list or tuple.
+ """
+ if inputs is None:
+ inputs = []
+
+ if not isinstance(inputs, collections.Sequence):
+ raise TypeError('inputs must be a list')
+
+ # Converts inputs to Tensors.
+ inputs = [ops.convert_to_tensor(x) for x in inputs]
+ input_arity = len(inputs)
+
+ arg_error = tpu_function.check_function_argument_count(
+ computation, input_arity, infeed_queue=None)
+ if arg_error is not None:
+ raise TypeError(
+ 'Supplied computation cannot be called with the specified inputs. You '
+ 'specified %d inputs: %s, but the computation needs %s' %
+        (input_arity, str([i.name for i in inputs]), arg_error))
+
+ cluster_name = ops.get_default_graph().unique_name('cluster')
+ pivot = control_flow_ops.no_op(name=cluster_name + '/pivot')
+ context = XLACompileContext(name=cluster_name, pivot=pivot)
+ try:
+ context.Enter()
+
+ # Add identity ops so even unused inputs are 'consumed' by the
+ # computation.
+ computation_inputs = [
+ array_ops.identity(x, name='input_{}'.format(i))
+ for i, x in enumerate(inputs)
+ ]
+
+ # Only resource variables work inside an XLA computation, so turn on
+ # resource variables for the computation.
+ vscope = variable_scope.get_variable_scope()
+ saved_use_resource = vscope.use_resource
+ vscope.set_use_resource(True)
+
+ with _disable_summary_context():
+ outputs = computation(*computation_inputs)
+
+ # Restore variable scope after computation.
+ vscope.set_use_resource(saved_use_resource)
+
+ # If the computation returns `None`, make it an empty tuple.
+ if outputs is None:
+ outputs = tuple()
+ # If the computation only returned one value, make it a tuple.
+ if not isinstance(outputs, collections.Sequence):
+ outputs = (outputs,)
+
+  # Append `no_op` here so that the return value of this function always
+  # contains at least one op that can trigger the XlaLaunch node.
+ outputs += (control_flow_ops.no_op(),)
+ try:
+ outputs = [
+ o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o)
+ for o in outputs
+ ]
+ except Exception as e:
+ raise ValueError(
+ 'XLA computation function return values must all either be Operations'
+ ' or convertible to Tensors. Got error: "%s"' % str(e))
+
+ # Separates the returned Operations and Tensors.
+ output_operations = [o for o in outputs if isinstance(o, ops.Operation)]
+ output_tensors = [o for o in outputs if not isinstance(o, ops.Operation)]
+
+ if outputs != output_tensors + output_operations:
+ raise ValueError(
+ 'XLA computation function must return zero or more Tensor values '
+ 'followed by zero or more Operations.')
+ output_arity = len(output_tensors)
+
+ new_output_tensors = []
+ for t in output_tensors:
+ with ops.device(t.device if t.device else ''):
+ new_output_tensors.append(array_ops.identity(t))
+
+ output_tensors = new_output_tensors
+ context.ExitResult(output_tensors)
+ finally:
+ context.report_unsupported_operations()
+ context.Exit()
+
+ outputs = [
+ xla_ops.xla_cluster_output(output_tensors[i], name='output{}'.format(i))
+ for i in xrange(output_arity)
+ ]
+
+ with ops.control_dependencies(output_operations):
+ if output_arity == 0:
+ # When XLA computation returns only operations and no tensors, a NoOp
+ # dependent on the operations in outputs is returned. Otherwise final
+ # outputs would be empty and there is no way to trigger returned
+ # operations.
+ return control_flow_ops.no_op(name='output_0')
+ else:
+      # Wraps the outputs in identity operators that carry the control
+      # dependencies.
+ return [
+ array_ops.identity(outputs[i], name='output_%d' % i)
+ for i in xrange(output_arity)
+ ]
+
+
+@contextlib.contextmanager
+def _disable_summary_context():
+ """Enters a context where all summary ops are skipped.
+
+  Summaries are not yet supported in xla.compile(), so this context manager
+  skips creating summary ops while the computation is being built. It is a
+  temporary workaround until XLA supports summary ops.
+
+ Yields:
+ None.
+ """
+ original_skip_summary_func = summary_op_util.skip_summary
+ summary_op_util.skip_summary = lambda: True
+
+ try:
+ yield
+ finally:
+ summary_op_util.skip_summary = original_skip_summary_func
+
+
+class _CapturedObject(object):
+ """A placeholder to capture an object."""
+
+ def __init__(self):
+ self._object = None
+
+ def capture(self, o):
+ if self._object:
+ raise RuntimeError(
+          'InternalError: _CapturedObject can capture only once. Please file '
+          'a bug.')
+
+ self._object = o
+
+ def get(self):
+ return self._object
+
+
+def _get_scaffold(captured_scaffold_fn):
+ """Retrieves the Scaffold from `captured_scaffold_fn`."""
+ scaffold_fn = captured_scaffold_fn.get()
+
+ if not scaffold_fn:
+ return None
+
+ scaffold = scaffold_fn()
+ if scaffold is None:
+ raise ValueError(
+ 'TPUEstimatorSpec.scaffold_fn returns None, which is not allowed')
+
+ return scaffold
+
+
+class _ModelFnWrapper(object):
+ """_ModelFnWrapper supports executing model_fn with XLA."""
+
+ def __init__(self, function):
+ self._model_fn = function
+
+ def __call__(self, features, labels, mode, params):
+
+ # TPUEstimator compiles model_fn when use_tpu=True. To avoid double
+    # compilation, we use params['use_tpu'] as a hint. When it is set to
+    # True, model_fn is called without compilation.
+    # Note that this condition isn't accurate when exporting a model. In that
+    # case we should ideally not compile, so that the user can see the detailed
+    # graph. However, we don't have enough information to tell whether model_fn
+ # is being called for export mode or not.
+ # TODO(ycao): Make this condition more accurate when implementing PREDICT
+ # mode.
+ if params.get('use_tpu'):
+ return self._call_model_fn(features, labels, mode, params)
+
+ if mode == model_fn_lib.ModeKeys.TRAIN:
+ train_step, captured_scaffold_fn = self._make_train_step(
+ features, labels, params)
+ (loss,) = compile(train_step)
+ return model_fn_lib.EstimatorSpec(
+ mode=mode,
+ loss=loss,
+ train_op=array_ops.identity(loss),
+ scaffold=_get_scaffold(captured_scaffold_fn))
+ elif mode == model_fn_lib.ModeKeys.EVAL:
+ eval_step, captured_eval_metric_fn, captured_scaffold_fn = (
+ self._make_eval_step(features, labels, params))
+ outputs = compile(eval_step)
+ loss = outputs[0]
+
+ # Calculate eval_metric_ops if eval_metric_fn is set and captured.
+ eval_metric_fn = captured_eval_metric_fn.get()
+ if eval_metric_fn:
+ eval_metric_fn_tensors = outputs[1:]
+ eval_metric_ops = eval_metric_fn(*eval_metric_fn_tensors)
+ else:
+ eval_metric_ops = None
+
+ return model_fn_lib.EstimatorSpec(
+ mode=mode,
+ loss=loss,
+ eval_metric_ops=eval_metric_ops,
+ scaffold=_get_scaffold(captured_scaffold_fn))
+ else:
+ raise NotImplementedError('%s is not implemented, only TRAIN and EVAL are'
+ ' supported' % mode)
+
+ def _make_train_step(self, features, labels, params):
+ """Creates a single step of training for xla.compile()."""
+ captured_scaffold_fn = _CapturedObject()
+
+ def train_step():
+ """A single step of training."""
+ estimator_spec = self._call_model_fn(features, labels,
+ model_fn_lib.ModeKeys.TRAIN, params)
+
+ try:
+ captured_scaffold_fn.capture(estimator_spec.scaffold_fn)
+ except AttributeError:
+ captured_scaffold_fn.capture(None)
+
+ # train_step will be run by xla.compile(). xla.compile() only supports
+ # tensor output while train_op can be either an operation or a tensor.
+ # Even though xla.compile() automatically adds operation-typed train_op as
+ # control dependency of other tensor outputs, it doesn't do so for
+ # tensor-typed train_op. Thus, we need to set it explicitly here.
+ with ops.control_dependencies([estimator_spec.train_op]):
+ return array_ops.identity(estimator_spec.loss)
+
+ return train_step, captured_scaffold_fn
+
+ def _make_eval_step(self, features, labels, params):
+ """Creates a single step of evaluation for xla.compile()."""
+ captured_eval_metric_fn = _CapturedObject()
+ captured_scaffold_fn = _CapturedObject()
+
+ def eval_step():
+ """A single step of evaluation."""
+ estimator_spec = self._call_model_fn(features, labels,
+ model_fn_lib.ModeKeys.EVAL, params)
+
+ try:
+ captured_scaffold_fn.capture(estimator_spec.scaffold_fn)
+ except AttributeError:
+ captured_scaffold_fn.capture(None)
+
+ eval_metric_fn = None
+ eval_metric_fn_tensors = []
+ try:
+ if estimator_spec.eval_metrics:
+ (eval_metric_fn, eval_metric_fn_tensors) = estimator_spec.eval_metrics
+ except AttributeError:
+ pass
+
+ # If a dictionary is provided, we need to convert it into a list sorted
+ # according to order of eval_metric_fn positional arguments.
+ if isinstance(eval_metric_fn_tensors, dict):
+ eval_metric_fn_args = function_utils.fn_args(eval_metric_fn)
+ eval_metric_fn_tensors = [
+ eval_metric_fn_tensors[i] for i in eval_metric_fn_args
+ ]
+
+ captured_eval_metric_fn.capture(eval_metric_fn)
+
+ return tuple([estimator_spec.loss] + eval_metric_fn_tensors)
+
+ return eval_step, captured_eval_metric_fn, captured_scaffold_fn
+
+ def _call_model_fn(self, features, labels, mode, params):
+ """Calls the model_fn with required parameters."""
+ model_fn_args = function_utils.fn_args(self._model_fn)
+ kwargs = {}
+
+ if 'labels' in model_fn_args:
+ kwargs['labels'] = labels
+ elif labels is not None:
+ raise ValueError(
+ 'model_fn does not take labels, but input_fn returns labels.')
+ if 'mode' in model_fn_args:
+ kwargs['mode'] = mode
+
+ if 'params' in model_fn_args:
+ kwargs['params'] = params
+
+ return self._verify_estimator_spec(
+ self._model_fn(features=features, **kwargs))
+
+ def _verify_estimator_spec(self, estimator_spec):
+ """Verifies estimator spec contains correct data."""
+ # TODO(ycao): Implement estimator spec verification for other modes.
+
+ try:
+ if estimator_spec.scaffold:
+ logging.warning('EstimatorSpec.scaffold is ignored with XLA compilation'
+ '. Please use TPUEstimatorSpec.scaffold_fn instead.')
+ except AttributeError:
+ pass
+
+ try:
+ if estimator_spec.eval_metric_ops:
+ raise ValueError('EstimatorSpec.eval_metric_ops is not supported with '
+ 'XLA compilation. Please use '
+ 'TPUEstimatorSpec.eval_metrics instead.')
+ except AttributeError:
+ pass
+
+ if estimator_spec.mode == model_fn_lib.ModeKeys.EVAL:
+ # If estimator_spec is of type TPUEstimatorSpec and contains eval_metrics,
+ # check that eval_metrics contains eval_metric_fn and
+ # eval_metric_fn_tensors with matching arguments.
+ try:
+ eval_metrics = estimator_spec.eval_metrics
+ except AttributeError:
+ eval_metrics = None
+
+ if eval_metrics:
+ (eval_metric_fn, eval_metric_fn_tensors) = eval_metrics
+ eval_metric_fn_args = function_utils.fn_args(eval_metric_fn)
+
+ if isinstance(eval_metric_fn_tensors, dict):
+ missing_tensors = [
+ i for i in eval_metric_fn_args if i not in eval_metric_fn_tensors
+ ]
+ additional_tensors = [
+ i for i in eval_metric_fn_tensors if i not in eval_metric_fn_args
+ ]
+
+ if missing_tensors:
+ raise ValueError('Arguments %s are needed by metric_fn (first '
+ 'element of TPUEstimatorSpec.eval_metrics) but '
+ 'they are not provided by evaluation tensors '
+ '(second element of TPUEstimatorSpec.eval_metrics)'
+ '.' % missing_tensors)
+
+ if additional_tensors:
+ raise ValueError('Arguments %s are provided by evaluation tensors '
+ '(second element of TPUEstimatorSpec.eval_metrics)'
+ ' but they are not needed by metric_fn (first '
+ 'element of TPUEstimatorSpec.eval_metrics).' %
+ additional_tensors)
+
+ return estimator_spec
+
+
+def estimator_model_fn(target_model_fn=None):
+ """estimator_model_fn decorates a model_fn to be compiled for execution.
+
+  Currently it only works with `TPUEstimator`. If you need to use it with the
+  base `Estimator`, please add `tf.enable_resource_variables()` at the
+  beginning of your program.
+
+ Example 1, decorating model_fn:
+ ```
+ @xla.estimator_model_fn()
+ def model_fn(features, labels, mode, params):
+ ...
+ return EstimatorSpec(...)
+
+
+ est = Estimator(model_fn=model_fn, ...)
+ est.train(...)
+
+ ```
+
+ Example 2, decorator as function:
+ ```
+ def model_fn(features, labels, mode, params):
+ ...
+ return EstimatorSpec(...)
+
+ est = Estimator(model_fn=xla.estimator_model_fn(model_fn), ...)
+ est.train(...)
+ ```
+
+ Args:
+ target_model_fn: model_fn to be decorated. This is only needed when
+ decorator is used in function call form (example 2).
+
+ Returns:
+ Decorated target_model_fn.
+ """
+
+ def decorated(function):
+ return tf_decorator.make_decorator(function, _ModelFnWrapper(function))
+
+ return decorated(target_model_fn) if target_model_fn else decorated
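A short, hedged sketch of how the APIs added above are intended to be used; the model body is hypothetical and only illustrates the shape of the calls described in the docstrings.

```python
# Sketch of using the new contrib APIs added above (TF 1.x assumed).
import tensorflow as tf
from tensorflow.contrib.compiler import xla


def build_model(x):
  # The computation handed to xla.compile must return tensors (optionally
  # followed by operations); here a single tensor.
  return tf.reduce_sum(tf.square(x))

(result,) = xla.compile(build_model, inputs=[tf.ones([4, 4])])


# Decorating a model_fn so its TRAIN/EVAL graphs are built via xla.compile().
@xla.estimator_model_fn()
def model_fn(features, labels, mode, params):
  # Hypothetical model: `features` is assumed to be a dense float tensor.
  predictions = tf.layers.dense(features, 1)
  loss = tf.losses.mean_squared_error(labels, predictions)
  train_op = tf.train.GradientDescentOptimizer(0.1).minimize(
      loss, global_step=tf.train.get_global_step())
  return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)
```

As the docstring notes, using the decorator with the base `Estimator` additionally requires resource variables to be enabled at program start.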
diff --git a/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer.py b/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer.py
index d1af15f7e4..67f8ac2b93 100644
--- a/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer.py
+++ b/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer.py
@@ -102,9 +102,9 @@ def _project_multipliers_wrt_euclidean_norm(multipliers, radius):
0.0,
(radius - standard_ops.reduce_sum(multipliers)) / standard_ops.maximum(
1.0, standard_ops.reduce_sum(inactive)))
- multipliers += scale * inactive
+ multipliers = multipliers + (scale * inactive)
new_inactive = standard_ops.cast(multipliers > 0, multipliers.dtype)
- multipliers *= new_inactive
+ multipliers = multipliers * new_inactive
return (iteration, multipliers, new_inactive, inactive)
iteration = standard_ops.constant(0)
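The `+=`/`*=` → explicit-assignment change above rebinds the loop variable with ordinary tensor ops instead of relying on augmented assignment, whose meaning differs between plain tensors and variables. The projection step itself is unchanged; a NumPy sketch of one iteration, for intuition only:

```python
import numpy as np

def project_multipliers_step(multipliers, inactive, radius):
    """One iteration of the Euclidean projection loop above (NumPy sketch)."""
    scale = min(0.0, (radius - multipliers.sum()) / max(1.0, inactive.sum()))
    # Explicit re-assignment, mirroring the change from `+=` / `*=` above.
    multipliers = multipliers + scale * inactive
    new_inactive = (multipliers > 0).astype(multipliers.dtype)
    multipliers = multipliers * new_inactive
    return multipliers, new_inactive

m = np.array([0.9, 0.6, 0.0])
inactive = np.ones_like(m)
for _ in range(3):
    m, inactive = project_multipliers_step(m, inactive, radius=1.0)
print(m)  # non-negative entries whose sum has been pulled down to the radius
```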
diff --git a/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer.py b/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer.py
index 2c673d9347..a6cb1f62f0 100644
--- a/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer.py
+++ b/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer.py
@@ -175,9 +175,9 @@ def _project_stochastic_matrix_wrt_euclidean_norm(matrix):
scale = (1.0 - standard_ops.reduce_sum(
matrix, axis=0, keepdims=True)) / standard_ops.maximum(
1.0, standard_ops.reduce_sum(inactive, axis=0, keepdims=True))
- matrix += scale * inactive
+ matrix = matrix + (scale * inactive)
new_inactive = standard_ops.cast(matrix > 0, matrix.dtype)
- matrix *= new_inactive
+ matrix = matrix * new_inactive
return (iteration, matrix, new_inactive, inactive)
iteration = standard_ops.constant(0)
@@ -210,8 +210,9 @@ def _project_log_stochastic_matrix_wrt_kl_divergence(log_matrix):
# For numerical reasons, make sure that the largest matrix element is zero
# before exponentiating.
- log_matrix -= standard_ops.reduce_max(log_matrix, axis=0, keepdims=True)
- log_matrix -= standard_ops.log(
+ log_matrix = log_matrix - standard_ops.reduce_max(
+ log_matrix, axis=0, keepdims=True)
+ log_matrix = log_matrix - standard_ops.log(
standard_ops.reduce_sum(
standard_ops.exp(log_matrix), axis=0, keepdims=True))
return log_matrix
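The log-domain normalization above (subtract the per-column max, then the per-column log-sum-exp) is the standard numerically stable way to renormalize each column to a probability distribution; a NumPy sketch:

```python
import numpy as np

def project_log_stochastic_columns(log_matrix):
    """Column-wise log-space normalization, as in the function above (sketch)."""
    # Subtract the per-column max before exponentiating, for numerical safety.
    log_matrix = log_matrix - log_matrix.max(axis=0, keepdims=True)
    # Subtract the per-column log-sum-exp so each column of exp() sums to 1.
    log_matrix = log_matrix - np.log(
        np.exp(log_matrix).sum(axis=0, keepdims=True))
    return log_matrix

lm = np.log(np.array([[2.0, 1.0], [6.0, 3.0]]))
print(np.exp(project_log_stochastic_columns(lm)))  # each column sums to 1
```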
diff --git a/tensorflow/contrib/copy_graph/python/util/copy_elements.py b/tensorflow/contrib/copy_graph/python/util/copy_elements.py
index 6c9ab6aeb8..9c5871da34 100644
--- a/tensorflow/contrib/copy_graph/python/util/copy_elements.py
+++ b/tensorflow/contrib/copy_graph/python/util/copy_elements.py
@@ -31,7 +31,7 @@ from __future__ import division
from __future__ import print_function
from copy import deepcopy
-from tensorflow.python.ops.variables import Variable
+from tensorflow.python.ops.variables import VariableV1
from tensorflow.python.client.session import Session
from tensorflow.python.framework import ops
@@ -55,7 +55,7 @@ def copy_variable_to_graph(org_instance, to_graph, scope=''):
TypeError: If `org_instance` is not a `Variable`.
"""
- if not isinstance(org_instance, Variable):
+ if not isinstance(org_instance, VariableV1):
raise TypeError(str(org_instance) + ' is not a Variable')
#The name of the new variable
@@ -88,7 +88,7 @@ def copy_variable_to_graph(org_instance, to_graph, scope=''):
#Initialize the new variable
with to_graph.as_default():
- new_var = Variable(
+ new_var = VariableV1(
init_value,
trainable,
name=new_name,
diff --git a/tensorflow/contrib/copy_graph/python/util/copy_test.py b/tensorflow/contrib/copy_graph/python/util/copy_test.py
index 05744bec4e..ba97c78456 100644
--- a/tensorflow/contrib/copy_graph/python/util/copy_test.py
+++ b/tensorflow/contrib/copy_graph/python/util/copy_test.py
@@ -36,7 +36,7 @@ class CopyVariablesTest(test.TestCase):
with graph1.as_default():
#Define a Variable in graph1
- some_var = variables.Variable(2)
+ some_var = variables.VariableV1(2)
#Initialize session
sess1 = session_lib.Session()
#Initialize the Variable
@@ -72,7 +72,7 @@ class CopyOpsTest(test.TestCase):
with graph1.as_default():
#Initialize a basic expression y = ax + b
x = array_ops.placeholder("float")
- a = variables.Variable(3.0)
+ a = variables.VariableV1(3.0)
b = constant_op.constant(4.0)
ax = math_ops.multiply(x, a)
y = math_ops.add(ax, b)
diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py
index 5a667485be..c59d3682d4 100644
--- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py
+++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py
@@ -413,6 +413,31 @@ class CudnnRNNTestParamsSize(TensorFlowTestCase):
self._testOneLSTMParamsSize(num_layers, num_units, input_size,
direction)
+ @unittest.skipUnless(test.is_built_with_cuda(),
+ "Test only applicable when running on GPUs")
+ def testLSTMParamsSizeShape(self):
+ with self.assertRaisesRegexp(
+ ValueError, "Shape must be rank 0 but is rank 1"):
+ model = _CreateModel(
+ cudnn_rnn_ops.CUDNN_LSTM,
+ constant_op.constant([4]), 200, 200,
+ direction=cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION)
+ params_size = model.params_size()
+ with self.assertRaisesRegexp(
+ ValueError, "Shape must be rank 0 but is rank 1"):
+ model = _CreateModel(
+ cudnn_rnn_ops.CUDNN_LSTM,
+ 4, constant_op.constant([200]), 200,
+ direction=cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION)
+ params_size = model.params_size()
+ with self.assertRaisesRegexp(
+ ValueError, "Shape must be rank 0 but is rank 1"):
+ model = _CreateModel(
+ cudnn_rnn_ops.CUDNN_LSTM,
+ 4, 200, constant_op.constant([200]),
+ direction=cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION)
+ params_size = model.params_size()
+
class CudnnRNNTestInference(TensorFlowTestCase):
diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py
index fda1b9f1b3..57793a8ff5 100644
--- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py
+++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py
@@ -460,7 +460,7 @@ class CudnnRNNTestBasic(test_util.TensorFlowTestCase):
grad, = gradients.gradients(
math_ops.reduce_sum(accumulation), (original_input,))
init_op = variables.global_variables_initializer()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
accumulation_eval, grad_eval = sess.run((accumulation, grad))
self.assertAllEqual([28, 100, 100], accumulation_eval.shape)
diff --git a/tensorflow/contrib/data/BUILD b/tensorflow/contrib/data/BUILD
index 9f710613dd..38f1c65a4d 100644
--- a/tensorflow/contrib/data/BUILD
+++ b/tensorflow/contrib/data/BUILD
@@ -4,17 +4,6 @@ licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
-load(
- "//tensorflow:tensorflow.bzl",
- "tf_custom_op_library",
- "tf_gen_op_libs",
- "if_not_windows",
-)
-load(
- "//tensorflow/core:platform/default/build_config_root.bzl",
- "if_static",
-)
-
py_library(
name = "data",
srcs = ["__init__.py"],
@@ -25,30 +14,3 @@ py_library(
"//tensorflow/python:util",
],
)
-
-cc_library(
- name = "lib_proto_parsing_for_dataset_ops",
- deps = if_not_windows(["//tensorflow/core:lib_proto_parsing"]),
-)
-
-tf_custom_op_library(
- name = "_dataset_ops.so",
- srcs = [
- "ops/dataset_ops.cc",
- "ops/indexed_dataset_ops.cc",
- ],
- deps = [
- "//tensorflow/contrib/data/kernels:dataset_kernels",
- "//tensorflow/contrib/data/kernels:indexed_dataset",
- ] + if_static(
- extra_deps = [":lib_proto_parsing_for_dataset_ops"],
- otherwise = [],
- ),
-)
-
-tf_gen_op_libs(
- op_lib_names = [
- "dataset_ops",
- "indexed_dataset_ops",
- ],
-)
diff --git a/tensorflow/contrib/data/__init__.py b/tensorflow/contrib/data/__init__.py
index c378b1ce8d..3cb51279c3 100644
--- a/tensorflow/contrib/data/__init__.py
+++ b/tensorflow/contrib/data/__init__.py
@@ -44,6 +44,7 @@ See [Importing Data](https://tensorflow.org/guide/datasets) for an overview.
@@group_by_reducer
@@group_by_window
@@ignore_errors
+@@latency_stats
@@make_batched_features_dataset
@@make_csv_dataset
@@make_saveable_from_iterator
@@ -57,9 +58,11 @@ See [Importing Data](https://tensorflow.org/guide/datasets) for an overview.
@@reduce_dataset
@@sample_from_datasets
@@scan
+@@set_stats_aggregator
@@shuffle_and_repeat
@@sliding_window_batch
@@sloppy_interleave
+@@StatsAggregator
@@unbatch
@@unique
@@ -111,6 +114,9 @@ from tensorflow.contrib.data.python.ops.resampling import rejection_resample
from tensorflow.contrib.data.python.ops.scan_ops import scan
from tensorflow.contrib.data.python.ops.shuffle_ops import shuffle_and_repeat
from tensorflow.contrib.data.python.ops.sliding import sliding_window_batch
+from tensorflow.contrib.data.python.ops.stats_ops import latency_stats
+from tensorflow.contrib.data.python.ops.stats_ops import set_stats_aggregator
+from tensorflow.contrib.data.python.ops.stats_ops import StatsAggregator
from tensorflow.contrib.data.python.ops.unique import unique
from tensorflow.contrib.data.python.ops.writers import TFRecordWriter
from tensorflow.python.data.ops.iterator_ops import get_next_as_optional
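The symbols newly exported above (`StatsAggregator`, `latency_stats`, `set_stats_aggregator`) are typically wired together as below. This is a hedged sketch based on the contrib `stats_ops` module; the exact summary plumbing may differ by version.

```python
# Hedged sketch of the newly exported stats APIs (TF 1.x, tf.contrib.data).
import tensorflow as tf

dataset = tf.data.Dataset.range(100)
# Record per-element latency under the tag "ds_latency".
dataset = dataset.apply(tf.contrib.data.latency_stats("ds_latency"))

# Route the collected statistics into a StatsAggregator.
aggregator = tf.contrib.data.StatsAggregator()
dataset = dataset.apply(tf.contrib.data.set_stats_aggregator(aggregator))

iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
stats_summary = aggregator.get_summary()  # summary tensor of the recorded stats
```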
diff --git a/tensorflow/contrib/data/ops/dataset_ops.cc b/tensorflow/contrib/data/ops/dataset_ops.cc
deleted file mode 100644
index ad410e17fe..0000000000
--- a/tensorflow/contrib/data/ops/dataset_ops.cc
+++ /dev/null
@@ -1,284 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "tensorflow/core/framework/common_shape_fns.h"
-#include "tensorflow/core/framework/op.h"
-
-namespace tensorflow {
-
-REGISTER_OP("DirectedInterleaveDataset")
- .Input("selector_input_dataset: variant")
- .Input("data_input_datasets: N * variant")
- .Output("handle: variant")
- .Attr("output_types: list(type) >= 1")
- .Attr("output_shapes: list(shape) >= 1")
- .Attr("N: int >= 1")
- .SetShapeFn(shape_inference::ScalarShape)
- .Doc(R"doc(
-A substitute for `InterleaveDataset` on a fixed list of `N` datasets.
-
-selector_input_dataset: A dataset of scalar `DT_INT64` elements that determines
- which of the `N` data inputs should produce the next output element.
-data_input_datasets: `N` datasets with the same type that will be interleaved
- according to the values of `selector_input_dataset`.
-)doc");
-
-REGISTER_OP("CSVDataset")
- .Input("filenames: string")
- .Input("compression_type: string")
- .Input("buffer_size: int64")
- .Input("header: bool")
- .Input("field_delim: string")
- .Input("use_quote_delim: bool")
- .Input("na_value: string")
- .Input("select_cols: int64")
- .Input("record_defaults: output_types")
- .Output("handle: variant")
- .Attr("output_types: list({float,double,int32,int64,string}) >= 1")
- .Attr("output_shapes: list(shape) >= 1")
- .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked
- // stateful to inhibit constant folding.
- .SetShapeFn([](shape_inference::InferenceContext* c) {
- shape_inference::ShapeHandle unused;
- // `filenames` must be a scalar or a vector.
- TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused));
- // `compression_type`, `buffer_size`, `header`, `field_delim`,
- // `use_quote_delim`, `na_value` must be scalars
- TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
- TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
- TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
- TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
- TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
- TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));
- // `select_cols` must be a vector
- TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 1, &unused));
- // `record_defaults` must be lists of scalars
- for (size_t i = 8; i < c->num_inputs(); ++i) {
- shape_inference::ShapeHandle v;
- TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(i), 1, &v));
- if (c->Rank(c->input(i)) == 1 && c->Value(c->Dim(v, 0)) > 1) {
- return errors::InvalidArgument(
- "Shape of a default must be a length-0 or length-1 vector, or a "
- "scalar.");
- }
- }
- return shape_inference::ScalarShape(c);
- });
-
-REGISTER_OP("IgnoreErrorsDataset")
- .Input("input_dataset: variant")
- .Output("handle: variant")
- .Attr("output_types: list(type) >= 1")
- .Attr("output_shapes: list(shape) >= 1")
- .SetShapeFn(shape_inference::ScalarShape)
- .Doc(R"doc(
-Creates a dataset that contains the elements of `input_dataset` ignoring errors.
-)doc");
-
-REGISTER_OP("UniqueDataset")
- .Input("input_dataset: variant")
- .Output("handle: variant")
- .Attr("output_types: list(type) >= 1")
- .Attr("output_shapes: list(shape) >= 1")
- .SetShapeFn(shape_inference::ScalarShape)
- .Doc(R"doc(
-Creates a dataset that contains the unique elements of `input_dataset`.
-)doc");
-
-REGISTER_OP("IteratorGetDevice")
- .Input("resource: resource")
- .Output("device: string")
- .SetShapeFn(shape_inference::ScalarShape)
- .Doc(R"doc(
-Returns the name of the device on which `resource` has been placed.
-)doc");
-
-REGISTER_OP("FunctionBufferingResource")
- .Input("string_arg: string")
- .Input("target_device: string")
- .Output("resource: resource")
- .Attr("shared_name: string")
- .Attr("container: string")
- .Attr("f: func")
- .Attr("buffer_size: int")
- .Attr("output_types: list(type)")
- .SetShapeFn(shape_inference::UnknownShape)
- .Doc(R"doc(
-Creates a resource that fills up a buffer by making function calls.
-
-string_arg: String argument to the function call.
-target_device: Target device to execute the function on.
-resource: Handle to the resource created.
-f: Function to be executed.
-buffer_size: Size of the buffer.
-container: If non-empty, this resource is placed in the given container.
- Otherwise, a default container is used.
-shared_name: If non-empty, this resource will be shared under the given name
- across multiple sessions.
-output_types: The type list for the return values.
-)doc");
-
-REGISTER_OP("FunctionBufferingResourceGetNext")
- .Input("function_buffer_resource: resource")
- .Attr("output_types: list(type)")
- .Output("output: output_types")
- .SetShapeFn(shape_inference::UnknownShape)
- .Doc(R"doc(
-Gets the next element from a FunctionBufferingResource.
-
-function_buffer_resource: The FunctionBufferingResource handle.
-output: A list of return values.
-output_types: The type list for the return values.
-)doc");
-
-REGISTER_OP("FunctionBufferingResourceReset")
- .Input("function_buffer_resource: resource")
- .SetShapeFn(shape_inference::UnknownShape)
- .Doc(R"doc(
-Resets the FunctionBufferingResource.
-
-function_buffer_resource: The FunctionBufferingResource handle.
-)doc");
-
-REGISTER_OP("MultiDeviceIterator")
- .Output("handle: resource")
- .Attr("devices: list(string) >= 1")
- .Attr("shared_name: string")
- .Attr("container: string")
- .Attr("output_types: list(type) >= 1")
- .Attr("output_shapes: list(shape) >= 1")
- .Doc(R"doc(
-Creates a MultiDeviceIterator resource.
-
-handle: Handle to the resource created.
-devices: A list of devices the iterator works across.
-shared_name: If non-empty, this resource will be shared under the given name
- across multiple sessions.
-container: If non-empty, this resource is placed in the given container.
- Otherwise, a default container is used.
-output_types: The type list for the return values.
-output_shapes: The list of shapes being produced.
-)doc");
-
-REGISTER_OP("MultiDeviceIteratorInit")
- .Input("dataset: variant")
- .Input("multi_device_iterator: resource")
- .Input("max_buffer_size: int64")
- .Output("incarnation_id: int64")
- .Doc(R"doc(
-Initializes the multi device iterator with the given dataset.
-max_buffer_size: The maximum size of the host side per device buffer to keep.
-incarnation_id: An int64 indicating which incarnation of the MultiDeviceIterator
- is running.
-dataset: Dataset to be iterated upon.
-multi_device_iterator: A MultiDeviceIteratorResource.
-)doc");
-
-REGISTER_OP("MultiDeviceIteratorGetNextFromShard")
- .Input("multi_device_iterator: resource")
- .Input("shard_num: int32")
- .Input("incarnation_id: int64")
- .Output("components: output_types")
- .Attr("output_types: list(type) >= 1")
- .Attr("output_shapes: list(shape) >= 1")
- .Doc(R"doc(
-Gets next element for the provided shard number.
-
-multi_device_iterator: A MultiDeviceIterator resource.
-shard_num: Integer representing which shard to fetch data for.
-incarnation_id: Which incarnation of the MultiDeviceIterator is running.
-components: Result of the get_next on the dataset.
-output_types: The type list for the return values.
-output_shapes: The list of shapes being produced.
-)doc");
-
-REGISTER_OP("MultiDeviceIteratorToStringHandle")
- .Input("multi_device_iterator: resource")
- .Output("string_handle: string")
- .Doc(R"doc(
-Produces a string handle for the given MultiDeviceIterator.
-
-multi_device_iterator: A MultiDeviceIterator resource.
-string_handle: A string representing the resource.
-)doc");
-
-REGISTER_OP("MultiDeviceIteratorFromStringHandle")
- .Input("string_handle: string")
- .Output("multi_device_iterator: resource")
- .Attr("output_types: list(type) >= 0 = []")
- .Attr("output_shapes: list(shape) >= 0 = []")
- .Doc(R"doc(
-Generates a MultiDeviceIterator resource from its provided string handle.
-
-string_handle: String representing the resource.
-multi_device_iterator: A MultiDeviceIterator resource.
-output_types: The type list for the return values.
-output_shapes: The list of shapes being produced.
-)doc");
-
-REGISTER_OP("ThreadPoolDataset")
- .Input("input_dataset: variant")
- .Input("thread_pool: resource")
- .Output("handle: variant")
- .Attr("output_types: list(type) >= 1")
- .Attr("output_shapes: list(shape) >= 1")
- .SetShapeFn(shape_inference::ScalarShape)
- .Doc(R"doc(
-Creates a dataset that uses a custom thread pool to compute `input_dataset`.
-
-handle: A resource produced by the ThreadPoolHandle op.
-)doc");
-
-REGISTER_OP("ThreadPoolHandle")
- .Output("handle: resource")
- .SetShapeFn(shape_inference::ScalarShape)
- .Attr("num_threads: int")
- .Attr("max_intra_op_parallelism: int = 1")
- .Attr("display_name: string")
- .Attr("container: string = ''")
- .Attr("shared_name: string = ''")
- .Doc(R"doc(
-Creates a custom thread pool with the given number of threads.
-
-handle: A resource that can be consumed by one or more ThreadPoolDataset ops.
-num_threads: The number of threads in the thread pool.
-max_intra_op_parallelism: The maximum degree of parallelism to use within
- operations that execute on this threadpool.
-display_name: A human-readable name for the threads that may be visible in
- some visualizations.
-)doc");
-
-REGISTER_OP("AssertNextDataset")
- .Input("input_dataset: variant")
- .Input("transformations: string")
- .Output("handle: variant")
- .Attr("output_types: list(type) >= 1")
- .Attr("output_shapes: list(shape) >= 1")
- .SetShapeFn([](shape_inference::InferenceContext* c) {
- shape_inference::ShapeHandle unused;
- // transformations should be a vector.
- TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
- return shape_inference::ScalarShape(c);
- });
-
-REGISTER_OP("LMDBDataset")
- .Input("filenames: string")
- .Output("handle: variant")
- .Attr("output_types: list(type) >= 1")
- .Attr("output_shapes: list(shape) >= 1")
- .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked
- // stateful to inhibit constant folding.
- .SetShapeFn(shape_inference::ScalarShape);
-
-} // namespace tensorflow
diff --git a/tensorflow/contrib/data/ops/indexed_dataset_ops.cc b/tensorflow/contrib/data/ops/indexed_dataset_ops.cc
deleted file mode 100644
index cd9b7c68a0..0000000000
--- a/tensorflow/contrib/data/ops/indexed_dataset_ops.cc
+++ /dev/null
@@ -1,80 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "tensorflow/core/framework/common_shape_fns.h"
-#include "tensorflow/core/framework/op.h"
-
-namespace tensorflow {
-
-REGISTER_OP("IdentityIndexedDataset")
- .Input("size: uint64")
- .Output("handle: variant")
- .SetIsStateful()
- .SetShapeFn(
- shape_inference::ScalarShape); // TODO(saeta): check input shapes.
-
-///////////////////////////////////////////////////////////////////////////////
-// IndexedDataset Internals
-///////////////////////////////////////////////////////////////////////////////
-
-// Creates the handle.
-REGISTER_OP("MaterializedIndexDatasetHandle")
- .Output("handle: resource")
- .Attr("container: string")
- .Attr("shared_name: string")
- .Attr("output_types: list(type) >= 1")
- .Attr("output_shapes: list(shape) >= 1")
- .SetShapeFn(shape_inference::ScalarShape);
-
-// Actually materialize the materialize handle.
-REGISTER_OP("IndexedDatasetMaterialize")
- .Input("dataset: variant")
- .Input("materialized: resource")
- .SetShapeFn(shape_inference::NoOutputs);
-
-namespace {
-
-Status GetShapeFn(shape_inference::InferenceContext* c) {
- shape_inference::ShapeHandle unused;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
- std::vector<PartialTensorShape> output_shapes;
- TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes));
- if (output_shapes.size() != c->num_outputs()) {
- return errors::InvalidArgument(
- "`output_shapes` must be the same length as `output_types` (",
- output_shapes.size(), " vs. ", c->num_outputs());
- }
- for (size_t i = 0; i < output_shapes.size(); ++i) {
- shape_inference::ShapeHandle output_shape_handle;
- TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(
- output_shapes[i], &output_shape_handle));
- c->set_output(static_cast<int>(i), output_shape_handle);
- }
- return Status::OK();
-}
-
-} // namespace
-
-REGISTER_OP("IndexedDatasetGet")
- .Input("materialized: resource")
- .Input("index: uint64")
- .Output("components: output_types")
- .Attr("output_types: list(type) >= 1")
- .Attr("output_shapes: list(shape) >= 1")
- .SetShapeFn(GetShapeFn)
- .Doc(R"doc(
-Gets the element at `index` from `materialized` IndexedDataset.
-)doc");
-
-} // namespace tensorflow
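Note: the contrib op registrations deleted above are not lost; equivalent ops are registered in core under "Experimental*" names and surface through the generated gen_experimental_dataset_ops module, which is what the updated BUILD targets and test below switch to. A minimal sketch of driving the relocated indexed-dataset ops through those wrappers (assumes a build of this era where the experimental wrappers are generated; it mirrors the updated test and is not part of the change itself):

    from tensorflow.python.framework import dtypes
    from tensorflow.python.framework import ops
    from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops

    # Build a 16-element identity indexed dataset and a handle to materialize it into.
    identity = ged_ops.experimental_identity_indexed_dataset(
        ops.convert_to_tensor(16, dtype=dtypes.uint64))
    handle = ged_ops.experimental_materialized_index_dataset_handle(
        container="", shared_name="",
        output_types=[dtypes.uint64], output_shapes=[[]])
    materialize = ged_ops.experimental_indexed_dataset_materialize(identity, handle)
    # Random access is available once `materialize` has been run in a session.
    get_op = ged_ops.experimental_indexed_dataset_get(
        handle, ops.convert_to_tensor(3, dtype=dtypes.uint64),
        output_types=[dtypes.uint64], output_shapes=[[]])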
diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD
index ba202839b2..21ac40eb21 100644
--- a/tensorflow/contrib/data/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/BUILD
@@ -139,12 +139,11 @@ py_test(
name = "indexed_dataset_ops_test",
srcs = ["indexed_dataset_ops_test.py"],
deps = [
- "//tensorflow/contrib/data/python/ops:contrib_op_loader",
- "//tensorflow/contrib/data/python/ops:gen_dataset_ops",
"//tensorflow/contrib/data/python/ops:indexed_dataset_ops",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
+ "//tensorflow/python:experimental_dataset_ops_gen",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
],
@@ -190,7 +189,6 @@ py_test(
"//tensorflow/python:training",
"//tensorflow/python:variables",
"//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/estimator",
"//tensorflow/python/estimator:estimator_py",
],
)
@@ -326,12 +324,7 @@ cuda_py_test(
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/ops:iterator_ops",
],
- tags = [
- "manual",
- "no_oss",
- "no_windows_gpu",
- "notap",
- ],
+ tags = ["no_windows_gpu"],
)
py_test(
diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
index 8e368bf2bc..e2508de9e9 100644
--- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
@@ -742,7 +742,7 @@ class RestructuredDatasetTest(test.TestCase):
iterator = result.make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for _ in range(5):
sess.run(get_next)
@@ -813,7 +813,7 @@ class RestructuredDatasetTest(test.TestCase):
.make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
with self.assertRaises(errors.InvalidArgumentError):
sess.run(get_next)
@@ -837,7 +837,7 @@ class RestructuredDatasetTest(test.TestCase):
iterator = result.make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for _ in range(5):
sess.run(get_next)
@@ -879,7 +879,7 @@ class RestructuredDatasetTest(test.TestCase):
iterator = result.make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for _ in range(5):
sess.run(get_next)
diff --git a/tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py
index 9c508d686d..46a7127b52 100644
--- a/tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py
@@ -19,29 +19,29 @@ from __future__ import print_function
import unittest
-from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import
-from tensorflow.contrib.data.python.ops import gen_dataset_ops
from tensorflow.contrib.data.python.ops import indexed_dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
from tensorflow.python.platform import test
class IndexedDatasetOpsTest(test.TestCase):
def testLowLevelIndexedDatasetOps(self):
- identity = gen_dataset_ops.identity_indexed_dataset(
+ identity = ged_ops.experimental_identity_indexed_dataset(
ops.convert_to_tensor(16, dtype=dtypes.uint64))
- handle = gen_dataset_ops.materialized_index_dataset_handle(
+ handle = ged_ops.experimental_materialized_index_dataset_handle(
container="",
shared_name="",
output_types=[dtypes.uint64],
output_shapes=[[]])
- materialize = gen_dataset_ops.indexed_dataset_materialize(identity, handle)
+ materialize = ged_ops.experimental_indexed_dataset_materialize(
+ identity, handle)
index = array_ops.placeholder(dtypes.uint64)
- get_op = gen_dataset_ops.indexed_dataset_get(
+ get_op = ged_ops.experimental_indexed_dataset_get(
handle, index, output_types=[dtypes.uint64], output_shapes=[[]])
with self.cached_session() as sess:
diff --git a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py
index 704c0d1eb2..7e2326bd17 100644
--- a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py
@@ -42,7 +42,7 @@ class CheckpointInputPipelineHookTest(test.TestCase):
del config
global_step = training_util.get_or_create_global_step()
update_global_step_op = global_step.assign_add(1)
- latest_feature = variables.Variable(
+ latest_feature = variables.VariableV1(
0, name='latest_feature', dtype=dtypes.int64)
store_latest_feature_op = latest_feature.assign(features)
ops.add_to_collection('my_vars', global_step)
diff --git a/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py
index 83b723710c..25aea0393f 100644
--- a/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py
@@ -116,7 +116,7 @@ class MapDefunTest(test.TestCase):
elems2 = array_ops.placeholder(dtypes.int32)
result = map_defun.map_defun(fn, [elems1, elems2],
[dtypes.int32, dtypes.int32], [(), ()])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaisesWithPredicateMatch(
errors.InvalidArgumentError,
"All inputs must have the same dimension 0."):
@@ -225,7 +225,7 @@ class MapDefunTest(test.TestCase):
c = constant_op.constant([1, 2, 3, 4, 5])
map_defun_op = map_defun.map_defun(simple_fn, [c], [dtypes.int32], [()])[0]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
thread = self.checkedThread(
self._assert_op_cancelled, args=(sess, map_defun_op))
thread.start()
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD b/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD
index b3187bf61b..1ae92bdeff 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD
@@ -20,6 +20,23 @@ py_test(
)
py_test(
+ name = "hoist_random_uniform_test",
+ size = "small",
+ srcs = ["hoist_random_uniform_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/contrib/data/python/ops:optimization",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "@absl_py//absl/testing:parameterized",
+ ],
+)
+
+py_test(
name = "latency_all_edges_test",
size = "small",
srcs = ["latency_all_edges_test.py"],
@@ -110,6 +127,22 @@ py_test(
)
py_test(
+ name = "noop_elimination_test",
+ size = "small",
+ srcs = ["noop_elimination_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/contrib/data/python/ops:batching",
+ "//tensorflow/contrib/data/python/ops:interleave_ops",
+ "//tensorflow/contrib/data/python/ops:optimization",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:errors",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
name = "optimize_dataset_op_test",
size = "small",
srcs = ["optimize_dataset_op_test.py"],
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py
index bd7b50b902..d10da80442 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py
@@ -31,7 +31,7 @@ class AssertNextDatasetTest(test.TestCase):
iterator = dataset.make_one_shot_iterator()
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual(0, sess.run(get_next))
def testAssertNextInvalid(self):
@@ -40,7 +40,7 @@ class AssertNextDatasetTest(test.TestCase):
iterator = dataset.make_one_shot_iterator()
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
"Asserted Whoops transformation at offset 0 but encountered "
@@ -53,7 +53,7 @@ class AssertNextDatasetTest(test.TestCase):
iterator = dataset.make_one_shot_iterator()
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
"Asserted next 2 transformations but encountered only 1."):
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/hoist_random_uniform_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/hoist_random_uniform_test.py
new file mode 100644
index 0000000000..9518c2e1ad
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/hoist_random_uniform_test.py
@@ -0,0 +1,102 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for HostState optimization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+
+from tensorflow.contrib.data.python.ops import optimization
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.platform import test
+
+
+class HoistRandomUniformTest(test.TestCase, parameterized.TestCase):
+
+ @staticmethod
+ def map_functions():
+ plus_one = lambda x: x + 1
+
+ def random(_):
+ return random_ops.random_uniform([],
+ minval=1,
+ maxval=10,
+ dtype=dtypes.float32,
+ seed=42)
+
+ def random_with_assert(x):
+ y = random(x)
+ assert_op = control_flow_ops.Assert(math_ops.greater_equal(y, 1), [y])
+ with ops.control_dependencies([assert_op]):
+ return y
+
+ twice_random = lambda x: (random(x) + random(x)) / 2.
+
+ tests = [("PlusOne", plus_one, False), ("RandomUniform", random, True),
+ ("RandomWithAssert", random_with_assert, True),
+ ("TwiceRandom", twice_random, False)]
+ return tuple(tests)
+
+ @parameterized.named_parameters(*map_functions.__func__())
+ def testHoisting(self, function, will_optimize):
+ dataset = dataset_ops.Dataset.range(5).apply(
+ optimization.assert_next(
+ ["Zip[0]", "Map"] if will_optimize else ["Map"])).map(function)
+
+ dataset = dataset.apply(optimization.optimize(["hoist_random_uniform"]))
+ self._testDataset(dataset)
+
+ def testAdditionalInputs(self):
+ a = constant_op.constant(1, dtype=dtypes.float32)
+ b = constant_op.constant(0, dtype=dtypes.float32)
+ some_tensor = math_ops.mul(a, b)
+
+ def random_with_capture(_):
+ return some_tensor + random_ops.random_uniform(
+ [], minval=1, maxval=10, dtype=dtypes.float32, seed=42)
+
+ dataset = dataset_ops.Dataset.range(5).apply(
+ optimization.assert_next(
+ ["Zip[0]", "Map"])).map(random_with_capture).apply(
+ optimization.optimize(["hoist_random_uniform"]))
+ self._testDataset(dataset)
+
+ def _testDataset(self, dataset):
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+ previous_result = 0
+ with self.cached_session() as sess:
+ for _ in range(5):
+ result = sess.run(get_next)
+ self.assertLessEqual(1, result)
+ self.assertLessEqual(result, 10)
+ # Check that the output looks random by verifying that two consecutive
+ # draws are not identical.
+ self.assertNotEqual(previous_result, result)
+ previous_result = result
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+
+if __name__ == "__main__":
+ test.main()
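For context on the new test above: hoist_random_uniform targets map functions whose only statefulness is a seeded RandomUniform, producing the random draw outside the user function and zipping it back in; that is why the optimized pipelines are expected to begin with "Zip[0]" followed by "Map", while TwiceRandom (two draws) is left untouched. A minimal sketch of opting a pipeline into the rewrite, using only the contrib helpers the test itself uses:

    from tensorflow.contrib.data.python.ops import optimization
    from tensorflow.python.data.ops import dataset_ops
    from tensorflow.python.framework import dtypes
    from tensorflow.python.ops import random_ops

    def noisy(_):
      # Stateless apart from a single seeded RandomUniform, so eligible for hoisting.
      return random_ops.random_uniform(
          [], minval=1, maxval=10, dtype=dtypes.float32, seed=42)

    dataset = (dataset_ops.Dataset.range(5)
               .map(noisy)
               .apply(optimization.optimize(["hoist_random_uniform"])))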
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/latency_all_edges_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/latency_all_edges_test.py
index db380c02a9..e4f18222fd 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/latency_all_edges_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/latency_all_edges_test.py
@@ -34,8 +34,8 @@ class OptimizeStatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
optimization.assert_next(
["LatencyStats", "Map", "LatencyStats", "Prefetch",
"LatencyStats"])).map(lambda x: x * x).prefetch(1).apply(
- optimization.optimize(["latency_all_edges"])).apply(
- stats_ops.set_stats_aggregator(stats_aggregator))
+ stats_ops.set_stats_aggregator(stats_aggregator)).apply(
+ optimization.optimize(["latency_all_edges"]))
iterator = dataset.make_initializable_iterator()
get_next = iterator.get_next()
summary_t = stats_aggregator.get_summary()
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py
index dde115925e..e75edf6086 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py
@@ -200,7 +200,7 @@ class MapAndFilterFusionTest(test.TestCase, parameterized.TestCase):
optimization.optimize(["filter_fusion"]))
iterator = dataset.make_one_shot_iterator()
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for x in range(5):
r = map_function(x)
filtered = False
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py
index e2c9bc82df..5b493f44c9 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py
@@ -173,16 +173,6 @@ class MapVectorizationBenchmark(test.Benchmark):
self.report_benchmark(iters=num_iters, wall_time=median_time, name=name)
return median_time
- def benchmark_CheapFns(self):
-
- input_sizes = [(10, 10, 3), (10, 100, 300)]
- batch_size = 1000
- for input_size in input_sizes:
- input_dataset = dataset_ops.Dataset.from_tensor_slices(
- (np.random.rand(*input_size), np.random.rand(*input_size))).repeat()
- for map_fn, str_id in self._get_known_cheap_fns():
- self._compare(input_dataset, map_fn, batch_size, input_size, str_id)
-
def _compare(self, input_dataset, map_fn, batch_size, input_size, str_id):
num_elems = np.prod(input_size)
name_template = "{}__batch_size_{}_input_size_{}_{}"
@@ -205,14 +195,28 @@ class MapVectorizationBenchmark(test.Benchmark):
"Speedup: {}\n".format(batch_size, input_size, str_id,
(unoptimized_time / optimized_time)))
- def _get_known_cheap_fns(self):
- return [
- (lambda *args: [array_ops.identity(x) for x in args], "identity"),
- (lambda *args: [x + 1 for x in args], "add_const"),
- (lambda *args: args[0], "select"),
- (lambda *args: [math_ops.cast(x, dtypes.float64) for x in args],
- "cast"),
- ]
+ # Known cheap functions
+ def benchmarkIdentity(self):
+ self._benchmark_helper(lambda *args: [array_ops.identity(x) for x in args],
+ "identity")
+
+ def benchmarkAddConst(self):
+ self._benchmark_helper(lambda *args: [x + 1 for x in args], "add_const")
+
+ def benchmarkSelect(self):
+ self._benchmark_helper(lambda *args: args[0], "select")
+
+ def benchmarkCast(self):
+ self._benchmark_helper(
+ lambda *args: [math_ops.cast(x, dtypes.float64) for x in args], "cast")
+
+ def _benchmark_helper(self, map_fn, str_id):
+ input_sizes = [(10, 10, 3), (10, 100, 300)]
+ batch_size = 1000
+ for input_size in input_sizes:
+ input_dataset = dataset_ops.Dataset.from_tensor_slices(
+ (np.random.rand(*input_size), np.random.rand(*input_size))).repeat()
+ self._compare(input_dataset, map_fn, batch_size, input_size, str_id)
if __name__ == "__main__":
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/model_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/model_dataset_op_test.py
index 0a87d3e905..3b62a7e468 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/model_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/model_dataset_op_test.py
@@ -40,7 +40,7 @@ class ModelDatasetTest(test.TestCase):
get_next = iterator.get_next()
deltas = []
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for _ in range(5):
sess.run(get_next.op)
for _ in range(100):
@@ -58,12 +58,13 @@ class ModelDatasetTest(test.TestCase):
dataset = dataset_ops.Dataset.from_tensors((np.random.rand(1, 4 * k),
np.random.rand(4 * k,
1))).repeat()
- dataset = dataset.map(math_ops.matmul, num_parallel_calls=56)
+ dataset = dataset.map(
+ math_ops.matmul, num_parallel_calls=optimization.AUTOTUNE)
iterator = dataset.apply(optimization.model()).make_one_shot_iterator()
get_next = iterator.get_next()
deltas = []
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for _ in range(5):
sess.run(get_next.op)
for _ in range(1000):
@@ -84,12 +85,14 @@ class ModelDatasetTest(test.TestCase):
1))).repeat()
dataset = dataset.apply(
batching.map_and_batch(
- math_ops.matmul, num_parallel_calls=28, batch_size=batch_size))
+ math_ops.matmul,
+ num_parallel_calls=optimization.AUTOTUNE,
+ batch_size=batch_size))
iterator = dataset.apply(optimization.model()).make_one_shot_iterator()
get_next = iterator.get_next()
deltas = []
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for _ in range(5):
sess.run(get_next.op)
for _ in range(10):
@@ -109,12 +112,14 @@ class ModelDatasetTest(test.TestCase):
1))).repeat()
dataset = dataset.map(math_ops.matmul)
dataset = dataset_ops.Dataset.range(1).repeat().interleave(
- lambda _: dataset, cycle_length=56, num_parallel_calls=56)
+ lambda _: dataset,
+ cycle_length=10,
+ num_parallel_calls=optimization.AUTOTUNE)
iterator = dataset.apply(optimization.model()).make_one_shot_iterator()
get_next = iterator.get_next()
deltas = []
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for _ in range(5):
sess.run(get_next.op)
for _ in range(1000):
@@ -146,20 +151,20 @@ class ModelDatasetTest(test.TestCase):
x, y = c
return a, b, math_ops.matmul(x, y)
- dataset = dataset.map(f1, num_parallel_calls=32)
+ dataset = dataset.map(f1, num_parallel_calls=optimization.AUTOTUNE)
dataset = dataset_ops.Dataset.range(1).repeat().interleave(
lambda _: dataset, cycle_length=2)
- dataset = dataset.map(f2, num_parallel_calls=16)
+ dataset = dataset.map(f2, num_parallel_calls=optimization.AUTOTUNE)
dataset = dataset_ops.Dataset.range(1).repeat().interleave(
lambda _: dataset, cycle_length=2)
- dataset = dataset.map(f3, num_parallel_calls=10)
+ dataset = dataset.map(f3, num_parallel_calls=optimization.AUTOTUNE)
iterator = dataset.apply(optimization.model()).make_one_shot_iterator()
get_next = iterator.get_next()
deltas = []
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for _ in range(5):
sess.run(get_next)
for _ in range(100):
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/noop_elimination_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/noop_elimination_test.py
new file mode 100644
index 0000000000..507feda3ad
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/noop_elimination_test.py
@@ -0,0 +1,57 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the MapParallelization optimization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.data.python.ops import optimization
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+class NoopEliminationTest(test.TestCase):
+
+ def testNoopElimination(self):
+ a = constant_op.constant(1, dtype=dtypes.int64)
+ b = constant_op.constant(2, dtype=dtypes.int64)
+ some_tensor = math_ops.mul(a, b)
+
+ dataset = dataset_ops.Dataset.range(5)
+ dataset = dataset.apply(
+ optimization.assert_next(
+ ["FiniteRepeat", "FiniteSkip", "Prefetch", "Prefetch"]))
+ dataset = dataset.repeat(some_tensor).skip(5).prefetch(0).take(-1).skip(
+ 0).repeat(1).prefetch(0)
+ dataset = dataset.apply(optimization.optimize(["noop_elimination"]))
+
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ with self.test_session() as sess:
+ for x in range(5):
+ result = sess.run(get_next)
+ self.assertAllEqual(result, x)
+
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+
+if __name__ == "__main__":
+ test.main()
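The expectation encoded in the test above is that noop_elimination drops transformations statically known to do nothing — take(-1), skip(0), and repeat(1) here — while repeat(some_tensor), skip(5), and the prefetch(0) calls survive as the FiniteRepeat, FiniteSkip, and two Prefetch nodes named in assert_next. A minimal sketch of the same opt-in on a user pipeline (illustrative; it relies only on the contrib helpers the test already imports):

    from tensorflow.contrib.data.python.ops import optimization
    from tensorflow.python.data.ops import dataset_ops

    original = (dataset_ops.Dataset.range(5)
                .repeat(2).skip(5).prefetch(0)
                .take(-1).skip(0).repeat(1).prefetch(0))
    # After the rewrite this behaves like:
    #   Dataset.range(5).repeat(2).skip(5).prefetch(0).prefetch(0)
    optimized = original.apply(optimization.optimize(["noop_elimination"]))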
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py
index 909da5aee0..a3fb824ce9 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py
@@ -38,7 +38,7 @@ class OptimizeDatasetTest(test.TestCase):
iterator = dataset.make_one_shot_iterator()
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertAllEqual([x * x for x in range(10)], sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@@ -51,7 +51,7 @@ class OptimizeDatasetTest(test.TestCase):
iterator = dataset.make_one_shot_iterator()
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertAllEqual([x * x for x in range(10)], sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@@ -64,7 +64,7 @@ class OptimizeDatasetTest(test.TestCase):
iterator = dataset.make_one_shot_iterator()
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertAllEqual([x * x for x in range(10)], sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@@ -76,7 +76,7 @@ class OptimizeDatasetTest(test.TestCase):
iterator = dataset.make_one_shot_iterator()
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(get_next)
def testOptimizationLargeInputFromTensor(self):
@@ -87,7 +87,7 @@ class OptimizeDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op, {input_t: np.ones([512, 1024, 1025], np.int32)})
sess.run(get_next)
@@ -99,7 +99,7 @@ class OptimizeDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op, {input_t: np.ones([1, 512, 1024, 1025], np.int32)})
sess.run(get_next)
diff --git a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py
index 0166ba0d44..33a64ea767 100644
--- a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py
@@ -31,7 +31,6 @@ from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
-from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import test
@@ -944,155 +943,5 @@ class CopyToDeviceTest(test.TestCase):
sess.run(elem_value_t)
-class MultiDeviceIteratorTest(test.TestCase):
-
- def testBasic(self):
- dataset = dataset_ops.Dataset.range(10)
- multi_device_iterator = prefetching_ops.MultiDeviceIterator(
- dataset, ["/cpu:1", "/cpu:2"])
- elem_on_1, elem_on_2 = multi_device_iterator.get_next()
-
- config = config_pb2.ConfigProto(device_count={"CPU": 3})
- with self.test_session(config=config) as sess:
- sess.run(multi_device_iterator.initializer)
- for i in range(0, 10, 2):
- self.assertEqual(i, sess.run(elem_on_1))
- self.assertEqual(i + 1, sess.run(elem_on_2))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(elem_on_1)
- sess.run(elem_on_2)
-
- def testOneOnSameDevice(self):
- with ops.device("/cpu:0"):
- dataset = dataset_ops.Dataset.range(10)
- multi_device_iterator = prefetching_ops.MultiDeviceIterator(
- dataset, ["/cpu:0", "/cpu:1"])
- elem_on_1, elem_on_2 = multi_device_iterator.get_next()
-
- config = config_pb2.ConfigProto(device_count={"CPU": 2})
- with self.test_session(config=config) as sess:
- sess.run(multi_device_iterator.initializer)
- for i in range(0, 10, 2):
- self.assertEqual(i, sess.run(elem_on_1))
- self.assertEqual(i + 1, sess.run(elem_on_2))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(elem_on_1)
- sess.run(elem_on_2)
-
- def testRepeatDevices(self):
- with ops.device("/cpu:0"):
- dataset = dataset_ops.Dataset.range(20)
- multi_device_iterator = prefetching_ops.MultiDeviceIterator(
- dataset, ["/cpu:1", "/cpu:2", "/cpu:1", "/cpu:2"])
- elements = multi_device_iterator.get_next()
- elem_on_1, elem_on_2, elem_on_3, elem_on_4 = elements
-
- config = config_pb2.ConfigProto(device_count={"CPU": 3})
- with self.test_session(config=config) as sess:
- sess.run(multi_device_iterator.initializer)
- for i in range(0, 20, 4):
- self.assertEqual(i, sess.run(elem_on_1))
- self.assertEqual(i + 1, sess.run(elem_on_2))
- self.assertEqual(i + 2, sess.run(elem_on_3))
- self.assertEqual(i + 3, sess.run(elem_on_4))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(elem_on_1)
- sess.run(elem_on_2)
- sess.run(elem_on_3)
- sess.run(elem_on_4)
-
- def testNotFullyDivisible(self):
- dataset = dataset_ops.Dataset.range(9)
- multi_device_iterator = prefetching_ops.MultiDeviceIterator(
- dataset, ["/cpu:1", "/cpu:2"])
- elem_on_1, elem_on_2 = multi_device_iterator.get_next()
-
- config = config_pb2.ConfigProto(device_count={"CPU": 3})
- with self.test_session(config=config) as sess:
- sess.run(multi_device_iterator.initializer)
- for i in range(0, 8, 2):
- self.assertEqual(i, sess.run(elem_on_1))
- self.assertEqual(i + 1, sess.run(elem_on_2))
- self.assertEqual(8, sess.run(elem_on_1))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(elem_on_1)
- sess.run(elem_on_2)
-
- def testUneven(self):
- dataset = dataset_ops.Dataset.range(10)
- multi_device_iterator = prefetching_ops.MultiDeviceIterator(
- dataset, ["/cpu:1", "/cpu:2"], max_buffer_size=4)
- elem_on_1, elem_on_2 = multi_device_iterator.get_next()
-
- config = config_pb2.ConfigProto(device_count={"CPU": 3})
- with self.test_session(config=config) as sess:
- sess.run(multi_device_iterator.initializer)
- for i in range(0, 10, 2):
- self.assertEqual(i, sess.run(elem_on_1))
- for i in range(0, 10, 2):
- self.assertEqual(i + 1, sess.run(elem_on_2))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(elem_on_1)
- sess.run(elem_on_2)
-
- def testMultipleInitializations(self):
- with ops.device("/cpu:0"):
- epoch = array_ops.placeholder(dtypes.int64, shape=[])
- dataset1 = dataset_ops.Dataset.from_tensors(epoch).repeat(1000)
- dataset2 = dataset_ops.Dataset.range(1000)
- dataset = dataset_ops.Dataset.zip((dataset1, dataset2))
- multi_device_iterator = prefetching_ops.MultiDeviceIterator(
- dataset, ["/cpu:1", "/cpu:2"], prefetch_buffer_size=4)
- elem_on_1, elem_on_2 = multi_device_iterator.get_next()
- init_op = multi_device_iterator.initializer
-
- config = config_pb2.ConfigProto(device_count={"CPU": 3})
- with self.test_session(config=config) as sess:
- for i in range(1000):
- sess.run(init_op, feed_dict={epoch: i})
- self.assertEqual([(i, 0), (i, 1)], sess.run([elem_on_1, elem_on_2]))
-
- def testBasicGpu(self):
- if not test_util.is_gpu_available():
- self.skipTest("No GPU available")
-
- with compat.forward_compatibility_horizon(2018, 8, 4):
- dataset = dataset_ops.Dataset.range(10)
- multi_device_iterator = prefetching_ops.MultiDeviceIterator(
- dataset, ["/cpu:1", "/gpu:0"])
- elem_on_1, elem_on_2 = multi_device_iterator.get_next()
-
- config = config_pb2.ConfigProto(device_count={"CPU": 2, "GPU": 1})
- with self.test_session(config=config) as sess:
- sess.run(multi_device_iterator.initializer)
- for i in range(0, 10, 2):
- self.assertEqual(i, sess.run(elem_on_1))
- self.assertEqual(i + 1, sess.run(elem_on_2))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(elem_on_1)
- sess.run(elem_on_2)
-
- def testUnevenGpu(self):
- if not test_util.is_gpu_available():
- self.skipTest("No GPU available")
-
- with compat.forward_compatibility_horizon(2018, 8, 4):
- dataset = dataset_ops.Dataset.range(10)
- multi_device_iterator = prefetching_ops.MultiDeviceIterator(
- dataset, ["/cpu:1", "/gpu:0"], max_buffer_size=4)
- elem_on_1, elem_on_2 = multi_device_iterator.get_next()
-
- config = config_pb2.ConfigProto(device_count={"CPU": 2, "GPU": 1})
- with self.test_session(config=config) as sess:
- sess.run(multi_device_iterator.initializer)
- for i in range(0, 10, 2):
- self.assertEqual(i, sess.run(elem_on_1))
- for i in range(0, 10, 2):
- self.assertEqual(i + 1, sess.run(elem_on_2))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(elem_on_1)
- sess.run(elem_on_2)
-
-
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/stats_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/stats_dataset_serialization_test.py
index 14cd3e9c4a..a10f85263a 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/stats_dataset_serialization_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/serialization/stats_dataset_serialization_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
from tensorflow.contrib.data.python.ops import stats_ops
from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
@@ -90,6 +91,16 @@ class StatsDatasetSerializationTest(
lambda: self._build_dataset_multiple_tags(num_outputs, tag1, tag2),
None, num_outputs)
+ def _build_dataset_stats_aggregator(self):
+ stats_aggregator = stats_ops.StatsAggregator()
+ return dataset_ops.Dataset.range(10).apply(
+ stats_ops.set_stats_aggregator(stats_aggregator))
+
+ def test_set_stats_aggregator_not_support_checkpointing(self):
+ with self.assertRaisesRegexp(errors.UnimplementedError,
+ "does not support checkpointing"):
+ self.run_core_tests(self._build_dataset_stats_aggregator, None, 10)
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py
index e25570c5ad..be8ae5e955 100644
--- a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py
@@ -25,6 +25,7 @@ from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
@@ -40,7 +41,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
next_element = iterator.get_next()
summary_t = stats_aggregator.get_summary()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(iterator.initializer)
expected_sum = 0.0
for i in range(100):
@@ -65,7 +66,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
next_element = iterator.get_next()
summary_t = stats_aggregator.get_summary()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(iterator.initializer)
for i in range(100):
self.assertEqual(i, sess.run(next_element))
@@ -84,7 +85,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
next_element = iterator.get_next()
summary_t = stats_aggregator.get_summary()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(iterator.initializer)
for i in range(100):
self.assertAllEqual(
@@ -92,6 +93,8 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
summary_str = sess.run(summary_t)
self._assertSummaryHasCount(summary_str, "Prefetch::buffer_utilization",
float(i + 1))
+ self._assertSummaryContains(summary_str, "Prefetch::buffer_capacity")
+ self._assertSummaryContains(summary_str, "Prefetch::buffer_size")
self._assertSummaryHasRange(summary_str, "Prefetch::buffer_utilization",
0, 1)
with self.assertRaises(errors.OutOfRangeError):
@@ -100,6 +103,53 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
self._assertSummaryHasCount(summary_str, "Prefetch::buffer_utilization",
100)
+ def testPrefetchBufferScalars(self):
+ stats_aggregator = stats_ops.StatsAggregator()
+ dataset = dataset_ops.Dataset.range(10).map(
+ lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))).prefetch(
+ 0).apply(stats_ops.set_stats_aggregator(stats_aggregator))
+ iterator = dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+ summary_t = stats_aggregator.get_summary()
+
+ with self.cached_session() as sess:
+ sess.run(iterator.initializer)
+ for i in range(10):
+ self.assertAllEqual(
+ np.array([i] * i, dtype=np.int64), sess.run(next_element))
+ summary_str = sess.run(summary_t)
+ self._assertSummaryHasScalarValue(summary_str,
+ "Prefetch::buffer_capacity", 0)
+ self._assertSummaryHasScalarValue(summary_str, "Prefetch::buffer_size",
+ 0)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testFilteredElementsStats(self):
+ stats_aggregator = stats_ops.StatsAggregator()
+ dataset = dataset_ops.Dataset.range(101).filter(
+ lambda x: math_ops.equal(math_ops.mod(x, 3), 0)).apply(
+ stats_ops.set_stats_aggregator(stats_aggregator))
+ iterator = dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+ summary_t = stats_aggregator.get_summary()
+
+ with self.test_session() as sess:
+ sess.run(iterator.initializer)
+ for i in range(34):
+ self.assertEqual(i * 3, sess.run(next_element))
+ if i != 0:
+ self._assertSummaryHasScalarValue(
+ sess.run(summary_t), "Filter::dropped_elements", float(i * 2))
+ self._assertSummaryHasScalarValue(
+ sess.run(summary_t), "Filter::filtered_elements", float(i + 1))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+ self._assertSummaryHasScalarValue(
+ sess.run(summary_t), "Filter::dropped_elements", 67.0)
+ self._assertSummaryHasScalarValue(
+ sess.run(summary_t), "Filter::filtered_elements", 34.0)
+
def testReinitialize(self):
stats_aggregator = stats_ops.StatsAggregator()
dataset = dataset_ops.Dataset.range(100).apply(
@@ -109,7 +159,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
next_element = iterator.get_next()
summary_t = stats_aggregator.get_summary()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for j in range(5):
sess.run(iterator.initializer)
for i in range(100):
@@ -127,7 +177,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(iterator.initializer)
for i in range(100):
self.assertEqual(i, sess.run(next_element))
@@ -144,7 +194,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
next_element = iterator.get_next()
summary_t = stats_aggregator.get_summary()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(iterator.initializer)
for i in range(100):
self.assertEqual(i, sess.run(next_element))
@@ -168,7 +218,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
next_element = iterator.get_next()
summary_t = stats_aggregator.get_summary()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(iterator.initializer)
for i in range(100):
self.assertEqual(i, sess.run(next_element))
@@ -188,7 +238,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
next_element = iterator_0.get_next() + iterator_1.get_next()
summary_t = stats_aggregator.get_summary()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run([iterator_0.initializer, iterator_1.initializer])
for i in range(100):
self.assertEqual(i * 2, sess.run(next_element))
diff --git a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py
index 2f5a44408f..b1b4c23510 100644
--- a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py
+++ b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py
@@ -25,6 +25,14 @@ from tensorflow.python.platform import test
class StatsDatasetTestBase(test.TestCase):
"""Base class for testing statistics gathered in `StatsAggregator`."""
+ def _assertSummaryContains(self, summary_str, tag):
+ summary_proto = summary_pb2.Summary()
+ summary_proto.ParseFromString(summary_str)
+ for value in summary_proto.value:
+ if tag == value.tag:
+ return
+ self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto))
+
def _assertSummaryHasCount(self, summary_str, tag, expected_value):
summary_proto = summary_pb2.Summary()
summary_proto.ParseFromString(summary_str)
@@ -52,3 +60,12 @@ class StatsDatasetTestBase(test.TestCase):
self.assertEqual(expected_value, value.histo.sum)
return
self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto))
+
+ def _assertSummaryHasScalarValue(self, summary_str, tag, expected_value):
+ summary_proto = summary_pb2.Summary()
+ summary_proto.ParseFromString(summary_str)
+ for value in summary_proto.value:
+ if tag == value.tag:
+ self.assertEqual(expected_value, value.simple_value)
+ return
+ self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto))
diff --git a/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py
index 6eaa0b1959..8b7b3ac0f7 100644
--- a/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py
@@ -89,13 +89,14 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
return dataset_ops.Dataset.zip(
tuple([fn(*arg) if isinstance(arg, tuple) else arg for arg in args]))
- dataset = self._structuredDataset(structure, shape, dtype).apply(
+ dataset = self._structuredDataset(structure, shape, dtype).repeat(5).apply(
grouping.window_dataset(5)).flat_map(fn)
get_next = dataset.make_one_shot_iterator().get_next()
with self.cached_session() as sess:
expected = sess.run(self._structuredElement(structure, shape, dtype))
- actual = sess.run(get_next)
- self._assertEqual(expected, actual)
+ for _ in range(5):
+ actual = sess.run(get_next)
+ self._assertEqual(expected, actual)
@parameterized.named_parameters(
("1", None, np.int32([]), dtypes.bool),
diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD
index a14781cd93..5cd1ed542b 100644
--- a/tensorflow/contrib/data/python/ops/BUILD
+++ b/tensorflow/contrib/data/python/ops/BUILD
@@ -78,7 +78,6 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":batching",
- ":gen_dataset_ops",
":interleave_ops",
":optimization",
":parsing_ops",
@@ -86,6 +85,7 @@ py_library(
"//tensorflow/python:constant_op",
"//tensorflow/python:dataset_ops_gen",
"//tensorflow/python:dtypes",
+ "//tensorflow/python:experimental_dataset_ops_gen",
"//tensorflow/python:framework_ops",
"//tensorflow/python:lib",
"//tensorflow/python:platform",
@@ -148,8 +148,7 @@ py_library(
srcs = ["error_ops.py"],
srcs_version = "PY2AND3",
deps = [
- ":contrib_op_loader",
- ":gen_dataset_ops",
+ "//tensorflow/python:experimental_dataset_ops_gen",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/util:nest",
"//tensorflow/python/data/util:sparse",
@@ -179,12 +178,11 @@ py_library(
srcs = ["interleave_ops.py"],
srcs_version = "PY2AND3",
deps = [
- ":contrib_op_loader",
- ":gen_dataset_ops",
":random_ops",
"//tensorflow/contrib/stateless",
"//tensorflow/python:array_ops",
"//tensorflow/python:dtypes",
+ "//tensorflow/python:experimental_dataset_ops_gen",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:util",
@@ -199,9 +197,8 @@ py_library(
srcs = ["optimization.py"],
srcs_version = "PY2AND3",
deps = [
- ":contrib_op_loader",
- ":gen_dataset_ops",
"//tensorflow/python:dtypes",
+ "//tensorflow/python:experimental_dataset_ops_gen",
"//tensorflow/python:framework_ops",
"//tensorflow/python/data/util:nest",
"//tensorflow/python/data/util:sparse",
@@ -304,8 +301,7 @@ py_library(
srcs = ["threadpool.py"],
srcs_version = "PY2AND3",
deps = [
- ":contrib_op_loader",
- ":gen_dataset_ops",
+ "//tensorflow/python:experimental_dataset_ops_gen",
"//tensorflow/python:resource_variable_ops",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/util:nest",
@@ -321,9 +317,8 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
- ":contrib_op_loader",
- ":gen_dataset_ops",
"//tensorflow/python:dtypes",
+ "//tensorflow/python:experimental_dataset_ops_gen",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/util:nest",
"//tensorflow/python/data/util:sparse",
@@ -342,47 +337,11 @@ py_library(
],
)
-tf_gen_op_wrapper_py(
- name = "gen_dataset_ops",
- out = "gen_dataset_ops.py",
- deps = [
- "//tensorflow/contrib/data:dataset_ops_op_lib",
- "//tensorflow/contrib/data:indexed_dataset_ops_op_lib",
- ],
-)
-
-tf_kernel_library(
- name = "dataset_ops_kernels",
- deps = [
- "//tensorflow/contrib/data/kernels:dataset_kernels",
- "//tensorflow/core:framework",
- ],
- alwayslink = 1,
-)
-
-tf_custom_op_py_library(
- name = "contrib_op_loader",
- srcs = ["contrib_op_loader.py"],
- dso = ["//tensorflow/contrib/data:_dataset_ops.so"],
- kernels = [
- ":dataset_ops_kernels",
- "//tensorflow/contrib/data:indexed_dataset_ops_op_lib",
- "//tensorflow/contrib/data:dataset_ops_op_lib",
- ],
- srcs_version = "PY2AND3",
- deps = [
- ":gen_dataset_ops",
- "//tensorflow/contrib/util:util_py",
- "//tensorflow/python:platform",
- ],
-)
-
py_library(
name = "indexed_dataset_ops",
srcs = ["indexed_dataset_ops.py"],
deps = [
- ":contrib_op_loader",
- ":gen_dataset_ops",
+ "//tensorflow/python:experimental_dataset_ops_gen",
"//tensorflow/python:framework_ops",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/util:nest",
@@ -394,7 +353,7 @@ py_library(
name = "prefetching_ops",
srcs = ["prefetching_ops.py"],
deps = [
- ":contrib_op_loader",
+ "//tensorflow/python:experimental_dataset_ops_gen",
"//tensorflow/python:framework_ops",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/util:nest",
diff --git a/tensorflow/contrib/data/python/ops/batching.py b/tensorflow/contrib/data/python/ops/batching.py
index 367c159dc5..7a0f221284 100644
--- a/tensorflow/contrib/data/python/ops/batching.py
+++ b/tensorflow/contrib/data/python/ops/batching.py
@@ -345,12 +345,12 @@ def _padded_batch_sparse_window(dataset, padded_shape):
dataset.apply(grouping.group_by_reducer(key_fn, reducer)))
-class _UnbatchDataset(dataset_ops.Dataset):
+class _UnbatchDataset(dataset_ops.UnaryDataset):
"""A dataset that splits the elements of its input into multiple elements."""
def __init__(self, input_dataset):
"""See `unbatch()` for more details."""
- super(_UnbatchDataset, self).__init__()
+ super(_UnbatchDataset, self).__init__(input_dataset)
flat_shapes = nest.flatten(input_dataset.output_shapes)
if any(s.ndims == 0 for s in flat_shapes):
raise ValueError("Cannot unbatch an input with scalar components.")
@@ -514,12 +514,12 @@ def padded_batch_and_drop_remainder(batch_size,
return _apply_fn
-class _DenseToSparseBatchDataset(dataset_ops.Dataset):
+class _DenseToSparseBatchDataset(dataset_ops.UnaryDataset):
"""A `Dataset` that batches ragged dense elements into `tf.SparseTensor`s."""
def __init__(self, input_dataset, batch_size, row_shape):
"""See `Dataset.dense_to_sparse_batch()` for more details."""
- super(_DenseToSparseBatchDataset, self).__init__()
+ super(_DenseToSparseBatchDataset, self).__init__(input_dataset)
if not isinstance(input_dataset.output_types, dtypes.DType):
raise TypeError("DenseToSparseDataset requires an input whose elements "
"have a single component, whereas the input has %r." %
@@ -548,7 +548,7 @@ class _DenseToSparseBatchDataset(dataset_ops.Dataset):
return self._input_dataset.output_types
-class _RestructuredDataset(dataset_ops.Dataset):
+class _RestructuredDataset(dataset_ops.UnaryDataset):
"""An internal helper for changing the structure and shape of a dataset."""
def __init__(self,
@@ -583,7 +583,7 @@ class _RestructuredDataset(dataset_ops.Dataset):
ValueError: If either `output_types` or `output_shapes` is not compatible
with the structure of `dataset`.
"""
- super(_RestructuredDataset, self).__init__()
+ super(_RestructuredDataset, self).__init__(dataset)
self._input_dataset = dataset
if not allow_unsafe_cast:
diff --git a/tensorflow/contrib/data/python/ops/error_ops.py b/tensorflow/contrib/data/python/ops/error_ops.py
index b4a7521e08..f962e623ee 100644
--- a/tensorflow/contrib/data/python/ops/error_ops.py
+++ b/tensorflow/contrib/data/python/ops/error_ops.py
@@ -17,9 +17,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import
-from tensorflow.contrib.data.python.ops import gen_dataset_ops
from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.ops import gen_experimental_dataset_ops
def ignore_errors():
@@ -51,16 +50,16 @@ def ignore_errors():
return _apply_fn
-class _IgnoreErrorsDataset(dataset_ops.Dataset):
+class _IgnoreErrorsDataset(dataset_ops.UnaryDataset):
"""A `Dataset` that silently ignores errors when computing its input."""
def __init__(self, input_dataset):
"""See `Dataset.ignore_errors()` for details."""
- super(_IgnoreErrorsDataset, self).__init__()
+ super(_IgnoreErrorsDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
def _as_variant_tensor(self):
- return gen_dataset_ops.ignore_errors_dataset(
+ return gen_experimental_dataset_ops.experimental_ignore_errors_dataset(
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
**dataset_ops.flat_structure(self))
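A pattern worth noting across this change: contrib transformations with a single input (_IgnoreErrorsDataset here, and _UnbatchDataset, _DenseToSparseBatchDataset, _RestructuredDataset, _GroupByReducerDataset and friends) now derive from dataset_ops.UnaryDataset and hand their input to the base constructor, presumably so the core Dataset machinery can track the input of each one-input transformation. A minimal sketch of the shape of such a subclass (the class name and pass-through behaviour are illustrative, not part of this change):

    from tensorflow.python.data.ops import dataset_ops

    class _PassThroughDataset(dataset_ops.UnaryDataset):
      """Illustrative one-input transformation following the new pattern."""

      def __init__(self, input_dataset):
        # Registering the input with the base class is the substantive change.
        super(_PassThroughDataset, self).__init__(input_dataset)
        self._input_dataset = input_dataset

      def _as_variant_tensor(self):
        # Pass-through for the sketch; a real transformation would call its own op.
        return self._input_dataset._as_variant_tensor()  # pylint: disable=protected-access

      @property
      def output_classes(self):
        return self._input_dataset.output_classes

      @property
      def output_shapes(self):
        return self._input_dataset.output_shapes

      @property
      def output_types(self):
        return self._input_dataset.output_types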
diff --git a/tensorflow/contrib/data/python/ops/grouping.py b/tensorflow/contrib/data/python/ops/grouping.py
index 099e10db92..7cae33beb3 100644
--- a/tensorflow/contrib/data/python/ops/grouping.py
+++ b/tensorflow/contrib/data/python/ops/grouping.py
@@ -255,6 +255,7 @@ def _map_x_dataset(map_func):
return _apply_fn
+# TODO(b/115382007) Remove this once canned reducers move to core.
def window_dataset(window_size):
"""A transformation that creates window datasets from the input dataset.
@@ -271,17 +272,22 @@ def window_dataset(window_size):
"""
def _apply_fn(dataset):
- return _WindowDataset(dataset, window_size)
+ return dataset_ops.WindowDataset(
+ dataset,
+ size=window_size,
+ shift=window_size,
+ stride=1,
+ drop_remainder=False)
return _apply_fn
-class _GroupByReducerDataset(dataset_ops.Dataset):
+class _GroupByReducerDataset(dataset_ops.UnaryDataset):
"""A `Dataset` that groups its input and performs a reduction."""
def __init__(self, input_dataset, key_func, reducer):
"""See `group_by_reducer()` for details."""
- super(_GroupByReducerDataset, self).__init__()
+ super(_GroupByReducerDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
@@ -410,12 +416,12 @@ class _GroupByReducerDataset(dataset_ops.Dataset):
**dataset_ops.flat_structure(self))
-class _GroupByWindowDataset(dataset_ops.Dataset):
+class _GroupByWindowDataset(dataset_ops.UnaryDataset):
"""A `Dataset` that groups its input and performs a windowed reduction."""
def __init__(self, input_dataset, key_func, reduce_func, window_size_func):
"""See `group_by_window()` for details."""
- super(_GroupByWindowDataset, self).__init__()
+ super(_GroupByWindowDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
@@ -519,12 +525,12 @@ class Reducer(object):
return self._finalize_func
-class _MapXDataset(dataset_ops.Dataset):
+class _MapXDataset(dataset_ops.UnaryDataset):
"""A `Dataset` that maps a function over elements in its input."""
def __init__(self, input_dataset, map_func):
"""See `map_x_dataset()` for details."""
- super(_MapXDataset, self).__init__()
+ super(_MapXDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
wrapped_func = dataset_ops.StructuredFunctionWrapper(
@@ -556,46 +562,3 @@ class _MapXDataset(dataset_ops.Dataset):
@property
def output_types(self):
return self._output_types
-
-
-class _WindowDataset(dataset_ops.Dataset):
- """A dataset that creates window datasets from the input elements."""
-
- def __init__(self, input_dataset, window_size):
- """See `window_dataset()` for more details."""
- super(_WindowDataset, self).__init__()
- self._input_dataset = input_dataset
- self._window_size = ops.convert_to_tensor(
- window_size, dtype=dtypes.int64, name="window_size")
- self._output_classes = nest.pack_sequence_as(
- input_dataset.output_classes,
- [
- dataset_ops._NestedDatasetComponent( # pylint: disable=protected-access
- output_classes=output_class,
- output_shapes=output_shape,
- output_types=output_type)
- for output_class, output_shape, output_type in zip(
- nest.flatten(input_dataset.output_classes),
- nest.flatten(input_dataset.output_shapes),
- nest.flatten(input_dataset.output_types))
- ])
- self._output_shapes = self._output_classes
- self._output_types = self._output_classes
-
- def _as_variant_tensor(self):
- return gen_dataset_ops.window_dataset(
- self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
- self._window_size,
- **dataset_ops.flat_structure(self))
-
- @property
- def output_classes(self):
- return self._output_classes
-
- @property
- def output_shapes(self):
- return self._output_shapes
-
- @property
- def output_types(self):
- return self._output_types
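The rewritten `window_dataset` above now delegates to the core `WindowDataset` with `shift=window_size`, `stride=1`, and `drop_remainder=False`. A short sketch of the equivalent call through the public `Dataset.window` method (illustrative values, not part of the diff):

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(10)

# Non-overlapping windows of 4 elements; the final partial window is kept.
windows = dataset.window(size=4, shift=4, stride=1, drop_remainder=False)

# Each element of `windows` is itself a nested Dataset; batching flattens it
# back into dense tensors of up to 4 elements.
batched = windows.flat_map(lambda w: w.batch(4))
```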
diff --git a/tensorflow/contrib/data/python/ops/indexed_dataset_ops.py b/tensorflow/contrib/data/python/ops/indexed_dataset_ops.py
index a0932b4081..9c06474a2f 100644
--- a/tensorflow/contrib/data/python/ops/indexed_dataset_ops.py
+++ b/tensorflow/contrib/data/python/ops/indexed_dataset_ops.py
@@ -19,14 +19,13 @@ from __future__ import print_function
import abc
-from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import
-from tensorflow.contrib.data.python.ops import gen_dataset_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import sparse
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
class MaterializedIndexedDataset(object):
@@ -57,7 +56,7 @@ class MaterializedIndexedDataset(object):
A tensor containing the values corresponding to `index`.
"""
# TODO(saeta): nest.pack_sequence_as(...)
- return gen_dataset_ops.indexed_dataset_get(
+ return ged_ops.experimental_indexed_dataset_get(
self._materialized_resource,
index,
output_types=nest.flatten(
@@ -90,16 +89,18 @@ class IndexedDataset(dataset_ops.Dataset):
container = ""
if shared_name is None:
shared_name = ""
- materialized_resource = gen_dataset_ops.materialized_index_dataset_handle(
- container=container,
- shared_name=shared_name,
- output_types=nest.flatten(
- sparse.as_dense_types(self.output_types, self.output_classes)),
- output_shapes=nest.flatten(
- sparse.as_dense_types(self.output_shapes, self.output_classes)))
+ materialized_resource = (
+ ged_ops.experimental_materialized_index_dataset_handle(
+ container=container,
+ shared_name=shared_name,
+ output_types=nest.flatten(
+ sparse.as_dense_types(self.output_types, self.output_classes)),
+ output_shapes=nest.flatten(
+ sparse.as_dense_types(self.output_shapes,
+ self.output_classes))))
with ops.colocate_with(materialized_resource):
- materializer = gen_dataset_ops.indexed_dataset_materialize(
+ materializer = ged_ops.experimental_indexed_dataset_materialize(
self._as_variant_tensor(), materialized_resource)
return MaterializedIndexedDataset(materialized_resource, materializer,
self.output_classes, self.output_types,
@@ -170,4 +171,7 @@ class IdentityIndexedDataset(IndexedDataset):
return tensor_shape.scalar()
def _as_variant_tensor(self):
- return gen_dataset_ops.identity_indexed_dataset(self._size)
+ return ged_ops.experimental_identity_indexed_dataset(self._size)
+
+ def _inputs(self):
+ return []
diff --git a/tensorflow/contrib/data/python/ops/interleave_ops.py b/tensorflow/contrib/data/python/ops/interleave_ops.py
index 92d4251a86..1ee9db1aa8 100644
--- a/tensorflow/contrib/data/python/ops/interleave_ops.py
+++ b/tensorflow/contrib/data/python/ops/interleave_ops.py
@@ -18,8 +18,6 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib import stateless
-from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import
-from tensorflow.contrib.data.python.ops import gen_dataset_ops
from tensorflow.contrib.data.python.ops import random_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import readers
@@ -28,6 +26,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_experimental_dataset_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util import deprecation
@@ -167,12 +166,17 @@ class _DirectedInterleaveDataset(dataset_ops.Dataset):
def _as_variant_tensor(self):
# pylint: disable=protected-access
- return gen_dataset_ops.directed_interleave_dataset(
- self._selector_input._as_variant_tensor(),
- [data_input._as_variant_tensor() for data_input in self._data_inputs],
- **dataset_ops.flat_structure(self))
+ return (
+ gen_experimental_dataset_ops.experimental_directed_interleave_dataset(
+ self._selector_input._as_variant_tensor(), [
+ data_input._as_variant_tensor()
+ for data_input in self._data_inputs
+ ], **dataset_ops.flat_structure(self)))
# pylint: enable=protected-access
+ def _inputs(self):
+ return [self._selector_input] + self._data_inputs
+
@property
def output_classes(self):
return self._data_inputs[0].output_classes
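`_DirectedInterleaveDataset` backs the public `sample_from_datasets`/`choose_from_datasets` helpers defined in this module; a brief sketch of the sampling entry point (weights are illustrative):

```python
import tensorflow as tf

datasets = [tf.data.Dataset.from_tensors("even").repeat(),
            tf.data.Dataset.from_tensors("odd").repeat()]

# Draw elements from the two inputs with a 70/30 split; the selector and data
# inputs are exactly what _inputs() now reports above.
dataset = tf.contrib.data.sample_from_datasets(datasets, weights=[0.7, 0.3])
```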
diff --git a/tensorflow/contrib/data/python/ops/optimization.py b/tensorflow/contrib/data/python/ops/optimization.py
index 73840452df..30348ede36 100644
--- a/tensorflow/contrib/data/python/ops/optimization.py
+++ b/tensorflow/contrib/data/python/ops/optimization.py
@@ -17,12 +17,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import
-from tensorflow.contrib.data.python.ops import gen_dataset_ops as contrib_gen_dataset_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.ops import gen_experimental_dataset_ops
# A constant that can be used to enable auto-tuning.
AUTOTUNE = -1
@@ -54,7 +53,7 @@ def model():
Returns:
A `Dataset` transformation function, which can be passed to
- @{tf.data.Dataset.apply}.
+ `tf.data.Dataset.apply`.
"""
def _apply_fn(dataset):
@@ -84,12 +83,12 @@ def optimize(optimizations=None):
return _apply_fn
-class _AssertNextDataset(dataset_ops.Dataset):
+class _AssertNextDataset(dataset_ops.UnaryDataset):
"""A `Dataset` that asserts which transformations happen next."""
def __init__(self, input_dataset, transformations):
"""See `assert_next()` for details."""
- super(_AssertNextDataset, self).__init__()
+ super(_AssertNextDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
if transformations is None:
raise ValueError("At least one transformation should be specified")
@@ -97,7 +96,7 @@ class _AssertNextDataset(dataset_ops.Dataset):
transformations, dtype=dtypes.string, name="transformations")
def _as_variant_tensor(self):
- return contrib_gen_dataset_ops.assert_next_dataset(
+ return gen_experimental_dataset_ops.experimental_assert_next_dataset(
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
self._transformations,
**dataset_ops.flat_structure(self))
@@ -115,12 +114,12 @@ class _AssertNextDataset(dataset_ops.Dataset):
return self._input_dataset.output_types
-class _ModelDataset(dataset_ops.Dataset):
+class _ModelDataset(dataset_ops.UnaryDataset):
"""A `Dataset` that acts as an identity, and models performance."""
def __init__(self, input_dataset):
"""See `optimize()` for details."""
- super(_ModelDataset, self).__init__()
+ super(_ModelDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
def _as_variant_tensor(self):
@@ -141,12 +140,12 @@ class _ModelDataset(dataset_ops.Dataset):
return self._input_dataset.output_types
-class _OptimizeDataset(dataset_ops.Dataset):
+class _OptimizeDataset(dataset_ops.UnaryDataset):
"""A `Dataset` that acts as an identity, and applies optimizations."""
def __init__(self, input_dataset, optimizations):
"""See `optimize()` for details."""
- super(_OptimizeDataset, self).__init__()
+ super(_OptimizeDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
if optimizations is None:
optimizations = []
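A minimal sketch of applying the transformations defined in this module (the rewrite name passed to `optimize` is illustrative; available passes depend on the Grappler build):

```python
import tensorflow as tf
from tensorflow.contrib.data.python.ops import optimization

dataset = tf.data.Dataset.range(10).map(lambda x: x + 1).batch(2)

# Request a specific static rewrite of the input pipeline graph.
dataset = dataset.apply(optimization.optimize(["map_and_batch_fusion"]))

# Or let the runtime model and tune pipeline performance dynamically.
dataset = dataset.apply(optimization.model())
```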
diff --git a/tensorflow/contrib/data/python/ops/parsing_ops.py b/tensorflow/contrib/data/python/ops/parsing_ops.py
index 2701605e64..cfbba701b0 100644
--- a/tensorflow/contrib/data/python/ops/parsing_ops.py
+++ b/tensorflow/contrib/data/python/ops/parsing_ops.py
@@ -26,11 +26,11 @@ from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import parsing_ops
-class _ParseExampleDataset(dataset_ops.Dataset):
+class _ParseExampleDataset(dataset_ops.UnaryDataset):
"""A `Dataset` that parses `example` dataset into a `dict` dataset."""
def __init__(self, input_dataset, features, num_parallel_calls):
- super(_ParseExampleDataset, self).__init__()
+ super(_ParseExampleDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
if not all(types == dtypes.string
for types in nest.flatten(input_dataset.output_types)):
diff --git a/tensorflow/contrib/data/python/ops/prefetching_ops.py b/tensorflow/contrib/data/python/ops/prefetching_ops.py
index 5222011d04..46f82e453a 100644
--- a/tensorflow/contrib/data/python/ops/prefetching_ops.py
+++ b/tensorflow/contrib/data/python/ops/prefetching_ops.py
@@ -19,8 +19,6 @@ from __future__ import print_function
import warnings
-from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import
-from tensorflow.contrib.data.python.ops import gen_dataset_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.util import nest
@@ -31,9 +29,9 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import functional_ops
-from tensorflow.python.ops import gen_dataset_ops as core_gen_dataset_ops
+from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
from tensorflow.python.ops import resource_variable_ops
@@ -65,7 +63,7 @@ def function_buffering_resource(string_arg,
"""
if shared_name is None:
shared_name = ""
- return gen_dataset_ops.function_buffering_resource(
+ return ged_ops.experimental_function_buffering_resource(
string_arg=string_arg,
target_device=target_device,
shared_name=shared_name,
@@ -79,14 +77,14 @@ def function_buffering_resource(string_arg,
def function_buffering_resource_get_next(function_buffer_resource,
output_types,
name=None):
- return gen_dataset_ops.function_buffering_resource_get_next(
+ return ged_ops.experimental_function_buffering_resource_get_next(
function_buffer_resource=function_buffer_resource,
output_types=output_types,
name=name)
def function_buffering_resource_reset(function_buffer_resource, name=None):
- return gen_dataset_ops.function_buffering_resource_reset(
+ return ged_ops.experimental_function_buffering_resource_reset(
function_buffer_resource=function_buffer_resource, name=name)
@@ -137,7 +135,7 @@ class _PrefetchToDeviceIterator(object):
ret = remote_iterator.get_next()
return nest.flatten(sparse.serialize_sparse_tensors(ret))
- iterator_device = gen_dataset_ops.iterator_get_device(
+ iterator_device = ged_ops.experimental_iterator_get_device(
self._input_iterator._iterator_resource)
with ops.device(device):
@@ -163,10 +161,11 @@ class _PrefetchToDeviceIterator(object):
if self._get_next_call_count > iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD:
warnings.warn(iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE)
- flat_ret = gen_dataset_ops.function_buffering_resource_get_next(
+ flat_ret = ged_ops.experimental_function_buffering_resource_get_next(
self._buffering_resource,
- output_types=nest.flatten(sparse.as_dense_types(
- self.output_types, self.output_classes)), name=name)
+ output_types=nest.flatten(
+ sparse.as_dense_types(self.output_types, self.output_classes)),
+ name=name)
ret = sparse.deserialize_sparse_tensors(
nest.pack_sequence_as(self.output_types, flat_ret),
@@ -220,7 +219,7 @@ class _PrefetchToDeviceEagerIterator(iterator_ops.EagerIterator):
buffer_size):
with ops.device("/device:CPU:0"):
super(_PrefetchToDeviceEagerIterator, self).__init__(input_dataset)
- input_iterator_handle = core_gen_dataset_ops.iterator_to_string_handle(
+ input_iterator_handle = gen_dataset_ops.iterator_to_string_handle(
self._resource)
self._device = device
@@ -239,7 +238,8 @@ class _PrefetchToDeviceEagerIterator(iterator_ops.EagerIterator):
self._buffering_resource = function_buffering_resource(
f=_prefetch_fn,
output_types=self._flat_output_types,
- target_device=gen_dataset_ops.iterator_get_device(self._resource),
+ target_device=ged_ops.experimental_iterator_get_device(
+ self._resource),
string_arg=input_iterator_handle,
buffer_size=buffer_size,
shared_name=iterator_ops._generate_shared_name(
@@ -253,7 +253,7 @@ class _PrefetchToDeviceEagerIterator(iterator_ops.EagerIterator):
# TODO(b/77291417): Fix
with context.execution_mode(context.SYNC):
with ops.device(self._device):
- ret = gen_dataset_ops.function_buffering_resource_get_next(
+ ret = ged_ops.experimental_function_buffering_resource_get_next(
function_buffer_resource=self._buffering_resource,
output_types=self._flat_output_types)
return sparse.deserialize_sparse_tensors(
@@ -262,10 +262,11 @@ class _PrefetchToDeviceEagerIterator(iterator_ops.EagerIterator):
# pylint: enable=protected-access
-class _PrefetchToDeviceDataset(dataset_ops.Dataset):
+class _PrefetchToDeviceDataset(dataset_ops.UnaryDataset):
"""A `Dataset` whose iterator prefetches elements to another device."""
def __init__(self, input_dataset, device, buffer_size):
+ super(_PrefetchToDeviceDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._device = device
self._buffer_size = buffer_size if buffer_size is not None else 1
@@ -374,7 +375,7 @@ def copy_to_device(target_device, source_device="/cpu:0"):
# TODO(rohanj): Use the _input_hostmem attr on the RemoteCall ops to indicate
# all inputs to the Op are in host memory, thereby avoiding some unnecessary
# Sends and Recvs.
-class _CopyToDeviceDataset(dataset_ops.Dataset):
+class _CopyToDeviceDataset(dataset_ops.UnaryDataset):
"""A `Dataset` that copies elements to another device."""
def __init__(self, input_dataset, target_device, source_device="/cpu:0"):
@@ -385,6 +386,7 @@ class _CopyToDeviceDataset(dataset_ops.Dataset):
target_device: The name of the device to which elements would be copied.
source_device: Device where input_dataset would be placed.
"""
+ super(_CopyToDeviceDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._target_device = target_device
spec = framework_device.DeviceSpec().from_string(self._target_device)
@@ -408,12 +410,12 @@ class _CopyToDeviceDataset(dataset_ops.Dataset):
"""
# pylint: disable=protected-access
ds_variant = self._input_dataset._as_variant_tensor()
- resource = core_gen_dataset_ops.anonymous_iterator(
+ resource = gen_dataset_ops.anonymous_iterator(
output_types=self._flat_output_types,
output_shapes=self._flat_output_shapes)
with ops.control_dependencies(
- [core_gen_dataset_ops.make_iterator(ds_variant, resource)]):
- return core_gen_dataset_ops.iterator_to_string_handle(resource)
+ [gen_dataset_ops.make_iterator(ds_variant, resource)]):
+ return gen_dataset_ops.iterator_to_string_handle(resource)
@function.Defun()
def _remote_init_func():
@@ -462,7 +464,7 @@ class _CopyToDeviceDataset(dataset_ops.Dataset):
Returns:
Tensor constant 0
"""
- iterator_resource = core_gen_dataset_ops.iterator_from_string_handle_v2(
+ iterator_resource = gen_dataset_ops.iterator_from_string_handle_v2(
string_handle,
output_types=self._flat_output_types,
output_shapes=self._flat_output_shapes)
@@ -503,7 +505,7 @@ class _CopyToDeviceDataset(dataset_ops.Dataset):
def _as_variant_tensor(self):
with ops.device(self._target_device):
- return core_gen_dataset_ops.generator_dataset(
+ return gen_dataset_ops.generator_dataset(
self._init_captured_args,
self._next_captured_args,
self._finalize_captured_args,
@@ -524,187 +526,3 @@ class _CopyToDeviceDataset(dataset_ops.Dataset):
@property
def output_classes(self):
return self._input_dataset.output_classes
-
-
-class _PerDeviceGenerator(dataset_ops.Dataset):
- """A `dummy` generator dataset."""
-
- def __init__(self, shard_num, multi_device_iterator_resource, incarnation_id,
- source_device, target_device, output_shapes, output_types,
- output_classes):
- self._target_device = target_device
- self._output_types = output_types
- self._output_shapes = output_shapes
- self._output_classes = output_classes
- self._flat_output_shapes = nest.flatten(
- sparse.as_dense_shapes(self._output_shapes, self._output_classes))
- self._flat_output_types = nest.flatten(
- sparse.as_dense_types(self._output_types, self._output_classes))
-
- multi_device_iterator_string_handle = (
- gen_dataset_ops.multi_device_iterator_to_string_handle(
- multi_device_iterator_resource))
-
- @function.Defun()
- def _init_func():
- return multi_device_iterator_string_handle
-
- @function.Defun()
- def _remote_init_func():
- return functional_ops.remote_call(
- target=source_device,
- args=_init_func.captured_inputs,
- Tout=[dtypes.string],
- f=_init_func)
-
- self._init_func = _remote_init_func
- self._init_captured_args = _remote_init_func.captured_inputs
-
- @function.Defun(dtypes.string)
- def _next_func(string_handle):
- multi_device_iterator = (
- gen_dataset_ops.multi_device_iterator_from_string_handle(
- string_handle=string_handle,
- output_types=self._flat_output_types,
- output_shapes=self._flat_output_shapes))
- return gen_dataset_ops.multi_device_iterator_get_next_from_shard(
- multi_device_iterator=multi_device_iterator,
- shard_num=shard_num,
- incarnation_id=incarnation_id,
- output_types=self._flat_output_types,
- output_shapes=self._flat_output_shapes)
-
- @function.Defun(dtypes.string)
- def _remote_next_func(string_handle):
- return functional_ops.remote_call(
- target=source_device,
- args=[string_handle] + _next_func.captured_inputs,
- Tout=self._flat_output_types,
- f=_next_func)
-
- self._next_func = _remote_next_func
- self._next_captured_args = _remote_next_func.captured_inputs
-
- @function.Defun(dtypes.string)
- def _finalize_func(unused_string_handle):
- return array_ops.constant(0, dtypes.int64)
-
- @function.Defun(dtypes.string)
- def _remote_finalize_func(string_handle):
- return functional_ops.remote_call(
- target=source_device,
- args=[string_handle] + _finalize_func.captured_inputs,
- Tout=[dtypes.int64],
- f=_finalize_func)
-
- self._finalize_func = _remote_finalize_func
- self._finalize_captured_args = _remote_finalize_func.captured_inputs
-
- def _as_variant_tensor(self):
- with ops.device(self._target_device):
- return core_gen_dataset_ops.generator_dataset(
- self._init_captured_args,
- self._next_captured_args,
- self._finalize_captured_args,
- init_func=self._init_func,
- next_func=self._next_func,
- finalize_func=self._finalize_func,
- output_types=self._flat_output_types,
- output_shapes=self._flat_output_shapes)
-
- @property
- def output_types(self):
- return self._output_types
-
- @property
- def output_shapes(self):
- return self._output_shapes
-
- @property
- def output_classes(self):
- return self._output_classes
-
-
-class MultiDeviceIterator(object):
- """An iterator over multiple devices."""
-
- def __init__(self,
- dataset,
- devices,
- max_buffer_size=1,
- prefetch_buffer_size=1,
- source_device="/cpu:0"):
- """Constructs a MultiDeviceIterator.
-
- Args:
- dataset: The input dataset to be iterated over.
- devices: The list of devices to fetch data to.
- max_buffer_size: Maximum size of the host side per device buffer to keep.
- prefetch_buffer_size: if > 1, then we setup a buffer on each device
- to prefetch into.
- source_device: The host device to place the `dataset` on.
- """
- self._dataset = dataset
- self._devices = devices
- self._source_device = source_device
- self._source_device_tensor = ops.convert_to_tensor(source_device)
-
- self._flat_output_shapes = nest.flatten(
- sparse.as_dense_shapes(self._dataset.output_shapes,
- self._dataset.output_classes))
- self._flat_output_types = nest.flatten(
- sparse.as_dense_types(self._dataset.output_types,
- self._dataset.output_classes))
-
- # Create the MultiDeviceIterator.
- with ops.device(self._source_device):
- self._multi_device_iterator_resource = (
- gen_dataset_ops.multi_device_iterator(
- devices=self._devices,
- shared_name="",
- container="",
- output_types=self._flat_output_types,
- output_shapes=self._flat_output_shapes))
-
- # The incarnation ID is used to ensure consistency between the per-device
- # iterators and the multi-device iterator.
- self._incarnation_id = gen_dataset_ops.multi_device_iterator_init(
- self._dataset._as_variant_tensor(), # pylint: disable=protected-access
- self._multi_device_iterator_resource,
- max_buffer_size=max_buffer_size)
-
- # TODO(rohanj): Explore the possibility of the MultiDeviceIterator to
- # initialize the device side of the pipeline. This would allow the
- # MultiDeviceIterator to choose, for example, to move some transformations
- # into the device side from its input. It might be useful in rewriting.
- # Create the per device iterators.
- self._device_iterators = []
- i = 0
- for device in self._devices:
- ds = _PerDeviceGenerator(
- i, self._multi_device_iterator_resource, self._incarnation_id,
- self._source_device_tensor, device, self._dataset.output_shapes,
- self._dataset.output_types, self._dataset.output_classes)
- if prefetch_buffer_size > 0:
- ds = ds.prefetch(prefetch_buffer_size)
- with ops.device(device):
- self._device_iterators.append(ds.make_initializable_iterator())
- i += 1
-
- device_iterator_initializers = [
- iterator.initializer for iterator in self._device_iterators
- ]
- self._initializer = control_flow_ops.group(*device_iterator_initializers)
-
- def get_next(self):
- result = []
- i = 0
- for device in self._devices:
- with ops.device(device):
- result.append(self._device_iterators[i].get_next())
- i += 1
- return result
-
- @property
- def initializer(self):
- return self._initializer
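The `_PerDeviceGenerator`/`MultiDeviceIterator` code deleted above has a core counterpart (note the new `//tensorflow/python/data/ops:multi_device_iterator_ops` dependency added to the distribute BUILD file later in this diff). A hedged sketch, assuming the core class keeps the constructor shape of the deleted contrib version:

```python
import tensorflow as tf
from tensorflow.python.data.ops import multi_device_iterator_ops

dataset = tf.data.Dataset.range(100).batch(10)

# One host-side iterator feeding per-device buffers on two GPUs.
mdi = multi_device_iterator_ops.MultiDeviceIterator(
    dataset, devices=["/gpu:0", "/gpu:1"], source_device="/cpu:0")

with tf.Session() as sess:
  sess.run(mdi.initializer)
  per_device_batches = sess.run(mdi.get_next())
```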
diff --git a/tensorflow/contrib/data/python/ops/random_ops.py b/tensorflow/contrib/data/python/ops/random_ops.py
index e670c4c835..344a0763c8 100644
--- a/tensorflow/contrib/data/python/ops/random_ops.py
+++ b/tensorflow/contrib/data/python/ops/random_ops.py
@@ -25,7 +25,7 @@ from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import gen_dataset_ops
-class RandomDataset(dataset_ops.Dataset):
+class RandomDataset(dataset_ops.DatasetSource):
"""A `Dataset` of pseudorandom values."""
def __init__(self, seed=None):
diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py
index 785b395707..360971e200 100644
--- a/tensorflow/contrib/data/python/ops/readers.py
+++ b/tensorflow/contrib/data/python/ops/readers.py
@@ -23,7 +23,6 @@ import csv
import numpy as np
from tensorflow.contrib.data.python.ops import batching
-from tensorflow.contrib.data.python.ops import gen_dataset_ops as contrib_gen_dataset_ops
from tensorflow.contrib.data.python.ops import interleave_ops
from tensorflow.contrib.data.python.ops import optimization
from tensorflow.contrib.data.python.ops import parsing_ops
@@ -38,6 +37,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.ops import gen_experimental_dataset_ops
from tensorflow.python.platform import gfile
from tensorflow.python.util import deprecation
@@ -508,7 +508,7 @@ def make_csv_dataset(
_DEFAULT_READER_BUFFER_SIZE_BYTES = 4 * 1024 * 1024 # 4 MB
-class CsvDataset(dataset_ops.Dataset):
+class CsvDataset(dataset_ops.DatasetSource):
"""A Dataset comprising lines from one or more CSV files."""
def __init__(self,
@@ -629,7 +629,7 @@ class CsvDataset(dataset_ops.Dataset):
def _as_variant_tensor(self):
# Constructs graph node for the dataset op.
- return contrib_gen_dataset_ops.csv_dataset(
+ return gen_experimental_dataset_ops.experimental_csv_dataset(
filenames=self._filenames,
record_defaults=self._record_defaults,
buffer_size=self._buffer_size,
@@ -924,7 +924,7 @@ def _get_file_names(file_pattern, shuffle):
return file_names
-class SqlDataset(dataset_ops.Dataset):
+class SqlDataset(dataset_ops.DatasetSource):
"""A `Dataset` consisting of the results from a SQL query."""
def __init__(self, driver_name, data_source_name, query, output_types):
@@ -985,7 +985,7 @@ class SqlDataset(dataset_ops.Dataset):
return self._output_types
-class LMDBDataset(dataset_ops.Dataset):
+class LMDBDataset(dataset_ops.DatasetSource):
"""A LMDB Dataset that reads the lmdb file."""
def __init__(self, filenames):
@@ -1013,7 +1013,7 @@ class LMDBDataset(dataset_ops.Dataset):
filenames, dtype=dtypes.string, name="filenames")
def _as_variant_tensor(self):
- return contrib_gen_dataset_ops.lmdb_dataset(
+ return gen_experimental_dataset_ops.experimental_lmdb_dataset(
self._filenames,
output_types=nest.flatten(self.output_types),
output_shapes=nest.flatten(self.output_shapes))
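For reference, a small sketch of constructing the `CsvDataset` reader whose kernel binding changed above (the input file name is hypothetical):

```python
import tensorflow as tf

# One float column and one string column; the defaults fix the output dtypes.
record_defaults = [tf.float32, tf.string]

dataset = tf.contrib.data.CsvDataset(
    ["train.csv"],          # hypothetical input file
    record_defaults,
    header=True,
    field_delim=",")
```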
diff --git a/tensorflow/contrib/data/python/ops/scan_ops.py b/tensorflow/contrib/data/python/ops/scan_ops.py
index 6b002b4a53..c52582cd35 100644
--- a/tensorflow/contrib/data/python/ops/scan_ops.py
+++ b/tensorflow/contrib/data/python/ops/scan_ops.py
@@ -27,12 +27,12 @@ from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import gen_dataset_ops
-class _ScanDataset(dataset_ops.Dataset):
+class _ScanDataset(dataset_ops.UnaryDataset):
"""A dataset that scans a function across its input."""
def __init__(self, input_dataset, initial_state, scan_func):
"""See `scan()` for details."""
- super(_ScanDataset, self).__init__()
+ super(_ScanDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
with ops.name_scope("initial_state"):
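A short sketch of the `scan` transformation backed by `_ScanDataset`, here computing a running sum over an int64 range:

```python
import tensorflow as tf
from tensorflow.contrib.data.python.ops import scan_ops

dataset = tf.data.Dataset.range(5)

# The state is threaded through each call; every step emits the new total,
# yielding 0, 1, 3, 6, 10.
dataset = dataset.apply(
    scan_ops.scan(tf.constant(0, dtype=tf.int64),
                  lambda state, x: (state + x, state + x)))
```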
diff --git a/tensorflow/contrib/data/python/ops/shuffle_ops.py b/tensorflow/contrib/data/python/ops/shuffle_ops.py
index 4356721704..985d1d87d0 100644
--- a/tensorflow/contrib/data/python/ops/shuffle_ops.py
+++ b/tensorflow/contrib/data/python/ops/shuffle_ops.py
@@ -25,16 +25,11 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_dataset_ops
-class _ShuffleAndRepeatDataset(dataset_ops.Dataset):
+class _ShuffleAndRepeatDataset(dataset_ops.UnaryDataset):
"""A `Dataset` that fuses `shuffle` and `repeat`."""
- def __init__(self,
- input_dataset,
- buffer_size,
- count=None,
- seed=None):
- """See `Dataset.map()` for details."""
- super(_ShuffleAndRepeatDataset, self).__init__()
+ def __init__(self, input_dataset, buffer_size, count=None, seed=None):
+ super(_ShuffleAndRepeatDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._buffer_size = ops.convert_to_tensor(
buffer_size, dtype=dtypes.int64, name="buffer_size")
diff --git a/tensorflow/contrib/data/python/ops/sliding.py b/tensorflow/contrib/data/python/ops/sliding.py
index 8025dcdd16..bcc383587c 100644
--- a/tensorflow/contrib/data/python/ops/sliding.py
+++ b/tensorflow/contrib/data/python/ops/sliding.py
@@ -26,12 +26,12 @@ from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.util import deprecation
-class _SlideDataset(dataset_ops.Dataset):
+class _SlideDataset(dataset_ops.UnaryDataset):
"""A `Dataset` that passes a sliding window over its input."""
def __init__(self, input_dataset, window_size, window_shift, window_stride):
"""See `sliding_window_batch` for details."""
- super(_SlideDataset, self).__init__()
+ super(_SlideDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._window_size = ops.convert_to_tensor(
window_size, dtype=dtypes.int64, name="window_stride")
@@ -67,6 +67,10 @@ class _SlideDataset(dataset_ops.Dataset):
@deprecation.deprecated_args(
None, "stride is deprecated, use window_shift instead", "stride")
+@deprecation.deprecated(
+ None, "Use `tf.data.Dataset.window(size=window_size, shift=window_shift, "
+ "stride=window_stride).flat_map(lambda x: x.batch(window.size))` "
+ "instead.")
def sliding_window_batch(window_size,
stride=None,
window_shift=None,
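The new deprecation notice points at the core window transformation; a minimal sketch of the suggested replacement for a size-3, shift-1 sliding window (the batch size equals the window size):

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(10)

# Previously: dataset.apply(
#     tf.contrib.data.sliding_window_batch(window_size=3, window_shift=1))
dataset = dataset.window(size=3, shift=1, stride=1).flat_map(
    lambda window: window.batch(3))
```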
diff --git a/tensorflow/contrib/data/python/ops/stats_ops.py b/tensorflow/contrib/data/python/ops/stats_ops.py
index 8426228992..bc47c5989d 100644
--- a/tensorflow/contrib/data/python/ops/stats_ops.py
+++ b/tensorflow/contrib/data/python/ops/stats_ops.py
@@ -23,34 +23,31 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_dataset_ops
-# TODO(b/38416882): Properly export in the `tf.contrib.data` API when stable
-# or make private / remove.
class StatsAggregator(object):
"""A stateful resource that aggregates statistics from one or more iterators.
To record statistics, use one of the custom transformation functions defined
in this module when defining your `tf.data.Dataset`. All statistics will be
aggregated by the `StatsAggregator` that is associated with a particular
- iterator (see below). For example, to record the total number of bytes
- produced by iterating over a dataset:
+ iterator (see below). For example, to record the latency of producing each
+ element by iterating over a dataset:
```python
dataset = ...
- dataset = dataset.apply(stats_ops.bytes_produced_stats("total_bytes"))
+ dataset = dataset.apply(stats_ops.latency_stats("total_bytes"))
```
- To associate a `StatsAggregator` with a `tf.data.Iterator` object, use
+ To associate a `StatsAggregator` with a `tf.data.Dataset` object, use
the following pattern:
```python
- dataset = ...
- iterator = dataset.make_one_shot_iterator()
stats_aggregator = stats_ops.StatsAggregator()
- set_op = stats_aggregator.subscribe(iterator)
+ dataset = ...
- with tf.Session() as sess:
- # Running `set_op` will associate `iterator` with `stats_aggregator`.
- sess.run(set_op)
+ # Apply `set_stats_aggregator` to associate `dataset` with `stats_aggregator`.
+ dataset = dataset.apply(
+ tf.contrib.data.set_stats_aggregator(stats_aggregator))
+ iterator = dataset.make_one_shot_iterator()
```
To get a protocol buffer summary of the currently aggregated statistics,
@@ -60,6 +57,7 @@ class StatsAggregator(object):
```python
stats_aggregator = stats_ops.StatsAggregator()
+ # ...
stats_summary = stats_aggregator.get_summary()
tf.add_to_collection(tf.GraphKeys.SUMMARIES, stats_summary)
```
@@ -73,6 +71,7 @@ class StatsAggregator(object):
"""Creates a `StatsAggregator`."""
self._resource = gen_dataset_ops.stats_aggregator_handle()
+ # TODO(b/116314787): Update this/add support for V2 summary API.
def get_summary(self):
"""Returns a string `tf.Tensor` that summarizes the aggregated statistics.
@@ -85,11 +84,11 @@ class StatsAggregator(object):
return gen_dataset_ops.stats_aggregator_summary(self._resource)
-class _SetStatsAggregatorDataset(dataset_ops.Dataset):
+class _SetStatsAggregatorDataset(dataset_ops.UnaryDataset):
"""A `Dataset` that acts as an identity, and sets given stats_aggregator."""
def __init__(self, input_dataset, stats_aggregator):
- super(_SetStatsAggregatorDataset, self).__init__()
+ super(_SetStatsAggregatorDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._stats_aggregator = stats_aggregator
@@ -112,13 +111,11 @@ class _SetStatsAggregatorDataset(dataset_ops.Dataset):
return self._input_dataset.output_classes
-# TODO(b/38416882): Properly export in the `tf.contrib.data` API when stable
-# or make private / remove.
def set_stats_aggregator(stats_aggregator):
- """Set the given stats_aggregator for aggregating the input dataset stats.
+ """Set the given `stats_aggregator` for aggregating the input dataset stats.
Args:
- stats_aggregator: A `StatsAggregator` object.
+ stats_aggregator: A `tf.contrib.data.StatsAggregator` object.
Returns:
A `Dataset` transformation function, which can be passed to
@@ -155,8 +152,6 @@ def bytes_produced_stats(tag):
return _apply_fn
-# TODO(b/38416882): Properly export in the `tf.contrib.data` API when stable
-# or make private / remove.
def latency_stats(tag):
"""Records the latency of producing each element of the input dataset.
@@ -178,11 +173,11 @@ def latency_stats(tag):
return _apply_fn
-class _StatsDataset(dataset_ops.Dataset):
+class _StatsDataset(dataset_ops.UnaryDataset):
"""A `Dataset` that acts as an identity, and also records statistics."""
def __init__(self, input_dataset, op_function, tag):
- super(_StatsDataset, self).__init__()
+ super(_StatsDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._op_function = op_function
self._tag = ops.convert_to_tensor(tag, dtype=dtypes.string)
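Pulling the updated docstring guidance together, a runnable sketch of wiring a `StatsAggregator` into a pipeline with the helpers defined in this module (TF 1.x graph mode; the stats tag is arbitrary):

```python
import tensorflow as tf
from tensorflow.contrib.data.python.ops import stats_ops

dataset = tf.data.Dataset.range(100)

# Record per-element production latency under an arbitrary tag.
dataset = dataset.apply(stats_ops.latency_stats("record_latency"))

# Associate the dataset with an aggregator and expose its summary protobuf.
aggregator = stats_ops.StatsAggregator()
dataset = dataset.apply(stats_ops.set_stats_aggregator(aggregator))

iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
summary = aggregator.get_summary()
tf.add_to_collection(tf.GraphKeys.SUMMARIES, summary)
```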
diff --git a/tensorflow/contrib/data/python/ops/threadpool.py b/tensorflow/contrib/data/python/ops/threadpool.py
index dc67accdcf..f73c3fd9cb 100644
--- a/tensorflow/contrib/data/python/ops/threadpool.py
+++ b/tensorflow/contrib/data/python/ops/threadpool.py
@@ -19,10 +19,9 @@ from __future__ import print_function
import threading
-from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import
-from tensorflow.contrib.data.python.ops import gen_dataset_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import context
+from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
from tensorflow.python.ops import resource_variable_ops
_uid_counter = 0
@@ -47,7 +46,7 @@ class PrivateThreadPool(object):
"""Creates a `PrivateThreadPool` with the given number of threads."""
if context.executing_eagerly():
shared_name = _generate_shared_name("privatethreadpool")
- self._resource = gen_dataset_ops.thread_pool_handle(
+ self._resource = ged_ops.experimental_thread_pool_handle(
num_threads=num_threads,
max_intra_op_parallelism=max_intra_op_parallelism,
display_name=display_name,
@@ -55,22 +54,22 @@ class PrivateThreadPool(object):
self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
handle=self._resource, handle_device=context.context().device_name)
else:
- self._resource = gen_dataset_ops.thread_pool_handle(
+ self._resource = ged_ops.experimental_thread_pool_handle(
num_threads=num_threads,
max_intra_op_parallelism=max_intra_op_parallelism,
display_name=display_name)
-class _ThreadPoolDataset(dataset_ops.Dataset):
+class _ThreadPoolDataset(dataset_ops.UnaryDataset):
"""A `Dataset` that acts as an identity, and sets a custom threadpool."""
def __init__(self, input_dataset, thread_pool):
- super(_ThreadPoolDataset, self).__init__()
+ super(_ThreadPoolDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._thread_pool = thread_pool
def _as_variant_tensor(self):
- return gen_dataset_ops.thread_pool_dataset(
+ return ged_ops.experimental_thread_pool_dataset(
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
self._thread_pool._resource, # pylint: disable=protected-access
**dataset_ops.flat_structure(self))
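A brief sketch of using `PrivateThreadPool` with `_ThreadPoolDataset`, assuming the module's `override_threadpool` helper (not shown in this hunk) is the intended entry point:

```python
import tensorflow as tf
from tensorflow.contrib.data.python.ops import threadpool

dataset = tf.data.Dataset.range(1000).map(lambda x: x * 2)

# Run this pipeline's ops on a dedicated 4-thread pool instead of the
# session-wide inter-op pool.
pool = threadpool.PrivateThreadPool(4, display_name="input_pipeline_pool")
dataset = threadpool.override_threadpool(dataset, pool)
```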
diff --git a/tensorflow/contrib/data/python/ops/unique.py b/tensorflow/contrib/data/python/ops/unique.py
index e0d606311c..ed363a7090 100644
--- a/tensorflow/contrib/data/python/ops/unique.py
+++ b/tensorflow/contrib/data/python/ops/unique.py
@@ -17,10 +17,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import
-from tensorflow.contrib.data.python.ops import gen_dataset_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import gen_experimental_dataset_ops
def unique():
@@ -47,12 +46,12 @@ def unique():
return _apply_fn
-class _UniqueDataset(dataset_ops.Dataset):
+class _UniqueDataset(dataset_ops.UnaryDataset):
"""A `Dataset` contains the unique elements from its input."""
def __init__(self, input_dataset):
"""See `unique()` for details."""
- super(_UniqueDataset, self).__init__()
+ super(_UniqueDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
if input_dataset.output_types not in (dtypes.int32, dtypes.int64,
dtypes.string):
@@ -61,7 +60,7 @@ class _UniqueDataset(dataset_ops.Dataset):
"`tf.int32`, `tf.int64`, or `tf.string` component.")
def _as_variant_tensor(self):
- return gen_dataset_ops.unique_dataset(
+ return gen_experimental_dataset_ops.experimental_unique_dataset(
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
**dataset_ops.flat_structure(self))
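A one-line usage sketch of the `unique` transformation rebound above (int32/int64/string inputs only, per the check in the constructor):

```python
import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices([1, 1, 2, 3, 3, 2])

# Yields 1, 2, 3: each distinct element is emitted once, in first-seen order.
dataset = dataset.apply(tf.contrib.data.unique())
```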
diff --git a/tensorflow/contrib/deprecated/summaries_test.py b/tensorflow/contrib/deprecated/summaries_test.py
index 6acf2a6469..4038224a1c 100644
--- a/tensorflow/contrib/deprecated/summaries_test.py
+++ b/tensorflow/contrib/deprecated/summaries_test.py
@@ -27,31 +27,31 @@ from tensorflow.python.platform import test
class DeprecatedSummariesTest(test.TestCase):
def testScalarSummary(self):
- with self.test_session():
+ with self.cached_session():
c = constant_op.constant(3)
s = logging_ops.scalar_summary('tag', c)
self.assertEqual(s.op.type, u'ScalarSummary')
def testHistogramSummary(self):
- with self.test_session():
+ with self.cached_session():
c = constant_op.constant(3)
s = logging_ops.histogram_summary('tag', c)
self.assertEqual(s.op.type, u'HistogramSummary')
def testImageSummary(self):
- with self.test_session():
+ with self.cached_session():
i = array_ops.ones((5, 4, 4, 3))
s = logging_ops.image_summary('tag', i)
self.assertEqual(s.op.type, u'ImageSummary')
def testAudioSummary(self):
- with self.test_session():
+ with self.cached_session():
c = constant_op.constant(3.0)
s = logging_ops.audio_summary('tag', c, sample_rate=8000)
self.assertEqual(s.op.type, u'AudioSummaryV2')
def testMergeSummary(self):
- with self.test_session():
+ with self.cached_session():
c = constant_op.constant(3)
a = logging_ops.scalar_summary('a', c)
b = logging_ops.scalar_summary('b', c)
diff --git a/tensorflow/contrib/distribute/README.md b/tensorflow/contrib/distribute/README.md
index 91a27f97b7..2e025765e4 100644
--- a/tensorflow/contrib/distribute/README.md
+++ b/tensorflow/contrib/distribute/README.md
@@ -231,7 +231,8 @@ The same `input_fn` will be used for all workers if you use
important to shuffle your dataset in your `input_fn`.
`MirroredStrategy` will insert a `tf.data.Dataset.shard` call in your
-`input_fn`. As a result, each worker gets a fraction of your input data.
+`input_fn` if `auto_shard_dataset` is set to `True`. As a result, each worker
+gets a fraction of your input data.
### Performance Tips
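A minimal `input_fn` sketch matching the sharding guidance above (dataset contents and batch size are illustrative):

```python
import tensorflow as tf

def input_fn():
  dataset = tf.data.Dataset.from_tensor_slices(list(range(1000)))
  # Shuffling here ensures workers do not all iterate in an identical order
  # once the strategy's shard call splits the data between them.
  return dataset.shuffle(buffer_size=1000).repeat().batch(32)
```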
diff --git a/tensorflow/contrib/distribute/__init__.py b/tensorflow/contrib/distribute/__init__.py
index 350f81f60f..823fe6a917 100644
--- a/tensorflow/contrib/distribute/__init__.py
+++ b/tensorflow/contrib/distribute/__init__.py
@@ -12,7 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Prototype of a distributed computation library for TF."""
+"""A distributed computation library for TF.
+
+See [tensorflow/contrib/distribute/README.md](
+https://www.tensorflow.org/code/tensorflow/contrib/distribute/README.md)
+for overview and examples.
+"""
from __future__ import absolute_import
from __future__ import division
diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD
index f72b827e04..e329b964c4 100644
--- a/tensorflow/contrib/distribute/python/BUILD
+++ b/tensorflow/contrib/distribute/python/BUILD
@@ -22,7 +22,6 @@ py_library(
visibility = ["//tensorflow:internal"],
deps = [
":input_ops",
- ":prefetching_ops_v2",
"//tensorflow/python:array_ops",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:device_util",
@@ -30,6 +29,7 @@ py_library(
"//tensorflow/python:framework_ops",
"//tensorflow/python:training",
"//tensorflow/python:util",
+ "//tensorflow/python/data/ops:multi_device_iterator_ops",
"//tensorflow/python/eager:context",
"//tensorflow/python/training/checkpointable:base",
"@six_archive//:six",
@@ -453,7 +453,7 @@ cuda_py_test(
cuda_py_test(
name = "estimator_training_test",
- size = "large",
+ size = "enormous",
srcs = ["estimator_training_test.py"],
additional_deps = [
":combinations",
@@ -472,11 +472,8 @@ cuda_py_test(
"//tensorflow/python:summary",
],
tags = [
- "manual",
"multi_and_single_gpu",
"no_pip",
- "nogpu",
- "notap",
],
)
@@ -651,32 +648,6 @@ cuda_py_test(
)
py_library(
- name = "prefetching_ops_v2",
- srcs = ["prefetching_ops_v2.py"],
- deps = [
- "//tensorflow/contrib/data/python/ops:contrib_op_loader",
- "//tensorflow/contrib/data/python/ops:prefetching_ops",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/util:nest",
- "//tensorflow/python/data/util:sparse",
- ],
-)
-
-cuda_py_test(
- name = "prefetching_ops_v2_test",
- srcs = ["prefetching_ops_v2_test.py"],
- additional_deps = [
- ":prefetching_ops_v2",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:framework_test_lib",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/ops:iterator_ops",
- ],
-)
-
-py_library(
name = "input_ops",
srcs = ["input_ops.py"],
visibility = ["//tensorflow:internal"],
@@ -731,12 +702,9 @@ cuda_py_test(
":keras_test_lib",
],
tags = [
- "manual",
"multi_and_single_gpu",
- "no_gpu",
"no_pip",
"no_windows_gpu",
- "notap",
"notsan",
],
)
diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py
index 77079d0df9..9809204f8f 100644
--- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py
+++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py
@@ -143,8 +143,10 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy):
def _real_mirrored_creator(devices, *args, **kwargs):
"""Creates one MirroredVariable on the current worker."""
index = {}
+ unique_var_name = ops.get_default_graph().unique_name(
+ kwargs["name"], mark_as_used=False).rstrip("/")
collective_instance_key = self._collective_keys.get_instance_key(
- key_id=kwargs["name"])
+ key_id=unique_var_name)
if "initial_value" not in kwargs:
raise ValueError("Initial value must be specified.")
initial_value = kwargs["initial_value"]
@@ -188,6 +190,10 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy):
with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
v = next_creator(*args, **kwargs)
+ if i == 0:
+ actual_var_name = v.name.split(":")[0]
+ assert unique_var_name == actual_var_name, "%r vs %r" % (
+ unique_var_name, actual_var_name)
assert not isinstance(v, values.DistributedVariable)
index[d] = v
return index
@@ -210,7 +216,7 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy):
"""Configures the object.
Args:
- session_config: a @{tf.ConfigProto}
+ session_config: a `tf.ConfigProto`
cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the
cluster configurations.
task_type: the current task type, such as "worker".
@@ -229,8 +235,6 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy):
if not session_config or not self._cluster_spec:
return
- session_config.isolate_session_state = True
-
assert self._task_type
assert self._task_id is not None
diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py
index 36e9761073..33ffbf6abe 100644
--- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py
+++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py
@@ -26,6 +26,7 @@ from tensorflow.contrib.distribute.python import combinations
from tensorflow.contrib.distribute.python import cross_tower_utils
from tensorflow.contrib.distribute.python import multi_worker_test_base
from tensorflow.core.protobuf import config_pb2
+from tensorflow.python import keras
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -34,9 +35,14 @@ from tensorflow.python.layers import core
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients
from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import nn
+from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
+from tensorflow.python.ops.losses import losses
from tensorflow.python.platform import test
+from tensorflow.python.training import adam
+from tensorflow.python.training import training_util
class CollectiveAllReduceStrategyTestBase(
@@ -146,6 +152,56 @@ class CollectiveAllReduceStrategyTestBase(
self.assertLess(error_after, error_before)
return error_after < error_before
+ def _test_complex_model(self, task_type, task_id, num_gpus):
+ d, master_target = self._get_test_object(task_type, task_id, num_gpus)
+
+ def model_fn():
+ """Mnist model with synthetic input."""
+ data_format = 'channels_last'
+ input_shape = [28, 28, 1]
+ l = keras.layers
+ max_pool = l.MaxPooling2D((2, 2), (2, 2),
+ padding='same',
+ data_format=data_format)
+ model = keras.Sequential([
+ l.Reshape(target_shape=input_shape, input_shape=(28 * 28,)),
+ l.Conv2D(
+ 32,
+ 5,
+ padding='same',
+ data_format=data_format,
+ activation=nn.relu), max_pool,
+ l.Conv2D(
+ 64,
+ 5,
+ padding='same',
+ data_format=data_format,
+ activation=nn.relu), max_pool,
+ l.Flatten(),
+ l.Dense(1024, activation=nn.relu),
+ l.Dropout(0.4),
+ l.Dense(10)
+ ])
+ image = random_ops.random_uniform([2, 28, 28])
+ label = random_ops.random_uniform([2, 1], maxval=10, dtype=dtypes.int32)
+ logits = model(image, training=True)
+ loss = losses.sparse_softmax_cross_entropy(labels=label, logits=logits)
+ optimizer = adam.AdamOptimizer(learning_rate=1e-4)
+ train_op = optimizer.minimize(loss,
+ training_util.get_or_create_global_step())
+ return train_op
+
+ with ops.Graph().as_default(), \
+ self.test_session(config=self._sess_config,
+ target=master_target) as sess:
+ with d.scope():
+ train_op = d.call_for_each_tower(model_fn)
+ train_op = d.group(d.unwrap(train_op))
+
+ sess.run(variables.global_variables_initializer())
+ sess.run(train_op)
+ return True
+
def _test_variable_initialization(self, task_type, task_id, num_gpus):
distribution, master_target = self._get_test_object(task_type, task_id,
num_gpus)
@@ -206,6 +262,14 @@ class DistributedCollectiveAllReduceStrategyTest(
self._cluster_spec,
num_gpus=num_gpus)
+ @combinations.generate(
+ combinations.combine(mode=['graph'], num_gpus=[0, 1, 2], required_gpus=1))
+ def testComplexModel(self, num_gpus):
+ if context.num_gpus() < num_gpus:
+ return
+ self._run_between_graph_clients(
+ self._test_complex_model, self._cluster_spec, num_gpus=num_gpus)
+
class DistributedCollectiveAllReduceStrategyTestWithChief(
CollectiveAllReduceStrategyTestBase, parameterized.TestCase):
@@ -236,6 +300,14 @@ class DistributedCollectiveAllReduceStrategyTestWithChief(
self._cluster_spec,
num_gpus=num_gpus)
+ @combinations.generate(
+ combinations.combine(mode=['graph'], num_gpus=[0, 1, 2], required_gpus=1))
+ def testComplexModel(self, num_gpus):
+ if context.num_gpus() < num_gpus:
+ return
+ self._run_between_graph_clients(
+ self._test_complex_model, self._cluster_spec, num_gpus=num_gpus)
+
class LocalCollectiveAllReduceStrategy(
CollectiveAllReduceStrategyTestBase, parameterized.TestCase):
@@ -246,6 +318,12 @@ class LocalCollectiveAllReduceStrategy(
return
self._test_minimize_loss_graph(None, None, num_gpus)
+ def testComplexModel(self, num_gpus=2):
+ # Collective ops don't support a strategy with only one device.
+ if context.num_gpus() < num_gpus:
+ return
+ self._test_complex_model(None, None, num_gpus)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/distribute/python/combinations.py b/tensorflow/contrib/distribute/python/combinations.py
index 244d1fcec8..82ca041cc2 100644
--- a/tensorflow/contrib/distribute/python/combinations.py
+++ b/tensorflow/contrib/distribute/python/combinations.py
@@ -59,6 +59,7 @@ from tensorflow.python.training import adagrad
from tensorflow.python.training import adam
from tensorflow.python.training import distribution_strategy_context
from tensorflow.python.training import gradient_descent
+from tensorflow.python.training import rmsprop
from tensorflow.python.util import tf_inspect
@@ -354,6 +355,8 @@ gradient_descent_optimizer_v1_fn = NamedObject(
"GradientDescentV1", lambda: gradient_descent.GradientDescentOptimizer(0.2))
adagrad_optimizer_v1_fn = NamedObject(
"AdagradV1", lambda: adagrad.AdagradOptimizer(0.001))
+rmsprop_optimizer_v1_fn = NamedObject(
+ "RmsPropV1", lambda: rmsprop.RMSPropOptimizer(0.001))
optimizers_v1 = [adam_optimizer_v1_fn, gradient_descent_optimizer_v1_fn,
adagrad_optimizer_v1_fn]
diff --git a/tensorflow/contrib/distribute/python/cross_tower_utils.py b/tensorflow/contrib/distribute/python/cross_tower_utils.py
index 24cb08fb48..9fc1b88955 100644
--- a/tensorflow/contrib/distribute/python/cross_tower_utils.py
+++ b/tensorflow/contrib/distribute/python/cross_tower_utils.py
@@ -221,9 +221,12 @@ def split_grads_by_size(threshold_size, device_grads):
return small_grads, large_grads
-# threading.Lock() cannot be pickled and therefore cannot be a field of
-# CollectiveKeys.
+# threading.Lock() and threading.local() cannot be pickled and therefore cannot
+# be a field of CollectiveKeys. Right now _thread_local is not necessary to be
+# an instance member of CollectiveKeys since we always create a new thread for
+# each tower.
_lock = threading.Lock()
+_thread_local = threading.local()
# TODO(yuefengz): use random key starts to avoid reusing keys?
@@ -266,14 +269,12 @@ class CollectiveKeys(object):
# For instance keys without ids
self._instance_key_start = instance_key_start
- self._thread_local = threading.local()
-
def _get_thread_local_object(self):
# We make instance key without key ids thread local so that it will work
# with MirroredStrategy and distribute coordinator.
- if not hasattr(self._thread_local, 'instance_key'):
- self._thread_local.instance_key = self._instance_key_start
- return self._thread_local
+ if not hasattr(_thread_local, 'instance_key'):
+ _thread_local.instance_key = self._instance_key_start
+ return _thread_local
def get_group_key(self, devices):
"""Returns a group key for the set of devices.
diff --git a/tensorflow/contrib/distribute/python/estimator_training_test.py b/tensorflow/contrib/distribute/python/estimator_training_test.py
index 5348512016..157618f72f 100644
--- a/tensorflow/contrib/distribute/python/estimator_training_test.py
+++ b/tensorflow/contrib/distribute/python/estimator_training_test.py
@@ -26,21 +26,12 @@ import tempfile
import threading
from absl.testing import parameterized
import numpy as np
-import six
-_portpicker_import_error = None
-try:
- import portpicker # pylint: disable=g-import-not-at-top
-except ImportError as _error: # pylint: disable=invalid-name
- _portpicker_import_error = _error
- portpicker = None
-
-# pylint: disable=g-import-not-at-top
from tensorflow.contrib.distribute.python import combinations
from tensorflow.contrib.distribute.python import mirrored_strategy
+from tensorflow.contrib.distribute.python import multi_worker_test_base
from tensorflow.contrib.distribute.python import parameter_server_strategy
from tensorflow.contrib.optimizer_v2 import adagrad
-from tensorflow.core.protobuf import config_pb2
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import distribute_coordinator as dc
from tensorflow.python.distribute import estimator_training as dc_training
@@ -57,7 +48,6 @@ from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.summary import summary_iterator
from tensorflow.python.summary.writer import writer_cache
-from tensorflow.python.training import server_lib
BATCH_SIZE = 10
LABEL_DIMENSION = 2
@@ -73,130 +63,38 @@ EVALUATOR = dc._TaskType.EVALUATOR
WORKER = dc._TaskType.WORKER
PS = dc._TaskType.PS
-original_run_distribute_coordinator = dc.run_distribute_coordinator
-
-
-# TODO(yuefengz): merge this method back to test_util.
-def _create_local_cluster(num_workers,
- num_ps,
- has_eval=False,
- protocol="grpc",
- worker_config=None,
- ps_config=None):
- if _portpicker_import_error:
- raise _portpicker_import_error # pylint: disable=raising-bad-type
- worker_ports = [portpicker.pick_unused_port() for _ in range(num_workers)]
- ps_ports = [portpicker.pick_unused_port() for _ in range(num_ps)]
-
- cluster_dict = {
- "worker": ["localhost:%s" % port for port in worker_ports],
- "ps": ["localhost:%s" % port for port in ps_ports]
- }
- if has_eval:
- cluster_dict["evaluator"] = ["localhost:%s" % portpicker.pick_unused_port()]
-
- cs = server_lib.ClusterSpec(cluster_dict)
-
- workers = [
- server_lib.Server(
- cs,
- job_name="worker",
- protocol=protocol,
- task_index=ix,
- config=worker_config,
- start=True) for ix in range(num_workers)
- ]
- ps_servers = [
- server_lib.Server(
- cs,
- job_name="ps",
- protocol=protocol,
- task_index=ix,
- config=ps_config,
- start=True) for ix in range(num_ps)
- ]
- if has_eval:
- evals = [
- server_lib.Server(
- cs,
- job_name="evaluator",
- protocol=protocol,
- task_index=0,
- config=worker_config,
- start=True)
- ]
- else:
- evals = []
-
- return workers, ps_servers, evals
-
-
-def _create_in_process_cluster(num_workers, num_ps, has_eval=False):
- """Create an in-process cluster that consists of only standard server."""
- # Leave some memory for cuda runtime.
- if has_eval:
- gpu_mem_frac = 0.7 / (num_workers + 1)
- else:
- gpu_mem_frac = 0.7 / num_workers
-
- worker_config = config_pb2.ConfigProto()
- worker_config.gpu_options.per_process_gpu_memory_fraction = gpu_mem_frac
-
- # Enable collective ops which has no impact on non-collective ops.
- # TODO(yuefengz, tucker): removing this after we move the initialization of
- # collective mgr to the session level.
- worker_config.experimental.collective_group_leader = (
- "/job:worker/replica:0/task:0")
-
- ps_config = config_pb2.ConfigProto()
- ps_config.device_count["GPU"] = 0
-
- return _create_local_cluster(
- num_workers,
- num_ps=num_ps,
- has_eval=has_eval,
- worker_config=worker_config,
- ps_config=ps_config,
- protocol="grpc")
-
-
-def _create_cluster_spec(has_chief=False,
- num_workers=1,
- num_ps=0,
- has_eval=False):
- if _portpicker_import_error:
- raise _portpicker_import_error # pylint: disable=raising-bad-type
-
- cluster_spec = {}
- if has_chief:
- cluster_spec[CHIEF] = ["localhost:%s" % portpicker.pick_unused_port()]
- if num_workers:
- cluster_spec[WORKER] = [
- "localhost:%s" % portpicker.pick_unused_port()
- for _ in range(num_workers)
- ]
- if num_ps:
- cluster_spec[PS] = [
- "localhost:%s" % portpicker.pick_unused_port() for _ in range(num_ps)
- ]
- if has_eval:
- cluster_spec[EVALUATOR] = ["localhost:%s" % portpicker.pick_unused_port()]
- return cluster_spec
+original_run_std_server = dc._run_std_server
-def _bytes_to_str(maybe_bytes):
- if isinstance(maybe_bytes, six.string_types):
- return maybe_bytes
- else:
- return str(maybe_bytes, "utf-8")
+class MockOsEnv(dict):
+
+ def __init__(self, *args):
+ self._thread_local = threading.local()
+ super(MockOsEnv, self).__init__(*args)
+
+ def get(self, key, default):
+ if not hasattr(self._thread_local, "dict"):
+ self._thread_local.dict = dict()
+ if key == "TF_CONFIG":
+ return dict.get(self._thread_local.dict, key, default)
+ else:
+ return dict.get(self, key, default)
+ def __getitem__(self, key):
+ if not hasattr(self._thread_local, "dict"):
+ self._thread_local.dict = dict()
+ if key == "TF_CONFIG":
+ return dict.__getitem__(self._thread_local.dict, key)
+ else:
+ return dict.__getitem__(self, key)
-def _strip_protocol(target):
- # cluster_spec expects "host:port" strings.
- if "//" in target:
- return target.split("//")[1]
- else:
- return target
+ def __setitem__(self, key, val):
+ if not hasattr(self._thread_local, "dict"):
+ self._thread_local.dict = dict()
+ if key == "TF_CONFIG":
+ return dict.__setitem__(self._thread_local.dict, key, val)
+ else:
+ return dict.__setitem__(self, key, val)
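The MockOsEnv added above patches os.environ so that TF_CONFIG is stored per thread while every other variable stays shared, letting each simulated task thread carry its own task type and index. A minimal standalone sketch of the same thread-local-override idea (plain Python, hypothetical class name, no TensorFlow):

import threading

class ThreadLocalEnv(dict):
  """Dict in which one key ("TF_CONFIG") is kept per thread; the rest is shared."""

  def __init__(self, *args):
    self._local = threading.local()
    super(ThreadLocalEnv, self).__init__(*args)

  def __setitem__(self, key, val):
    if key == "TF_CONFIG":
      setattr(self._local, key, val)
    else:
      dict.__setitem__(self, key, val)

  def get(self, key, default=None):
    if key == "TF_CONFIG":
      return getattr(self._local, key, default)
    return dict.get(self, key, default)

env = ThreadLocalEnv({"PATH": "/usr/bin"})

def task(name):
  env["TF_CONFIG"] = name                    # visible only in this thread
  assert env.get("TF_CONFIG") == name
  assert env.get("PATH") == "/usr/bin"       # shared entries remain visible

threads = [threading.Thread(target=task, args=("task%d" % i,)) for i in range(3)]
for t in threads:
  t.start()
for t in threads:
  t.join()
assert env.get("TF_CONFIG") is None          # the main thread never set it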
class DistributeCoordinatorIntegrationTest(test.TestCase,
@@ -205,22 +103,20 @@ class DistributeCoordinatorIntegrationTest(test.TestCase,
@classmethod
def setUpClass(cls):
"""Create a local cluster with 2 workers."""
- cls._workers, cls._ps, cls._evals = _create_in_process_cluster(
+ cls._cluster_spec = multi_worker_test_base.create_in_process_cluster(
num_workers=3, num_ps=2, has_eval=True)
- cls._cluster_spec = {
- "worker": [
- _strip_protocol(_bytes_to_str(w.target)) for w in cls._workers
- ],
- "ps": [_strip_protocol(_bytes_to_str(ps.target)) for ps in cls._ps],
- "evaluator": [
- _strip_protocol(_bytes_to_str(e.target)) for e in cls._evals
- ]
- }
def setUp(self):
self._model_dir = tempfile.mkdtemp()
- self._event = threading.Event()
+ self._mock_os_env = MockOsEnv()
+ self._mock_context = test.mock.patch.object(os, "environ",
+ self._mock_os_env)
super(DistributeCoordinatorIntegrationTest, self).setUp()
+ self._mock_context.__enter__()
+
+ def tearDown(self):
+ self._mock_context.__exit__(None, None, None)
+ super(DistributeCoordinatorIntegrationTest, self).tearDown()
def dataset_input_fn(self, x, y, batch_size, shuffle):
@@ -391,43 +287,17 @@ class DistributeCoordinatorIntegrationTest(test.TestCase,
train_distribute, eval_distribute, remote_cluster=self._cluster_spec)
self._inspect_train_and_eval_events(estimator)
- def _mock_run_distribute_coordinator(
- self,
- worker_fn,
- strategy,
- eval_fn,
- eval_strategy,
- mode=dc.CoordinatorMode.STANDALONE_CLIENT,
- cluster_spec=None,
- session_config=None):
- # Calls the original `run_distribute_coordinator` method but gets task config
- # from environment variables and then signals the caller.
- task_type = None
- task_id = None
- if not cluster_spec:
- cluster_spec = None
- tf_config = json.loads(os.environ.get("TF_CONFIG", "{}"))
- if not cluster_spec:
- cluster_spec = tf_config.get("cluster", {})
- task_env = tf_config.get("task", {})
- if task_env:
- task_type = task_env.get("type", task_type)
- task_id = int(task_env.get("index", task_id))
- self._event.set()
- original_run_distribute_coordinator(
- worker_fn,
- strategy,
- eval_fn,
- eval_strategy,
- mode=mode,
- cluster_spec=cluster_spec,
- task_type=task_type,
- task_id=task_id,
- session_config=session_config)
-
- def _task_thread(self, train_distribute, eval_distribute):
- with test.mock.patch.object(dc, "run_distribute_coordinator",
- self._mock_run_distribute_coordinator):
+ def _mock_run_std_server(self, *args, **kwargs):
+ ret = original_run_std_server(*args, **kwargs)
+ # Wait for all std servers to be brought up in order to reduce the chance of
+ # remote sessions taking local ports that have been assigned to std servers.
+ self._barrier.wait()
+ return ret
+
+ def _task_thread(self, train_distribute, eval_distribute, tf_config):
+ os.environ["TF_CONFIG"] = json.dumps(tf_config)
+ with test.mock.patch.object(dc, "_run_std_server",
+ self._mock_run_std_server):
self._complete_flow(train_distribute, eval_distribute)
def _run_task_in_thread(self, cluster_spec, task_type, task_id,
@@ -448,13 +318,10 @@ class DistributeCoordinatorIntegrationTest(test.TestCase,
"index": task_id
}
}
- self._event.clear()
t = threading.Thread(
- target=self._task_thread, args=(train_distribute, eval_distribute))
- with test.mock.patch.dict("os.environ",
- {"TF_CONFIG": json.dumps(tf_config)}):
- t.start()
- self._event.wait()
+ target=self._task_thread,
+ args=(train_distribute, eval_distribute, tf_config))
+ t.start()
return t
def _run_multiple_tasks_in_threads(self, cluster_spec, train_distribute,
@@ -489,7 +356,11 @@ class DistributeCoordinatorIntegrationTest(test.TestCase,
else:
eval_distribute = None
- cluster_spec = _create_cluster_spec(num_workers=3, num_ps=2, has_eval=True)
+ cluster_spec = multi_worker_test_base.create_cluster_spec(
+ num_workers=3, num_ps=2, has_eval=True)
+ # 3 workers, 2 ps and 1 evaluator.
+ self._barrier = dc._Barrier(6)
+
threads = self._run_multiple_tasks_in_threads(
cluster_spec, train_distribute, eval_distribute)
for task_type, ts in threads.items():
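The dc._Barrier(6) above is sized to the six tasks in the cluster (3 workers + 2 ps + 1 evaluator); together with the patched _run_std_server, every task blocks until all std servers have bound their ports before any remote session is created. A rough sketch of that rendezvous using Python 3's threading.Barrier in place of the internal dc._Barrier:

import threading

NUM_TASKS = 6                              # 3 workers + 2 ps + 1 evaluator
barrier = threading.Barrier(NUM_TASKS)

def run_task(task_id):
  # ... start this task's std server (binds its port) ...
  barrier.wait()                           # block until every server is up
  # ... only now open sessions that may connect to the other tasks ...

threads = [threading.Thread(target=run_task, args=(i,)) for i in range(NUM_TASKS)]
for t in threads:
  t.start()
for t in threads:
  t.join()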
@@ -516,7 +387,10 @@ class DistributeCoordinatorIntegrationTest(test.TestCase,
else:
eval_distribute = None
- cluster_spec = _create_cluster_spec(num_workers=3, num_ps=2, has_eval=True)
+ cluster_spec = multi_worker_test_base.create_cluster_spec(
+ num_workers=3, num_ps=0, has_eval=True)
+ # 3 workers and 1 evaluator.
+ self._barrier = dc._Barrier(4)
threads = self._run_multiple_tasks_in_threads(
cluster_spec, train_distribute, eval_distribute)
threads[WORKER][0].join()
diff --git a/tensorflow/contrib/distribute/python/examples/simple_estimator_example.py b/tensorflow/contrib/distribute/python/examples/simple_estimator_example.py
index 44a69ed23a..79a9803d75 100644
--- a/tensorflow/contrib/distribute/python/examples/simple_estimator_example.py
+++ b/tensorflow/contrib/distribute/python/examples/simple_estimator_example.py
@@ -22,6 +22,8 @@ from __future__ import print_function
import tensorflow as tf
+from tensorflow.python.keras import metrics as metrics_module
+
def build_model_fn_optimizer():
"""Simple model_fn with optimizer."""
@@ -45,7 +47,10 @@ def build_model_fn_optimizer():
return y * y
if mode == tf.estimator.ModeKeys.EVAL:
- return tf.estimator.EstimatorSpec(mode, loss=loss_fn())
+ acc_obj = metrics_module.BinaryAccuracy()
+ acc_obj.update_state(labels, labels)
+ return tf.estimator.EstimatorSpec(
+ mode, loss=loss_fn(), eval_metric_ops={"Accuracy": acc_obj})
assert mode == tf.estimator.ModeKeys.TRAIN
@@ -61,18 +66,26 @@ def main(_):
["/device:GPU:0", "/device:GPU:1"])
config = tf.estimator.RunConfig(train_distribute=distribution,
eval_distribute=distribution)
+ # Since there are 2 devices and 10 samples, we set steps=5.
+ steps = 5
- def input_fn():
+ def train_input_fn():
features = tf.data.Dataset.from_tensors([[1.]]).repeat(10)
labels = tf.data.Dataset.from_tensors([1.]).repeat(10)
return tf.data.Dataset.zip((features, labels))
estimator = tf.estimator.Estimator(
model_fn=build_model_fn_optimizer(), config=config)
- estimator.train(input_fn=input_fn, steps=10)
+ estimator.train(input_fn=train_input_fn, steps=steps)
+
+ def eval_input_fn():
+ features = tf.data.Dataset.from_tensors([[1.]]).repeat(10)
+ labels = tf.data.Dataset.from_tensors([1.]).repeat(10)
+ return tf.data.Dataset.zip((features, labels))
- eval_result = estimator.evaluate(input_fn=input_fn, steps=10)
+ eval_result = estimator.evaluate(input_fn=eval_input_fn, steps=steps)
print("Eval result: {}".format(eval_result))
+ assert eval_result["Accuracy"] == 1.0
def predict_input_fn():
predict_features = tf.data.Dataset.from_tensors([[1.]]).repeat(10)
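The assert on eval_result["Accuracy"] holds because update_state(labels, labels) above compares the labels with themselves, so binary accuracy (0.5 threshold by default) is exactly 1.0 regardless of the model. A quick numpy check of that arithmetic:

import numpy as np

labels = np.array([1., 0., 1., 1.])
predictions = labels                              # update_state(labels, labels)
accuracy = np.mean((predictions > 0.5) == (labels > 0.5))
assert accuracy == 1.0                            # matches eval_result["Accuracy"]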
diff --git a/tensorflow/contrib/distribute/python/input_ops_test.py b/tensorflow/contrib/distribute/python/input_ops_test.py
index c5acb7ced4..559de97bb1 100644
--- a/tensorflow/contrib/distribute/python/input_ops_test.py
+++ b/tensorflow/contrib/distribute/python/input_ops_test.py
@@ -20,8 +20,6 @@ from __future__ import print_function
import os
-from tensorflow.contrib.data.python.ops import batching
-from tensorflow.contrib.data.python.ops import interleave_ops
from tensorflow.contrib.distribute.python import input_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import readers
@@ -126,20 +124,6 @@ class AutoShardDatasetTest(test.TestCase):
# contain records in order of files.
self._verifySimpleShardingOutput(dataset, self._record)
- def testParallelInterleave(self):
- dataset = dataset_ops.Dataset.from_tensor_slices(
- self._createTFRecordFiles())
- dataset = dataset.apply(interleave_ops.parallel_interleave(
- readers.TFRecordDataset,
- cycle_length=4,
- block_length=self._num_records))
- dataset = input_ops.auto_shard_dataset(
- dataset, self._num_shards, self._shard_index)
-
- # Since block_length == num records in each file, the output will still
- # contain records in order of files.
- self._verifySimpleShardingOutput(dataset, self._record)
-
def testListfiles(self):
filenames = self._createTFRecordFiles()
file_pattern = filenames[0].rsplit("/", 1)[0] + "/tf_record.*.txt"
@@ -171,8 +155,8 @@ class AutoShardDatasetTest(test.TestCase):
dataset = dataset.prefetch(buffer_size=batch_size)
dataset = dataset.shuffle(2 * self._num_files * self._num_records)
dataset = dataset.repeat(num_epochs)
- dataset = dataset.apply(batching.map_and_batch(
- lambda x: x, batch_size=batch_size))
+ dataset = dataset.map(lambda x: x)
+ dataset = dataset.batch(batch_size)
dataset = dataset.prefetch(buffer_size=None)
# Auto shard.
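The removed contrib batching.map_and_batch call is a fused map+batch; for this test the unfused core pipeline is functionally equivalent, since the fusion was only a performance optimization. An illustrative TF 1.x-style comparison, assuming a local session:

import tensorflow as tf

ds = tf.data.Dataset.range(8)
# Old: ds.apply(batching.map_and_batch(lambda x: x, batch_size=4))
# New, equivalent pipeline:
ds = ds.map(lambda x: x).batch(4)

it = ds.make_one_shot_iterator()
with tf.Session() as sess:
  print(sess.run(it.get_next()))                  # [0 1 2 3]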
diff --git a/tensorflow/contrib/distribute/python/keras_test.py b/tensorflow/contrib/distribute/python/keras_test.py
index 5f35e38189..3aab2c521f 100644
--- a/tensorflow/contrib/distribute/python/keras_test.py
+++ b/tensorflow/contrib/distribute/python/keras_test.py
@@ -173,13 +173,42 @@ def batch_wrapper(dataset, batch_size, distribution):
return dataset.batch(batch_size)
-def all_combinations():
+def get_model():
+ x = keras.layers.Input(shape=(3,), name='input')
+ y = keras.layers.Dense(4, name='dense')(x)
+ model = keras.Model(x, y)
+ return model
+
+
+def get_dataset(distribution):
+ inputs = np.zeros((10, 3), dtype=np.float32)
+ targets = np.zeros((10, 4), dtype=np.float32)
+ dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
+ dataset = dataset.repeat(100)
+ dataset = batch_wrapper(dataset, 10, distribution)
+ return dataset
+
+
+strategies = [combinations.default_strategy,
+ combinations.one_device_strategy,
+ combinations.mirrored_strategy_with_gpu_and_cpu,
+ combinations.mirrored_strategy_with_two_gpus,
+ combinations.tpu_strategy_one_step]
+
+
+def strategy_combinations():
return combinations.combine(
- distribution=[combinations.default_strategy,
- combinations.one_device_strategy,
- combinations.mirrored_strategy_with_gpu_and_cpu,
- combinations.mirrored_strategy_with_two_gpus,
- combinations.tpu_strategy_one_step],
+ distribution=strategies,
+ mode=['graph'])
+
+
+def strategy_and_optimizer_combinations():
+ return combinations.combine(
+ distribution=strategies,
+ optimizer=[combinations.adagrad_optimizer_v1_fn,
+ combinations.adam_optimizer_v1_fn,
+ combinations.gradient_descent_optimizer_v1_fn,
+ combinations.rmsprop_optimizer_v1_fn],
mode=['graph'])
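For readers unfamiliar with the combinations helper: combine() effectively takes the cartesian product of the named parameter lists and feeds each resulting parameter dict to the decorated test via @combinations.generate. A simplified stand-in (not the real implementation) showing that expansion:

import itertools

def combine(**kwargs):
  """Simplified stand-in: cartesian product of the named parameter lists."""
  keys = sorted(kwargs)
  return [dict(zip(keys, values))
          for values in itertools.product(*(kwargs[k] for k in keys))]

# 5 strategies x 4 optimizers x 1 mode -> 20 parameterizations.
cases = combine(distribution=list(range(5)), optimizer=list(range(4)), mode=['graph'])
print(len(cases))   # 20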
@@ -205,6 +234,7 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase):
keras_model = simple_functional_model()
keras_model.compile(
loss='categorical_crossentropy',
+ metrics=[keras.metrics.CategoricalAccuracy()],
optimizer=rmsprop.RMSPropOptimizer(learning_rate=0.01))
config = run_config_lib.RunConfig(tf_random_seed=_RANDOM_SEED,
model_dir=self._base_dir,
@@ -229,6 +259,7 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase):
keras_model = simple_sequential_model()
keras_model.compile(
loss='categorical_crossentropy',
+ metrics=[keras.metrics.CategoricalAccuracy()],
optimizer=rmsprop.RMSPropOptimizer(learning_rate=0.01))
config = run_config_lib.RunConfig(tf_random_seed=_RANDOM_SEED,
model_dir=self._base_dir,
@@ -358,13 +389,11 @@ class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase):
def test_calling_model_with_numpy_arrays(self):
with self.cached_session():
- x = keras.layers.Input(shape=(3,), name='input')
- y = keras.layers.Dense(4, name='dense')(x)
- model = keras.Model(x, y)
+ model = get_model()
optimizer = gradient_descent.GradientDescentOptimizer(0.001)
loss = 'mse'
- metrics = ['mae']
+ metrics = ['mae', keras.metrics.CategoricalAccuracy()]
strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:1',
'/device:GPU:0'])
model.compile(optimizer, loss, metrics=metrics, distribute=strategy)
@@ -390,23 +419,17 @@ class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase):
# with batch_size
model.predict(inputs, batch_size=8)
- @combinations.generate(all_combinations())
+ @combinations.generate(strategy_combinations())
def test_calling_model_on_same_dataset(self, distribution):
with self.cached_session():
- x = keras.layers.Input(shape=(3,), name='input')
- y = keras.layers.Dense(4, name='dense')(x)
- model = keras.Model(x, y)
+ model = get_model()
optimizer = gradient_descent.GradientDescentOptimizer(0.001)
loss = 'mse'
- metrics = ['mae']
+ metrics = ['mae', keras.metrics.CategoricalAccuracy()]
model.compile(optimizer, loss, metrics=metrics, distribute=distribution)
- inputs = np.zeros((10, 3), dtype=np.float32)
- targets = np.zeros((10, 4), dtype=np.float32)
- dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
- dataset = dataset.repeat(100)
- dataset = batch_wrapper(dataset, 10, distribution)
+ dataset = get_dataset(distribution)
# Call fit with validation data
model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0,
@@ -432,7 +455,7 @@ class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase):
optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.001)
loss = 'mse'
- metrics = ['mae']
+ metrics = ['mae', keras.metrics.CategoricalAccuracy()]
strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:0',
'/device:CPU:0'])
model.compile(optimizer, loss, metrics=metrics, distribute=strategy)
@@ -459,23 +482,17 @@ class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase):
model.fit(dataset_dict, epochs=1, steps_per_epoch=2, verbose=1)
- @combinations.generate(all_combinations())
+ @combinations.generate(strategy_combinations())
def test_fit_eval_and_predict_methods_on_dataset(self, distribution):
with self.cached_session():
- x = keras.layers.Input(shape=(3,), name='input')
- y = keras.layers.Dense(4, name='dense')(x)
- model = keras.Model(x, y)
+ model = get_model()
optimizer = gradient_descent.GradientDescentOptimizer(0.001)
loss = 'mse'
- metrics = ['mae']
+ metrics = ['mae', keras.metrics.CategoricalAccuracy()]
model.compile(optimizer, loss, metrics=metrics, distribute=distribution)
- inputs = np.zeros((10, 3), dtype=np.float32)
- targets = np.zeros((10, 4), dtype=np.float32)
- dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
- dataset = dataset.repeat(100)
- dataset = batch_wrapper(dataset, 10, distribution)
+ dataset = get_dataset(distribution)
model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1)
model.evaluate(dataset, steps=2, verbose=1)
@@ -484,37 +501,23 @@ class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase):
model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0,
validation_data=dataset, validation_steps=2)
- def test_raise_error_for_stateful_metrics(self):
-
- class ExampleStatefulMetric(keras.layers.Layer):
-
- def __init__(self, name='true_positives', **kwargs):
- super(ExampleStatefulMetric, self).__init__(name=name, **kwargs)
- self.stateful = True
-
- def __call__(self, y_true, y_pred):
- return y_pred - y_true
-
+ @combinations.generate(strategy_and_optimizer_combinations())
+ def test_fit_eval_and_predict_with_optimizer(self, distribution, optimizer):
with self.cached_session():
- x = keras.layers.Input(shape=(3,), name='input')
- y = keras.layers.Dense(4, name='dense')(x)
- model = keras.Model(x, y)
+ model = get_model()
- optimizer = gradient_descent.GradientDescentOptimizer(0.001)
loss = 'mse'
- metrics = ['mae', ExampleStatefulMetric()]
- strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:1',
- '/device:GPU:0'])
- with self.assertRaisesRegexp(
- NotImplementedError, 'Stateful metrics are not supported with '
- 'DistributionStrategy.'):
- model.compile(optimizer, loss, metrics=metrics, distribute=strategy)
+ model.compile(optimizer(), loss, distribute=distribution)
+
+ dataset = get_dataset(distribution)
+
+ model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1)
+ model.evaluate(dataset, steps=2, verbose=1)
+ model.predict(dataset, steps=2)
def test_unsupported_features(self):
with self.cached_session():
- x = keras.layers.Input(shape=(3,), name='input')
- y = keras.layers.Dense(4, name='dense')(x)
- model = keras.Model(x, y)
+ model = get_model()
optimizer = gradient_descent.GradientDescentOptimizer(0.001)
loss = 'mse'
@@ -524,11 +527,7 @@ class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase):
model.compile(optimizer, loss, metrics=metrics, distribute=strategy)
- inputs = np.zeros((10, 3), dtype=np.float32)
- targets = np.zeros((10, 4), dtype=np.float32)
- dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
- dataset = dataset.repeat(100)
- dataset = dataset.batch(10)
+ dataset = get_dataset(strategy)
# Test with validation split
with self.assertRaisesRegexp(
@@ -565,9 +564,7 @@ class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase):
def test_calling_with_unsupported_predefined_callbacks(self):
with self.cached_session():
- x = keras.layers.Input(shape=(3,), name='input')
- y = keras.layers.Dense(4, name='dense')(x)
- model = keras.Model(x, y)
+ model = get_model()
optimizer = gradient_descent.GradientDescentOptimizer(0.001)
loss = 'mse'
@@ -576,11 +573,7 @@ class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase):
'/device:GPU:0'])
model.compile(optimizer, loss, metrics=metrics, distribute=strategy)
- inputs = np.zeros((10, 3), dtype=np.float32)
- targets = np.zeros((10, 4), dtype=np.float32)
- dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
- dataset = dataset.repeat(100)
- dataset = dataset.batch(10)
+ dataset = get_dataset(strategy)
def schedule(_):
return 0.001
@@ -604,9 +597,7 @@ class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase):
def test_dataset_input_shape_validation(self):
with self.cached_session():
- x = keras.layers.Input(shape=(3,), name='input')
- y = keras.layers.Dense(4, name='dense')(x)
- model = keras.Model(x, y)
+ model = get_model()
optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001)
loss = 'mse'
@@ -635,6 +626,25 @@ class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase):
'expected input to have shape'):
model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0)
+ @combinations.generate(combinations.combine(
+ distribution=[combinations.tpu_strategy_one_step],
+ mode=['graph']))
+ def test_dataset_input_shape_fully_defined(self, distribution):
+ with self.cached_session():
+ model = get_model()
+
+ optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001)
+ loss = 'mse'
+ model.compile(optimizer, loss, distribute=distribution)
+
+ dataset = get_dataset(distribution)
+ # Input shapes are not fully known. Batch dimension is unknown as we are
+ # not using the drop_remainder argument.
+ dataset = dataset.repeat(100).batch(10)
+
+ with self.assertRaisesRegexp(ValueError, 'requires fully defined shapes'):
+ model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0)
+
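The new test relies on the fact that, without drop_remainder, the batch dimension stays unknown (the final batch can be short), which TPUStrategy rejects. A small TF 1.x-style illustration of the shape difference:

import tensorflow as tf

ds = tf.data.Dataset.range(10).batch(3)
print(ds.output_shapes)                      # (?,)  -- batch dim unknown

ds = tf.data.Dataset.range(10).batch(3, drop_remainder=True)
print(ds.output_shapes)                      # (3,)  -- fully defined, TPU-friendly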
def test_learning_phase_value(self):
# TODO(anjalisridhar): Modify this test to use Lambdas since we can compare
# meaningful values. Currently we don't pass the learning phase if the
@@ -699,7 +709,7 @@ class LossMaskingWithDistributionStrategyTest(test.TestCase):
class NormalizationLayerWithDistributionStrategyTest(
test.TestCase, parameterized.TestCase):
- @combinations.generate(all_combinations())
+ @combinations.generate(strategy_combinations())
def test_batchnorm_correctness(self, distribution):
with self.cached_session():
model = keras.models.Sequential()
@@ -727,19 +737,57 @@ class NormalizationLayerWithDistributionStrategyTest(
class CorrectnessWithDistributionStrategyTest(test.TestCase,
parameterized.TestCase):
- @combinations.generate(all_combinations())
+ @combinations.generate(strategy_combinations())
+ def test_metric_correctness(self, distribution):
+ with self.cached_session():
+ keras.backend.set_image_data_format('channels_last')
+ num_samples = 10000
+
+ x_train = np.random.randint(0, 2, num_samples)
+ x_train = np.reshape(x_train, (num_samples, 1))
+ y_train = x_train
+ x_train = x_train.astype('float32')
+ y_train = y_train.astype('float32')
+
+ # Create identity model.
+ model = keras.Sequential()
+ model.add(
+ keras.layers.Dense(1, input_shape=(1,), kernel_initializer='ones'))
+ model.compile(
+ loss=keras.losses.mean_squared_error,
+ optimizer=gradient_descent.GradientDescentOptimizer(0.5),
+ metrics=[keras.metrics.BinaryAccuracy()],
+ distribute=distribution)
+
+ batch_size = 64
+ batch_size //= distribution.num_towers
+ train_dataset = dataset_ops.Dataset.from_tensor_slices((x_train, y_train))
+ train_dataset = batch_wrapper(train_dataset, batch_size, distribution)
+
+ history = model.fit(x=train_dataset, epochs=1, steps_per_epoch=10)
+ self.assertEqual(history.history['binary_accuracy'], [1.0])
+
+ @combinations.generate(strategy_combinations())
def test_correctness(self, distribution):
with self.cached_session():
keras.backend.set_image_data_format('channels_last')
num_samples = 10000
+
+ # Train and predict datasets are created with the same input numpy arrays.
x_train = np.random.rand(num_samples, 1)
y_train = 3 * x_train
x_train = x_train.astype('float32')
y_train = y_train.astype('float32')
+ # The model is built once and the initial weights are saved.
+ # This is used to initialize the model for both the distribution and
+ # non-distribution run.
+ model = keras.Sequential()
+ model.add(keras.layers.Dense(1, input_shape=(1,)))
+ initial_weights = model.get_weights()
+
def fit_and_predict(with_distribution=None):
- model = keras.Sequential()
- model.add(keras.layers.Dense(1, input_shape=(1,)))
+ model.set_weights(initial_weights)
model.compile(
loss=keras.losses.mean_squared_error,
optimizer=gradient_descent.GradientDescentOptimizer(0.5),
@@ -751,12 +799,14 @@ class CorrectnessWithDistributionStrategyTest(test.TestCase,
train_dataset = dataset_ops.Dataset.from_tensor_slices((x_train,
y_train))
train_dataset = batch_wrapper(train_dataset, batch_size, distribution)
- # Running only 100 steps instead of the full dataset to keep test
- # duration small.
- model.fit(x=train_dataset, epochs=1, steps_per_epoch=100)
+ # We initialize the model to the same weights for the distribution and
+ # non-distribution runs. If you instead want random initial weights for
+ # each run, you need to run the model through the entire dataset at least
+ # once so that the weights converge to the same values.
+ model.fit(x=train_dataset, epochs=1, steps_per_epoch=10)
weights = model.get_weights()
-
x_predict = [[1.], [2.], [3.], [4.]]
predict_batch_size = 4
if with_distribution:
diff --git a/tensorflow/contrib/distribute/python/metrics_v1_test.py b/tensorflow/contrib/distribute/python/metrics_v1_test.py
index 8163494c8e..f7773aff4f 100644
--- a/tensorflow/contrib/distribute/python/metrics_v1_test.py
+++ b/tensorflow/contrib/distribute/python/metrics_v1_test.py
@@ -86,10 +86,11 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
def _test_metric(self, distribution, dataset_fn, metric_fn, expected_fn):
with ops.Graph().as_default(), distribution.scope():
iterator = distribution.distribute_dataset(
- dataset_fn).make_one_shot_iterator()
+ dataset_fn).make_initializable_iterator()
value, update = distribution.call_for_each_tower(
metric_fn, iterator.get_next())
update = distribution.group(update)
+ self.evaluate(iterator.initializer)
self.evaluate(variables.local_variables_initializer())
# TODO(josh11b): Once we switch to using a global batch size for input,
# replace "distribution.num_towers" with "1".
diff --git a/tensorflow/contrib/distribute/python/minimize_loss_test.py b/tensorflow/contrib/distribute/python/minimize_loss_test.py
index ba147e7824..d082d5c419 100644
--- a/tensorflow/contrib/distribute/python/minimize_loss_test.py
+++ b/tensorflow/contrib/distribute/python/minimize_loss_test.py
@@ -41,6 +41,14 @@ from tensorflow.python.ops.losses import losses_impl
class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
+ def _get_iterator(self, ds):
+ if context.executing_eagerly():
+ iterator = ds.make_one_shot_iterator()
+ else:
+ iterator = ds.make_initializable_iterator()
+ self.evaluate(iterator.initializer)
+ return iterator
+
@combinations.generate(
combinations.times(
combinations.distributions_and_v1_optimizers(),
@@ -62,8 +70,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
distribution.call_for_each_tower(
model_fn, *inputs, run_concurrently=layer.built))
- iterator = distribution.distribute_dataset(
- dataset_fn).make_one_shot_iterator()
+ iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))
def run_step():
return distribution.run_steps_on_dataset(
@@ -99,8 +106,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
model_fn, dataset_fn, layer = minimize_loss_example(
optimizer_fn, use_bias=True, use_callable_loss=use_callable_loss)
- iterator = distribution.distribute_dataset(
- dataset_fn).make_one_shot_iterator()
+ iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))
def run_step():
return distribution.group(
@@ -159,8 +165,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
distribution.call_for_each_tower(
model_fn, *inputs, run_concurrently=layer.built))
- iterator = distribution.distribute_dataset(
- dataset_fn).make_one_shot_iterator()
+ iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))
def run_step():
return distribution.run_steps_on_dataset(
@@ -244,8 +249,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
fetches += ops.get_collection(ops.GraphKeys.UPDATE_OPS)
return control_flow_ops.group(fetches)
- iterator = distribution.distribute_dataset(
- dataset_fn).make_one_shot_iterator()
+ iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))
def run_step():
return distribution.run_steps_on_dataset(
@@ -338,8 +342,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
distribution.call_for_each_tower(
model_fn, x, y, run_concurrently=False))
- iterator = distribution.distribute_dataset(
- dataset_fn).make_one_shot_iterator()
+ iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))
def run_step():
return distribution.run_steps_on_dataset(
@@ -432,8 +435,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
output=loss)
return distribution.group(train_op)
- iterator = distribution.distribute_dataset(
- dataset_fn).make_one_shot_iterator()
+ iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))
def run_step():
initial_loss = lambda: constant_op.constant(1e7)
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py
index 0c6805d682..93d42e09a2 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py
@@ -347,6 +347,8 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
set, the `configure` method will try to find the best one.
prefetch_on_device: optional boolean to specify whether to prefetch input
data to devices.
+ auto_shard_dataset: whether to auto-shard the dataset when there are
+ multiple workers.
"""
def __init__(self,
@@ -354,11 +356,13 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
num_gpus=None,
num_gpus_per_worker=None,
cross_tower_ops=None,
- prefetch_on_device=None):
+ prefetch_on_device=None,
+ auto_shard_dataset=False):
super(MirroredStrategy, self).__init__()
self._cross_tower_ops = cross_tower_ops
self._prefetch_on_device = prefetch_on_device
+ self._auto_shard_dataset = auto_shard_dataset
# Remember num GPUs, which might be needed by the `configure` method.
if num_gpus is not None and num_gpus_per_worker is not None:
raise ValueError(
@@ -477,10 +481,11 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
if self._cluster_spec:
return values.MultiWorkerDataset(
partial(self._call_dataset_fn, dataset_fn), self._worker_device_map,
- self._prefetch_on_device)
+ self._prefetch_on_device, self._auto_shard_dataset)
else:
return values.PerDeviceDataset(
- self._call_dataset_fn(dataset_fn), self._devices,
+ self._call_dataset_fn(dataset_fn),
+ self._devices,
self._prefetch_on_device)
# TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed.
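An illustrative construction of the strategy with the new flag (it defaults to False, so existing single-worker users are unaffected, and auto-sharding only applies once a cluster_spec is configured); the argument names are those shown in the hunk above, with devices being the strategy's existing first argument:

from tensorflow.contrib.distribute.python import mirrored_strategy

strategy = mirrored_strategy.MirroredStrategy(
    ["/device:GPU:0", "/device:GPU:1"],
    auto_shard_dataset=True)   # opt back in to per-worker dataset sharding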
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
index c6894e9013..04c712ce1d 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
@@ -300,9 +300,15 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
dist = mirrored_strategy.MirroredStrategy(
["/device:GPU:0", "/device:CPU:0"])
- features = dist.distribute_dataset(
- lambda: dataset_ops.Dataset.from_tensors([[1.]]).repeat(10)
- ).make_one_shot_iterator().get_next()
+ ds = dist.distribute_dataset(
+ lambda: dataset_ops.Dataset.from_tensors([[1.]]).repeat(10))
+ if context.executing_eagerly():
+ iterator = ds.make_one_shot_iterator()
+ else:
+ iterator = ds.make_initializable_iterator()
+ self.evaluate([iterator.initializer])
+
+ features = iterator.get_next()
with dist.scope():
result = dist.call_for_each_tower(
@@ -1271,7 +1277,17 @@ class MirroredStrategyDefunTest(test.TestCase):
self.evaluate(device_result))
for defun in defuns:
- self.assertEqual(set(mock_model.variables), set(defun.variables))
+ # PolymorphicFunctions are specialized to the current device stack, so
+ # call_for_each has one trace per device. To check that the expected set
+ # of variables was accessed on each trace, we first retrieve each
+ # device-specific graph function.
+ per_device_graph_functions = dist.call_for_each_tower(
+ defun.get_concrete_function,
+ mock_model, *inputs, run_concurrently=False)
+ for device in devices:
+ graph_function = per_device_graph_functions.get(device=device)
+ self.assertEqual(set(mock_model.variables),
+ set(graph_function.graph.variables))
@test_util.run_in_graph_and_eager_modes()
def testVariableInDefun(self):
diff --git a/tensorflow/contrib/distribute/python/monitor.py b/tensorflow/contrib/distribute/python/monitor.py
index 7644acedc9..17b7ab74f6 100644
--- a/tensorflow/contrib/distribute/python/monitor.py
+++ b/tensorflow/contrib/distribute/python/monitor.py
@@ -51,6 +51,7 @@ class Monitor(object):
else:
if session is None:
raise ValueError("Should provide a `session` in Graph mode.")
+ session.run(step_callable._iterator.initializer) # pylint: disable=protected-access
self._run_step = session.make_callable(step_callable())
session.run(variables.global_variables_initializer())
diff --git a/tensorflow/contrib/distribute/python/multi_worker_test_base.py b/tensorflow/contrib/distribute/python/multi_worker_test_base.py
index 18b4503eff..9f92ba7dde 100644
--- a/tensorflow/contrib/distribute/python/multi_worker_test_base.py
+++ b/tensorflow/contrib/distribute/python/multi_worker_test_base.py
@@ -36,9 +36,29 @@ from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.client import session
from tensorflow.python.estimator import run_config
from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import server_lib
+ASSIGNED_PORTS = set()
+lock = threading.Lock()
+
+
+def pick_unused_port():
+ """Returns an unused and unassigned local port."""
+ if _portpicker_import_error:
+ raise _portpicker_import_error # pylint: disable=raising-bad-type
+
+ global ASSIGNED_PORTS
+ with lock:
+ while True:
+ port = portpicker.pick_unused_port()
+ if port > 10000 and port not in ASSIGNED_PORTS:
+ ASSIGNED_PORTS.add(port)
+ logging.info('Using local port %r', port)
+ return port
+
+
def _create_cluster(num_workers,
num_ps,
has_chief=False,
@@ -49,8 +69,8 @@ def _create_cluster(num_workers,
"""Creates and starts local servers and returns the cluster_spec dict."""
if _portpicker_import_error:
raise _portpicker_import_error # pylint: disable=raising-bad-type
- worker_ports = [portpicker.pick_unused_port() for _ in range(num_workers)]
- ps_ports = [portpicker.pick_unused_port() for _ in range(num_ps)]
+ worker_ports = [pick_unused_port() for _ in range(num_workers)]
+ ps_ports = [pick_unused_port() for _ in range(num_ps)]
cluster_dict = {}
if num_workers > 0:
@@ -58,9 +78,9 @@ def _create_cluster(num_workers,
if num_ps > 0:
cluster_dict['ps'] = ['localhost:%s' % port for port in ps_ports]
if has_eval:
- cluster_dict['evaluator'] = ['localhost:%s' % portpicker.pick_unused_port()]
+ cluster_dict['evaluator'] = ['localhost:%s' % pick_unused_port()]
if has_chief:
- cluster_dict['chief'] = ['localhost:%s' % portpicker.pick_unused_port()]
+ cluster_dict['chief'] = ['localhost:%s' % pick_unused_port()]
cs = server_lib.ClusterSpec(cluster_dict)
@@ -139,11 +159,36 @@ def create_in_process_cluster(num_workers,
num_workers,
num_ps=num_ps,
has_chief=has_chief,
+ has_eval=has_eval,
worker_config=worker_config,
ps_config=ps_config,
protocol='grpc')
+def create_cluster_spec(has_chief=False,
+ num_workers=1,
+ num_ps=0,
+ has_eval=False):
+ """Create a cluster spec with tasks with unused local ports."""
+ if _portpicker_import_error:
+ raise _portpicker_import_error # pylint: disable=raising-bad-type
+
+ cluster_spec = {}
+ if has_chief:
+ cluster_spec['chief'] = ['localhost:%s' % pick_unused_port()]
+ if num_workers:
+ cluster_spec['worker'] = [
+ 'localhost:%s' % pick_unused_port() for _ in range(num_workers)
+ ]
+ if num_ps:
+ cluster_spec['ps'] = [
+ 'localhost:%s' % pick_unused_port() for _ in range(num_ps)
+ ]
+ if has_eval:
+ cluster_spec['evaluator'] = ['localhost:%s' % pick_unused_port()]
+ return cluster_spec
+
+
class MultiWorkerTestBase(test.TestCase):
"""Base class for testing multi node strategy and dataset."""
diff --git a/tensorflow/contrib/distribute/python/optimizer_v2_test.py b/tensorflow/contrib/distribute/python/optimizer_v2_test.py
index 6e9ba37a19..3064433129 100644
--- a/tensorflow/contrib/distribute/python/optimizer_v2_test.py
+++ b/tensorflow/contrib/distribute/python/optimizer_v2_test.py
@@ -42,8 +42,11 @@ class MinimizeLossOptimizerV2Test(test.TestCase, parameterized.TestCase):
model_fn, dataset_fn, layer = minimize_loss_example(
optimizer_fn, use_bias=True, use_callable_loss=use_callable_loss)
- iterator = distribution.distribute_dataset(
- dataset_fn).make_one_shot_iterator()
+ ds = distribution.distribute_dataset(dataset_fn)
+ if context.executing_eagerly():
+ iterator = ds.make_one_shot_iterator()
+ else:
+ iterator = ds.make_initializable_iterator()
def run_step():
return control_flow_ops.group(distribution.unwrap(
@@ -52,6 +55,7 @@ class MinimizeLossOptimizerV2Test(test.TestCase, parameterized.TestCase):
if not context.executing_eagerly():
with self.cached_session() as sess:
+ sess.run(iterator.initializer)
run_step = sess.make_callable(run_step())
self.evaluate(variables.global_variables_initializer())
diff --git a/tensorflow/contrib/distribute/python/prefetching_ops_v2.py b/tensorflow/contrib/distribute/python/prefetching_ops_v2.py
deleted file mode 100644
index 1ff60c0762..0000000000
--- a/tensorflow/contrib/distribute/python/prefetching_ops_v2.py
+++ /dev/null
@@ -1,228 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Extension of prefetching_ops to support more than one device."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import warnings
-
-from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import
-from tensorflow.contrib.data.python.ops import gen_dataset_ops
-from tensorflow.contrib.data.python.ops import prefetching_ops
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.data.ops import iterator_ops
-from tensorflow.python.data.util import nest as data_nest
-from tensorflow.python.data.util import sparse
-from tensorflow.python.eager import context
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import function
-from tensorflow.python.framework import ops
-from tensorflow.python.util import nest
-
-
-# pylint: disable=protected-access
-class _PrefetchToDeviceIterator(object):
- """A replacement for `tf.data.Iterator` that prefetches to another device.
-
- Args:
- input_dataset: The input dataset.
- one_shot: If true, we make a one shot iterator that's already initialized.
- devices: Devices on which to prefetch.
- buffer_size: Size of the prefetching buffer.
- shared_name: (Optional.) If non-empty, the returned iterator will be
- shared under the given name across multiple sessions that share the
- same devices (e.g. when using a remote server). Only used if one_shot
- is False.
-
- Returns:
- An Iterator type object.
- """
-
- def __init__(self,
- input_dataset,
- one_shot,
- devices,
- buffer_size,
- shared_name=None):
- self._input_dataset = input_dataset
- self._get_next_call_count = 0
- self._one_shot = one_shot
- if shared_name is None:
- shared_name = ""
- self._devices = devices
-
- if self._one_shot:
- self._input_iterator = input_dataset.make_one_shot_iterator()
- else:
- self._input_iterator = iterator_ops.Iterator.from_structure(
- self._input_dataset.output_types, self._input_dataset.output_shapes,
- shared_name, self._input_dataset.output_classes)
- input_iterator_handle = self._input_iterator.string_handle()
-
- @function.Defun(dtypes.string)
- def _prefetch_fn(handle):
- """Prefetches one element from `input_iterator`."""
- remote_iterator = iterator_ops.Iterator.from_string_handle(
- handle, self._input_iterator.output_types,
- self._input_iterator.output_shapes,
- self._input_iterator.output_classes)
- ret = remote_iterator.get_next()
- return nest.flatten(sparse.serialize_sparse_tensors(ret))
-
- target_device = gen_dataset_ops.iterator_get_device(
- self._input_iterator._iterator_resource)
- self._buffering_resources = []
- for device in nest.flatten(self._devices):
- with ops.device(device):
- buffer_resource_handle = prefetching_ops.function_buffering_resource(
- f=_prefetch_fn,
- output_types=data_nest.flatten(
- sparse.as_dense_types(self._input_dataset.output_types,
- self._input_dataset.output_classes)),
- target_device=target_device,
- string_arg=input_iterator_handle,
- buffer_size=buffer_size,
- shared_name=shared_name)
- self._buffering_resources.append(buffer_resource_handle)
-
- if not self._one_shot:
- reset_ops = []
- for buffer_resource in self._buffering_resources:
- reset_ops.append(
- prefetching_ops.function_buffering_resource_reset(buffer_resource))
- with ops.control_dependencies(reset_ops):
- self._initializer = self._input_iterator.make_initializer(
- self._input_dataset)
-
- def get_next(self, name=None):
- """See `tf.data.Iterator.get_next`."""
- self._get_next_call_count += 1
- if self._get_next_call_count > iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD:
- warnings.warn(iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE)
-
- flat_result = []
- # TODO(priyag): This will fail if the input size (typically number of
- # batches) is not divisible by number of devices.
- # How do we handle that more gracefully / let the user know?
- for buffer_resource in self._buffering_resources:
- flat_ret = gen_dataset_ops.function_buffering_resource_get_next(
- buffer_resource,
- output_types=data_nest.flatten(sparse.as_dense_types(
- self.output_types, self.output_classes)), name=name)
-
- ret = sparse.deserialize_sparse_tensors(
- data_nest.pack_sequence_as(self.output_types, flat_ret),
- self.output_types, self.output_shapes, self.output_classes)
-
- for tensor, shape in zip(
- data_nest.flatten(ret), data_nest.flatten(self.output_shapes)):
- if isinstance(tensor, ops.Tensor):
- tensor.set_shape(shape)
- flat_result.append(ret)
-
- return nest.pack_sequence_as(self._devices, flat_result)
-
- @property
- def initializer(self):
- if self._one_shot:
- raise NotImplementedError("Can't initialize a one_shot_iterator")
- return self._initializer
-
- @property
- def output_classes(self):
- return self._input_dataset.output_classes
-
- @property
- def output_shapes(self):
- return self._input_dataset.output_shapes
-
- @property
- def output_types(self):
- return self._input_dataset.output_types
-# pylint: enable=protected-access
-
-
-class _PrefetchToDeviceDataset(dataset_ops.Dataset):
- """A `Dataset` whose iterator prefetches elements to other device(s)."""
-
- def __init__(self, input_dataset, devices, buffer_size):
- self._input_dataset = input_dataset
- self._devices = devices
- self._buffer_size = buffer_size if buffer_size is not None else 1
-
- def make_one_shot_iterator(self):
- return _PrefetchToDeviceIterator(
- self._input_dataset,
- one_shot=True,
- devices=self._devices,
- buffer_size=self._buffer_size)
-
- def make_initializable_iterator(self, shared_name=None):
- if context.executing_eagerly():
- raise RuntimeError(
- "make_initializable_iterator is not supported when eager "
- "execution is enabled.")
-
- return _PrefetchToDeviceIterator(
- self._input_dataset,
- one_shot=False,
- devices=self._devices,
- buffer_size=self._buffer_size,
- shared_name=shared_name)
-
- def _as_variant_tensor(self):
- # TODO(mrry): Raise this error earlier (e.g. when one of the Dataset
- # transformation methods is called.
- # TODO(mrry): Investigate support for chaining further transformations after
- # the prefetch, including GPU support.
- raise NotImplementedError("`prefetch_to_devices()` must be the last "
- "transformation in a dataset pipeline.")
-
- # TODO(priyag): Fix the output types, shapes and classes to match the result
- # of get_next (which has the additional nesting layer of devices now).
- @property
- def output_types(self):
- return self._input_dataset.output_types
-
- @property
- def output_shapes(self):
- return self._input_dataset.output_shapes
-
- @property
- def output_classes(self):
- return self._input_dataset.output_classes
-
-
-def prefetch_to_devices(devices, buffer_size=None):
- """A transformation that prefetches dataset values to the given `devices`.
-
- NOTE: Although the transformation creates a `tf.data.Dataset`, the
- transformation must be the final `Dataset` in the input pipeline.
-
- Args:
- devices: A nested structure of devices on which to prefetch the data. It can
- be a single device name, or a tuple or list of device names.
- buffer_size: (Optional.) The number of elements to buffer on each device.
- Defaults to an automatically chosen value.
-
- Returns:
- A `Dataset` transformation function, which can be passed to
- `tf.data.Dataset.apply`.
- """
- def _apply_fn(dataset):
- return _PrefetchToDeviceDataset(dataset, devices, buffer_size)
-
- return _apply_fn
diff --git a/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py b/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py
deleted file mode 100644
index 16799104e8..0000000000
--- a/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py
+++ /dev/null
@@ -1,90 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for prefetching_ops_v2."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.distribute.python import prefetching_ops_v2
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import test_util
-from tensorflow.python.platform import test
-
-
-class PrefetchingOpsV2Test(test.TestCase):
-
- def testPrefetchToOneDevice(self):
- if not test_util.is_gpu_available():
- self.skipTest("No GPU available")
-
- host_dataset = dataset_ops.Dataset.range(10)
- device_dataset = host_dataset.apply(
- prefetching_ops_v2.prefetch_to_devices("/gpu:0"))
-
- iterator = device_dataset.make_one_shot_iterator()
- next_element = iterator.get_next()
-
- with self.cached_session() as sess:
- for i in range(10):
- self.assertEqual(i, sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testPrefetchToTwoDevicesInAList(self):
- if not test_util.is_gpu_available():
- self.skipTest("No GPU available")
-
- host_dataset = dataset_ops.Dataset.range(10)
- device_dataset = host_dataset.apply(
- prefetching_ops_v2.prefetch_to_devices(["/cpu:0", "/gpu:0"]))
-
- iterator = device_dataset.make_one_shot_iterator()
- next_element = iterator.get_next()
-
- output = []
- # TODO(rohanj): Modify test to go till the end of the dataset when we
- # switch to MultiDeviceIterator.
- with self.cached_session() as sess:
- for _ in range(4):
- result = sess.run(next_element)
- self.assertEqual(2, len(result))
- output.extend(result)
- self.assertEquals(set(range(8)), set(output))
-
- def testPrefetchToTwoDevicesWithReinit(self):
- if not test_util.is_gpu_available():
- self.skipTest("No GPU available")
-
- host_dataset = dataset_ops.Dataset.range(10)
- device_dataset = host_dataset.apply(
- prefetching_ops_v2.prefetch_to_devices(["/cpu:0", "/gpu:0"]))
-
- iterator = device_dataset.make_initializable_iterator()
- next_element = iterator.get_next()
-
- # TODO(rohanj): Modify test to go till the end of the dataset when we
- # switch to MultiDeviceIterator.
- with self.cached_session() as sess:
- sess.run(iterator.initializer)
- for _ in range(4):
- sess.run(next_element)
- sess.run(iterator.initializer)
- for _ in range(4):
- sess.run(next_element)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/distribute/python/step_fn.py b/tensorflow/contrib/distribute/python/step_fn.py
index 1b5a4f64e5..23bf36184f 100644
--- a/tensorflow/contrib/distribute/python/step_fn.py
+++ b/tensorflow/contrib/distribute/python/step_fn.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.eager import backprop
+from tensorflow.python.eager import context
from tensorflow.python.training import optimizer as optimizer_lib
@@ -50,7 +51,11 @@ class StandardInputStep(Step):
def __init__(self, dataset_fn, distribution):
super(StandardInputStep, self).__init__(distribution)
self._distributed_input = distribution.distribute_dataset(dataset_fn)
- self._iterator = self._distributed_input.make_one_shot_iterator()
+ if context.executing_eagerly():
+ self._iterator = self._distributed_input.make_one_shot_iterator()
+ else:
+ # TODO(priyag): Expose initializer via some initializer property.
+ self._iterator = self._distributed_input.make_initializable_iterator()
class StandardSingleLossStep(StandardInputStep):
diff --git a/tensorflow/contrib/distribute/python/step_fn_test.py b/tensorflow/contrib/distribute/python/step_fn_test.py
index f1ada49fa3..1ff9b9ceec 100644
--- a/tensorflow/contrib/distribute/python/step_fn_test.py
+++ b/tensorflow/contrib/distribute/python/step_fn_test.py
@@ -50,6 +50,7 @@ class SingleLossStepTest(test.TestCase, parameterized.TestCase):
run_step = single_loss_step
else:
with self.cached_session() as sess:
+ sess.run(single_loss_step._iterator.initializer)
run_step = sess.make_callable(single_loss_step())
self.evaluate(variables.global_variables_initializer())
diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py
index 6ba83976fc..a6762e5e87 100644
--- a/tensorflow/contrib/distribute/python/tpu_strategy.py
+++ b/tensorflow/contrib/distribute/python/tpu_strategy.py
@@ -158,7 +158,7 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
raise ValueError(
'TPU currently requires fully defined shapes. Either use '
'set_shape() on the input tensors or use '
- 'dataset.apply(map_and_batch(..., drop_remainder=True)).')
+ 'dataset.batch(..., drop_remainder=True).')
types = nest.flatten(iterator.output_types)
enqueue_ops = [
@@ -307,6 +307,22 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
def num_towers_per_host(self):
return self._tpu_metadata.num_of_cores_per_host
+ @property
+ def between_graph(self):
+ return False
+
+ @property
+ def should_init(self):
+ return True
+
+ @property
+ def should_checkpoint(self):
+ return True
+
+ @property
+ def should_save_summary(self):
+ return True
+
def get_host_cpu_device(self, host_id):
if self._tpu_cluster_resolver.get_master() in ('', 'local'):
return '/replica:0/task:0/device:CPU:0'
@@ -324,4 +340,3 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
cluster_spec = self._tpu_cluster_resolver.cluster_spec()
if cluster_spec:
session_config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())
-
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py
index fafa6384a1..327775a729 100644
--- a/tensorflow/contrib/distribute/python/values.py
+++ b/tensorflow/contrib/distribute/python/values.py
@@ -26,7 +26,7 @@ import weakref
import six
from tensorflow.contrib.distribute.python import input_ops
-from tensorflow.contrib.distribute.python import prefetching_ops_v2
+from tensorflow.python.data.ops import multi_device_iterator_ops
from tensorflow.python.eager import context
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import ops
@@ -683,7 +683,7 @@ class PerDeviceDataIterator(object):
def get_next(self, name=None):
"""Scatter the input across devices."""
if self._prefetch_on_device:
- data_list = self._iterator.get_next(name=name)
+ data_list = self._iterator.get_next()
index = dict(zip(self._devices, data_list))
else:
batch = self._iterator.get_next(name=name)
@@ -703,21 +703,24 @@ class PerDeviceDataIterator(object):
class PerDeviceDataset(object):
"""Like `tf.data.Dataset` split devices, producing `PerDevice` data."""
- def __init__(self, dataset, devices, prefetch_on_device=None):
+ def __init__(
+ self,
+ dataset,
+ devices,
+ prefetch_on_device=None,
+ ):
self._devices = devices
# Default to using prefetching in graph mode, unless specified.
- # TODO(priyag): Enable prefetching in eager mode.
+ # TODO(rohanj): Enable prefetching in eager mode.
self._prefetch_on_device = prefetch_on_device
if self._prefetch_on_device is None:
self._prefetch_on_device = not context.executing_eagerly()
assert not (self._prefetch_on_device and context.executing_eagerly()), (
"Prefetching is only supported in graph mode currently")
- if self._prefetch_on_device:
- self._dataset = dataset.apply(
- prefetching_ops_v2.prefetch_to_devices(self._devices))
- else:
+ self._dataset = dataset
+ if not self._prefetch_on_device:
# TODO(priyag): If dropping remainder is not appropriate, find another
# approach to distributing the dataset when not possible to divide evenly.
# Possibly not an issue when we start using PartitionedDataset.
@@ -725,15 +728,33 @@ class PerDeviceDataset(object):
def make_one_shot_iterator(self):
"""Get a one time use iterator for the distributed PerDeviceDataset."""
+ # Graph mode prefetching with one shot iterator is disabled.
+ if not context.executing_eagerly():
+ raise ValueError("Cannot create a one shot iterator. Please use "
+ "`make_initializable_iterator()` instead.")
+ # Eager-mode prefetching would have errored out in the constructor, so the
+ # only remaining cases are non-prefetching eager / graph mode; we delegate
+ # to PerDeviceDataIterator to handle them.
dataset_iterator = self._dataset.make_one_shot_iterator()
return PerDeviceDataIterator(
- dataset_iterator, self._devices, self._prefetch_on_device)
+ dataset_iterator, self._devices, prefetch_on_device=False)
def make_initializable_iterator(self):
"""Get an initializable iterator for the distributed PerDeviceDataset."""
- dataset_iterator = self._dataset.make_initializable_iterator()
+ # Eager mode generates already initialized iterators. Hence we cannot create
+ # an initializable iterator.
+ if context.executing_eagerly():
+ raise ValueError("Cannot create initializable iterator in Eager mode. "
+ "Please use `make_one_shot_iterator` instead.")
+ if self._prefetch_on_device:
+ dataset_iterator = multi_device_iterator_ops.MultiDeviceIterator(
+ self._dataset, self._devices)
+ else:
+ dataset_iterator = self._dataset.make_initializable_iterator()
return PerDeviceDataIterator(
- dataset_iterator, self._devices, self._prefetch_on_device)
+ dataset_iterator,
+ self._devices,
+ prefetch_on_device=self._prefetch_on_device)
class MultiWorkerDataIterator(object):
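A rough graph-mode sketch of the new prefetching path that replaces prefetching_ops_v2, assuming the MultiDeviceIterator behavior relied on above (get_next() with no arguments returns one element per device, and the iterator exposes an initializer) and assuming a GPU is available:

from tensorflow.python.client import session
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import multi_device_iterator_ops

devices = ["/device:CPU:0", "/device:GPU:0"]
ds = dataset_ops.Dataset.range(8).batch(2)
mdi = multi_device_iterator_ops.MultiDeviceIterator(ds, devices)
per_device = mdi.get_next()            # list with one tensor per device

with session.Session() as sess:
  sess.run(mdi.initializer)
  print(sess.run(per_device))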
@@ -793,7 +814,8 @@ class MultiWorkerDataset(object):
eager mode.
"""
- def __init__(self, dataset_fn, worker_device_map, prefetch_on_device=None):
+ def __init__(self, dataset_fn, worker_device_map, prefetch_on_device=None,
+ auto_shard=False):
"""Initialize the MultiWorkerDataset object.
Args:
@@ -801,6 +823,7 @@ class MultiWorkerDataset(object):
worker_device_map: a dict mapping from each worker to a list of devices
that belong to this worker.
prefetch_on_device: whether to prefetch to devices.
+ auto_shard: whether to auto-shard the dataset.
"""
self._worker_device_map = worker_device_map
self._datasets = {}
@@ -810,10 +833,13 @@ class MultiWorkerDataset(object):
six.iteritems(worker_device_map)):
with ops.device(worker):
worker_input = dataset_fn()
- worker_input = input_ops.auto_shard_dataset(
- worker_input, len(worker_device_map), i)
+ if auto_shard:
+ worker_input = input_ops.auto_shard_dataset(
+ worker_input, len(worker_device_map), i)
self._datasets[worker] = PerDeviceDataset(
- worker_input, worker_devices, prefetch_on_device=prefetch_on_device)
+ worker_input,
+ worker_devices,
+ prefetch_on_device=prefetch_on_device)
def make_one_shot_iterator(self):
iterators = {}
diff --git a/tensorflow/contrib/distribute/python/values_test.py b/tensorflow/contrib/distribute/python/values_test.py
index 15a85a28f5..002d61f46e 100644
--- a/tensorflow/contrib/distribute/python/values_test.py
+++ b/tensorflow/contrib/distribute/python/values_test.py
@@ -349,7 +349,11 @@ class PerDeviceDatasetTest(test.TestCase):
def _test_iterator_no_prefetch(self, devices, dataset, expected_values):
per_device_dataset = values.PerDeviceDataset(
dataset, devices, prefetch_on_device=False)
- iterator = per_device_dataset.make_one_shot_iterator()
+ if context.executing_eagerly():
+ iterator = per_device_dataset.make_one_shot_iterator()
+ else:
+ iterator = per_device_dataset.make_initializable_iterator()
+ self.evaluate([iterator.initializer])
for expected_value in expected_values:
next_element = iterator.get_next()
@@ -366,20 +370,14 @@ class PerDeviceDatasetTest(test.TestCase):
if not context.executing_eagerly():
per_device_dataset = values.PerDeviceDataset(
dataset, devices, prefetch_on_device=True)
- iterator = per_device_dataset.make_one_shot_iterator()
+ iterator = per_device_dataset.make_initializable_iterator()
+ self.evaluate([iterator.initializer])
- # With prefetching, we cannot guarantee which input ends up on which
- # device, so we verify that the complete set seen on all devices is
- # correct, and equal numbers are distributed to each device.
- combined_actual = []
- combined_expected = []
for expected_value in expected_values:
next_element = iterator.get_next()
- combined_actual.extend(self.evaluate([
- values.select_device(d, next_element) for d in devices]))
- combined_expected.extend(expected_value)
-
- self.assertEqual(set(combined_expected), set(combined_actual))
+ computed_value = self.evaluate(
+ [values.select_device(d, next_element) for d in devices])
+ self.assertEqual(expected_value, computed_value)
with self.assertRaises(errors.OutOfRangeError):
next_element = iterator.get_next()
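The test changes above pick the iterator flavor based on the execution mode: one-shot iterators under eager execution, initializable iterators plus an explicit initializer run in graph mode. Condensed into one place, the pattern looks like this inside a `test.TestCase` method, with `per_device_dataset` built as in the test above (a sketch restating the hunks, not new behaviour):

```python
from tensorflow.python.eager import context

if context.executing_eagerly():
  iterator = per_device_dataset.make_one_shot_iterator()
else:
  iterator = per_device_dataset.make_initializable_iterator()
  # Graph mode: the initializer must be run before get_next() is evaluated.
  self.evaluate([iterator.initializer])

next_element = iterator.get_next()
```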
diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD
index 9aadc634da..3ff7da4f89 100644
--- a/tensorflow/contrib/distributions/BUILD
+++ b/tensorflow/contrib/distributions/BUILD
@@ -25,7 +25,6 @@ py_library(
"`tf.contrib.distributions` to `tfp.distributions`."),
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:check_ops",
"//tensorflow/python:clip_ops",
@@ -61,7 +60,6 @@ py_library(
":bijectors_py",
"//tensorflow/contrib/framework:framework_py",
"//tensorflow/contrib/learn",
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:check_ops",
"//tensorflow/python:control_flow_ops",
@@ -706,8 +704,8 @@ cuda_py_test(
":bijectors_py",
":distributions_py",
"//third_party/py/numpy",
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:math_ops",
@@ -722,8 +720,8 @@ cuda_py_test(
additional_deps = [
":distributions_py",
"//third_party/py/numpy",
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python/ops/linalg",
],
shard_count = 4,
tags = ["noasan"], # times out, http://b/78588814
@@ -739,8 +737,8 @@ cuda_py_test(
additional_deps = [
":distributions_py",
"//third_party/py/numpy",
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:math_ops",
@@ -794,8 +792,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
@@ -831,8 +829,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
@@ -852,8 +850,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
@@ -871,8 +869,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
@@ -907,8 +905,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
@@ -926,10 +924,10 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python/ops/linalg",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
@@ -945,8 +943,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
@@ -964,8 +962,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
@@ -983,8 +981,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
@@ -1002,8 +1000,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
@@ -1021,8 +1019,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
@@ -1040,8 +1038,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
@@ -1075,8 +1073,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
@@ -1126,8 +1124,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
@@ -1161,8 +1159,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
@@ -1180,8 +1178,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
@@ -1201,8 +1199,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
@@ -1221,8 +1219,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
@@ -1240,8 +1238,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
@@ -1259,8 +1257,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
@@ -1278,8 +1276,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
@@ -1297,8 +1295,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
@@ -1316,8 +1314,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softsign_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softsign_test.py
index 8dad80aa64..c32ea9ade7 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softsign_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softsign_test.py
@@ -93,12 +93,12 @@ class SoftsignBijectorTest(test.TestCase):
bijector.inverse_log_det_jacobian(y, event_ndims=1)))
def testScalarCongruency(self):
- with self.test_session():
+ with self.cached_session():
bijector = Softsign(validate_args=True)
assert_scalar_congruency(bijector, lower_x=-20., upper_x=20.)
def testBijectiveAndFinite(self):
- with self.test_session():
+ with self.cached_session():
bijector = Softsign(validate_args=True)
x = np.linspace(-20., 20., 100).astype(np.float32)
y = np.linspace(-0.99, 0.99, 100).astype(np.float32)
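The test migrates from `self.test_session()` to `self.cached_session()`, which reuses a single session for the whole test. A minimal sketch of the same pattern with an extra numeric check; the import paths and the formula `softsign(x) = x / (1 + |x|)` are assumptions here, not part of the diff:

```python
import numpy as np

from tensorflow.contrib.distributions.python.ops.bijectors.softsign import Softsign
from tensorflow.python.platform import test


class SoftsignSketchTest(test.TestCase):

  def testForwardMatchesFormula(self):
    with self.cached_session():
      bijector = Softsign(validate_args=True)
      x = np.array([1., -2., 0.5], dtype=np.float32)
      # Softsign maps x to x / (1 + |x|).
      self.assertAllClose(x / (1. + np.abs(x)),
                          self.evaluate(bijector.forward(x)))


if __name__ == "__main__":
  test.main()
```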
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/distribution_test.py b/tensorflow/contrib/distributions/python/kernel_tests/distribution_test.py
index f073f51a69..9b9b3ce2dd 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/distribution_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/distribution_test.py
@@ -212,7 +212,7 @@ class DistributionTest(test.TestCase):
def testStrWorksCorrectlyScalar(self):
normal = tfd.Normal(loc=np.float16(0), scale=np.float16(1))
self.assertEqual(
- ("tf.distributions.Normal("
+ ("tfp.distributions.Normal("
"\"Normal/\", "
"batch_shape=(), "
"event_shape=(), "
@@ -221,7 +221,7 @@ class DistributionTest(test.TestCase):
chi2 = tfd.Chi2(df=np.float32([1., 2.]), name="silly")
self.assertEqual(
- ("tf.distributions.Chi2("
+ ("tfp.distributions.Chi2("
"\"silly/\", " # What a silly name that is!
"batch_shape=(2,), "
"event_shape=(), "
@@ -230,7 +230,7 @@ class DistributionTest(test.TestCase):
exp = tfd.Exponential(rate=array_ops.placeholder(dtype=dtypes.float32))
self.assertEqual(
- ("tf.distributions.Exponential(\"Exponential/\", "
+ ("tfp.distributions.Exponential(\"Exponential/\", "
# No batch shape.
"event_shape=(), "
"dtype=float32)"),
@@ -240,7 +240,7 @@ class DistributionTest(test.TestCase):
mvn_static = tfd.MultivariateNormalDiag(
loc=np.zeros([2, 2]), name="MVN")
self.assertEqual(
- ("tf.distributions.MultivariateNormalDiag("
+ ("tfp.distributions.MultivariateNormalDiag("
"\"MVN/\", "
"batch_shape=(2,), "
"event_shape=(2,), "
@@ -251,7 +251,7 @@ class DistributionTest(test.TestCase):
loc=array_ops.placeholder(shape=[None, 3], dtype=dtypes.float32),
name="MVN2")
self.assertEqual(
- ("tf.distributions.MultivariateNormalDiag("
+ ("tfp.distributions.MultivariateNormalDiag("
"\"MVN2/\", "
"batch_shape=(?,), " # Partially known.
"event_shape=(3,), "
@@ -261,7 +261,7 @@ class DistributionTest(test.TestCase):
def testReprWorksCorrectlyScalar(self):
normal = tfd.Normal(loc=np.float16(0), scale=np.float16(1))
self.assertEqual(
- ("<tf.distributions.Normal"
+ ("<tfp.distributions.Normal"
" 'Normal/'"
" batch_shape=()"
" event_shape=()"
@@ -270,7 +270,7 @@ class DistributionTest(test.TestCase):
chi2 = tfd.Chi2(df=np.float32([1., 2.]), name="silly")
self.assertEqual(
- ("<tf.distributions.Chi2"
+ ("<tfp.distributions.Chi2"
" 'silly/'" # What a silly name that is!
" batch_shape=(2,)"
" event_shape=()"
@@ -279,7 +279,7 @@ class DistributionTest(test.TestCase):
exp = tfd.Exponential(rate=array_ops.placeholder(dtype=dtypes.float32))
self.assertEqual(
- ("<tf.distributions.Exponential"
+ ("<tfp.distributions.Exponential"
" 'Exponential/'"
" batch_shape=<unknown>"
" event_shape=()"
@@ -290,7 +290,7 @@ class DistributionTest(test.TestCase):
mvn_static = tfd.MultivariateNormalDiag(
loc=np.zeros([2, 2]), name="MVN")
self.assertEqual(
- ("<tf.distributions.MultivariateNormalDiag"
+ ("<tfp.distributions.MultivariateNormalDiag"
" 'MVN/'"
" batch_shape=(2,)"
" event_shape=(2,)"
@@ -301,7 +301,7 @@ class DistributionTest(test.TestCase):
loc=array_ops.placeholder(shape=[None, 3], dtype=dtypes.float32),
name="MVN2")
self.assertEqual(
- ("<tf.distributions.MultivariateNormalDiag"
+ ("<tfp.distributions.MultivariateNormalDiag"
" 'MVN2/'"
" batch_shape=(?,)" # Partially known.
" event_shape=(3,)"
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/moving_stats_test.py b/tensorflow/contrib/distributions/python/kernel_tests/moving_stats_test.py
index 3c988dad8a..be7c756bea 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/moving_stats_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/moving_stats_test.py
@@ -38,8 +38,8 @@ class MovingReduceMeanVarianceTest(test.TestCase):
true_stddev = np.array([[1.1, 0.5]])
with self.cached_session() as sess:
# Start "x" out with this mean.
- mean_var = variables.Variable(array_ops.zeros_like(true_mean))
- variance_var = variables.Variable(array_ops.ones_like(true_stddev))
+ mean_var = variables.VariableV1(array_ops.zeros_like(true_mean))
+ variance_var = variables.VariableV1(array_ops.ones_like(true_stddev))
x = random_ops.random_normal(shape, dtype=np.float64, seed=0)
x = true_stddev * x + true_mean
ema, emv = moving_stats.assign_moving_mean_variance(
@@ -115,7 +115,7 @@ class MovingLogExponentialMovingMeanExpTest(test.TestCase):
# Start "x" out with this mean.
x = random_ops.random_normal(shape, dtype=np.float64, seed=0)
x = true_stddev * x + true_mean
- log_mean_exp_var = variables.Variable(array_ops.zeros_like(true_mean))
+ log_mean_exp_var = variables.VariableV1(array_ops.zeros_like(true_mean))
variables.global_variables_initializer().run()
log_mean_exp = moving_stats.assign_log_moving_mean_exp(
log_mean_exp_var, x, decay=decay)
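The moving-stats tests now construct `variables.VariableV1` explicitly rather than `variables.Variable`, pinning the legacy reference-variable behaviour they exercise. A minimal, self-contained sketch of that construction (the `true_mean` values are invented; only `true_stddev` appears in the hunk above):

```python
import numpy as np

from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables

true_mean = np.array([[0., 3.]])
true_stddev = np.array([[1.1, 0.5]])

# VariableV1 keeps the ref-variable semantics the moving-average ops expect.
mean_var = variables.VariableV1(array_ops.zeros_like(true_mean))
variance_var = variables.VariableV1(array_ops.ones_like(true_stddev))
```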
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/util/BUILD b/tensorflow/contrib/distributions/python/kernel_tests/util/BUILD
deleted file mode 100644
index 42ecea034d..0000000000
--- a/tensorflow/contrib/distributions/python/kernel_tests/util/BUILD
+++ /dev/null
@@ -1,51 +0,0 @@
-# Description:
-# Internal testing utilities, e.g., computing the correct answer to
-# put in a unit test.
-
-licenses(["notice"]) # Apache 2.0
-
-py_library(
- name = "correlation_matrix_volumes_py",
- srcs = [
- "correlation_matrix_volumes_lib.py",
- ],
- deps = [
- "//tensorflow/contrib/distributions:distributions_py",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:errors",
- "//tensorflow/python:framework",
- "//tensorflow/python:framework_for_generated_wrappers",
- "//tensorflow/python:math_ops",
- "//third_party/py/numpy",
- ],
-)
-
-py_binary(
- name = "correlation_matrix_volumes",
- srcs = [
- "correlation_matrix_volumes.py",
- ],
- deps = [
- ":correlation_matrix_volumes_py",
- ],
-)
-
-py_test(
- name = "correlation_matrix_volumes_test",
- size = "medium",
- srcs = ["correlation_matrix_volumes_test.py"],
- tags = [
- "no_pip",
- "optonly",
- ],
- deps = [
- ":correlation_matrix_volumes_py",
- # For statistical testing
- "//tensorflow/contrib/distributions:distributions_py",
- "//third_party/py/numpy",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:check_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:framework",
- ],
-)
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes.py b/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes.py
deleted file mode 100644
index 2eab51cd30..0000000000
--- a/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes.py
+++ /dev/null
@@ -1,98 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Executable to estimate the volume of various sets of correlation matrices.
-
-See correlation_matrix_volumes_lib.py for purpose and methodology.
-
-Invocation example:
-```
-python correlation_matrix_volumes.py --num_samples 1e7
-```
-
-This will compute 10,000,000-sample confidence intervals for the
-volumes of several sets of correlation matrices. Which sets, and the
-desired statistical significance, are hard-coded in this source file.
-"""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import pprint
-
-from absl import app
-from absl import flags
-
-from tensorflow.contrib.distributions.python.kernel_tests.util import correlation_matrix_volumes_lib as corr
-
-FLAGS = flags.FLAGS
-
-# Float to support giving the number of samples in scientific notation.
-# The production run used for the LKJ test used 1e7 samples.
-flags.DEFINE_float('num_samples', 1e4, 'Number of samples to use.')
-
-
-def ctv_debatched(det_bounds, dim, num_samples, error_rate=1e-6, seed=42):
- # This wrapper undoes the batching in compute_true_volumes, because
- # apparently several 5x5x9x1e7 Tensors of float32 can strain RAM.
- bounds = {}
- for db in det_bounds:
- bounds[db] = corr.compute_true_volumes(
- [db], dim, num_samples, error_rate=error_rate, seed=seed)[db]
- return bounds
-
-
-# The particular bounds in all three of these functions were chosen by
-# a somewhat arbitrary walk through an empirical tradeoff, for the
-# purpose of testing the LKJ distribution. Setting the determinant
-# bound lower
-# - Covers more of the testee's sample space, and
-# - Increases the probability that the rejection sampler will hit, thus
-# - Decreases the relative error (at a fixed sample count) in the
-# rejection-based volume estimate;
-# but also
-# - Increases the variance of the estimator used in the LKJ test.
-# This latter variance is also affected by the dimension and the
-# tested concentration parameter, and can be compensated for with more
-# compute (expensive) or a looser discrepancy limit (unsatisfying).
-# The values here are the projection of the points in that test design
-# space that ended up getting chosen.
-def compute_3x3_volumes(num_samples):
- det_bounds = [0.01, 0.25, 0.3, 0.35, 0.4, 0.45]
- return ctv_debatched(
- det_bounds, 3, num_samples, error_rate=5e-7, seed=46)
-
-
-def compute_4x4_volumes(num_samples):
- det_bounds = [0.01, 0.25, 0.3, 0.35, 0.4, 0.45]
- return ctv_debatched(
- det_bounds, 4, num_samples, error_rate=5e-7, seed=47)
-
-
-def compute_5x5_volumes(num_samples):
- det_bounds = [0.01, 0.2, 0.25, 0.3, 0.35, 0.4]
- return ctv_debatched(
- det_bounds, 5, num_samples, error_rate=5e-7, seed=48)
-
-
-def main(_):
- full_bounds = {}
- full_bounds[3] = compute_3x3_volumes(int(FLAGS.num_samples))
- full_bounds[4] = compute_4x4_volumes(int(FLAGS.num_samples))
- full_bounds[5] = compute_5x5_volumes(int(FLAGS.num_samples))
- pprint.pprint(full_bounds)
-
-if __name__ == '__main__':
- app.run(main)
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes_lib.py b/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes_lib.py
deleted file mode 100644
index 455e71f00c..0000000000
--- a/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes_lib.py
+++ /dev/null
@@ -1,323 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Estimating the volume of the correlation matrices with bounded determinant.
-
-Why? Because lkj_test.py tests the sampler for the LKJ distribution
-by estimating the same volume another way.
-
-How? Rejection sampling. Or, more precisely, importance sampling,
-proposing from the uniform distribution on symmetric matrices with
-diagonal 1s and entries in [-1, 1]. Such a matrix is a correlation
-matrix if and only if it is also positive semi-definite.
-
-The samples can then be converted into a confidence interval on the
-volume in question by the [Clopper-Pearson
-method](https://en.wikipedia.org/wiki/Binomial_proportion_confidence_interval),
-also implemented here.
-"""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import importlib
-import sys
-
-import numpy as np
-
-from tensorflow.python.client import session
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import linalg_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops.distributions import uniform
-from tensorflow.python.ops.distributions import util
-from tensorflow.python.platform import tf_logging
-
-__all__ = [
- "correlation_matrix_volume_rejection_samples",
- "compute_true_volumes",
-]
-
-
-def try_import(name): # pylint: disable=invalid-name
- module = None
- try:
- module = importlib.import_module(name)
- except ImportError as e:
- tf_logging.warning("Could not import %s: %s" % (name, str(e)))
- return module
-
-optimize = try_import("scipy.optimize")
-stats = try_import("scipy.stats")
-
-
-def _psd_mask(x):
- """Computes whether each square matrix in the input is positive semi-definite.
-
- Args:
- x: A floating-point `Tensor` of shape `[B1, ..., Bn, M, M]`.
-
- Returns:
- mask: A floating-point `Tensor` of shape `[B1, ... Bn]`. Each
- scalar is 1 if the corresponding matrix was PSD, otherwise 0.
- """
- # Allegedly
- # https://scicomp.stackexchange.com/questions/12979/testing-if-a-matrix-is-positive-semi-definite
- # it is more efficient to test for positive semi-definiteness by
- # trying to compute the Cholesky decomposition -- the matrix is PSD
- # if you succeed and not PSD if you fail. However, TensorFlow's
- # Cholesky raises an exception if _any_ of the input matrices are
- # not PSD, from which I don't know how to extract _which ones_, so I
- # proceed by explicitly computing all the eigenvalues and checking
- # whether they are all positive or not.
- #
- # Also, as was discussed in the answer, it is somewhat dangerous to
- # treat SPD-ness as binary in floating-point arithmetic. Cholesky
- # factorization can complete and 'look' like everything is fine
- # (e.g., O(1) entries and a diagonal of all ones) but the matrix can
- # have an exponential condition number.
- eigenvalues, _ = linalg_ops.self_adjoint_eig(x)
- return math_ops.cast(
- math_ops.reduce_min(eigenvalues, axis=-1) >= 0, dtype=x.dtype)
-
-
-def _det_large_enough_mask(x, det_bounds):
- """Returns whether the input matches the given determinant limit.
-
- Args:
- x: A floating-point `Tensor` of shape `[B1, ..., Bn, M, M]`.
- det_bounds: A floating-point `Tensor` that must broadcast to shape
- `[B1, ..., Bn]`, giving the desired lower bound on the
- determinants in `x`.
-
- Returns:
- mask: A floating-point `Tensor` of shape [B1, ..., Bn]. Each
- scalar is 1 if the corresponding matrix had determinant above
- the corresponding bound, otherwise 0.
- """
- # For the curious: I wonder whether it is possible and desirable to
- # use a Cholesky decomposition-based algorithm for this, since the
- # only matrices whose determinant this code cares about will be PSD.
- # Didn't figure out how to code that in TensorFlow.
- #
- # Expert opinion is that it would be about twice as fast since
- # Cholesky is roughly half the cost of Gaussian Elimination with
- # Partial Pivoting. But this is less of an impact than the switch in
- # _psd_mask.
- return math_ops.cast(
- linalg_ops.matrix_determinant(x) > det_bounds, dtype=x.dtype)
-
-
-def _uniform_correlation_like_matrix(num_rows, batch_shape, dtype, seed):
- """Returns a uniformly random `Tensor` of "correlation-like" matrices.
-
- A "correlation-like" matrix is a symmetric square matrix with all entries
- between -1 and 1 (inclusive) and 1s on the main diagonal. Of these,
- the ones that are positive semi-definite are exactly the correlation
- matrices.
-
- Args:
- num_rows: Python `int` dimension of the correlation-like matrices.
- batch_shape: `Tensor` or Python `tuple` of `int` shape of the
- batch to return.
- dtype: `dtype` of the `Tensor` to return.
- seed: Random seed.
-
- Returns:
- matrices: A `Tensor` of shape `batch_shape + [num_rows, num_rows]`
- and dtype `dtype`. Each entry is in [-1, 1], and each matrix
- along the bottom two dimensions is symmetric and has 1s on the
- main diagonal.
- """
- num_entries = num_rows * (num_rows + 1) / 2
- ones = array_ops.ones(shape=[num_entries], dtype=dtype)
- # It seems wasteful to generate random values for the diagonal since
- # I am going to throw them away, but `fill_triangular` fills the
- # diagonal, so I probably need them.
- # It's not impossible that it would be more efficient to just fill
- # the whole matrix with random values instead of messing with
- # `fill_triangular`. Then would need to filter almost half out with
- # `matrix_band_part`.
- unifs = uniform.Uniform(-ones, ones).sample(batch_shape, seed=seed)
- tril = util.fill_triangular(unifs)
- symmetric = tril + array_ops.matrix_transpose(tril)
- diagonal_ones = array_ops.ones(
- shape=util.pad(batch_shape, axis=0, back=True, value=num_rows),
- dtype=dtype)
- return array_ops.matrix_set_diag(symmetric, diagonal_ones)
-
-
-def correlation_matrix_volume_rejection_samples(
- det_bounds, dim, sample_shape, dtype, seed):
- """Returns rejection samples from trying to get good correlation matrices.
-
- The proposal being rejected from is the uniform distribution on
- "correlation-like" matrices. We say a matrix is "correlation-like"
- if it is a symmetric square matrix with all entries between -1 and 1
- (inclusive) and 1s on the main diagonal. Of these, the ones that
- are positive semi-definite are exactly the correlation matrices.
-
- The rejection algorithm, then, is to sample a `Tensor` of
- `sample_shape` correlation-like matrices of dimensions `dim` by
- `dim`, and check each one for (i) being a correlation matrix (i.e.,
- PSD), and (ii) having determinant at least the corresponding entry
- of `det_bounds`.
-
- Args:
- det_bounds: A `Tensor` of lower bounds on the determinants of
- acceptable matrices. The shape must broadcast with `sample_shape`.
- dim: A Python `int` dimension of correlation matrices to sample.
- sample_shape: Python `tuple` of `int` shape of the samples to
- compute, excluding the two matrix dimensions.
- dtype: The `dtype` in which to do the computation.
- seed: Random seed.
-
- Returns:
- weights: A `Tensor` of shape `sample_shape`. Each entry is 0 if the
- corresponding matrix was not a correlation matrix, or had too
- small of a determinant. Otherwise, the entry is the
- multiplicative inverse of the density of proposing that matrix
- uniformly, i.e., the volume of the set of `dim` by `dim`
- correlation-like matrices.
- volume: The volume of the set of `dim` by `dim` correlation-like
- matrices.
- """
- with ops.name_scope("rejection_sampler"):
- rej_proposals = _uniform_correlation_like_matrix(
- dim, sample_shape, dtype, seed=seed)
- rej_proposal_volume = 2. ** (dim * (dim - 1) / 2.)
- # The density of proposing any given point is 1 / rej_proposal_volume;
- # The weight of that point should be scaled by
- # 1 / density = rej_proposal_volume.
- rej_weights = rej_proposal_volume * _psd_mask(
- rej_proposals) * _det_large_enough_mask(rej_proposals, det_bounds)
- return rej_weights, rej_proposal_volume
-
-
-def _clopper_pearson_confidence_interval(samples, error_rate):
- """Computes a confidence interval for the mean of the given 1-D distribution.
-
- Assumes (and checks) that the given distribution is Bernoulli, i.e.,
- takes only two values. This licenses using the CDF of the binomial
- distribution for the confidence, which is tighter (for extreme
- probabilities) than the DKWM inequality. The method is known as the
- [Clopper-Pearson method]
- (https://en.wikipedia.org/wiki/Binomial_proportion_confidence_interval).
-
- Assumes:
-
- - The given samples were drawn iid from the distribution of interest.
-
- - The given distribution is a Bernoulli, i.e., supported only on
- low and high.
-
- Guarantees:
-
- - The probability (over the randomness of drawing the given sample)
- that the true mean is outside the returned interval is no more
- than the given error_rate.
-
- Args:
- samples: `np.ndarray` of samples drawn iid from the distribution
- of interest.
- error_rate: Python `float` admissible rate of mistakes.
-
- Returns:
- low: Lower bound of confidence interval.
- high: Upper bound of confidence interval.
-
- Raises:
- ValueError: If `samples` has rank other than 1 (batch semantics
- are not implemented), or if `samples` contains values other than
- `low` or `high` (as that makes the distribution not Bernoulli).
- """
- # TODO(b/78025336) Migrate this confidence interval function
- # to statistical_testing.py. In order to do that
- # - Get the binomial CDF from the Binomial distribution
- # - Implement scalar root finding in TF. Batch bisection search
- # shouldn't be too hard, and is definitely good enough for this
- # problem. Batching the Brent algorithm (from scipy) that is used
- # here may be more involved, but may also not be necessary---it's
- # only used here because scipy made it convenient. In particular,
- # robustness is more important than speed here, which may make
- # bisection search actively better.
- # - The rest is just a matter of rewriting in the appropriate style.
- if optimize is None or stats is None:
- raise ValueError(
- "Scipy is required for computing Clopper-Pearson confidence intervals")
- if len(samples.shape) != 1:
- raise ValueError("Batch semantics not implemented")
- n = len(samples)
- low = np.amin(samples)
- high = np.amax(samples)
- successes = np.count_nonzero(samples - low)
- failures = np.count_nonzero(samples - high)
- if successes + failures != n:
- uniques = np.unique(samples)
- msg = ("Purportedly Bernoulli distribution had distinct samples"
- " {}, {}, and {}".format(uniques[0], uniques[1], uniques[2]))
- raise ValueError(msg)
- def p_small_enough(p):
- prob = stats.binom.logcdf(successes, n, p)
- return prob - np.log(error_rate / 2.)
- def p_big_enough(p):
- prob = stats.binom.logsf(successes, n, p)
- return prob - np.log(error_rate / 2.)
- high_p = optimize.brentq(
- p_small_enough, float(successes) / n, 1., rtol=1e-9)
- low_p = optimize.brentq(
- p_big_enough, 0., float(successes) / n, rtol=1e-9)
- low_interval = low + (high - low) * low_p
- high_interval = low + (high - low) * high_p
- return (low_interval, high_interval)
-
-
-def compute_true_volumes(
- det_bounds, dim, num_samples, error_rate=1e-6, seed=42):
- """Returns confidence intervals for the desired correlation matrix volumes.
-
- The confidence intervals are computed by the [Clopper-Pearson method]
- (https://en.wikipedia.org/wiki/Binomial_proportion_confidence_interval).
-
- Args:
- det_bounds: A rank-1 numpy array of lower bounds on the
- determinants of acceptable matrices. Entries must be unique.
- dim: A Python `int` dimension of correlation matrices to sample.
- num_samples: The number of samples to draw.
- error_rate: The statistical significance of the returned
- confidence intervals. The significance is broadcast: Each
- returned interval separately may be incorrect with probability
- (under the sample of correlation-like matrices drawn internally)
- at most `error_rate`.
- seed: Random seed.
-
- Returns:
- bounds: A Python `dict` mapping each determinant bound to the low, high
- tuple giving the confidence interval.
- """
- bounds = {}
- with session.Session() as sess:
- rej_weights, _ = correlation_matrix_volume_rejection_samples(
- det_bounds, dim, [num_samples, len(det_bounds)], np.float32, seed=seed)
- rej_weights = sess.run(rej_weights)
- for rw, det in zip(np.rollaxis(rej_weights, 1), det_bounds):
- template = ("Estimating volume of {}x{} correlation "
- "matrices with determinant >= {}.")
- print(template.format(dim, dim, det))
- sys.stdout.flush()
- bounds[det] = _clopper_pearson_confidence_interval(
- rw, error_rate=error_rate)
- return bounds
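The removed library's `_clopper_pearson_confidence_interval` obtained the interval by root-finding the binomial log-CDF/log-SF with `scipy.optimize.brentq`. For reference, a hedged sketch of the equivalent standard closed form via Beta quantiles (not part of the deleted code):

```python
import numpy as np
from scipy import stats


def clopper_pearson_interval(successes, n, error_rate):
  """Two-sided Clopper-Pearson interval for a Bernoulli mean."""
  # The endpoints are Beta quantiles, degenerating to 0 or 1 when every
  # trial failed or succeeded.
  low = 0.0 if successes == 0 else stats.beta.ppf(
      error_rate / 2., successes, n - successes + 1)
  high = 1.0 if successes == n else stats.beta.ppf(
      1. - error_rate / 2., successes + 1, n - successes)
  return low, high


print(clopper_pearson_interval(successes=7, n=100, error_rate=1e-6))
```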
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes_test.py b/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes_test.py
deleted file mode 100644
index 8f99300e63..0000000000
--- a/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes_test.py
+++ /dev/null
@@ -1,150 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for correlation_matrix_volumes_lib.py."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-
-from tensorflow.contrib.distributions.python.kernel_tests.util import correlation_matrix_volumes_lib as corr
-from tensorflow.contrib.distributions.python.ops import statistical_testing as st
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import check_ops
-from tensorflow.python.platform import test
-
-
-# NxN correlation matrices are determined by the N*(N-1)/2
-# lower-triangular entries. In addition to being between -1 and 1,
-# they must also obey the constraint that the determinant of the
-# resulting symmetric matrix is non-negative. In 2x2, we can even
-# analytically compute the volume when the determinant is bounded to >
-# epsilon, as that boils down to the one lower-triangular entry being
-# less than sqrt(1 - epsilon) in absolute value.
-def two_by_two_volume(det_bound):
- return 2 * np.sqrt(1.0 - det_bound)
-
-
-# The post
-# https://psychometroscar.com/the-volume-of-a-3-x-3-correlation-matrix/
-# derives (with elementary calculus) that the volume (with respect to
-# Lebesgue^3 measure) of the set of 3x3 correlation matrices is
-# pi^2/2. The same result is also obtained by [1].
-def three_by_three_volume():
- return np.pi**2 / 2.
-
-
-# The volume of the unconstrained set of correlation matrices is also
-# the normalization constant of the LKJ distribution from [2]. As
-# part of defining the distribution, that reference a derives general
-# formula for this volume for all dimensions. A TensorFlow
-# computation thereof gave the below result for 4x4:
-def four_by_four_volume():
- # This constant computed as math_ops.exp(lkj.log_norm_const(4, [1.0]))
- return 11.6973076
-
-# [1] Rousseeuw, P. J., & Molenberghs, G. (1994). "The shape of
-# correlation matrices." The American Statistician, 48(4), 276-279.
-
-# [2] Daniel Lewandowski, Dorota Kurowicka, and Harry Joe, "Generating
-# random correlation matrices based on vines and extended onion
-# method," Journal of Multivariate Analysis 100 (2009), pp 1989-2001.
-
-
-class CorrelationMatrixVolumesTest(test.TestCase):
-
- def testRejection2D(self):
- num_samples = int(1e5) # Chosen for a small min detectable discrepancy
- det_bounds = np.array(
- [0.01, 0.02, 0.03, 0.04, 0.05, 0.3, 0.35, 0.4, 0.5], dtype=np.float32)
- exact_volumes = two_by_two_volume(det_bounds)
- (rej_weights,
- rej_proposal_volume) = corr.correlation_matrix_volume_rejection_samples(
- det_bounds, 2, [num_samples, 9], dtype=np.float32, seed=43)
- # shape of rej_weights: [num_samples, 9, 2, 2]
- chk1 = st.assert_true_mean_equal_by_dkwm(
- rej_weights, low=0., high=rej_proposal_volume, expected=exact_volumes,
- false_fail_rate=1e-6)
- chk2 = check_ops.assert_less(
- st.min_discrepancy_of_true_means_detectable_by_dkwm(
- num_samples, low=0., high=rej_proposal_volume,
- # Correct the false fail rate due to different broadcasting
- false_fail_rate=1.1e-7, false_pass_rate=1e-6),
- 0.036)
- with ops.control_dependencies([chk1, chk2]):
- rej_weights = array_ops.identity(rej_weights)
- self.evaluate(rej_weights)
-
- def testRejection3D(self):
- num_samples = int(1e5) # Chosen for a small min detectable discrepancy
- det_bounds = np.array([0.0], dtype=np.float32)
- exact_volumes = np.array([three_by_three_volume()], dtype=np.float32)
- (rej_weights,
- rej_proposal_volume) = corr.correlation_matrix_volume_rejection_samples(
- det_bounds, 3, [num_samples, 1], dtype=np.float32, seed=44)
- # shape of rej_weights: [num_samples, 1, 3, 3]
- chk1 = st.assert_true_mean_equal_by_dkwm(
- rej_weights, low=0., high=rej_proposal_volume, expected=exact_volumes,
- false_fail_rate=1e-6)
- chk2 = check_ops.assert_less(
- st.min_discrepancy_of_true_means_detectable_by_dkwm(
- num_samples, low=0., high=rej_proposal_volume,
- false_fail_rate=1e-6, false_pass_rate=1e-6),
- # Going for about a 3% relative error
- 0.15)
- with ops.control_dependencies([chk1, chk2]):
- rej_weights = array_ops.identity(rej_weights)
- self.evaluate(rej_weights)
-
- def testRejection4D(self):
- num_samples = int(1e5) # Chosen for a small min detectable discrepancy
- det_bounds = np.array([0.0], dtype=np.float32)
- exact_volumes = [four_by_four_volume()]
- (rej_weights,
- rej_proposal_volume) = corr.correlation_matrix_volume_rejection_samples(
- det_bounds, 4, [num_samples, 1], dtype=np.float32, seed=45)
- # shape of rej_weights: [num_samples, 1, 4, 4]
- chk1 = st.assert_true_mean_equal_by_dkwm(
- rej_weights, low=0., high=rej_proposal_volume, expected=exact_volumes,
- false_fail_rate=1e-6)
- chk2 = check_ops.assert_less(
- st.min_discrepancy_of_true_means_detectable_by_dkwm(
- num_samples, low=0., high=rej_proposal_volume,
- false_fail_rate=1e-6, false_pass_rate=1e-6),
- # Going for about a 10% relative error
- 1.1)
- with ops.control_dependencies([chk1, chk2]):
- rej_weights = array_ops.identity(rej_weights)
- self.evaluate(rej_weights)
-
- def testVolumeEstimation2D(self):
- # Test that the confidence intervals produced by
- # corr.compute_true_volumes are sound, in the sense of containing
- # the exact volume.
- num_samples = int(1e5) # Chosen by symmetry with testRejection2D
- det_bounds = np.array(
- [0.01, 0.02, 0.03, 0.04, 0.05, 0.3, 0.35, 0.4, 0.5], dtype=np.float32)
- volume_bounds = corr.compute_true_volumes(
- det_bounds, 2, num_samples, error_rate=1e-6, seed=47)
- exact_volumes = two_by_two_volume(det_bounds)
- for det, volume in zip(det_bounds, exact_volumes):
- computed_low, computed_high = volume_bounds[det]
- self.assertLess(computed_low, volume)
- self.assertGreater(computed_high, volume)
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/distributions/python/ops/autoregressive.py b/tensorflow/contrib/distributions/python/ops/autoregressive.py
index bb9b8043b2..3ba1c3a665 100644
--- a/tensorflow/contrib/distributions/python/ops/autoregressive.py
+++ b/tensorflow/contrib/distributions/python/ops/autoregressive.py
@@ -65,13 +65,14 @@ class Autoregressive(distribution_lib.Distribution):
```
where the ellipses (`...`) represent `n-2` composed calls to `fn`, `fn`
- constructs a `tf.distributions.Distribution`-like instance, and `x0` is a
+ constructs a `tfp.distributions.Distribution`-like instance, and `x0` is a
fixed initializing `Tensor`.
#### Examples
```python
- tfd = tf.contrib.distributions
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
def normal_fn(self, event_size):
n = event_size * (event_size + 1) / 2
@@ -127,7 +128,7 @@ class Autoregressive(distribution_lib.Distribution):
Args:
distribution_fn: Python `callable` which constructs a
- `tf.distributions.Distribution`-like instance from a `Tensor` (e.g.,
+ `tfp.distributions.Distribution`-like instance from a `Tensor` (e.g.,
`sample0`). The function must respect the "autoregressive property",
i.e., there exists a permutation of event such that each coordinate is a
diffeomorphic function of only preceding coordinates.
diff --git a/tensorflow/contrib/distributions/python/ops/batch_reshape.py b/tensorflow/contrib/distributions/python/ops/batch_reshape.py
index 519077bc9a..612376efb7 100644
--- a/tensorflow/contrib/distributions/python/ops/batch_reshape.py
+++ b/tensorflow/contrib/distributions/python/ops/batch_reshape.py
@@ -45,7 +45,8 @@ class BatchReshape(distribution_lib.Distribution):
#### Examples
```python
- tfd = tf.contrib.distributions
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
dtype = np.float32
dims = 2
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py b/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py
index 296e66f2b2..3b3d8ee6f2 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py
@@ -61,8 +61,8 @@ class MaskedAutoregressiveFlow(bijector.Bijector):
`shift_and_log_scale_fn`, `masked_autoregressive_default_template`, achieves
this property by zeroing out weights in its `masked_dense` layers.
- In the `tf.distributions` framework, a "normalizing flow" is implemented as a
- `tf.contrib.distributions.bijectors.Bijector`. The `forward` "autoregression"
+ In the `tfp` framework, a "normalizing flow" is implemented as a
+ `tfp.bijectors.Bijector`. The `forward` "autoregression"
is implemented using a `tf.while_loop` and a deep neural network (DNN) with
masked weights such that the autoregressive property is automatically met in
the `inverse`.
@@ -126,8 +126,9 @@ class MaskedAutoregressiveFlow(bijector.Bijector):
#### Examples
```python
- tfd = tf.contrib.distributions
- tfb = tfd.bijectors
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
+ tfb = tfp.bijectors
dims = 5
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/permute.py b/tensorflow/contrib/distributions/python/ops/bijectors/permute.py
index f182a1adcb..178c3c94bf 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/permute.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/permute.py
@@ -41,9 +41,10 @@ class Permute(bijector.Bijector):
"""Permutes the rightmost dimension of a `Tensor`.
```python
- tfd = tf.contrib.distributions
+ import tensorflow_probability as tfp
+ tfb = tfp.bijectors
- reverse = tfd.bijectors.Permute(permutation=[2, 1, 0])
+ reverse = tfb.Permute(permutation=[2, 1, 0])
reverse.forward([-1., 0., 1.])
# ==> [1., 0., -1]
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/real_nvp.py b/tensorflow/contrib/distributions/python/ops/bijectors/real_nvp.py
index 773ae24461..0bcb08cdea 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/real_nvp.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/real_nvp.py
@@ -90,8 +90,9 @@ class RealNVP(bijector.Bijector):
#### Example Use
```python
- tfd = tf.contrib.distributions
- tfb = tfd.bijectors
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
+ tfb = tfp.bijectors
# A common choice for a normalizing flow is to use a Gaussian for the base
# distribution. (However, any continuous distribution would work.) E.g.,
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py b/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py
index c8282229a3..71ac29038f 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py
@@ -80,9 +80,10 @@ class Reshape(bijector.Bijector):
Example usage:
```python
- tfd = tf.contrib.distributions
+ import tensorflow_probability as tfp
+ tfb = tfp.bijectors
- r = tfd.bijectors.Reshape(event_shape_out=[1, -1])
+ r = tfb.Reshape(event_shape_out=[1, -1])
r.forward([3., 4.]) # shape [2]
# ==> [[3., 4.]] # shape [1, 2]
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/scale_tril.py b/tensorflow/contrib/distributions/python/ops/bijectors/scale_tril.py
index 6fbe866578..0a6d690b65 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/scale_tril.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/scale_tril.py
@@ -42,7 +42,10 @@ class ScaleTriL(chain.Chain):
#### Examples
```python
- tfb = tf.contrib.distributions.bijectors
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
+ tfb = tfp.bijectors
+
b = tfb.ScaleTriL(
diag_bijector=tfb.Exp(),
diag_shift=None)
diff --git a/tensorflow/contrib/distributions/python/ops/cauchy.py b/tensorflow/contrib/distributions/python/ops/cauchy.py
index cb5223b055..c461833b9a 100644
--- a/tensorflow/contrib/distributions/python/ops/cauchy.py
+++ b/tensorflow/contrib/distributions/python/ops/cauchy.py
@@ -63,7 +63,8 @@ class Cauchy(distribution.Distribution):
Examples of initialization of one or a batch of distributions.
```python
- tfd = tf.contrib.distributions
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
# Define a single scalar Cauchy distribution.
dist = tfd.Cauchy(loc=0., scale=3.)
diff --git a/tensorflow/contrib/distributions/python/ops/deterministic.py b/tensorflow/contrib/distributions/python/ops/deterministic.py
index affc64a14f..507c5d3679 100644
--- a/tensorflow/contrib/distributions/python/ops/deterministic.py
+++ b/tensorflow/contrib/distributions/python/ops/deterministic.py
@@ -198,8 +198,11 @@ class Deterministic(_BaseDeterministic):
#### Examples
```python
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
+
# Initialize a single Deterministic supported at zero.
- constant = tf.contrib.distributions.Deterministic(0.)
+ constant = tfd.Deterministic(0.)
constant.prob(0.)
==> 1.
constant.prob(2.)
@@ -208,7 +211,7 @@ class Deterministic(_BaseDeterministic):
# Initialize a [2, 2] batch of scalar constants.
loc = [[0., 1.], [2., 3.]]
x = [[0., 1.1], [1.99, 3.]]
- constant = tf.contrib.distributions.Deterministic(loc)
+ constant = tfd.Deterministic(loc)
constant.prob(x)
==> [[1., 0.], [0., 1.]]
```
@@ -310,7 +313,8 @@ class VectorDeterministic(_BaseDeterministic):
#### Examples
```python
- tfd = tf.contrib.distributions
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
# Initialize a single VectorDeterministic supported at [0., 2.] in R^2.
constant = tfd.Deterministic([0., 2.])
diff --git a/tensorflow/contrib/distributions/python/ops/gumbel.py b/tensorflow/contrib/distributions/python/ops/gumbel.py
index acdea4d61d..4b50df5b48 100644
--- a/tensorflow/contrib/distributions/python/ops/gumbel.py
+++ b/tensorflow/contrib/distributions/python/ops/gumbel.py
@@ -63,7 +63,8 @@ class _Gumbel(distribution.Distribution):
Examples of initialization of one or a batch of distributions.
```python
- tfd = tf.contrib.distributions
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
# Define a single scalar Gumbel distribution.
dist = tfd.Gumbel(loc=0., scale=3.)
diff --git a/tensorflow/contrib/distributions/python/ops/half_normal.py b/tensorflow/contrib/distributions/python/ops/half_normal.py
index b02c403106..f121637086 100644
--- a/tensorflow/contrib/distributions/python/ops/half_normal.py
+++ b/tensorflow/contrib/distributions/python/ops/half_normal.py
@@ -66,15 +66,18 @@ class HalfNormal(distribution.Distribution):
Examples of initialization of one or a batch of distributions.
```python
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
+
# Define a single scalar HalfNormal distribution.
- dist = tf.contrib.distributions.HalfNormal(scale=3.0)
+ dist = tfd.HalfNormal(scale=3.0)
# Evaluate the cdf at 1, returning a scalar.
dist.cdf(1.)
# Define a batch of two scalar valued HalfNormals.
# The first has scale 11.0, the second 22.0
- dist = tf.contrib.distributions.HalfNormal(scale=[11.0, 22.0])
+ dist = tfd.HalfNormal(scale=[11.0, 22.0])
# Evaluate the pdf of the first distribution on 1.0, and the second on 1.5,
# returning a length two tensor.
diff --git a/tensorflow/contrib/distributions/python/ops/independent.py b/tensorflow/contrib/distributions/python/ops/independent.py
index 0672702b96..e1cfff3c66 100644
--- a/tensorflow/contrib/distributions/python/ops/independent.py
+++ b/tensorflow/contrib/distributions/python/ops/independent.py
@@ -70,7 +70,8 @@ class Independent(distribution_lib.Distribution):
#### Examples
```python
- tfd = tf.contrib.distributions
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
# Make independent distribution from a 2-batch Normal.
ind = tfd.Independent(
diff --git a/tensorflow/contrib/distributions/python/ops/inverse_gamma.py b/tensorflow/contrib/distributions/python/ops/inverse_gamma.py
index 70d050d7a6..452628257e 100644
--- a/tensorflow/contrib/distributions/python/ops/inverse_gamma.py
+++ b/tensorflow/contrib/distributions/python/ops/inverse_gamma.py
@@ -89,7 +89,9 @@ class InverseGamma(distribution.Distribution):
#### Examples
```python
- tfd = tf.contrib.distributions
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
+
dist = tfd.InverseGamma(concentration=3.0, rate=2.0)
dist2 = tfd.InverseGamma(concentration=[3.0, 4.0], rate=[2.0, 3.0])
```
diff --git a/tensorflow/contrib/distributions/python/ops/logistic.py b/tensorflow/contrib/distributions/python/ops/logistic.py
index 02e3bad51e..21c9b5a354 100644
--- a/tensorflow/contrib/distributions/python/ops/logistic.py
+++ b/tensorflow/contrib/distributions/python/ops/logistic.py
@@ -61,7 +61,8 @@ class Logistic(distribution.Distribution):
Examples of initialization of one or a batch of distributions.
```python
- tfd = tf.contrib.distributions
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
# Define a single scalar Logistic distribution.
dist = tfd.Logistic(loc=0., scale=3.)
diff --git a/tensorflow/contrib/distributions/python/ops/mixture.py b/tensorflow/contrib/distributions/python/ops/mixture.py
index 3b7114ef06..52b67f2c54 100644
--- a/tensorflow/contrib/distributions/python/ops/mixture.py
+++ b/tensorflow/contrib/distributions/python/ops/mixture.py
@@ -50,7 +50,9 @@ class Mixture(distribution.Distribution):
```python
# Create a mixture of two Gaussians:
- tfd = tf.contrib.distributions
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
+
mix = 0.3
bimix_gauss = tfd.Mixture(
cat=tfd.Categorical(probs=[mix, 1.-mix]),
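The diff context cuts the mixture docstring example short; a hedged completion of the two-component Gaussian mixture (component parameters assumed, not taken from this hunk):

```python
import tensorflow_probability as tfp
tfd = tfp.distributions

mix = 0.3
bimix_gauss = tfd.Mixture(
    cat=tfd.Categorical(probs=[mix, 1. - mix]),
    components=[
        tfd.Normal(loc=-1., scale=0.1),
        tfd.Normal(loc=+1., scale=0.5),
    ])

# Sampling first draws a component index from the Categorical, then a value
# from that component.
samples = bimix_gauss.sample(5)
```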
diff --git a/tensorflow/contrib/distributions/python/ops/mixture_same_family.py b/tensorflow/contrib/distributions/python/ops/mixture_same_family.py
index 8ffee940d0..f4d394ff29 100644
--- a/tensorflow/contrib/distributions/python/ops/mixture_same_family.py
+++ b/tensorflow/contrib/distributions/python/ops/mixture_same_family.py
@@ -44,7 +44,8 @@ class MixtureSameFamily(distribution.Distribution):
#### Examples
```python
- tfd = tf.contrib.distributions
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
### Create a mixture of two scalar Gaussians:
@@ -113,12 +114,12 @@ class MixtureSameFamily(distribution.Distribution):
"""Construct a `MixtureSameFamily` distribution.
Args:
- mixture_distribution: `tf.distributions.Categorical`-like instance.
+ mixture_distribution: `tfp.distributions.Categorical`-like instance.
Manages the probability of selecting components. The number of
categories must match the rightmost batch dimension of the
`components_distribution`. Must have either scalar `batch_shape` or
`batch_shape` matching `components_distribution.batch_shape[:-1]`.
- components_distribution: `tf.distributions.Distribution`-like instance.
+ components_distribution: `tfp.distributions.Distribution`-like instance.
Right-most batch dimension indexes components.
validate_args: Python `bool`, default `False`. When `True` distribution
parameters are checked for validity despite possibly degrading runtime
diff --git a/tensorflow/contrib/distributions/python/ops/mvn_diag.py b/tensorflow/contrib/distributions/python/ops/mvn_diag.py
index cd0c282ba6..0b5b76be92 100644
--- a/tensorflow/contrib/distributions/python/ops/mvn_diag.py
+++ b/tensorflow/contrib/distributions/python/ops/mvn_diag.py
@@ -85,7 +85,8 @@ class MultivariateNormalDiag(
#### Examples
```python
- tfd = tf.contrib.distributions
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
# Initialize a single 2-variate Gaussian.
mvn = tfd.MultivariateNormalDiag(
diff --git a/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py b/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py
index 74d9d04fc7..80546083d3 100644
--- a/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py
+++ b/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py
@@ -87,7 +87,8 @@ class MultivariateNormalDiagPlusLowRank(
#### Examples
```python
- tfd = tf.contrib.distributions
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
# Initialize a single 3-variate Gaussian with covariance `cov = S @ S.T`,
# `S = diag(d) + U @ diag(m) @ U.T`. The perturbation, `U @ diag(m) @ U.T`, is
diff --git a/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py b/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py
index dbc4c1b3dc..bcb4937980 100644
--- a/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py
+++ b/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py
@@ -73,7 +73,8 @@ class MultivariateNormalFullCovariance(mvn_tril.MultivariateNormalTriL):
#### Examples
```python
- tfd = tf.contrib.distributions
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
# Initialize a single 3-variate Gaussian.
mu = [1., 2, 3]
diff --git a/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py b/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py
index efe5a6d0d9..8fdc99824b 100644
--- a/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py
+++ b/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py
@@ -91,7 +91,8 @@ class MultivariateNormalLinearOperator(
#### Examples
```python
- tfd = tf.contrib.distributions
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
# Initialize a single 3-variate Gaussian.
mu = [1., 2, 3]
diff --git a/tensorflow/contrib/distributions/python/ops/mvn_tril.py b/tensorflow/contrib/distributions/python/ops/mvn_tril.py
index c6a23e4336..c21f70fc3b 100644
--- a/tensorflow/contrib/distributions/python/ops/mvn_tril.py
+++ b/tensorflow/contrib/distributions/python/ops/mvn_tril.py
@@ -77,13 +77,14 @@ class MultivariateNormalTriL(
```
Trainable (batch) lower-triangular matrices can be created with
- `tf.contrib.distributions.matrix_diag_transform()` and/or
- `tf.contrib.distributions.fill_triangular()`
+ `tfp.distributions.matrix_diag_transform()` and/or
+ `tfp.distributions.fill_triangular()`
#### Examples
```python
- tfd = tf.contrib.distributions
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
# Initialize a single 3-variate Gaussian.
mu = [1., 2, 3]
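The docstring above points at `matrix_diag_transform()` / `fill_triangular()` for trainable lower-triangular matrices; a minimal non-trainable sketch of the same idea derives the factor from a covariance matrix via a Cholesky decomposition (hedged example, not part of the patch):

```python
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

cov = [[1.0, 0.5, 0.1],
       [0.5, 2.0, 0.3],
       [0.1, 0.3, 1.5]]
scale_tril = tf.linalg.cholesky(cov)  # Lower-triangular factor of `cov`.

mvn = tfd.MultivariateNormalTriL(loc=[1., 2., 3.], scale_tril=scale_tril)
mvn.log_prob([0., 0., 0.])  # Scalar log-density.
```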
diff --git a/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py b/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py
index 7a7ad1be35..85683e3233 100644
--- a/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py
+++ b/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py
@@ -220,7 +220,8 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution):
#### Examples
```python
- tfd = tf.contrib.distributions
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
# Create two batches of PoissonLogNormalQuadratureCompounds, one with
# prior `loc = 0.` and another with `loc = 1.` In both cases `scale = 1.`
diff --git a/tensorflow/contrib/distributions/python/ops/quantized_distribution.py b/tensorflow/contrib/distributions/python/ops/quantized_distribution.py
index 18a0f754e6..134658deab 100644
--- a/tensorflow/contrib/distributions/python/ops/quantized_distribution.py
+++ b/tensorflow/contrib/distributions/python/ops/quantized_distribution.py
@@ -196,8 +196,9 @@ class QuantizedDistribution(distributions.Distribution):
parameter determining the unnormalized probability of that component.
```python
- tfd = tf.contrib.distributions
- tfb = tfd.bijectors
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
+ tfb = tfp.bijectors
net = wavenet(inputs)
loc, unconstrained_scale, logits = tf.split(net,
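The wavenet snippet above is cut off by the hunk; a smaller, self-contained sketch of the same pattern (hedged, not from the patch) discretizes a single continuous base distribution onto integer pixel values:

```python
import tensorflow_probability as tfp
tfd = tfp.distributions

pixel_dist = tfd.QuantizedDistribution(
    distribution=tfd.Logistic(loc=127.5, scale=20.),
    low=0., high=255.)

pixel_dist.prob(128.)  # Probability mass of the integer bin at 128.
```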
diff --git a/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py b/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py
index a9d0fb4ccf..4b520b912e 100644
--- a/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py
+++ b/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py
@@ -124,7 +124,7 @@ class SinhArcsinh(transformed_distribution.TransformedDistribution):
tailweight: Tailweight parameter. Default is `1.0` (unchanged tailweight)
distribution: `tf.Distribution`-like instance. Distribution that is
transformed to produce this distribution.
- Default is `tf.distributions.Normal(0., 1.)`.
+ Default is `tfp.distributions.Normal(0., 1.)`.
Must be a scalar-batch, scalar-event distribution. Typically
`distribution.reparameterization_type = FULLY_REPARAMETERIZED` or it is
a function of non-trainable parameters. WARNING: If you backprop through
diff --git a/tensorflow/contrib/distributions/python/ops/statistical_testing.py b/tensorflow/contrib/distributions/python/ops/statistical_testing.py
index c25e8c51d7..af22f4843a 100644
--- a/tensorflow/contrib/distributions/python/ops/statistical_testing.py
+++ b/tensorflow/contrib/distributions/python/ops/statistical_testing.py
@@ -30,27 +30,27 @@ is some expected constant. Suppose the support of P is the interval
`[0, 1]`. Then you might do this:
```python
-tfd = tf.contrib.distributions
-
-expected_mean = ...
-num_samples = 5000
-samples = ... draw 5000 samples from P
-
-# Check that the mean looks right
-check1 = tfd.assert_true_mean_equal_by_dkwm(
- samples, low=0., high=1., expected=expected_mean,
- false_fail_rate=1e-6)
-
-# Check that the difference in means detectable with 5000 samples is
-# small enough
-check2 = tf.assert_less(
- tfd.min_discrepancy_of_true_means_detectable_by_dkwm(
- num_samples, low=0., high=1.0,
- false_fail_rate=1e-6, false_pass_rate=1e-6),
- 0.01)
-
-# Be sure to execute both assertion ops
-sess.run([check1, check2])
+ from tensorflow_probability.python.distributions.internal import statistical_testing
+
+ expected_mean = ...
+ num_samples = 5000
+ samples = ... draw 5000 samples from P
+
+ # Check that the mean looks right
+ check1 = statistical_testing.assert_true_mean_equal_by_dkwm(
+ samples, low=0., high=1., expected=expected_mean,
+ false_fail_rate=1e-6)
+
+ # Check that the difference in means detectable with 5000 samples is
+ # small enough
+ check2 = tf.assert_less(
+ statistical_testing.min_discrepancy_of_true_means_detectable_by_dkwm(
+ num_samples, low=0., high=1.0,
+ false_fail_rate=1e-6, false_pass_rate=1e-6),
+ 0.01)
+
+ # Be sure to execute both assertion ops
+ sess.run([check1, check2])
```
The second assertion is an instance of experiment design. It's a
diff --git a/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py b/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py
index 3c8aae2797..a3d178357b 100644
--- a/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py
+++ b/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py
@@ -300,7 +300,8 @@ class VectorDiffeomixture(distribution_lib.Distribution):
#### Examples
```python
- tfd = tf.contrib.distributions
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
# Create two batches of VectorDiffeomixtures, one with mix_loc=[0.],
# another with mix_loc=[1]. In both cases, `K=2` and the affine
diff --git a/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py b/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py
index 73356a3625..36cbd71f8b 100644
--- a/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py
+++ b/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py
@@ -90,7 +90,8 @@ class VectorExponentialDiag(
#### Examples
```python
- tfd = tf.contrib.distributions
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
# Initialize a single 2-variate VectorExponential, supported on
# {(x, y) in R^2 : x > 0, y > 0}.
diff --git a/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py b/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py
index 9a47b48557..fd5bf9ecc7 100644
--- a/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py
+++ b/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py
@@ -108,7 +108,8 @@ class VectorExponentialLinearOperator(
#### Examples
```python
- tfd = tf.contrib.distributions
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
# Initialize a single 2-variate VectorExponential, supported on
# {(x, y) in R^2 : x > 0, y > 0}.
diff --git a/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py b/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py
index e68ddc569c..8cd4e128c7 100644
--- a/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py
+++ b/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py
@@ -102,7 +102,8 @@ class VectorLaplaceDiag(
#### Examples
```python
- tfd = tf.contrib.distributions
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
# Initialize a single 2-variate VectorLaplace.
vla = tfd.VectorLaplaceDiag(
diff --git a/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py b/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py
index 3923161a33..67d2ccd28d 100644
--- a/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py
+++ b/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py
@@ -110,7 +110,8 @@ class VectorLaplaceLinearOperator(
#### Examples
```python
- tfd = tf.contrib.distributions
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
# Initialize a single 3-variate VectorLaplace with some desired covariance.
mu = [1., 2, 3]
diff --git a/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py b/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py
index 49ffff24ca..da57d0cb55 100644
--- a/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py
+++ b/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py
@@ -152,7 +152,7 @@ class VectorSinhArcsinhDiag(transformed_distribution.TransformedDistribution):
broadcastable with `event_shape`.
distribution: `tf.Distribution`-like instance. Distribution from which `k`
iid samples are used as input to transformation `F`. Default is
- `tf.distributions.Normal(loc=0., scale=1.)`.
+ `tfp.distributions.Normal(loc=0., scale=1.)`.
Must be a scalar-batch, scalar-event distribution. Typically
`distribution.reparameterization_type = FULLY_REPARAMETERIZED` or it is
a function of non-trainable parameters. WARNING: If you backprop through
diff --git a/tensorflow/contrib/distributions/python/ops/vector_student_t.py b/tensorflow/contrib/distributions/python/ops/vector_student_t.py
index f289b39e51..bad91a0844 100644
--- a/tensorflow/contrib/distributions/python/ops/vector_student_t.py
+++ b/tensorflow/contrib/distributions/python/ops/vector_student_t.py
@@ -92,7 +92,8 @@ class _VectorStudentT(transformed_distribution.TransformedDistribution):
Extra leading dimensions, if provided, allow for batches.
```python
- tfd = tf.contrib.distributions
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
# Initialize a single 3-variate vector Student's t-distribution.
mu = [1., 2, 3]
diff --git a/tensorflow/contrib/distributions/python/ops/wishart.py b/tensorflow/contrib/distributions/python/ops/wishart.py
index 49b9de0ab5..ee2fc58864 100644
--- a/tensorflow/contrib/distributions/python/ops/wishart.py
+++ b/tensorflow/contrib/distributions/python/ops/wishart.py
@@ -480,11 +480,14 @@ class WishartCholesky(_WishartLinearOperator):
#### Examples
```python
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
+
# Initialize a single 3x3 Wishart with Cholesky factored scale matrix and 5
# degrees-of-freedom.(*)
df = 5
chol_scale = tf.cholesky(...) # Shape is [3, 3].
- dist = tf.contrib.distributions.WishartCholesky(df=df, scale=chol_scale)
+ dist = tfd.WishartCholesky(df=df, scale=chol_scale)
# Evaluate this on an observation in R^3, returning a scalar.
x = ... # A 3x3 positive definite matrix.
@@ -498,14 +501,14 @@ class WishartCholesky(_WishartLinearOperator):
# Initialize two 3x3 Wisharts with Cholesky factored scale matrices.
df = [5, 4]
chol_scale = tf.cholesky(...) # Shape is [2, 3, 3].
- dist = tf.contrib.distributions.WishartCholesky(df=df, scale=chol_scale)
+ dist = tfd.WishartCholesky(df=df, scale=chol_scale)
# Evaluate this on four observations.
x = [[x0, x1], [x2, x3]] # Shape is [2, 2, 3, 3].
dist.prob(x) # Shape is [2, 2].
# (*) - To efficiently create a trainable covariance matrix, see the example
- # in tf.contrib.distributions.matrix_diag_transform.
+ # in tfp.distributions.matrix_diag_transform.
```
"""
@@ -604,11 +607,14 @@ class WishartFull(_WishartLinearOperator):
#### Examples
```python
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
+
# Initialize a single 3x3 Wishart with Full factored scale matrix and 5
# degrees-of-freedom.(*)
df = 5
scale = ... # Shape is [3, 3]; positive definite.
- dist = tf.contrib.distributions.WishartFull(df=df, scale=scale)
+ dist = tfd.WishartFull(df=df, scale=scale)
# Evaluate this on an observation in R^3, returning a scalar.
x = ... # A 3x3 positive definite matrix.
@@ -622,14 +628,14 @@ class WishartFull(_WishartLinearOperator):
# Initialize two 3x3 Wisharts with Full factored scale matrices.
df = [5, 4]
scale = ... # Shape is [2, 3, 3].
- dist = tf.contrib.distributions.WishartFull(df=df, scale=scale)
+ dist = tfd.WishartFull(df=df, scale=scale)
# Evaluate this on four observations.
x = [[x0, x1], [x2, x3]] # Shape is [2, 2, 3, 3]; xi is positive definite.
dist.prob(x) # Shape is [2, 2].
# (*) - To efficiently create a trainable covariance matrix, see the example
- # in tf.contrib.distributions.matrix_diag_transform.
+ # in tfd.matrix_diag_transform.
```
"""
diff --git a/tensorflow/contrib/eager/README.md b/tensorflow/contrib/eager/README.md
index 86d203452e..4bd2769e87 100644
--- a/tensorflow/contrib/eager/README.md
+++ b/tensorflow/contrib/eager/README.md
@@ -44,7 +44,6 @@ Installation instructions at https://www.tensorflow.org/install/
For an introduction to eager execution in TensorFlow, see:
-- [User Guide](https://www.tensorflow.org/guide/eager) ([source](../../docs_src/guide/eager.md))
-- Notebook: [Basic Usage](python/examples/notebooks/1_basics.ipynb)
-- Notebook: [Gradients](python/examples/notebooks/2_gradients.ipynb)
-- Notebook: [Importing Data](python/examples/notebooks/3_datasets.ipynb)
+- [User Guide](https://www.tensorflow.org/guide/eager) ([source](https://github.com/tensorflow/docs/blob/master/site/en/tutorials/eager/index.md))
+- Notebook: [Basic Usage](https://github.com/tensorflow/docs/blob/master/site/en/tutorials/eager/eager_basics.ipynb)
+- Notebook: [Automatic differentiation and gradient tape](https://github.com/tensorflow/docs/blob/master/site/en/tutorials/eager/automatic_differentiation.ipynb)
diff --git a/tensorflow/contrib/eager/python/BUILD b/tensorflow/contrib/eager/python/BUILD
index 84517b57c7..33a1d572a2 100644
--- a/tensorflow/contrib/eager/python/BUILD
+++ b/tensorflow/contrib/eager/python/BUILD
@@ -14,6 +14,7 @@ py_library(
":datasets",
":metrics",
":network",
+ ":parameter_server",
":remote",
":saver",
"//tensorflow/python:framework_ops",
@@ -97,6 +98,18 @@ py_library(
],
)
+py_library(
+ name = "parameter_server",
+ srcs = ["parameter_server.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:framework",
+ "//tensorflow/python:resource_variable_ops",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python/eager:context",
+ ],
+)
+
cuda_py_test(
name = "saver_test",
srcs = ["saver_test.py"],
@@ -241,6 +254,7 @@ py_test(
srcs = ["remote_test.py"],
srcs_version = "PY2AND3",
deps = [
+ ":parameter_server",
":remote",
"//tensorflow/contrib/eager/python:tfe",
"//tensorflow/python:array_ops",
diff --git a/tensorflow/contrib/eager/python/examples/BUILD b/tensorflow/contrib/eager/python/examples/BUILD
index 6f02c90368..97c299a911 100644
--- a/tensorflow/contrib/eager/python/examples/BUILD
+++ b/tensorflow/contrib/eager/python/examples/BUILD
@@ -6,6 +6,7 @@ package(default_visibility = ["//tensorflow:internal"])
py_library(
name = "examples_pip",
deps = [
+ "//tensorflow/contrib/eager/python/examples/densenet",
"//tensorflow/contrib/eager/python/examples/gan:mnist",
"//tensorflow/contrib/eager/python/examples/l2hmc",
"//tensorflow/contrib/eager/python/examples/l2hmc:neural_nets",
diff --git a/tensorflow/contrib/eager/python/examples/gan/BUILD b/tensorflow/contrib/eager/python/examples/gan/BUILD
index c61ec2dbae..d64c8eb9ce 100644
--- a/tensorflow/contrib/eager/python/examples/gan/BUILD
+++ b/tensorflow/contrib/eager/python/examples/gan/BUILD
@@ -3,6 +3,7 @@ licenses(["notice"]) # Apache 2.0
package(default_visibility = ["//tensorflow:internal"])
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
+load("//tensorflow:tensorflow.bzl", "py_binary")
py_binary(
name = "mnist",
diff --git a/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc_test.py b/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc_test.py
index 9557479885..1c925e455b 100644
--- a/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc_test.py
+++ b/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc_test.py
@@ -37,26 +37,43 @@ def get_default_hparams():
n_warmup_iters=3)
+def step(dynamics, optimizer, samples):
+ loss, grads, samples, _ = l2hmc.loss_and_grads(
+ dynamics, samples, loss_fn=l2hmc.compute_loss)
+ optimizer.apply_gradients(zip(grads, dynamics.variables))
+
+ return loss, samples
+
+
+# To be compatible with defun, a function cannot return an Operation, so the
+# function above is used for defun and eager execution, while this one is used
+# in graph mode so the gradient-update op can be fetched and run.
+def graph_step(dynamics, optimizer, samples):
+ loss, grads, samples, _ = l2hmc.loss_and_grads(
+ dynamics, samples, loss_fn=l2hmc.compute_loss)
+ train_op = optimizer.apply_gradients(zip(grads, dynamics.variables))
+
+ return train_op, loss, samples
+
+
def warmup(dynamics,
optimizer,
n_iters=1,
n_samples=200,
- loss_fn=l2hmc.compute_loss):
+ step_fn=step):
"""Warmup optimization to reduce overhead."""
samples = tf.random_normal(
shape=[n_samples, dynamics.x_dim], dtype=tf.float32)
for _ in range(n_iters):
- _, grads, samples, _ = l2hmc.loss_and_grads(
- dynamics, samples, loss_fn=loss_fn)
- optimizer.apply_gradients(zip(grads, dynamics.variables))
+ _, samples = step_fn(dynamics, optimizer, samples)
def fit(dynamics,
samples,
optimizer,
- loss_fn=l2hmc.compute_loss,
+ step_fn=step,
n_iters=5000,
verbose=True,
logdir=None):
@@ -66,9 +83,7 @@ def fit(dynamics,
summary_writer = tf.contrib.summary.create_file_writer(logdir)
for i in range(n_iters):
- loss, grads, samples, _ = l2hmc.loss_and_grads(
- dynamics, samples, loss_fn=loss_fn)
- optimizer.apply_gradients(zip(grads, dynamics.variables))
+ loss, samples = step_fn(dynamics, optimizer, samples)
if verbose:
print("Iteration %d: loss %.4f" % (i, loss))
@@ -130,51 +145,48 @@ class L2hmcBenchmark(tf.test.Benchmark):
"""Benchmark Graph performance."""
hparams = get_default_hparams()
- tf.reset_default_graph()
- with tf.Graph().as_default():
- energy_fn, _, _ = l2hmc.get_scg_energy_fn()
- dynamics = l2hmc.Dynamics(
- x_dim=hparams.x_dim,
- minus_loglikelihood_fn=energy_fn,
- n_steps=hparams.n_steps,
- eps=hparams.eps)
- x = tf.placeholder(tf.float32, shape=[None, hparams.x_dim])
- loss, x_out, _ = l2hmc.compute_loss(dynamics, x)
-
- global_step = tf.Variable(0., name="global_step", trainable=False)
- learning_rate = tf.train.exponential_decay(
- hparams.learning_rate, global_step, 1000, 0.96, staircase=True)
- optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
- train_op = optimizer.minimize(loss, global_step=global_step)
-
- # Single thread; fairer comparison against eager
- session_conf = tf.ConfigProto(
- intra_op_parallelism_threads=1, inter_op_parallelism_threads=1)
-
- with tf.Session(config=session_conf) as sess:
- sess.run(tf.global_variables_initializer())
-
- # Warmup to reduce initialization effect when timing
- samples = npr.normal(size=[hparams.n_samples, hparams.x_dim])
- for _ in range(hparams.n_warmup_iters):
- _, _, _, _ = sess.run(
- [x_out, loss, train_op, learning_rate], feed_dict={x: samples})
-
- # Training
- start_time = time.time()
- for i in range(hparams.n_iters):
- samples, loss_np, _, _ = sess.run(
- [x_out, loss, train_op, learning_rate], feed_dict={x: samples})
- print("Iteration %d: loss %.4f" % (i, loss_np))
- wall_time = time.time() - start_time
- examples_per_sec = hparams.n_samples / wall_time
-
- self.report_benchmark(
- name="graph_train_%s" % ("gpu"
- if tf.test.is_gpu_available() else "cpu"),
- iters=hparams.n_iters,
- extras={"examples_per_sec": examples_per_sec},
- wall_time=wall_time)
+ tf.enable_resource_variables()
+ for sample_size in [10, 25, 50, 100, 200]:
+ hparams.n_samples = sample_size
+ tf.reset_default_graph()
+ with tf.Graph().as_default():
+ energy_fn, _, _ = l2hmc.get_scg_energy_fn()
+ x = tf.random_normal([hparams.n_samples, hparams.x_dim],
+ dtype=tf.float32)
+ dynamics = l2hmc.Dynamics(
+ x_dim=hparams.x_dim,
+ minus_loglikelihood_fn=energy_fn,
+ n_steps=hparams.n_steps,
+ eps=hparams.eps)
+ loss, _, _ = l2hmc.compute_loss(dynamics, x)
+
+ optimizer = tf.train.AdamOptimizer(learning_rate=hparams.learning_rate)
+ train_op, loss, _ = graph_step(dynamics, optimizer, x)
+
+ # Single thread; fairer comparison against eager
+ session_conf = tf.ConfigProto(inter_op_parallelism_threads=1)
+
+ with tf.Session(config=session_conf) as sess:
+ sess.run(tf.global_variables_initializer())
+
+ # Warmup to reduce initialization effect when timing
+ for _ in range(hparams.n_warmup_iters):
+ _, _ = sess.run([train_op, loss])
+
+ # Training
+ start_time = time.time()
+ for i in range(hparams.n_iters):
+ _, loss_np = sess.run([train_op, loss])
+ print("Iteration %d: loss %.4f" % (i, loss_np))
+ wall_time = (time.time() - start_time) / hparams.n_iters
+ examples_per_sec = hparams.n_samples / wall_time
+
+ self.report_benchmark(
+ name="graph_train_%s_%d" %
+ ("gpu" if tf.test.is_gpu_available() else "cpu", sample_size),
+ iters=hparams.n_iters,
+ extras={"examples_per_sec": examples_per_sec},
+ wall_time=wall_time)
def benchmark_eager(self):
self._benchmark_eager()
@@ -186,32 +198,44 @@ class L2hmcBenchmark(tf.test.Benchmark):
"""Benchmark Eager performance."""
hparams = get_default_hparams()
- energy_fn, _, _ = l2hmc.get_scg_energy_fn()
- dynamics = l2hmc.Dynamics(
- x_dim=hparams.x_dim,
- minus_loglikelihood_fn=energy_fn,
- n_steps=hparams.n_steps,
- eps=hparams.eps)
- optimizer = tf.train.AdamOptimizer(learning_rate=hparams.learning_rate)
- loss_fn = tfe.defun(l2hmc.compute_loss) if defun else l2hmc.compute_loss
-
- # Warmup to reduce initialization effect when timing
- warmup(dynamics, optimizer, n_iters=hparams.n_warmup_iters, loss_fn=loss_fn)
-
- # Training
- samples = tf.random_normal(
- shape=[hparams.n_samples, hparams.x_dim], dtype=tf.float32)
- start_time = time.time()
- fit(dynamics, samples, optimizer, loss_fn=loss_fn, n_iters=hparams.n_iters)
- wall_time = time.time() - start_time
- examples_per_sec = hparams.n_samples / wall_time
-
- self.report_benchmark(
- name="eager_train_%s%s" % ("gpu" if tf.test.is_gpu_available() else
- "cpu", "_defun" if defun else ""),
- iters=hparams.n_iters,
- extras={"examples_per_sec": examples_per_sec},
- wall_time=wall_time)
+ for sample_size in [10, 25, 50, 100, 200]:
+ hparams.n_samples = sample_size
+ energy_fn, _, _ = l2hmc.get_scg_energy_fn()
+ dynamics = l2hmc.Dynamics(
+ x_dim=hparams.x_dim,
+ minus_loglikelihood_fn=energy_fn,
+ n_steps=hparams.n_steps,
+ eps=hparams.eps)
+ optimizer = tf.train.AdamOptimizer(learning_rate=hparams.learning_rate)
+ step_fn = tfe.defun(step) if defun else step
+
+ # Warmup to reduce initialization effect when timing
+ warmup(
+ dynamics,
+ optimizer,
+ n_iters=hparams.n_warmup_iters,
+ n_samples=hparams.n_samples,
+ step_fn=step_fn)
+
+ # Training
+ samples = tf.random_normal(
+ shape=[hparams.n_samples, hparams.x_dim], dtype=tf.float32)
+ start_time = time.time()
+ fit(dynamics,
+ samples,
+ optimizer,
+ step_fn=step_fn,
+ n_iters=hparams.n_iters)
+ wall_time = (time.time() - start_time) / hparams.n_iters
+ examples_per_sec = hparams.n_samples / wall_time
+
+ self.report_benchmark(
+ name="eager_train_%s%s_%d" %
+ ("gpu" if tf.test.is_gpu_available() else "cpu",
+ "_defun" if defun else "", sample_size),
+ iters=hparams.n_iters,
+ extras={"examples_per_sec": examples_per_sec},
+ wall_time=wall_time)
del dynamics
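The `step` / `graph_step` split above exists because a defun-wrapped function must return Tensors rather than Operations. A toy sketch of that constraint under TF 1.x eager APIs (hedged, not part of the benchmark) looks like this:

```python
import tensorflow as tf
import tensorflow.contrib.eager as tfe

tf.enable_eager_execution()

w = tf.Variable(2.0)
optimizer = tf.train.GradientDescentOptimizer(0.1)

def step():
  with tf.GradientTape() as tape:
    loss = tf.square(w - 1.0)
  grads = tape.gradient(loss, [w])
  optimizer.apply_gradients(zip(grads, [w]))  # Executed for its side effect.
  return loss                                 # A Tensor, not an Operation.

fast_step = tfe.defun(step)
print(fast_step().numpy())
```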
diff --git a/tensorflow/contrib/eager/python/examples/linear_regression/BUILD b/tensorflow/contrib/eager/python/examples/linear_regression/BUILD
index 2f6cfdf31e..74ce9e84f0 100644
--- a/tensorflow/contrib/eager/python/examples/linear_regression/BUILD
+++ b/tensorflow/contrib/eager/python/examples/linear_regression/BUILD
@@ -3,6 +3,7 @@ licenses(["notice"]) # Apache 2.0
package(default_visibility = ["//tensorflow:internal"])
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
+load("//tensorflow:tensorflow.bzl", "py_binary")
py_binary(
name = "linear_regression",
diff --git a/tensorflow/contrib/eager/python/examples/rnn_colorbot/BUILD b/tensorflow/contrib/eager/python/examples/rnn_colorbot/BUILD
index f83eb5c476..d500b632eb 100644
--- a/tensorflow/contrib/eager/python/examples/rnn_colorbot/BUILD
+++ b/tensorflow/contrib/eager/python/examples/rnn_colorbot/BUILD
@@ -3,6 +3,7 @@ licenses(["notice"]) # Apache 2.0
package(default_visibility = ["//tensorflow:internal"])
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
+load("//tensorflow:tensorflow.bzl", "py_binary")
py_binary(
name = "rnn_colorbot",
diff --git a/tensorflow/contrib/eager/python/examples/rnn_ptb/BUILD b/tensorflow/contrib/eager/python/examples/rnn_ptb/BUILD
index 4b4792cd49..2cc2fcbfeb 100644
--- a/tensorflow/contrib/eager/python/examples/rnn_ptb/BUILD
+++ b/tensorflow/contrib/eager/python/examples/rnn_ptb/BUILD
@@ -3,6 +3,7 @@ licenses(["notice"]) # Apache 2.0
package(default_visibility = ["//tensorflow:internal"])
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
+load("//tensorflow:tensorflow.bzl", "py_binary")
py_binary(
name = "rnn_ptb",
diff --git a/tensorflow/contrib/eager/python/parameter_server.py b/tensorflow/contrib/eager/python/parameter_server.py
new file mode 100644
index 0000000000..3a9e7b027e
--- /dev/null
+++ b/tensorflow/contrib/eager/python/parameter_server.py
@@ -0,0 +1,289 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""EXPERIMENTAL utilities for parameter server training with eager execution.
+
+Note: this should eventually be merged with the distribution strategy for
+ParameterServer.
+"""
+
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import contextlib
+import time
+
+from tensorflow.python.eager import context
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.training.checkpointable import base as checkpointable
+
+
+def _eager_safe_variable_handle(shape, dtype, shared_name, name, graph_mode):
+ """Creates a variable handle with information to do shape inference."""
+ container = ops.get_default_graph()._container # pylint: disable=protected-access
+ if container is None:
+ container = ""
+ handle = resource_variable_ops.var_handle_op(shape=shape, dtype=dtype,
+ shared_name=shared_name,
+ name=name,
+ container=container)
+ if graph_mode:
+ return handle
+
+ with context.graph_mode(), ops.Graph().as_default() as graph:
+ h = resource_variable_ops.var_handle_op(shape=shape, dtype=dtype,
+ shared_name=shared_name,
+ name=name,
+ container=container)
+
+ # Tensor._handle_data contains information for the shape-inference code to
+ # know the shape and dtype of the variable pointed to by a handle. Since
+ # shape inference doesn't run in eager mode we copy this data here for when
+ # the handle is captured by an eager mode function.
+ # pylint: disable=protected-access
+ if ops._USE_C_SHAPES:
+ handle._handle_data = resource_variable_ops.get_resource_handle_data(h)
+ else:
+ if h._handle_data is None:
+ ops.set_shape_and_handle_data_for_outputs(h.op)
+ handle._handle_data = h._handle_data
+ # pylint: enable=protected-access
+ # Clean up op->graph->op reference cycles.
+ ops.dismantle_graph(graph)
+ return handle
+
+
+class SharedVariable(resource_variable_ops.ResourceVariable):
+ """Experimental Variable designed for parameter server training.
+
+ A SharedVariable has a name, and two instances of SharedVariable with the
+ same name will have the same value, even if they are in different Sessions,
+ as long as they are placed on the same device.
+
+ The storage associated with SharedVariables is also not deleted when they go
+ out of scope.
+ """
+
+ def __init__(self, # pylint: disable=super-init-not-called
+ initial_value=None,
+ trainable=True,
+ name=None,
+ dtype=None,
+ constraint=None,
+ initialize=True,
+ **unused_kwargs):
+ """Creates a variable.
+
+ Args:
+ initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
+ which is the initial value for the Variable. The initial value must have
+ a shape specified unless `validate_shape` is set to False. Can also be a
+ callable with no argument that returns the initial value when called.
+ (Note that initializer functions from init_ops.py must first be bound
+ to a shape before being used here.)
+ trainable: If `True`, automatically watches this variable on GradientTape
+ whenever it's used.
+ name: Optional name for the variable. Defaults to `'Variable'` and gets
+ uniquified automatically.
+ dtype: If set, initial_value will be converted to the given type.
+ If None, either the datatype will be kept (if initial_value is
+ a Tensor) or float32 will be used (if it is a Python object convertible
+ to a Tensor).
+ constraint: An optional projection function to be applied to the variable
+ after being updated by an `Optimizer` (e.g. used to implement norm
+ constraints or value constraints for layer weights). The function must
+ take as input the unprojected Tensor representing the value of the
+ variable and return the Tensor for the projected value
+ (which must have the same shape). Constraints are not safe to
+ use when doing asynchronous distributed training.
+ initialize: if True, runs initialization in eager execution; leaves the
+ variable uninitialized otherwise.
+
+ Raises:
+ ValueError: If the initial value is not specified, or does not have a
+ shape and `validate_shape` is `True`.
+ """
+ if initial_value is None:
+ raise ValueError("initial_value must be specified.")
+ init_from_fn = callable(initial_value)
+
+ if isinstance(initial_value, ops.Tensor) and hasattr(
+ initial_value, "graph") and initial_value.graph.building_function:
+ raise ValueError("Tensor-typed variable initializers must either be "
+ "wrapped in an init_scope or callable "
+ "(e.g., `tf.Variable(lambda : "
+ "tf.truncated_normal([10, 40]))`) when building "
+ "functions. Please file a feature request if this "
+ "restriction inconveniences you.")
+
+ if constraint is not None and not callable(constraint):
+ raise ValueError("The `constraint` argument must be a callable.")
+
+ if isinstance(initial_value, checkpointable.CheckpointInitialValue):
+ self._maybe_initialize_checkpointable()
+ self._update_uid = initial_value.checkpoint_position.restore_uid
+ initial_value = initial_value.wrapped_value
+
+ self._trainable = trainable
+ self._save_slice_info = None
+ # Store the graph key so optimizers know how to only retrieve variables from
+ # this graph.
+ self._graph_key = ops.get_default_graph()._graph_key # pylint: disable=protected-access
+ with ops.init_scope():
+ self._in_graph_mode = not context.executing_eagerly()
+ with ops.name_scope(name, "Variable", []
+ if init_from_fn else [initial_value]) as name:
+ # pylint: disable=protected-access
+ handle_name = ops._name_from_scope_name(name)
+ shared_name = handle_name
+ if init_from_fn:
+ # Use attr_scope and device(None) to simulate the behavior of
+ # colocate_with when the variable we want to colocate with doesn't
+ # yet exist.
+ if self._in_graph_mode:
+ with ops.name_scope("Initializer"), ops.device(None):
+ initial_value = ops.convert_to_tensor(
+ initial_value(), name="initial_value", dtype=dtype)
+ self._handle = _eager_safe_variable_handle(
+ shape=initial_value.get_shape(),
+ dtype=initial_value.dtype.base_dtype,
+ shared_name=shared_name,
+ name=name,
+ graph_mode=self._in_graph_mode)
+ self._shape = initial_value.get_shape()
+ else:
+ initial_value = initial_value()
+ with ops.name_scope("Initializer"):
+ initial_value = ops.convert_to_tensor(
+ initial_value, name="initial_value", dtype=dtype)
+ self._handle = _eager_safe_variable_handle(
+ shape=initial_value.get_shape(),
+ dtype=initial_value.dtype.base_dtype,
+ shared_name=shared_name,
+ name=name,
+ graph_mode=False)
+ self._shape = initial_value.get_shape()
+ # pylint: enable=protected-access
+
+ # Or get the initial value from a Tensor or Python object.
+ else:
+ with ops.name_scope("Initializer"):
+ initial_value = ops.convert_to_tensor(
+ initial_value, name="initial_value", dtype=dtype)
+ # pylint: disable=protected-access
+ if (self._in_graph_mode and initial_value is not None and
+ initial_value.op._get_control_flow_context() is not None):
+ raise ValueError(
+ "Initializer for variable %s is from inside a control-flow "
+ "construct, such as a loop or conditional. When creating a "
+ "variable inside a loop or conditional, use a lambda as the "
+ "initializer." % name)
+ # pylint: enable=protected-access
+ self._handle = _eager_safe_variable_handle(
+ shape=initial_value.get_shape(),
+ dtype=initial_value.dtype.base_dtype,
+ shared_name=shared_name,
+ name=name,
+ graph_mode=self._in_graph_mode)
+ self._shape = initial_value.get_shape()
+
+ self._unique_id = shared_name
+ self._initial_value = initial_value if self._in_graph_mode else None
+ self._handle_name = handle_name + ":0"
+ self._dtype = initial_value.dtype.base_dtype
+ self._constraint = constraint
+
+ if self._in_graph_mode:
+ with ops.name_scope("IsInitialized"):
+ self._is_initialized_op = (
+ resource_variable_ops.var_is_initialized_op(self._handle))
+ if initial_value is not None:
+ with ops.name_scope("Assign") as n, ops.colocate_with(self._handle):
+ self._initializer_op = (
+ resource_variable_ops.assign_variable_op(
+ self._handle,
+ self._try_guard_against_uninitialized_dependencies(
+ initial_value),
+ name=n))
+ with ops.name_scope("Read"), ops.colocate_with(self._handle):
+ # Manually assign reads to the handle's device to avoid log
+ # messages.
+ with ops.device(self._handle.device):
+ value = self._read_variable_op()
+ self._graph_element = value
+ self._cached_value = None
+ else:
+ if initialize:
+ resource_variable_ops.assign_variable_op(self._handle,
+ initial_value)
+ self._is_initialized_op = None
+ self._initializer_op = None
+ self._graph_element = None
+ self._cached_value = None
+
+ self._handle_deleter = None
+ self._cached_shape_as_list = None
+
+
+@contextlib.contextmanager
+def parameter_server_scope(is_chief, ps_job_name, num_ps_tasks):
+ """Strategy to use parameter servers in eager.
+
+ Creates SharedVariable objects for variables created in this scope. These
+ SharedVariable objects will be placed round-robin on the parameter servers
+ specified by the ps_job_name and num_ps_tasks arguments.
+
+ To use parameter servers you only need to wrap your model initialization in
+ this scope:
+
+ ```
+ with tf.contrib.eager.parameter_server_scope(
+ is_chief, ps_job_name, num_ps_tasks):
+ my_model = tf.keras.Sequential([...]) # Or
+ input = tf.keras.Input(...)
+ ....
+ my_model = tf.keras.Model(input, output)
+ my_model.compile(...)
+ # or other usages of the model.
+ ```
+
+ Args:
+ is_chief: Boolean. Whether this worker is responsible for initializing
+ variables.
+ ps_job_name: The name of the ps job in this cluster.
+ num_ps_tasks: The number of ps tasks to use.
+
+ Yields:
+ a context manager.
+ """
+ # Note: capturing in a list to allow assignment.
+ ps_index = [0]
+
+ def variable_creator_scope(unused_next_creator, **kwargs):
+ kwargs["initialize"] = is_chief
+ with ops.device(
+ "/job:%s/task:%s" % (ps_job_name, ps_index[0] % num_ps_tasks)):
+ ps_index[0] += 1
+ v = SharedVariable(**kwargs)
+ if not is_chief:
+ while not resource_variable_ops.var_is_initialized_op(v.handle):
+ time.sleep(10)
+ return v
+
+ with variable_scope.variable_creator_scope(variable_creator_scope):
+ yield
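A minimal local sketch of the `SharedVariable` semantics described in its docstring (the same name on the same device means shared storage, and `initialize=False` avoids re-initializing it). This is a hedged usage example, not part of the module:

```python
import tensorflow as tf
from tensorflow.contrib.eager.python import parameter_server

tf.enable_eager_execution()

v = parameter_server.SharedVariable(1.0, name="shared_w")
v.assign(3.0)

# A second handle with the same name on the same (local) device aliases the
# same storage; initialize=False leaves the existing value untouched.
alias = parameter_server.SharedVariable(0.0, name="shared_w", initialize=False)
print(alias.read_value())  # Expected to read the value written through `v`.
```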
diff --git a/tensorflow/contrib/eager/python/remote_test.py b/tensorflow/contrib/eager/python/remote_test.py
index 13029db975..ba6fe9701d 100644
--- a/tensorflow/contrib/eager/python/remote_test.py
+++ b/tensorflow/contrib/eager/python/remote_test.py
@@ -23,6 +23,7 @@ import os
import numpy as np
+from tensorflow.contrib.eager.python import parameter_server
from tensorflow.contrib.eager.python import remote
from tensorflow.core.protobuf import cluster_pb2
from tensorflow.core.protobuf import tensorflow_server_pb2
@@ -33,6 +34,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import server_lib
@@ -120,6 +122,24 @@ class RemoteExecutionTest(test.TestCase):
y = math_ops.matmul(x1, x2)
np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())
+ def testParameterServer(self):
+ with parameter_server.parameter_server_scope(
+ is_chief=True, ps_job_name=JOB_NAME, num_ps_tasks=3):
+ v0 = variables.Variable([1.0], name="v0")
+ v1 = variables.Variable([2.0], name="v1")
+ v0.assign(v0 * v1)
+ self.assertAllEqual(v0.read_value(), [2.0])
+ self.assertAllEqual(v0.device,
+ "/job:%s/replica:0/task:0/device:CPU:0" % JOB_NAME)
+ self.assertAllEqual(v1.device,
+ "/job:%s/replica:0/task:1/device:CPU:0" % JOB_NAME)
+ v1.assign_add(v1)
+ # Simulate aliasing another variable of the same name as v1
+ with ops.device("/job:%s/replica:0/task:1/device:CPU:0" % JOB_NAME):
+ v1_replica = parameter_server.SharedVariable(
+ [1.0], name="v1", initialize=False)
+ self.assertAllEqual(v1_replica.read_value(), [4.0])
+
@run_sync_and_async
def testSimpleWeightRead(self):
"""Basic remote eager weight read."""
diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD
index 6db311d52d..1ea00fb7f3 100644
--- a/tensorflow/contrib/estimator/BUILD
+++ b/tensorflow/contrib/estimator/BUILD
@@ -132,21 +132,11 @@ py_library(
srcs = ["python/estimator/dnn_with_layer_annotations.py"],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/python:array_ops",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:init_ops",
- "//tensorflow/python:layers",
- "//tensorflow/python:nn",
- "//tensorflow/python:partitioned_variables",
- "//tensorflow/python:summary",
- "//tensorflow/python:variable_scope",
+ "//tensorflow:tensorflow_py_no_contrib",
"//tensorflow/python/estimator",
"//tensorflow/python/estimator:head",
"//tensorflow/python/estimator:model_fn",
"//tensorflow/python/estimator:optimizers",
- "//tensorflow/python/feature_column",
- "//tensorflow/python/ops/losses",
- "//tensorflow/python/saved_model:utils",
],
)
@@ -162,22 +152,13 @@ py_test(
],
deps = [
":dnn_with_layer_annotations",
- "//tensorflow/core:protos_all_py",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:data_flow_ops",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:parsing_ops",
- "//tensorflow/python:platform",
- "//tensorflow/python:summary",
- "//tensorflow/python:training",
+ "//tensorflow:tensorflow_py_no_contrib",
"//tensorflow/python/estimator:dnn",
"//tensorflow/python/estimator:dnn_testing_utils",
"//tensorflow/python/estimator:export_export",
"//tensorflow/python/estimator:numpy_io",
"//tensorflow/python/estimator:pandas_io",
"//tensorflow/python/estimator:prediction_keys",
- "//tensorflow/python/feature_column",
"@six_archive//:six",
],
)
@@ -283,9 +264,7 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:platform",
- "//tensorflow/python:summary",
+ "//tensorflow:tensorflow_py_no_contrib",
"//tensorflow/python/estimator:exporter",
],
)
@@ -297,7 +276,7 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":exporter",
- "//tensorflow/python:platform",
+ "//tensorflow:tensorflow_py_no_contrib",
"//tensorflow/python/estimator",
"//tensorflow/python/estimator:exporter",
],
@@ -502,7 +481,6 @@ py_library(
"//tensorflow/python/estimator",
"//tensorflow/python/estimator:head",
"//tensorflow/python/estimator:optimizers",
- "//tensorflow/python/ops/losses",
"@six_archive//:six",
],
)
@@ -557,13 +535,10 @@ py_library(
srcs = ["python/estimator/saved_model_estimator.py"],
deps = [
":export",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:platform",
- "//tensorflow/python:training",
+ "//tensorflow:tensorflow_py_no_contrib",
"//tensorflow/python/estimator",
"//tensorflow/python/estimator:export",
"//tensorflow/python/estimator:model_fn",
- "//tensorflow/python/saved_model",
],
)
@@ -578,16 +553,7 @@ py_test(
deps = [
":export",
":saved_model_estimator",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:control_flow_ops",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:metrics",
- "//tensorflow/python:platform",
- "//tensorflow/python:state_ops",
- "//tensorflow/python:training",
- "//tensorflow/python:variables",
- "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow:tensorflow_py_no_contrib",
"//tensorflow/python/estimator",
"//tensorflow/python/estimator:export_export",
"//tensorflow/python/estimator:export_output",
diff --git a/tensorflow/contrib/estimator/__init__.py b/tensorflow/contrib/estimator/__init__.py
index 78914ecaca..419609b1af 100644
--- a/tensorflow/contrib/estimator/__init__.py
+++ b/tensorflow/contrib/estimator/__init__.py
@@ -76,7 +76,7 @@ _allowed_symbols = [
'stop_if_no_decrease_hook',
'build_raw_supervised_input_receiver_fn',
'build_supervised_input_receiver_fn_from_input_fn',
- 'SavedModelEstimator'
+ 'SavedModelEstimator',
'DNNClassifierWithLayerAnnotations',
'DNNRegressorWithLayerAnnotations',
]
diff --git a/tensorflow/contrib/estimator/python/estimator/boosted_trees.py b/tensorflow/contrib/estimator/python/estimator/boosted_trees.py
index 7ed77bcce6..a1f1c5f3d7 100644
--- a/tensorflow/contrib/estimator/python/estimator/boosted_trees.py
+++ b/tensorflow/contrib/estimator/python/estimator/boosted_trees.py
@@ -20,6 +20,7 @@ from __future__ import print_function
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.estimator import estimator
from tensorflow.python.estimator.canned import boosted_trees as canned_boosted_trees
+from tensorflow.python.estimator.canned import head as head_lib
def _validate_input_fn_and_repeat_dataset(train_input_fn):
@@ -33,7 +34,19 @@ def _validate_input_fn_and_repeat_dataset(train_input_fn):
return _input_fn
-class _BoostedTreesEstimator(estimator.Estimator):
+def _is_classification_head(head):
+ """Infers if the head is a classification head."""
+ # Check against the classification heads defined in canned/head.py. This
+ # list is not exhaustive: classification heads defined outside the head
+ # library are not detected.
+ # pylint: disable=protected-access
+ return isinstance(head,
+ (head_lib._BinaryLogisticHeadWithSigmoidCrossEntropyLoss,
+ head_lib._MultiClassHeadWithSoftmaxCrossEntropyLoss))
+ # pylint: enable=protected-access
+
+
+class _BoostedTreesEstimator(canned_boosted_trees._BoostedTreesBase): # pylint: disable=protected-access
"""An Estimator for Tensorflow Boosted Trees models."""
def __init__(self,
@@ -96,9 +109,12 @@ class _BoostedTreesEstimator(estimator.Estimator):
negative gain). For pre and post pruning, you MUST provide
tree_complexity >0.
+ Raises:
+ ValueError: when wrong arguments are given or unsupported functionalities
+ are requested.
"""
- # pylint:disable=protected-access
# HParams for the model.
+ # pylint: disable=protected-access
tree_hparams = canned_boosted_trees._TreeHParams(
n_trees, max_depth, learning_rate, l1_regularization, l2_regularization,
tree_complexity, min_node_weight, center_bias, pruning_mode)
@@ -115,8 +131,14 @@ class _BoostedTreesEstimator(estimator.Estimator):
config=config)
super(_BoostedTreesEstimator, self).__init__(
- model_fn=_model_fn, model_dir=model_dir, config=config)
- # pylint:enable=protected-access
+ model_fn=_model_fn,
+ model_dir=model_dir,
+ config=config,
+ feature_columns=feature_columns,
+ head=head,
+ center_bias=center_bias,
+ is_classification=_is_classification_head(head))
+ # pylint: enable=protected-access
def boosted_trees_classifier_train_in_memory(
diff --git a/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py b/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py
index b1581f3750..e23d9c0fc4 100644
--- a/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py
@@ -360,5 +360,79 @@ class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase):
[pred['predictions'] for pred in predictions])
+class BoostedTreesDebugOutputTest(test_util.TensorFlowTestCase):
+
+ def setUp(self):
+ self._head = canned_boosted_trees._create_regression_head(label_dimension=1)
+ self._feature_columns = {
+ feature_column.bucketized_column(
+ feature_column.numeric_column('f_%d' % i, dtype=dtypes.float32),
+ BUCKET_BOUNDARIES) for i in range(NUM_FEATURES)
+ }
+
+ def testContribEstimatorThatDFCIsInPredictions(self):
+ # pylint:disable=protected-access
+ head = canned_boosted_trees._create_regression_head(label_dimension=1)
+ train_input_fn = _make_train_input_fn(is_classification=False)
+ predict_input_fn = numpy_io.numpy_input_fn(
+ x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False)
+
+ est = boosted_trees._BoostedTreesEstimator(
+ feature_columns=self._feature_columns,
+ n_batches_per_layer=1,
+ head=head,
+ n_trees=1,
+ max_depth=5,
+ center_bias=True)
+ # pylint:enable=protected-access
+
+ num_steps = 100
+ # Train for a few steps. Validate debug outputs in prediction dicts.
+ est.train(train_input_fn, steps=num_steps)
+ debug_predictions = est.experimental_predict_with_explanations(
+ predict_input_fn)
+ biases, dfcs = zip(*[(pred['bias'], pred['dfc'])
+ for pred in debug_predictions])
+ self.assertAllClose([1.8] * 5, biases)
+ self.assertAllClose(({
+ 0: -0.070499420166015625,
+ 1: -0.095000028610229492,
+ 2: 0.0
+ }, {
+ 0: -0.53763031959533691,
+ 1: 0.063333392143249512,
+ 2: 0.0
+ }, {
+ 0: -0.51756942272186279,
+ 1: -0.095000028610229492,
+ 2: 0.0
+ }, {
+ 0: 0.1563495397567749,
+ 1: 0.063333392143249512,
+ 2: 0.0
+ }, {
+ 0: 0.96934974193572998,
+ 1: 0.063333392143249512,
+ 2: 0.0
+ }), dfcs)
+
+ # Assert sum(dfcs) + bias == predictions.
+ expected_predictions = [[1.6345005], [1.32570302], [1.1874305],
+ [2.01968288], [2.83268309]]
+ predictions = [
+ [sum(dfc.values()) + bias] for (dfc, bias) in zip(dfcs, biases)
+ ]
+ self.assertAllClose(expected_predictions, predictions)
+
+ # Test when user doesn't include bias or dfc in predict_keys.
+ debug_predictions = est.experimental_predict_with_explanations(
+ predict_input_fn, predict_keys=['predictions'])
+ for prediction_dict in debug_predictions:
+ self.assertTrue('bias' in prediction_dict)
+ self.assertTrue('dfc' in prediction_dict)
+ self.assertTrue('predictions' in prediction_dict)
+ self.assertEqual(len(prediction_dict), 3)
+
+
if __name__ == '__main__':
googletest.main()
diff --git a/tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations.py b/tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations.py
index 152431d1b2..5faf0aacfe 100644
--- a/tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations.py
+++ b/tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations.py
@@ -24,7 +24,6 @@ import pickle
from google.protobuf.any_pb2 import Any
from tensorflow.python.estimator import estimator
-from tensorflow.python.estimator import model_fn
from tensorflow.python.estimator.canned import dnn
from tensorflow.python.feature_column import feature_column as feature_column_lib
from tensorflow.python.framework import ops
@@ -68,7 +67,7 @@ def _to_any_wrapped_tensor_info(tensor):
return any_buf
-def make_input_layer_with_layer_annotations(original_input_layer, mode):
+def make_input_layer_with_layer_annotations(original_input_layer):
"""Make an input_layer replacement function that adds layer annotations."""
def input_layer_with_layer_annotations(features,
@@ -76,7 +75,9 @@ def make_input_layer_with_layer_annotations(original_input_layer, mode):
weight_collections=None,
trainable=True,
cols_to_vars=None,
- cols_to_output_tensors=None):
+ scope=None,
+ cols_to_output_tensors=None,
+ from_template=False):
"""Returns a dense `Tensor` as input layer based on given `feature_columns`.
Generally a single example in training data is described with
@@ -112,9 +113,12 @@ def make_input_layer_with_layer_annotations(original_input_layer, mode):
'some_variable:0' shape=(5, 10), <tf.Variable 'some_variable:1'
shape=(5, 10)]} If a column creates no variables, its value will be an
empty list.
+ scope: A name or variable scope to use.
cols_to_output_tensors: If not `None`, must be a dictionary that will be
filled with a mapping from '_FeatureColumn' to the associated output
`Tensor`s.
+ from_template: True if the method is being instantiated from a
+ `make_template`.
Returns:
A `Tensor` which represents input layer of a model. Its shape
@@ -132,47 +136,45 @@ def make_input_layer_with_layer_annotations(original_input_layer, mode):
weight_collections=weight_collections,
trainable=trainable,
cols_to_vars=cols_to_vars,
- cols_to_output_tensors=local_cols_to_output_tensors)
+ scope=scope,
+ cols_to_output_tensors=local_cols_to_output_tensors,
+ from_template=from_template)
if cols_to_output_tensors is not None:
cols_to_output_tensors = local_cols_to_output_tensors
- if mode and mode == model_fn.ModeKeys.PREDICT:
- # Only annotate in PREDICT mode.
-
- # Annotate features.
- # These are the parsed Tensors, before embedding.
-
- # Only annotate features used by FeatureColumns.
- # We figure which ones are used by FeatureColumns by creating a parsing
- # spec and looking at the keys.
- spec = feature_column_lib.make_parse_example_spec(feature_columns)
- for key in spec.keys():
- tensor = features[key]
- ops.add_to_collection(
- LayerAnnotationsCollectionNames.keys(
- LayerAnnotationsCollectionNames.UNPROCESSED_FEATURES), key)
- ops.add_to_collection(
- LayerAnnotationsCollectionNames.values(
- LayerAnnotationsCollectionNames.UNPROCESSED_FEATURES),
- _to_any_wrapped_tensor_info(tensor))
-
- # Annotate feature columns.
- for column in feature_columns:
- # TODO(cyfoo): Find a better way to serialize and deserialize
- # _FeatureColumn.
- ops.add_to_collection(LayerAnnotationsCollectionNames.FEATURE_COLUMNS,
- serialize_feature_column(column))
-
- for column, tensor in local_cols_to_output_tensors.items():
- ops.add_to_collection(
- LayerAnnotationsCollectionNames.keys(
- LayerAnnotationsCollectionNames.PROCESSED_FEATURES),
- column.name)
- ops.add_to_collection(
- LayerAnnotationsCollectionNames.values(
- LayerAnnotationsCollectionNames.PROCESSED_FEATURES),
- _to_any_wrapped_tensor_info(tensor))
+ # Annotate features.
+ # These are the parsed Tensors, before embedding.
+
+ # Only annotate features used by FeatureColumns.
+ # We figure which ones are used by FeatureColumns by creating a parsing
+ # spec and looking at the keys.
+ spec = feature_column_lib.make_parse_example_spec(feature_columns)
+ for key in spec.keys():
+ tensor = ops.convert_to_tensor(features[key])
+ ops.add_to_collection(
+ LayerAnnotationsCollectionNames.keys(
+ LayerAnnotationsCollectionNames.UNPROCESSED_FEATURES), key)
+ ops.add_to_collection(
+ LayerAnnotationsCollectionNames.values(
+ LayerAnnotationsCollectionNames.UNPROCESSED_FEATURES),
+ _to_any_wrapped_tensor_info(tensor))
+
+ # Annotate feature columns.
+ for column in feature_columns:
+ # TODO(cyfoo): Find a better way to serialize and deserialize
+ # _FeatureColumn.
+ ops.add_to_collection(LayerAnnotationsCollectionNames.FEATURE_COLUMNS,
+ serialize_feature_column(column))
+
+ for column, tensor in local_cols_to_output_tensors.items():
+ ops.add_to_collection(
+ LayerAnnotationsCollectionNames.keys(
+ LayerAnnotationsCollectionNames.PROCESSED_FEATURES), column.name)
+ ops.add_to_collection(
+ LayerAnnotationsCollectionNames.values(
+ LayerAnnotationsCollectionNames.PROCESSED_FEATURES),
+ _to_any_wrapped_tensor_info(tensor))
return input_layer
@@ -301,9 +303,9 @@ def DNNClassifierWithLayerAnnotations( # pylint: disable=invalid-name
def _model_fn(features, labels, mode, config):
with _monkey_patch(
- feature_column_lib, 'input_layer',
- make_input_layer_with_layer_annotations(feature_column_lib.input_layer,
- mode)):
+ feature_column_lib, '_internal_input_layer',
+ make_input_layer_with_layer_annotations(
+ feature_column_lib._internal_input_layer)): # pylint: disable=protected-access
return original.model_fn(features, labels, mode, config)
return estimator.Estimator(
@@ -422,9 +424,9 @@ def DNNRegressorWithLayerAnnotations( # pylint: disable=invalid-name
def _model_fn(features, labels, mode, config):
with _monkey_patch(
- feature_column_lib, 'input_layer',
- make_input_layer_with_layer_annotations(feature_column_lib.input_layer,
- mode)):
+ feature_column_lib, '_internal_input_layer',
+ make_input_layer_with_layer_annotations(
+ feature_column_lib._internal_input_layer)): # pylint: disable=protected-access
return original.model_fn(features, labels, mode, config)
return estimator.Estimator(
diff --git a/tensorflow/contrib/estimator/python/estimator/early_stopping.py b/tensorflow/contrib/estimator/python/estimator/early_stopping.py
index 3eab21d5ac..cafe8279c7 100644
--- a/tensorflow/contrib/estimator/python/estimator/early_stopping.py
+++ b/tensorflow/contrib/estimator/python/estimator/early_stopping.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import collections
import operator
import os
@@ -56,6 +57,13 @@ def make_early_stopping_hook(estimator,
tf.estimator.train_and_evaluate(estimator, train_spec, ...)
```
+ Caveat: Current implementation supports early-stopping both training and
+ evaluation in local mode. In distributed mode, training can be stopped but
+ evaluation (where it's a separate job) will indefinitely wait for new model
+ checkpoints to evaluate, so you will need other means to detect and stop it.
+ Early-stopping evaluation in distributed mode requires changes in
+ `train_and_evaluate` API and will be addressed in a future revision.
+
Args:
estimator: A `tf.estimator.Estimator` instance.
should_stop_fn: `callable`, function that takes no arguments and returns a
@@ -108,6 +116,13 @@ def stop_if_higher_hook(estimator,
tf.estimator.train_and_evaluate(estimator, train_spec, ...)
```
+ Caveat: Current implementation supports early-stopping both training and
+ evaluation in local mode. In distributed mode, training can be stopped but
+ evaluation (where it's a separate job) will indefinitely wait for new model
+ checkpoints to evaluate, so you will need other means to detect and stop it.
+ Early-stopping evaluation in distributed mode requires changes in
+ `train_and_evaluate` API and will be addressed in a future revision.
+
Args:
estimator: A `tf.estimator.Estimator` instance.
metric_name: `str`, metric to track. "loss", "accuracy", etc.
@@ -157,6 +172,13 @@ def stop_if_lower_hook(estimator,
tf.estimator.train_and_evaluate(estimator, train_spec, ...)
```
+ Caveat: Current implementation supports early-stopping both training and
+ evaluation in local mode. In distributed mode, training can be stopped but
+ evaluation (where it's a separate job) will indefinitely wait for new model
+ checkpoints to evaluate, so you will need other means to detect and stop it.
+ Early-stopping evaluation in distributed mode requires changes in
+ `train_and_evaluate` API and will be addressed in a future revision.
+
Args:
estimator: A `tf.estimator.Estimator` instance.
metric_name: `str`, metric to track. "loss", "accuracy", etc.
@@ -206,6 +228,13 @@ def stop_if_no_increase_hook(estimator,
tf.estimator.train_and_evaluate(estimator, train_spec, ...)
```
+ Caveat: Current implementation supports early-stopping both training and
+ evaluation in local mode. In distributed mode, training can be stopped but
+ evaluation (where it's a separate job) will indefinitely wait for new model
+ checkpoints to evaluate, so you will need other means to detect and stop it.
+ Early-stopping evaluation in distributed mode requires changes in
+ `train_and_evaluate` API and will be addressed in a future revision.
+
Args:
estimator: A `tf.estimator.Estimator` instance.
metric_name: `str`, metric to track. "loss", "accuracy", etc.
@@ -256,6 +285,13 @@ def stop_if_no_decrease_hook(estimator,
tf.estimator.train_and_evaluate(estimator, train_spec, ...)
```
+ Caveat: Current implementation supports early stopping of both training and
+ evaluation in local mode. In distributed mode, training can be stopped, but
+ evaluation (which runs as a separate job) will wait indefinitely for new
+ model checkpoints to evaluate, so you will need other means to detect and
+ stop it. Early stopping of evaluation in distributed mode requires changes
+ to the `train_and_evaluate` API and will be addressed in a future revision.
+
Args:
estimator: A `tf.estimator.Estimator` instance.
metric_name: `str`, metric to track. "loss", "accuracy", etc.
@@ -306,7 +342,8 @@ def read_eval_metrics(eval_dir):
metrics[value.tag] = value.simple_value
if metrics:
eval_metrics_dict[event.step] = metrics
- return eval_metrics_dict
+ return collections.OrderedDict(
+ sorted(eval_metrics_dict.items(), key=lambda t: t[0]))
def _stop_if_threshold_crossed_hook(estimator, metric_name, threshold,
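With `read_eval_metrics` now returning an `OrderedDict` keyed by global step in ascending order, callers can iterate evaluations chronologically without re-sorting. A minimal usage sketch (not part of the patch), assuming the contrib import path shown above; `print_loss_history` and `eval_dir` are illustrative names only:

```python
from tensorflow.contrib.estimator.python.estimator import early_stopping

def print_loss_history(eval_dir):
  # read_eval_metrics returns {global_step: {metric_name: value}}; after this
  # change the result is an OrderedDict sorted by global step, so iteration
  # yields evaluations in the order they were written.
  for step, metrics in early_stopping.read_eval_metrics(eval_dir).items():
    print(step, metrics.get('loss'))
```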
diff --git a/tensorflow/contrib/estimator/python/estimator/hooks.py b/tensorflow/contrib/estimator/python/estimator/hooks.py
index 66c46e66b7..49f7bbd320 100644
--- a/tensorflow/contrib/estimator/python/estimator/hooks.py
+++ b/tensorflow/contrib/estimator/python/estimator/hooks.py
@@ -53,6 +53,7 @@ class InMemoryEvaluatorHook(training.SessionRunHook):
```
Current limitations of this approach are:
+
* It doesn't support multi-node distributed mode.
* It doesn't support saveable objects other than variables (such as boosted
tree support)
diff --git a/tensorflow/contrib/estimator/python/estimator/hooks_test.py b/tensorflow/contrib/estimator/python/estimator/hooks_test.py
index c6c6cad95a..62ffad56da 100644
--- a/tensorflow/contrib/estimator/python/estimator/hooks_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/hooks_test.py
@@ -294,7 +294,7 @@ class InMemoryEvaluatorHookTest(test.TestCase):
def model_fn(features, labels, mode):
_, _ = features, labels
- w = variables.Variable(
+ w = variables.VariableV1(
initial_value=[0.],
trainable=False,
collections=[ops.GraphKeys.SAVEABLE_OBJECTS])
diff --git a/tensorflow/contrib/factorization/BUILD b/tensorflow/contrib/factorization/BUILD
index 9e1f14f990..e344d7a23b 100644
--- a/tensorflow/contrib/factorization/BUILD
+++ b/tensorflow/contrib/factorization/BUILD
@@ -64,7 +64,6 @@ tf_custom_op_py_library(
"//tensorflow/python:util",
"//tensorflow/python:variable_scope",
"//tensorflow/python:variables",
- "//tensorflow/python/estimator",
"//tensorflow/python/estimator:estimator_py",
"//tensorflow/python/feature_column:feature_column_py",
"//third_party/py/numpy",
diff --git a/tensorflow/contrib/framework/python/framework/tensor_util_test.py b/tensorflow/contrib/framework/python/framework/tensor_util_test.py
index b1820c10c8..9b0b9b1e1b 100644
--- a/tensorflow/contrib/framework/python/framework/tensor_util_test.py
+++ b/tensorflow/contrib/framework/python/framework/tensor_util_test.py
@@ -186,7 +186,7 @@ class WithShapeTest(test.TestCase):
unexpected_shapes)
def test_with_shape_2x2_with_partial_expected_shape(self):
- with self.test_session():
+ with self.cached_session():
value = [[42, 43], [44, 45]]
actual_shape = [2, 2]
tensor = constant_op.constant(value, shape=actual_shape)
diff --git a/tensorflow/contrib/framework/python/ops/variables_test.py b/tensorflow/contrib/framework/python/ops/variables_test.py
index f9b0efd1da..c223df5b6e 100644
--- a/tensorflow/contrib/framework/python/ops/variables_test.py
+++ b/tensorflow/contrib/framework/python/ops/variables_test.py
@@ -192,7 +192,7 @@ class GlobalStepTest(test.TestCase):
def test_invalid_dtype(self):
with ops.Graph().as_default() as g:
self.assertEquals(None, variables_lib2.get_global_step())
- variables_lib.Variable(
+ variables_lib.VariableV1(
0.0,
trainable=False,
dtype=dtypes.float32,
@@ -205,7 +205,7 @@ class GlobalStepTest(test.TestCase):
def test_invalid_shape(self):
with ops.Graph().as_default() as g:
self.assertEquals(None, variables_lib2.get_global_step())
- variables_lib.Variable(
+ variables_lib.VariableV1(
[0],
trainable=False,
dtype=dtypes.int32,
@@ -229,7 +229,7 @@ class GlobalStepTest(test.TestCase):
def test_get_global_step(self):
with ops.Graph().as_default() as g:
self.assertEquals(None, variables_lib2.get_global_step())
- variables_lib.Variable(
+ variables_lib.VariableV1(
0,
trainable=False,
dtype=dtypes.int32,
@@ -607,10 +607,10 @@ class ModelVariablesTest(test.TestCase):
with self.cached_session():
with variable_scope.variable_scope('A'):
variables_lib2.local_variable([5])
- a = variables_lib.Variable([5])
+ a = variables_lib.VariableV1([5])
with variable_scope.variable_scope('B'):
variables_lib2.local_variable([5])
- b = variables_lib.Variable([5])
+ b = variables_lib.VariableV1([5])
self.assertEquals([a], variables_lib2.get_trainable_variables('A'))
self.assertEquals([b], variables_lib2.get_trainable_variables('B'))
@@ -953,7 +953,7 @@ class AssignFromCheckpointTest(test.TestCase):
# Create a set of variables to save in the checkpoint.
for var_name in var_names_to_values:
var_value = var_names_to_values[var_name]
- var_list.append(variables_lib.Variable(var_value, name=var_name))
+ var_list.append(variables_lib.VariableV1(var_value, name=var_name))
saver = saver_lib.Saver(var_list)
init_op = variables_lib.variables_initializer(var_list)
sess.run(init_op)
@@ -1106,7 +1106,7 @@ class AssignFromCheckpointFnTest(test.TestCase):
# Create a set of variables to save in the checkpoint.
for var_name in var_names_to_values:
var_value = var_names_to_values[var_name]
- var_list.append(variables_lib.Variable(var_value, name=var_name))
+ var_list.append(variables_lib.VariableV1(var_value, name=var_name))
saver = saver_lib.Saver(var_list)
init_op = variables_lib.variables_initializer(var_list)
sess.run(init_op)
@@ -1297,7 +1297,7 @@ class AssignFromCheckpointFnTest(test.TestCase):
class ZeroInitializerOpTest(test.TestCase):
def _testZeroInitializer(self, shape, initializer, use_init):
- var = variables_lib.Variable(initializer)
+ var = variables_lib.VariableV1(initializer)
var_zero = variables_lib2.zero_initializer(var)
with self.cached_session() as sess:
with self.assertRaisesOpError('Attempting to use uninitialized value'):
@@ -1350,12 +1350,12 @@ class FilterVariablesTest(test.TestCase):
g = ops.Graph()
with g.as_default():
var_list = []
- var_list.append(variables_lib.Variable(0, name='conv1/weights'))
- var_list.append(variables_lib.Variable(0, name='conv1/biases'))
- var_list.append(variables_lib.Variable(0, name='conv2/weights'))
- var_list.append(variables_lib.Variable(0, name='conv2/biases'))
- var_list.append(variables_lib.Variable(0, name='clfs/weights'))
- var_list.append(variables_lib.Variable(0, name='clfs/biases'))
+ var_list.append(variables_lib.VariableV1(0, name='conv1/weights'))
+ var_list.append(variables_lib.VariableV1(0, name='conv1/biases'))
+ var_list.append(variables_lib.VariableV1(0, name='conv2/weights'))
+ var_list.append(variables_lib.VariableV1(0, name='conv2/biases'))
+ var_list.append(variables_lib.VariableV1(0, name='clfs/weights'))
+ var_list.append(variables_lib.VariableV1(0, name='clfs/biases'))
self._var_list = var_list
def _test_filter_variables(self,
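The constructions above (and the similar hunk in `hooks_test.py`) move these tests from `variables_lib.Variable` to `variables_lib.VariableV1`. A minimal sketch of the distinction, assuming the standard `tensorflow.python.ops.variables` module; the variable name is illustrative:

```python
from tensorflow.python.ops import variables as variables_lib

# VariableV1 pins the legacy graph-mode variable behavior (a ref-typed
# variable registered in collections by default), which is what these
# framework tests exercise; plain `Variable` is being repointed at the 2.0
# semantics, so tests that depend on V1 behavior name VariableV1 explicitly.
w = variables_lib.VariableV1(0, name='conv1/weights', trainable=False)
```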
diff --git a/tensorflow/contrib/fused_conv/BUILD b/tensorflow/contrib/fused_conv/BUILD
index 0f0813c07f..490da9b33b 100644
--- a/tensorflow/contrib/fused_conv/BUILD
+++ b/tensorflow/contrib/fused_conv/BUILD
@@ -17,11 +17,14 @@ licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
-load("//tensorflow:tensorflow.bzl", "tf_kernel_library")
+load(
+ "//tensorflow:tensorflow.bzl",
+ "tf_kernel_library",
+ "tf_custom_op_library",
+ "tf_gen_op_libs",
+ "tf_gen_op_wrapper_py",
+)
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
-load("//tensorflow:tensorflow.bzl", "tf_custom_op_library")
-load("//tensorflow:tensorflow.bzl", "tf_gen_op_libs")
-load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py")
load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
tf_custom_op_py_library(
@@ -109,13 +112,13 @@ tf_gen_op_wrapper_py(
deps = [":fused_conv2d_bias_activation_op_op_lib"],
)
-cuda_py_test(
- name = "fused_conv2d_bias_activation_op_test",
- size = "large",
- srcs = ["python/ops/fused_conv2d_bias_activation_op_test.py"],
- additional_deps = [
+py_library(
+ name = "fused_conv2d_bias_activation_op_test_base",
+ testonly = 1,
+ srcs = ["python/ops/fused_conv2d_bias_activation_op_test_base.py"],
+ visibility = ["//tensorflow/compiler/tf2xla:internal"],
+ deps = [
":fused_conv_py",
- "//third_party/py/numpy",
"//tensorflow/python:array_ops",
"//tensorflow/python:client",
"//tensorflow/python:client_testlib",
@@ -128,16 +131,27 @@ cuda_py_test(
"//tensorflow/python:random_ops",
"//tensorflow/python:training",
"//tensorflow/python:variables",
+ "//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
+ ],
+)
+
+cuda_py_test(
+ name = "fused_conv2d_bias_activation_op_test",
+ size = "large",
+ srcs = ["python/ops/fused_conv2d_bias_activation_op_test.py"],
+ additional_deps = [
+ ":fused_conv2d_bias_activation_op_test_base",
+ "//tensorflow/python:client_testlib",
],
tags = [
- "manual",
- "requires_cudnn6",
+ "no_pip",
+ "requires-gpu-sm70",
],
)
cuda_py_test(
name = "fused_conv2d_bias_activation_benchmark",
- size = "large",
srcs = ["python/ops/fused_conv2d_bias_activation_benchmark.py"],
additional_deps = [
":fused_conv_py",
@@ -155,7 +169,6 @@ cuda_py_test(
],
main = "python/ops/fused_conv2d_bias_activation_benchmark.py",
tags = [
- "manual",
- "requires_cudnn6",
+ "requires-gpu-sm70",
],
)
diff --git a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
index e9e6464d06..93b1aaa85e 100644
--- a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
+++ b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
@@ -111,8 +111,8 @@ class FusedConv2DBiasActivationOp : public OpKernel {
context,
(GetTensorDim(strides, data_format_, 'N') == 1 &&
GetTensorDim(strides, data_format_, 'C') == 1),
- errors::InvalidArgument("Convolutional strides are not supported in "
- "the batch or depth dimensions."));
+ errors::Unimplemented("Convolutional strides are not supported in "
+ "the batch and depth dimensions."));
// Assuming qint8 <--> NCHW_VECT_C, OIHW_VECT_I (int8x4) here.
constexpr bool is_int8x4 = std::is_same<T, qint8>::value;
diff --git a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py
index 0185ef662c..e5c8a34fc1 100644
--- a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py
+++ b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py
@@ -12,898 +12,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Functional tests for fused conv2d bias and activation operation."""
+
+"""Tests for fused convolutions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import numpy as np
-
-from tensorflow.contrib.fused_conv.python.ops import fused_conv2d_bias_activation_op
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors_impl
-from tensorflow.python.framework import test_util
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import gen_array_ops
-from tensorflow.python.ops import nn_ops
-from tensorflow.python.ops import random_ops
+from tensorflow.contrib.fused_conv.python.ops import fused_conv2d_bias_activation_op_test_base as test_base
from tensorflow.python.platform import test
-from tensorflow.python.platform import tf_logging
-
-
-def GetShrunkInceptionShapes(shrink=10):
- """Iterator for smaller versions of convolution shapes in 2015 Inception.
-
- Relative to inception, each depth value is `depth // shrink`.
-
- Args:
- shrink: Factor to shrink each depth value by relative to Inception.
-
- Yields:
- Tuple (input_size, filter_size, out_size, stride, padding), the convolution
- parameters of Inception layers.
- """
- input_sizes = [[4, 5, 5, 1248], [4, 8, 8, 384], [4, 8, 8, 384], [
- 4, 8, 8, 2048
- ], [4, 8, 8, 448], [4, 8, 8, 2048], [4, 8, 8, 2048], [4, 8, 8, 2048], [
- 4, 8, 8, 1760
- ], [4, 8, 8, 1760], [4, 8, 8, 1760], [4, 8, 8, 1760], [4, 17, 17, 192], [
- 4, 17, 17, 192
- ], [4, 17, 17, 1248], [4, 17, 17, 128], [4, 17, 17, 1248], [4, 17, 17, 224], [
- 4, 17, 17, 192
- ], [4, 17, 17, 192], [4, 17, 17, 1216], [4, 17, 17, 1216], [4, 17, 17, 224], [
- 4, 17, 17, 192
- ], [4, 17, 17, 192], [4, 17, 17, 1152], [4, 17, 17, 1152], [4, 17, 17, 192], [
- 4, 17, 17, 160
- ], [4, 17, 17, 1152], [4, 17, 17, 1024], [4, 17, 17, 128], [4, 17, 17, 1024],
- [4, 17, 17, 128], [4, 17, 17, 1024], [4, 17, 17, 128], [
- 4, 17, 17, 768
- ], [4, 17, 17, 128], [4, 17, 17, 128], [4, 17, 17, 768],
- [4, 17, 17, 768], [4, 35, 35, 96], [4, 35, 35, 288], [
- 4, 35, 35, 64
- ], [4, 35, 35, 288], [4, 35, 35, 256], [4, 35, 35, 48], [
- 4, 35, 35, 256
- ], [4, 35, 35, 96], [4, 35, 35, 192], [4, 35, 35, 192], [
- 4, 35, 35, 192
- ], [4, 73, 73, 64], [4, 73, 73, 64], [4, 147, 147, 24]]
- filter_sizes = [[1, 1, 1248, 128], [1, 3, 384, 384], [3, 1, 384, 384], [
- 1, 1, 2048, 192
- ], [3, 3, 448, 384], [1, 1, 2048, 320], [1, 1, 2048, 448], [1, 1, 2048, 384],
- [1, 1, 1760, 384], [1, 1, 1760, 192], [1, 1, 1760, 448], [
- 1, 1, 1760, 320
- ], [3, 3, 192, 192], [3, 3, 192, 192], [1, 1, 1248, 192], [
- 3, 3, 128, 320
- ], [1, 1, 1248, 128], [1, 3, 224, 224], [3, 1, 192, 256], [
- 1, 3, 192, 256
- ], [1, 1, 1216, 192], [1, 1, 1216, 96], [3, 1, 224, 224], [
- 3, 3, 192, 224
- ], [1, 3, 192, 192], [1, 1, 1152, 192], [1, 1, 1152, 128], [
- 3, 1, 192, 192
- ], [3, 3, 160, 192], [1, 1, 1152, 160], [1, 1, 1024, 128], [
- 1, 3, 128, 192
- ], [1, 1, 1024, 160], [3, 1, 128, 192], [1, 1, 1024, 256], [
- 3, 1, 128, 128
- ], [1, 1, 768, 192], [1, 3, 128, 128], [3, 3, 128, 128], [
- 1, 1, 768, 128
- ], [1, 1, 768, 320], [3, 3, 96, 96], [3, 3, 288, 384], [
- 3, 3, 64, 96
- ], [1, 1, 288, 64], [1, 1, 256, 64], [5, 5, 48, 64],
- [1, 1, 256, 48], [3, 3, 96, 96], [1, 1, 192, 32], [
- 1, 1, 192, 64
- ], [1, 1, 192, 48], [3, 3, 64, 192], [1, 1, 64,
- 64], [1, 1, 24, 64]]
- out_sizes = [[4, 5, 5, 128], [4, 8, 8, 384], [4, 8, 8, 384], [4, 8, 8, 192], [
- 4, 8, 8, 384
- ], [4, 8, 8, 320], [4, 8, 8, 448], [4, 8, 8, 384], [4, 8, 8, 384], [
- 4, 8, 8, 192
- ], [4, 8, 8, 448], [4, 8, 8, 320], [4, 8, 8, 192], [4, 17, 17, 192], [
- 4, 17, 17, 192
- ], [4, 8, 8, 320], [4, 17, 17, 128], [4, 17, 17, 224], [4, 17, 17, 256], [
- 4, 17, 17, 256
- ], [4, 17, 17, 192], [4, 17, 17, 96], [4, 17, 17, 224], [4, 17, 17, 224], [
- 4, 17, 17, 192
- ], [4, 17, 17, 192], [4, 17, 17, 128], [4, 17, 17, 192], [4, 17, 17, 192], [
- 4, 17, 17, 160
- ], [4, 17, 17, 128], [4, 17, 17, 192], [4, 17, 17, 160], [4, 17, 17, 192], [
- 4, 17, 17, 256
- ], [4, 17, 17, 128], [4, 17, 17, 192], [4, 17, 17, 128], [4, 17, 17, 128], [
- 4, 17, 17, 128
- ], [4, 17, 17, 320], [4, 17, 17, 96], [4, 17, 17, 384], [4, 35, 35, 96], [
- 4, 35, 35, 64
- ], [4, 35, 35, 64], [4, 35, 35, 64], [4, 35, 35, 48], [4, 35, 35, 96],
- [4, 35, 35, 32], [4, 35, 35, 64], [4, 35, 35, 48],
- [4, 71, 71, 192], [4, 73, 73, 64], [4, 147, 147, 64]]
- strides = [
- 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1,
- 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 1,
- 1, 1, 1, 1, 1
- ]
- # Shrink sizes to make the test faster
- for i in input_sizes:
- i[3] //= shrink
- for f in filter_sizes:
- f[2] //= shrink
- f[3] //= shrink
- for o in out_sizes:
- o[3] //= shrink
- # pylint: disable=invalid-name
- VALID = "VALID"
- SAME = "SAME"
- # pylint: enable=invalid-name
- paddings = [
- SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME,
- VALID, SAME, SAME, VALID, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME,
- SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME,
- SAME, SAME, SAME, SAME, SAME, VALID, VALID, SAME, SAME, SAME, SAME, SAME,
- SAME, SAME, SAME, SAME, VALID, VALID, VALID
- ]
- for i, f, o, s, p in zip(input_sizes, filter_sizes, out_sizes, strides,
- paddings):
- yield i, f, o, s, p
-
-
-def GetTestConfigs():
- """Get all the valid tests configs to run.
-
- Returns:
- all the valid test configs as tuples of data_format and use_gpu.
- """
- test_configs = [("NCHW", True), ("NHWC", True)]
- return test_configs
-
-
-class FusedConv2DBiasActivationTest(test.TestCase):
-
- def _DtypesToTest(self, use_gpu):
- return [dtypes.float32]
-
- def _FilterFormatsToTest(self, use_gpu):
- return ["HWIO", "OIHW"]
-
- def _SetupValuesForDevice(self, tensor_in_sizes, filter_in_sizes, bias,
- strides, padding, activation_mode, data_format,
- filter_format, dtype):
- """Verifies the output values of the convolution function.
-
- Args:
- tensor_in_sizes: Input tensor dimensions in
- [batch, input_rows, input_cols, input_depth].
- filter_in_sizes: Filter tensor dimensions in
- [kernel_rows, kernel_cols, input_depth, output_depth].
- bias: 1-D bias tensor of length output_depth.
- strides: Stride: [col_stride, row_stride]
- padding: Padding type.
- activation_mode: Activation mode.
- data_format: Format of the data tensors.
- filter_format: Filter format to use for the fused convolution.
- dtype: Data type for inputs and outputs.
- Returns:
- Symbolic tensor value and reference value that can be used to
- execute the computation and verify the results.
- """
- input_size = np.prod(tensor_in_sizes)
- filter_size = np.prod(filter_in_sizes)
- bias_size = filter_in_sizes[-1] # equals to output depth
- # Initializes the input tensor with array containing incrementing
- # numbers from 1.
- x1 = [f * 1.0 for f in range(1, input_size + 1)]
- x2 = [f * 1.0 for f in range(1, filter_size + 1)]
- # This is to guarantee that there is always negative values after
- # bias add so that we can test whether relu works correctly.
- x3 = bias
- with self.test_session(use_gpu=True):
- t1 = constant_op.constant(x1, shape=tensor_in_sizes, dtype=dtype)
- t2 = constant_op.constant(x2, shape=filter_in_sizes, dtype=dtype)
- fused_t2 = t2
- if filter_format == "OIHW":
- fused_t2 = HwioToOihw(t2)
- t3 = constant_op.constant(x3, shape=[bias_size], dtype=dtype)
- strides = [1] + strides + [1]
- if data_format == "NCHW":
- t1 = test_util.NHWCToNCHW(t1)
- strides = test_util.NHWCToNCHW(strides)
- output = fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
- t1,
- fused_t2,
- t3,
- strides=strides,
- padding=padding,
- data_format=data_format,
- filter_format=filter_format,
- activation_mode=activation_mode)
- ref_conv_output = nn_ops.conv2d(
- t1, t2, strides=strides, padding=padding, data_format=data_format)
- ref_bias_output = nn_ops.bias_add(
- ref_conv_output, t3, data_format=data_format)
- ref_output = nn_ops.relu(ref_bias_output)
- if data_format == "NCHW":
- output = test_util.NCHWToNHWC(output)
- ref_output = test_util.NCHWToNHWC(ref_output)
-
- return output, ref_output
-
- def _CompareFwdValues(self, tensor_in_sizes, filter_in_sizes, conv_strides,
- padding):
- """Verifies that CPU and GPU produce the same values.
-
- Args:
- tensor_in_sizes: Input tensor dimensions in
- [batch, input_rows, input_cols, input_depth].
- filter_in_sizes: Filter tensor dimensions in
- [kernel_rows, kernel_cols, input_depth, output_depth].
- conv_strides: [row_stride, col_stride] for the convolution;
- padding: Padding type.
- """
- x1 = np.random.rand(*tensor_in_sizes).astype(np.float32)
- x2 = np.random.rand(*filter_in_sizes).astype(np.float32)
- x3 = np.random.rand(*[filter_in_sizes[-1]]).astype(np.float32)
-
- def _SetupVal(data_format, use_gpu):
- with self.test_session(use_gpu=use_gpu):
- t1 = constant_op.constant(x1, shape=tensor_in_sizes)
- t2 = constant_op.constant(x2, shape=filter_in_sizes)
- t3 = constant_op.constant(x3, shape=[filter_in_sizes[-1]])
- strides = [1] + conv_strides + [1]
- if data_format == "NCHW":
- t1 = test_util.NHWCToNCHW(t1)
- strides = test_util.NHWCToNCHW(strides)
- output = fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
- t1,
- t2,
- t3,
- strides=strides,
- padding=padding,
- data_format=data_format,
- activation_mode="Relu")
-
- if data_format == "NCHW":
- output = test_util.NCHWToNHWC(output)
- return output
-
- tensors = []
- for (data_format, use_gpu) in GetTestConfigs():
- tensors.append(_SetupVal(data_format, use_gpu))
- with self.test_session() as sess:
- values = sess.run(tensors)
- for i in range(1, len(values)):
- self.assertAllClose(values[0], values[i], rtol=1e-5, atol=1e-5)
-
- def _VerifyValues(self, tensor_in_sizes, filter_in_sizes, bias, strides,
- padding):
- tensors = []
- ref_tensors = []
- for (data_format, use_gpu) in GetTestConfigs():
- for dtype in self._DtypesToTest(use_gpu):
- for filter_format in self._FilterFormatsToTest(use_gpu):
- result, expected = self._SetupValuesForDevice(
- tensor_in_sizes, filter_in_sizes, bias, strides, padding, "Relu",
- data_format, filter_format, dtype)
- tensors.append(result)
- ref_tensors.append(expected)
- with self.test_session() as sess:
- values = sess.run(tensors)
- ref_values = sess.run(ref_tensors)
- for i in range(len(tensors)):
- conv = tensors[i]
- value = values[i]
- ref_value = ref_values[i]
- tf_logging.info("expected = ", ref_value)
- tf_logging.info("actual = ", value)
- tol = 1e-5
- if value.dtype == np.float16:
- tol = 1e-3
- self.assertAllClose(
- np.ravel(ref_value), np.ravel(value), atol=tol, rtol=tol)
- self.assertShapeEqual(value, conv)
-
- def testConv2D1x1Filter(self, gpu_only=True):
- if gpu_only and not test.is_gpu_available():
- tf_logging.info("Skipping Conv2D1x1Filter test.")
- return
- # expected_output = [
- # 0.0, 0.0, 0.0, 21.0, 0.0, 0.0, 57.0, 0.0, 0.0, 93.0, 41.0, 0.0, 129.0,
- # 86.0, 43.0, 165.0, 131.0, 97.0
- # ]
- medians = [-45.0, -130.0, -215.0]
- self._VerifyValues(
- tensor_in_sizes=[1, 2, 3, 3],
- filter_in_sizes=[1, 1, 3, 3],
- bias=medians,
- strides=[1, 1],
- padding="VALID")
-
- def testConv2DEmpty(self, gpu_only=True):
- if gpu_only and not test.is_gpu_available():
- tf_logging.info("Skipping Conv2DEmpty test.")
- return
- # expected_output = []
- self._VerifyValues(
- tensor_in_sizes=[0, 2, 3, 3],
- filter_in_sizes=[1, 1, 3, 3],
- bias=[0.0, 0.0, 0.0],
- strides=[1, 1],
- padding="VALID")
-
- def testConv2D2x2Filter(self, gpu_only=True):
- if gpu_only and not test.is_gpu_available():
- tf_logging.info("Skipping Conv2D2x2Filter test.")
- return
- # expected_output = [0.0, 0.0, 0.0, 401.0, 533.0, 665.0]
- self._VerifyValues(
- tensor_in_sizes=[1, 2, 3, 3],
- filter_in_sizes=[2, 2, 3, 3],
- bias=[-2500.0, -2500.0, -2500.0],
- strides=[1, 1],
- padding="VALID")
-
- def testConv2D1x2Filter(self, gpu_only=True):
- if gpu_only and not test.is_gpu_available():
- tf_logging.info("Skipping Conv2D1x2Filter test.")
- return
- # expected_output = [
- # 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 190.0, 265.0, 340.0, 343.0, 436.0, 529.0
- # ]
- self._VerifyValues(
- tensor_in_sizes=[1, 2, 3, 3],
- filter_in_sizes=[1, 2, 3, 3],
- bias=[-500.0, -500.0, -500.0],
- strides=[1, 1],
- padding="VALID")
-
- def testConv2D2x2FilterStride2(self, gpu_only=True):
- if gpu_only and not test.is_gpu_available():
- tf_logging.info("Skipping Conv2D2x2FilterStride2 test.")
- return
- # expected_output = [0.0, 67.0, 163.0]
- self._VerifyValues(
- tensor_in_sizes=[1, 2, 3, 3],
- filter_in_sizes=[2, 2, 3, 3],
- bias=[-2300.0, -2300.0, -2300.0],
- strides=[2, 2],
- padding="VALID")
-
- def testConv2D2x2FilterStride2Same(self, gpu_only=True):
- if gpu_only and not test.is_gpu_available():
- tf_logging.info("Skipping Conv2D2x2FilterStride2Same test.")
- return
- # expected_output = [0.0, 2367.0, 2463.0, 1230.0, 1305.0, 1380.0]
- self._VerifyValues(
- tensor_in_sizes=[1, 2, 3, 3],
- filter_in_sizes=[2, 2, 3, 3],
- bias=[-2300.0, -1000.0, -1000.0],
- strides=[2, 2],
- padding="SAME")
-
- def testConv2D2x2FilterStride1x2(self, gpu_only=True):
- if gpu_only and not test.is_gpu_available():
- tf_logging.info("Skipping Conv2D2x2FilterStride1x2 test.")
- return
- # expected_output = [0.0, 0.0, 8.0, 28.0, 48.0, 68.0]
- self._VerifyValues(
- tensor_in_sizes=[1, 3, 6, 1],
- filter_in_sizes=[2, 2, 1, 1],
- bias=[-90.0],
- strides=[1, 2],
- padding="VALID")
-
- def testConv2DKernelSmallerThanStrideValid(self, gpu_only=True):
- if gpu_only and not test.is_gpu_available():
- tf_logging.info("Skipping Conv2DKernelSmallerThanStrideValid test.")
- return
- # expected_output = [0, 0, 175, 205]
- self._VerifyValues(
- tensor_in_sizes=[1, 7, 7, 1],
- filter_in_sizes=[2, 2, 1, 1],
- bias=[-100.0],
- strides=[3, 3],
- padding="VALID")
-
- def testConv2DKernelSmallerThanStrideSame(self, gpu_only=True):
- if gpu_only and not test.is_gpu_available():
- tf_logging.info("Skipping Conv2DKernelSmallerThanStrideSame test.")
- return
- # expected = [0, 0, 2, 4]
- self._VerifyValues(
- tensor_in_sizes=[1, 3, 3, 1],
- filter_in_sizes=[1, 1, 1, 1],
- bias=[-5.0],
- strides=[2, 2],
- padding="SAME")
-
- # expected = [0, 0, 4, 6]
- self._VerifyValues(
- tensor_in_sizes=[1, 4, 4, 1],
- filter_in_sizes=[1, 1, 1, 1],
- bias=[-5.0],
- strides=[2, 2],
- padding="SAME")
-
- # expected = [4, 0, 1, 0]
- self._VerifyValues(
- tensor_in_sizes=[1, 4, 4, 1],
- filter_in_sizes=[2, 2, 1, 1],
- bias=[-40.0],
- strides=[3, 3],
- padding="SAME")
-
- def testConv2DKernelSizeMatchesInputSize(self, gpu_only=True):
- if gpu_only and not test.is_gpu_available():
- tf_logging.info("Skipping Conv2DKernelSizeMatchesInputSize test.")
- return
- # expected = [0, 5]
- self._VerifyValues(
- tensor_in_sizes=[1, 2, 2, 1],
- filter_in_sizes=[2, 2, 1, 2],
- bias=[-50.0, -55.0],
- strides=[1, 1],
- padding="VALID")
-
- # expected = [0, 2, 282, 322]
- self._VerifyValues(
- tensor_in_sizes=[1, 8, 8, 1],
- filter_in_sizes=[2, 2, 1, 1],
- bias=[-200.0],
- strides=[4, 4],
- padding="SAME")
-
- def testShapeFunctionEdgeCases(self):
- # All shapes unknown.
- c1 = fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
- array_ops.placeholder(dtypes.float32),
- array_ops.placeholder(dtypes.float32),
- array_ops.placeholder(dtypes.float32),
- strides=[1, 1, 1, 1],
- padding="SAME",
- activation_mode="Relu")
- self.assertEqual([None, None, None, None], c1.get_shape().as_list())
-
- # Incorrect input shape.
- with self.assertRaises(ValueError):
- fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
- array_ops.placeholder(dtypes.float32, shape=[1, 3]),
- array_ops.placeholder(dtypes.float32),
- array_ops.placeholder(dtypes.float32),
- strides=[1, 1, 1, 1],
- padding="SAME",
- activation_mode="Relu")
-
- # Incorrect filter shape.
- with self.assertRaises(ValueError):
- fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
- array_ops.placeholder(dtypes.float32),
- array_ops.placeholder(dtypes.float32, shape=[1, 3]),
- array_ops.placeholder(dtypes.float32),
- strides=[1, 1, 1, 1],
- padding="SAME",
- activation_mode="Relu")
-
- # Depth mismatch.
- with self.assertRaises(ValueError):
- fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
- array_ops.placeholder(dtypes.float32, shape=[32, 20, 20, 3]),
- array_ops.placeholder(dtypes.float32, shape=[4, 4, 2, 2]),
- array_ops.placeholder(dtypes.float32),
- strides=[1, 1, 1, 1],
- padding="SAME",
- activation_mode="Relu")
-
- def testOpEdgeCases(self, gpu_only=True):
- if gpu_only and not test.is_gpu_available():
- tf_logging.info("Skipping OpEdgeCases tests.")
- return
- with self.test_session() as sess:
- # Illegal strides.
- with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
- "Convolutional strides are not supported in "
- "the batch or depth dimensions."):
- sess.run(
- fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
- array_ops.placeholder(dtypes.float32),
- array_ops.placeholder(dtypes.float32),
- array_ops.placeholder(dtypes.float32),
- strides=[2, 1, 1, 1],
- padding="SAME",
- activation_mode="Relu"))
- with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
- "Convolutional strides are not supported in "
- "the batch or depth dimensions."):
- sess.run(
- fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
- array_ops.placeholder(dtypes.float32),
- array_ops.placeholder(dtypes.float32),
- array_ops.placeholder(dtypes.float32),
- strides=[1, 1, 1, 2],
- padding="SAME",
- activation_mode="Relu"))
-
- # Illegal activation mode.
- with self.assertRaisesRegexp(ValueError,
- "Op passed string 'Tanh' not in:"):
- sess.run(
- fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
- array_ops.placeholder(dtypes.float32),
- array_ops.placeholder(dtypes.float32),
- array_ops.placeholder(dtypes.float32),
- strides=[1, 1, 1, 1],
- padding="SAME",
- activation_mode="Tanh"))
-
- # Filter larger than input.
- with self.assertRaisesRegexp(ValueError, "Negative dimension size"):
- sess.run(
- fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
- array_ops.placeholder(dtypes.float32, shape=[32, 20, 20, 3]),
- array_ops.placeholder(dtypes.float32, shape=[20, 21, 3, 2]),
- array_ops.placeholder(dtypes.float32, shape=[2]),
- strides=[1, 1, 1, 1],
- padding="VALID",
- activation_mode="Relu"))
- with self.assertRaisesRegexp(ValueError, "Negative dimension size"):
- sess.run(
- fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
- array_ops.placeholder(dtypes.float32, shape=[32, 20, 20, 3]),
- array_ops.placeholder(dtypes.float32, shape=[21, 20, 3, 2]),
- array_ops.placeholder(dtypes.float32, shape=[2]),
- strides=[1, 1, 1, 1],
- padding="VALID",
- activation_mode="Relu"))
-
-
-def GetInceptionFwdTest(input_size, filter_size, stride, padding,
- gpu_only=True):
-
- def Test(self):
- if gpu_only and not test.is_gpu_available():
- tf_logging.info("Skipping InceptionFwd %s", (input_size, filter_size,
- stride, padding))
- return
- tf_logging.info("Testing InceptionFwd %s", (input_size, filter_size, stride,
- padding))
- self._CompareFwdValues(input_size, filter_size, [stride, stride], padding)
-
- return Test
-
-
-def CalculateConvolvedOutputDim(input_dim, filter_dim, stride, padding_type):
- """Calculates the size of an output dimension of a strided convolution.
-
- Given the sizes of the corresponding dimension of the input and filter shapes,
- and the stride and padding_types, calculates the size of the output dimension.
- This function can be called separately for each input dimension.
-
- Args:
- input_dim: An `int` specifying the size of the input dimension.
- filter_dim: An `int` specifying the size of the filter dimension.
- stride: An `int` specifying the step size of the convolution along the
- input dimension.
- padding_type: either 'VALID' or 'SAME'.
-
- Returns:
- The size of the output dimension.
- """
- if padding_type == "VALID":
- return (input_dim - filter_dim + stride) // stride
- else: # padding_type == 'SAME'
- return (input_dim + stride - 1) // stride
-
-
-def NchwVectCToNchw(in_tensor):
- # [N, C / 4, H, W, 4] => [N, C / 4, 4, H, W] == [N, C, H, W]
- t = array_ops.transpose(in_tensor, [0, 1, 4, 2, 3])
- n = in_tensor.shape.dims[0].value
- c = in_tensor.shape.dims[1].value * in_tensor.shape.dims[4].value
- h = in_tensor.shape.dims[2].value
- w = in_tensor.shape.dims[3].value
- return array_ops.reshape(t, [n, c, h, w])
-
-
-def OihwVectIToHwio(in_tensor):
- # [O, I / 4, H, W, 4] => [O, I / 4, 4, H, W] == [O, I, H, W]
- t = array_ops.transpose(in_tensor, [2, 3, 1, 4, 0])
- o = in_tensor.shape.dims[0].value
- i = in_tensor.shape.dims[1].value * in_tensor.shape.dims[4].value
- h = in_tensor.shape.dims[2].value
- w = in_tensor.shape.dims[3].value
- return array_ops.reshape(t, [h, w, i, o])
-
-
-def NchwToNchwVectC(in_tensor):
- n, c, h, w = in_tensor.shape.as_list()
- assert c % 4 == 0
- t = array_ops.reshape(in_tensor, [n, c // 4, 4, h, w])
- return array_ops.transpose(t, [0, 1, 3, 4, 2])
-
-
-def HwioToOihw(in_tensor):
- return array_ops.transpose(in_tensor, [3, 2, 0, 1])
-
-
-def SimulateFusedConv2dBiasActivationInt8(conv_input_scale, conv_input, kernel,
- padding, strides, side_input_scale,
- side_input, biases, apply_relu):
- """Simulates the int8 fused 2-D convolution op using separate float ops.
-
- The arguments and return values have the same format, meanings and
- restrictions as the actual op.
- Args:
- conv_input_scale: A scalar 'float'.
- conv_input: A `Tensor` of type `qint8` in NCHW_VECT_C layout.
- kernel: A `Tensor` of type `qint8` in OIHW_VECT_I layout.
- padding: A `string` from: `"SAME", "VALID"`.
- strides: A list of `ints`.
- side_input_scale: A scalar 'float'.
- side_input: A `Tensor` of type `qint8` in NCHW_VECT_C layout.
- biases: A `Tensor` of type `float32` in NCHW layout.
- apply_relu: A boolean to specify whether to apply "Relu" activation function
- that clips outputs to the range [0, 127], or "None" activation that clips
- to the range [-128, 127].
- Returns:
- A `Tensor` of type `qint8` in NCHW_VECT_C layout.
- """
- conv_result = nn_ops.conv2d(
- NchwVectCToNchw(gen_array_ops.dequantize(conv_input, -128, 127)),
- OihwVectIToHwio(gen_array_ops.dequantize(kernel, -128, 127)),
- strides=strides,
- padding=padding,
- data_format="NCHW") * conv_input_scale
-
- conv_and_side_inputs = conv_result + side_input_scale * NchwVectCToNchw(
- gen_array_ops.dequantize(side_input, -128, 127))
-
- output = nn_ops.bias_add(conv_and_side_inputs, biases, data_format="NCHW")
- if apply_relu:
- output = nn_ops.relu(output)
-
- result, _, _ = gen_array_ops.quantize_v2(
- NchwToNchwVectC(output), -128, 127, dtypes.qint8)
- return result
-
-
-class FusedConvInt8Tests(test.TestCase):
- _test_params = [
- {
- "batch_size": 1,
- "input_channels": 4,
- "output_channels": 4,
- "input_height": 8,
- "input_width": 8,
- "filter_height": 6,
- "filter_width": 6,
- "vertical_stride": 2,
- "horizontal_stride": 2,
- "conv_input_scale": 0.002,
- "side_input_scale": 0.0,
- "bias_scale": 1,
- "padding_type": "SAME"
- },
- {
- "batch_size": 1,
- "input_channels": 4,
- "output_channels": 4,
- "input_height": 6,
- "input_width": 6,
- "filter_height": 6,
- "filter_width": 6,
- "vertical_stride": 2,
- "horizontal_stride": 2,
- "conv_input_scale": 0.002,
- "side_input_scale": 0.0,
- "bias_scale": 1,
- "padding_type": "SAME"
- },
- {
- "batch_size": 2,
- "input_channels": 8,
- "output_channels": 16,
- "input_height": 8,
- "input_width": 8,
- "filter_height": 3,
- "filter_width": 3,
- "vertical_stride": 2,
- "horizontal_stride": 2,
- "conv_input_scale": 0.002,
- "side_input_scale": 0.0,
- "bias_scale": 1,
- "padding_type": "VALID"
- },
- {
- "batch_size": 2,
- "input_channels": 8,
- "output_channels": 16,
- "input_height": 8,
- "input_width": 8,
- "filter_height": 3,
- "filter_width": 3,
- "vertical_stride": 2,
- "horizontal_stride": 2,
- "conv_input_scale": 0.002,
- "side_input_scale": 0.0,
- "bias_scale": 1,
- "padding_type": "SAME"
- },
- {
- "batch_size": 2,
- "input_channels": 8,
- "output_channels": 16,
- "input_height": 8,
- "input_width": 8,
- "filter_height": 3,
- "filter_width": 3,
- "vertical_stride": 2,
- "horizontal_stride": 2,
- "conv_input_scale": 0.002,
- "side_input_scale": 0.5,
- "bias_scale": 1,
- "padding_type": "VALID"
- },
- {
- "batch_size": 2,
- "input_channels": 16,
- "output_channels": 16,
- "input_height": 9,
- "input_width": 9,
- "filter_height": 3,
- "filter_width": 3,
- "vertical_stride": 1,
- "horizontal_stride": 1,
- "conv_input_scale": 0.001,
- "side_input_scale": 0.5,
- "bias_scale": 1,
- "padding_type": "SAME"
- },
- {
- "batch_size": 3,
- "input_channels": 8,
- "output_channels": 8,
- "input_height": 9,
- "input_width": 9,
- "filter_height": 5,
- "filter_width": 5,
- "vertical_stride": 1,
- "horizontal_stride": 1,
- "conv_input_scale": 0.001,
- "side_input_scale": 0.5,
- "bias_scale": 1,
- "padding_type": "SAME"
- },
- {
- "batch_size": 3,
- "input_channels": 8,
- "output_channels": 8,
- "input_height": 9,
- "input_width": 9,
- "filter_height": 7,
- "filter_width": 1,
- "vertical_stride": 2,
- "horizontal_stride": 1,
- "conv_input_scale": 0.002,
- "side_input_scale": 0.5,
- "bias_scale": 1,
- "padding_type": "SAME"
- },
- {
- "batch_size": 3,
- "input_channels": 8,
- "output_channels": 8,
- "input_height": 9,
- "input_width": 9,
- "filter_height": 1,
- "filter_width": 7,
- "vertical_stride": 1,
- "horizontal_stride": 1,
- "conv_input_scale": 0.002,
- "side_input_scale": 0.5,
- "bias_scale": 1,
- "padding_type": "SAME"
- },
- ]
-
- def runTest(self, test_param, apply_relu):
- batch_size = test_param["batch_size"]
- input_channels = test_param["input_channels"]
- output_channels = test_param["output_channels"]
- input_height = test_param["input_height"]
- input_width = test_param["input_width"]
- filter_height = test_param["filter_height"]
- filter_width = test_param["filter_width"]
- vertical_stride = test_param["vertical_stride"]
- horizontal_stride = test_param["horizontal_stride"]
- conv_input_scale = test_param["conv_input_scale"]
- side_input_scale = test_param["side_input_scale"]
- bias_scale = test_param["bias_scale"]
- padding_type = test_param["padding_type"]
-
- conv_input, _, _ = gen_array_ops.quantize_v2(
- random_ops.random_uniform(
- [batch_size, input_channels // 4, input_height, input_width, 4],
- minval=-0.0,
- maxval=1.0,
- dtype=dtypes.float32), -1.0, 1.0, dtypes.qint8)
-
- kernel, _, _ = gen_array_ops.quantize_v2(
- random_ops.random_uniform(
- [
- output_channels, input_channels // 4, filter_height,
- filter_width, 4
- ],
- minval=-1.0,
- maxval=1.0,
- dtype=dtypes.float32), -1.0, 1.0, dtypes.qint8)
-
- output_height = CalculateConvolvedOutputDim(input_height, filter_height,
- vertical_stride, padding_type)
- output_width = CalculateConvolvedOutputDim(input_width, filter_width,
- horizontal_stride, padding_type)
- tf_logging.info("output_height=", output_height, ", output_width=",
- output_width)
-
- side_input, _, _ = gen_array_ops.quantize_v2(
- random_ops.random_uniform(
- [batch_size, output_channels // 4, output_height, output_width, 4],
- minval=0.0,
- maxval=1.0,
- dtype=dtypes.float32), -1.0, 1.0, dtypes.qint8)
-
- biases = random_ops.random_uniform(
- [output_channels],
- minval=-10 * bias_scale,
- maxval=20 * bias_scale,
- dtype=dtypes.float32)
-
- strides = [1, 1, vertical_stride, horizontal_stride]
-
- actual = fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
- conv_input,
- kernel,
- biases,
- strides=strides,
- padding=padding_type,
- conv_input_scale=conv_input_scale,
- side_input_scale=side_input_scale,
- side_input=side_input,
- activation_mode="Relu" if apply_relu else "None",
- data_format="NCHW_VECT_C",
- filter_format="OIHW_VECT_I")
- expected = SimulateFusedConv2dBiasActivationInt8(
- conv_input_scale, conv_input, kernel, padding_type, strides,
- side_input_scale, side_input, biases, apply_relu)
- with self.test_session(use_gpu=True) as sess:
- actual_y, expected_y = sess.run([actual, expected])
- tf_logging.info("actual_y = ", actual_y)
- tf_logging.info("expected_y = ", expected_y)
- self.assertTrue(np.array_equal(actual_y, expected_y))
+# Instantiate the two test suites from test_base, mixing in test.TestCase as
+# the test framework.
+class FusedConv2DBiasActivationTest(test_base.FusedConv2DBiasActivationTest,
+ test.TestCase):
+ pass
- def testFusedConvInt8(self):
- if not test.is_gpu_available(
- cuda_only=True, min_cuda_compute_capability=(6, 1)):
- tf_logging.info("int8 test skipped because not run with --config=cuda or "
- "no GPUs with compute capability >= 6.1 are available.")
- return
- for apply_relu in [True, False]:
- for test_param in self._test_params:
- self.runTest(test_param, apply_relu)
+class FusedConvInt8Tests(test_base.FusedConvInt8Tests, test.TestCase):
+ pass
-if __name__ == "__main__":
- for index, (input_size_, filter_size_, output_size_, stride_,
- padding_) in enumerate(GetShrunkInceptionShapes()):
- setattr(FusedConv2DBiasActivationTest, "testInceptionFwd_" + str(index),
- GetInceptionFwdTest(input_size_, filter_size_, stride_, padding_))
- # TODO(b/35359731)
- # Fwd, BckInput, and BackFilter to test that for certain input parameter
- # set, winograd nonfused algorithm will be excluded from conv autotune. If
- # in such case, winograd nonfused algorithm is added as one option of the
- # conv autotune, and cuDNN version is smaller than 7, the following tests
- # will fail.
- ishape = [1, 400, 400, 1]
- fshape = [1, 1, 1, 256]
- oshape = [1, 400, 400, 256]
- setattr(FusedConv2DBiasActivationTest,
- "testInceptionFwd_No_Winograd_Nonfused",
- GetInceptionFwdTest(ishape, fshape, 1, "SAME", gpu_only=True))
+if __name__ == '__main__':
test.main()
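The rewritten test above now only instantiates the suites defined in the new `fused_conv2d_bias_activation_op_test_base.py` (added below), mixing in `test.TestCase`. The base class also exposes a `test_scope()` context manager hook that its tests enter when building graphs, so another backend can reuse the same suite. A hedged sketch of such an instantiation, where the class name and device string are placeholders and not anything defined in this patch:

```python
import contextlib

from tensorflow.contrib.fused_conv.python.ops import fused_conv2d_bias_activation_op_test_base as test_base
from tensorflow.python.framework import ops
from tensorflow.python.platform import test


class FusedConvOnSomeDeviceTest(test_base.FusedConv2DBiasActivationTest,
                                test.TestCase):
  """Runs the shared suite while pinning ops to a chosen device."""

  @contextlib.contextmanager
  def test_scope(self):
    # Overrides the no-op test_scope() from the base class; the base-class
    # tests build their graphs inside `with self.cached_session(),
    # self.test_scope():`, so everything lands on this device.
    with ops.device('/device:GPU:0'):
      yield


if __name__ == '__main__':
  test.main()
```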
diff --git a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test_base.py b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test_base.py
new file mode 100644
index 0000000000..35fc65e4ba
--- /dev/null
+++ b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test_base.py
@@ -0,0 +1,945 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Provides test suites that can be run to test fused convolutions.
+
+Each of the two test suites in this module, FusedConv2DBiasActivationTest and
+FusedConvInt8Tests, should be "instantiated" by declaring a class which inherits
+from the FusedConv test and a class that provides the standard test.TestCase
+API.
+
+See e.g. fused_conv2d_bias_activation_op_test.py in this folder.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import contextlib
+import numpy as np
+
+from tensorflow.contrib.fused_conv.python.ops import fused_conv2d_bias_activation_op
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors_impl
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_array_ops
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging
+
+
+def _GetShrunkInceptionShapes(shrink=10):
+ """Iterator for smaller versions of convolution shapes in 2015 Inception.
+
+ Relative to inception, each depth value is `depth // shrink`.
+
+ Args:
+ shrink: Factor to shrink each depth value by relative to Inception.
+
+ Yields:
+ Tuple (input_size, filter_size, out_size, stride, padding), the convolution
+ parameters of Inception layers.
+ """
+ input_sizes = [[4, 5, 5, 1248], [4, 8, 8, 384], [4, 8, 8, 384], [
+ 4, 8, 8, 2048
+ ], [4, 8, 8, 448], [4, 8, 8, 2048], [4, 8, 8, 2048], [4, 8, 8, 2048], [
+ 4, 8, 8, 1760
+ ], [4, 8, 8, 1760], [4, 8, 8, 1760], [4, 8, 8, 1760], [4, 17, 17, 192], [
+ 4, 17, 17, 192
+ ], [4, 17, 17, 1248], [4, 17, 17, 128], [4, 17, 17, 1248], [4, 17, 17, 224], [
+ 4, 17, 17, 192
+ ], [4, 17, 17, 192], [4, 17, 17, 1216], [4, 17, 17, 1216], [4, 17, 17, 224], [
+ 4, 17, 17, 192
+ ], [4, 17, 17, 192], [4, 17, 17, 1152], [4, 17, 17, 1152], [4, 17, 17, 192], [
+ 4, 17, 17, 160
+ ], [4, 17, 17, 1152], [4, 17, 17, 1024], [4, 17, 17, 128], [4, 17, 17, 1024],
+ [4, 17, 17, 128], [4, 17, 17, 1024], [4, 17, 17, 128], [
+ 4, 17, 17, 768
+ ], [4, 17, 17, 128], [4, 17, 17, 128], [4, 17, 17, 768],
+ [4, 17, 17, 768], [4, 35, 35, 96], [4, 35, 35, 288], [
+ 4, 35, 35, 64
+ ], [4, 35, 35, 288], [4, 35, 35, 256], [4, 35, 35, 48], [
+ 4, 35, 35, 256
+ ], [4, 35, 35, 96], [4, 35, 35, 192], [4, 35, 35, 192], [
+ 4, 35, 35, 192
+ ], [4, 73, 73, 64], [4, 73, 73, 64], [4, 147, 147, 24]]
+ filter_sizes = [[1, 1, 1248, 128], [1, 3, 384, 384], [3, 1, 384, 384], [
+ 1, 1, 2048, 192
+ ], [3, 3, 448, 384], [1, 1, 2048, 320], [1, 1, 2048, 448], [1, 1, 2048, 384],
+ [1, 1, 1760, 384], [1, 1, 1760, 192], [1, 1, 1760, 448], [
+ 1, 1, 1760, 320
+ ], [3, 3, 192, 192], [3, 3, 192, 192], [1, 1, 1248, 192], [
+ 3, 3, 128, 320
+ ], [1, 1, 1248, 128], [1, 3, 224, 224], [3, 1, 192, 256], [
+ 1, 3, 192, 256
+ ], [1, 1, 1216, 192], [1, 1, 1216, 96], [3, 1, 224, 224], [
+ 3, 3, 192, 224
+ ], [1, 3, 192, 192], [1, 1, 1152, 192], [1, 1, 1152, 128], [
+ 3, 1, 192, 192
+ ], [3, 3, 160, 192], [1, 1, 1152, 160], [1, 1, 1024, 128], [
+ 1, 3, 128, 192
+ ], [1, 1, 1024, 160], [3, 1, 128, 192], [1, 1, 1024, 256], [
+ 3, 1, 128, 128
+ ], [1, 1, 768, 192], [1, 3, 128, 128], [3, 3, 128, 128], [
+ 1, 1, 768, 128
+ ], [1, 1, 768, 320], [3, 3, 96, 96], [3, 3, 288, 384], [
+ 3, 3, 64, 96
+ ], [1, 1, 288, 64], [1, 1, 256, 64], [5, 5, 48, 64],
+ [1, 1, 256, 48], [3, 3, 96, 96], [1, 1, 192, 32], [
+ 1, 1, 192, 64
+ ], [1, 1, 192, 48], [3, 3, 64, 192], [1, 1, 64,
+ 64], [1, 1, 24, 64]]
+ out_sizes = [[4, 5, 5, 128], [4, 8, 8, 384], [4, 8, 8, 384], [4, 8, 8, 192], [
+ 4, 8, 8, 384
+ ], [4, 8, 8, 320], [4, 8, 8, 448], [4, 8, 8, 384], [4, 8, 8, 384], [
+ 4, 8, 8, 192
+ ], [4, 8, 8, 448], [4, 8, 8, 320], [4, 8, 8, 192], [4, 17, 17, 192], [
+ 4, 17, 17, 192
+ ], [4, 8, 8, 320], [4, 17, 17, 128], [4, 17, 17, 224], [4, 17, 17, 256], [
+ 4, 17, 17, 256
+ ], [4, 17, 17, 192], [4, 17, 17, 96], [4, 17, 17, 224], [4, 17, 17, 224], [
+ 4, 17, 17, 192
+ ], [4, 17, 17, 192], [4, 17, 17, 128], [4, 17, 17, 192], [4, 17, 17, 192], [
+ 4, 17, 17, 160
+ ], [4, 17, 17, 128], [4, 17, 17, 192], [4, 17, 17, 160], [4, 17, 17, 192], [
+ 4, 17, 17, 256
+ ], [4, 17, 17, 128], [4, 17, 17, 192], [4, 17, 17, 128], [4, 17, 17, 128], [
+ 4, 17, 17, 128
+ ], [4, 17, 17, 320], [4, 17, 17, 96], [4, 17, 17, 384], [4, 35, 35, 96], [
+ 4, 35, 35, 64
+ ], [4, 35, 35, 64], [4, 35, 35, 64], [4, 35, 35, 48], [4, 35, 35, 96],
+ [4, 35, 35, 32], [4, 35, 35, 64], [4, 35, 35, 48],
+ [4, 71, 71, 192], [4, 73, 73, 64], [4, 147, 147, 64]]
+ strides = [
+ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 1,
+ 1, 1, 1, 1, 1
+ ]
+ # Shrink sizes to make the test faster
+ for i in input_sizes:
+ i[3] //= shrink
+ for f in filter_sizes:
+ f[2] //= shrink
+ f[3] //= shrink
+ for o in out_sizes:
+ o[3] //= shrink
+ # pylint: disable=invalid-name
+ VALID = "VALID"
+ SAME = "SAME"
+ # pylint: enable=invalid-name
+ paddings = [
+ SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME,
+ VALID, SAME, SAME, VALID, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME,
+ SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME,
+ SAME, SAME, SAME, SAME, SAME, VALID, VALID, SAME, SAME, SAME, SAME, SAME,
+ SAME, SAME, SAME, SAME, VALID, VALID, VALID
+ ]
+ for i, f, o, s, p in zip(input_sizes, filter_sizes, out_sizes, strides,
+ paddings):
+ yield i, f, o, s, p
+
+
+def _GetTestConfigs():
+ """Get all the valid tests configs to run.
+
+ Returns:
+ all the valid test configs as tuples of data_format and use_gpu.
+ """
+ test_configs = [("NCHW", True), ("NHWC", True)]
+ return test_configs
+
+
+def _IotaNdF32Constant(dim_sizes):
+
+ def MakeList(dims):
+ if len(dims) == 1:
+ return [float(1 + f) for f in range(dims[0])]
+ return [MakeList(dims[1:]) for _ in range(dims[0])]
+
+ return constant_op.constant(MakeList(dim_sizes), dtype=dtypes.float32)
+
+
+def _GetInceptionFwdTest(input_size,
+ filter_size,
+ stride,
+ padding,
+ gpu_only=True):
+
+ def Test(self):
+ if gpu_only and not test.is_gpu_available():
+ tf_logging.info("Skipping InceptionFwd %s",
+ (input_size, filter_size, stride, padding))
+ return
+ tf_logging.info("Testing InceptionFwd %s",
+ (input_size, filter_size, stride, padding))
+ self.CompareFwdValues(input_size, filter_size, [stride, stride], padding)
+
+ return Test
+
+
+class FusedConv2DBiasActivationTest(object):
+
+ @contextlib.contextmanager
+ def test_scope(self): # pylint: disable=invalid-name
+ """Can be overridden in base classes to provide a test scope."""
+ yield
+
+ def _DtypesToTest(self, use_gpu):
+ return [dtypes.float32]
+
+ def _FilterFormatsToTest(self, use_gpu):
+ return ["HWIO", "OIHW"]
+
+ def _SetupValuesForDevice(self, tensor_in_sizes, filter_in_sizes, bias,
+ strides, padding, activation_mode, data_format,
+ filter_format, dtype):
+ """Verifies the output values of the convolution function.
+
+ Args:
+ tensor_in_sizes: Input tensor dimensions in
+ [batch, input_rows, input_cols, input_depth].
+ filter_in_sizes: Filter tensor dimensions in
+ [kernel_rows, kernel_cols, input_depth, output_depth].
+ bias: 1-D bias tensor of length output_depth.
+ strides: Stride: [row_stride, col_stride]
+ padding: Padding type.
+ activation_mode: Activation mode.
+ data_format: Format of the data tensors.
+ filter_format: Filter format to use for the fused convolution.
+ dtype: Data type for inputs and outputs.
+ Returns:
+ Symbolic tensor value and reference value that can be used to
+ execute the computation and verify the results.
+ """
+ input_size = np.prod(tensor_in_sizes)
+ filter_size = np.prod(filter_in_sizes)
+ bias_size = filter_in_sizes[-1] # equals the output depth
+ # Initializes the input tensor with array containing incrementing
+ # numbers from 1.
+ x1 = [f * 1.0 for f in range(1, input_size + 1)]
+ x2 = [f * 1.0 for f in range(1, filter_size + 1)]
+ # This is to guarantee that there are always negative values after
+ # bias add so that we can test whether relu works correctly.
+ x3 = bias
+ with self.cached_session(use_gpu=True), self.test_scope():
+ t1 = constant_op.constant(x1, shape=tensor_in_sizes, dtype=dtype)
+ t2 = constant_op.constant(x2, shape=filter_in_sizes, dtype=dtype)
+ fused_t2 = t2
+ if filter_format == "OIHW":
+ fused_t2 = _HwioToOihw(t2)
+ t3 = constant_op.constant(x3, shape=[bias_size], dtype=dtype)
+ strides = [1] + strides + [1]
+ if data_format == "NCHW":
+ t1 = test_util.NHWCToNCHW(t1)
+ strides = test_util.NHWCToNCHW(strides)
+ output = fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
+ t1,
+ fused_t2,
+ t3,
+ strides=strides,
+ padding=padding,
+ data_format=data_format,
+ filter_format=filter_format,
+ activation_mode=activation_mode)
+ ref_conv_output = nn_ops.conv2d(
+ t1, t2, strides=strides, padding=padding, data_format=data_format)
+ ref_bias_output = nn_ops.bias_add(
+ ref_conv_output, t3, data_format=data_format)
+ ref_output = nn_ops.relu(ref_bias_output)
+ if data_format == "NCHW":
+ output = test_util.NCHWToNHWC(output)
+ ref_output = test_util.NCHWToNHWC(ref_output)
+
+ return output, ref_output
+
+ def CompareFwdValues(self, tensor_in_sizes, filter_in_sizes, conv_strides,
+ padding):
+ """Verifies that CPU and GPU produce the same values.
+
+ Args:
+ tensor_in_sizes: Input tensor dimensions in
+ [batch, input_rows, input_cols, input_depth].
+ filter_in_sizes: Filter tensor dimensions in
+ [kernel_rows, kernel_cols, input_depth, output_depth].
+ conv_strides: [row_stride, col_stride] for the convolution;
+ padding: Padding type.
+ """
+ x1 = np.random.rand(*tensor_in_sizes).astype(np.float32)
+ x2 = np.random.rand(*filter_in_sizes).astype(np.float32)
+ x3 = np.random.rand(*[filter_in_sizes[-1]]).astype(np.float32)
+
+ def _SetupVal(data_format, use_gpu):
+ with self.cached_session(use_gpu=use_gpu), self.test_scope():
+ t1 = constant_op.constant(x1, shape=tensor_in_sizes)
+ t2 = constant_op.constant(x2, shape=filter_in_sizes)
+ t3 = constant_op.constant(x3, shape=[filter_in_sizes[-1]])
+ strides = [1] + conv_strides + [1]
+ if data_format == "NCHW":
+ t1 = test_util.NHWCToNCHW(t1)
+ strides = test_util.NHWCToNCHW(strides)
+ output = fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
+ t1,
+ t2,
+ t3,
+ strides=strides,
+ padding=padding,
+ data_format=data_format,
+ activation_mode="Relu")
+
+ if data_format == "NCHW":
+ output = test_util.NCHWToNHWC(output)
+ return output
+
+ tensors = []
+ for (data_format, use_gpu) in _GetTestConfigs():
+ tensors.append(_SetupVal(data_format, use_gpu))
+ with self.cached_session() as sess, self.test_scope():
+ values = sess.run(tensors)
+ for i in range(1, len(values)):
+ self.assertAllClose(values[0], values[i], rtol=1e-3, atol=1e-3)
+
+ def _VerifyValues(self, tensor_in_sizes, filter_in_sizes, bias, strides,
+ padding):
+ tensors = []
+ ref_tensors = []
+ for (data_format, use_gpu) in _GetTestConfigs():
+ for dtype in self._DtypesToTest(use_gpu):
+ for filter_format in self._FilterFormatsToTest(use_gpu):
+ result, expected = self._SetupValuesForDevice(
+ tensor_in_sizes, filter_in_sizes, bias, strides, padding, "Relu",
+ data_format, filter_format, dtype)
+ tensors.append(result)
+ ref_tensors.append(expected)
+ with self.cached_session() as sess, self.test_scope():
+ values = sess.run(tensors)
+ ref_values = sess.run(ref_tensors)
+ for i in range(len(tensors)):
+ conv = tensors[i]
+ value = values[i]
+ ref_value = ref_values[i]
+ tf_logging.info("expected = %s", ref_value)
+ tf_logging.info("actual = %s", value)
+ tol = 1e-5
+ if value.dtype == np.float16:
+ tol = 1e-3
+ self.assertAllClose(
+ np.ravel(ref_value), np.ravel(value), atol=tol, rtol=tol)
+ self.assertShapeEqual(value, conv)
+
+ def testConv2D1x1Filter(self, gpu_only=True):
+ if gpu_only and not test.is_gpu_available():
+ tf_logging.info("Skipping Conv2D1x1Filter test.")
+ return
+ # expected_output = [
+ # 0.0, 0.0, 0.0, 21.0, 0.0, 0.0, 57.0, 0.0, 0.0, 93.0, 41.0, 0.0, 129.0,
+ # 86.0, 43.0, 165.0, 131.0, 97.0
+ # ]
+ medians = [-45.0, -130.0, -215.0]
+ self._VerifyValues(
+ tensor_in_sizes=[1, 2, 3, 3],
+ filter_in_sizes=[1, 1, 3, 3],
+ bias=medians,
+ strides=[1, 1],
+ padding="VALID")
+
+ def testConv2DEmpty(self, gpu_only=True):
+ if gpu_only and not test.is_gpu_available():
+ tf_logging.info("Skipping Conv2DEmpty test.")
+ return
+ # expected_output = []
+ self._VerifyValues(
+ tensor_in_sizes=[0, 2, 3, 3],
+ filter_in_sizes=[1, 1, 3, 3],
+ bias=[0.0, 0.0, 0.0],
+ strides=[1, 1],
+ padding="VALID")
+
+ def testConv2D2x2Filter(self, gpu_only=True):
+ if gpu_only and not test.is_gpu_available():
+ tf_logging.info("Skipping Conv2D2x2Filter test.")
+ return
+ # expected_output = [0.0, 0.0, 0.0, 401.0, 533.0, 665.0]
+ self._VerifyValues(
+ tensor_in_sizes=[1, 2, 3, 3],
+ filter_in_sizes=[2, 2, 3, 3],
+ bias=[-2500.0, -2500.0, -2500.0],
+ strides=[1, 1],
+ padding="VALID")
+
+ def testConv2D1x2Filter(self, gpu_only=True):
+ if gpu_only and not test.is_gpu_available():
+ tf_logging.info("Skipping Conv2D1x2Filter test.")
+ return
+ # expected_output = [
+ # 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 190.0, 265.0, 340.0, 343.0, 436.0, 529.0
+ # ]
+ self._VerifyValues(
+ tensor_in_sizes=[1, 2, 3, 3],
+ filter_in_sizes=[1, 2, 3, 3],
+ bias=[-500.0, -500.0, -500.0],
+ strides=[1, 1],
+ padding="VALID")
+
+ def testConv2D2x2FilterStride2(self, gpu_only=True):
+ if gpu_only and not test.is_gpu_available():
+ tf_logging.info("Skipping Conv2D2x2FilterStride2 test.")
+ return
+ # expected_output = [0.0, 67.0, 163.0]
+ self._VerifyValues(
+ tensor_in_sizes=[1, 2, 3, 3],
+ filter_in_sizes=[2, 2, 3, 3],
+ bias=[-2300.0, -2300.0, -2300.0],
+ strides=[2, 2],
+ padding="VALID")
+
+ def testConv2D2x2FilterStride2Same(self, gpu_only=True):
+ if gpu_only and not test.is_gpu_available():
+ tf_logging.info("Skipping Conv2D2x2FilterStride2Same test.")
+ return
+ # expected_output = [0.0, 2367.0, 2463.0, 1230.0, 1305.0, 1380.0]
+ self._VerifyValues(
+ tensor_in_sizes=[1, 2, 3, 3],
+ filter_in_sizes=[2, 2, 3, 3],
+ bias=[-2300.0, -1000.0, -1000.0],
+ strides=[2, 2],
+ padding="SAME")
+
+ def testConv2D2x2FilterStride1x2(self, gpu_only=True):
+ if gpu_only and not test.is_gpu_available():
+ tf_logging.info("Skipping Conv2D2x2FilterStride1x2 test.")
+ return
+ # expected_output = [0.0, 0.0, 8.0, 28.0, 48.0, 68.0]
+ self._VerifyValues(
+ tensor_in_sizes=[1, 3, 6, 1],
+ filter_in_sizes=[2, 2, 1, 1],
+ bias=[-90.0],
+ strides=[1, 2],
+ padding="VALID")
+
+ def testConv2DKernelSmallerThanStrideValid(self, gpu_only=True):
+ if gpu_only and not test.is_gpu_available():
+ tf_logging.info("Skipping Conv2DKernelSmallerThanStrideValid test.")
+ return
+ # expected_output = [0, 0, 175, 205]
+ self._VerifyValues(
+ tensor_in_sizes=[1, 7, 7, 1],
+ filter_in_sizes=[2, 2, 1, 1],
+ bias=[-100.0],
+ strides=[3, 3],
+ padding="VALID")
+
+ def testConv2DKernelSmallerThanStrideSame(self, gpu_only=True):
+ if gpu_only and not test.is_gpu_available():
+ tf_logging.info("Skipping Conv2DKernelSmallerThanStrideSame test.")
+ return
+ # expected = [0, 0, 2, 4]
+ self._VerifyValues(
+ tensor_in_sizes=[1, 3, 3, 1],
+ filter_in_sizes=[1, 1, 1, 1],
+ bias=[-5.0],
+ strides=[2, 2],
+ padding="SAME")
+
+ # expected = [0, 0, 4, 6]
+ self._VerifyValues(
+ tensor_in_sizes=[1, 4, 4, 1],
+ filter_in_sizes=[1, 1, 1, 1],
+ bias=[-5.0],
+ strides=[2, 2],
+ padding="SAME")
+
+ # expected = [4, 0, 1, 0]
+ self._VerifyValues(
+ tensor_in_sizes=[1, 4, 4, 1],
+ filter_in_sizes=[2, 2, 1, 1],
+ bias=[-40.0],
+ strides=[3, 3],
+ padding="SAME")
+
+ def testConv2DKernelSizeMatchesInputSize(self, gpu_only=True):
+ if gpu_only and not test.is_gpu_available():
+ tf_logging.info("Skipping Conv2DKernelSizeMatchesInputSize test.")
+ return
+ # expected = [0, 5]
+ self._VerifyValues(
+ tensor_in_sizes=[1, 2, 2, 1],
+ filter_in_sizes=[2, 2, 1, 2],
+ bias=[-50.0, -55.0],
+ strides=[1, 1],
+ padding="VALID")
+
+ # expected = [0, 2, 282, 322]
+ self._VerifyValues(
+ tensor_in_sizes=[1, 8, 8, 1],
+ filter_in_sizes=[2, 2, 1, 1],
+ bias=[-200.0],
+ strides=[4, 4],
+ padding="SAME")
+
+ def testShapeFunctionEdgeCases(self):
+ # All shapes unknown.
+ c1 = fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
+ array_ops.placeholder(dtypes.float32),
+ array_ops.placeholder(dtypes.float32),
+ array_ops.placeholder(dtypes.float32),
+ strides=[1, 1, 1, 1],
+ padding="SAME",
+ activation_mode="Relu")
+ self.assertEqual([None, None, None, None], c1.get_shape().as_list())
+
+ # Incorrect input shape.
+ with self.assertRaises(ValueError):
+ fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
+ array_ops.placeholder(dtypes.float32, shape=[1, 3]),
+ array_ops.placeholder(dtypes.float32),
+ array_ops.placeholder(dtypes.float32),
+ strides=[1, 1, 1, 1],
+ padding="SAME",
+ activation_mode="Relu")
+
+ # Incorrect filter shape.
+ with self.assertRaises(ValueError):
+ fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
+ array_ops.placeholder(dtypes.float32),
+ array_ops.placeholder(dtypes.float32, shape=[1, 3]),
+ array_ops.placeholder(dtypes.float32),
+ strides=[1, 1, 1, 1],
+ padding="SAME",
+ activation_mode="Relu")
+
+ # Depth mismatch.
+ with self.assertRaises(ValueError):
+ fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
+ array_ops.placeholder(dtypes.float32, shape=[32, 20, 20, 3]),
+ array_ops.placeholder(dtypes.float32, shape=[4, 4, 2, 2]),
+ array_ops.placeholder(dtypes.float32),
+ strides=[1, 1, 1, 1],
+ padding="SAME",
+ activation_mode="Relu")
+
+ def testOpEdgeCases(self, gpu_only=True):
+ if gpu_only and not test.is_gpu_available():
+ tf_logging.info("Skipping OpEdgeCases tests.")
+ return
+ with self.cached_session() as sess, self.test_scope():
+ # Illegal strides.
+ with self.assertRaisesRegexp(
+ errors_impl.UnimplementedError,
+ ".*strides.*in the batch and depth dimensions"):
+ sess.run(
+ fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
+ _IotaNdF32Constant([1, 1, 1, 1]),
+ _IotaNdF32Constant([1, 1, 1, 1]),
+ _IotaNdF32Constant([1]),
+ strides=[2, 1, 1, 1],
+ padding="SAME",
+ activation_mode="Relu"))
+ with self.assertRaisesRegexp(
+ errors_impl.UnimplementedError,
+ ".*strides.*in the batch and depth dimensions"):
+ sess.run(
+ fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
+ _IotaNdF32Constant([1, 1, 1, 1]),
+ _IotaNdF32Constant([1, 1, 1, 1]),
+ _IotaNdF32Constant([1]),
+ strides=[1, 1, 1, 2],
+ padding="SAME",
+ activation_mode="Relu"))
+
+ # Illegal activation mode.
+ with self.assertRaisesRegexp(ValueError,
+ "Op passed string 'Tanh' not in:"):
+ sess.run(
+ fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
+ _IotaNdF32Constant([1, 1, 1, 1]),
+ _IotaNdF32Constant([1, 1, 1, 1]),
+ _IotaNdF32Constant([1]),
+ strides=[1, 1, 1, 1],
+ padding="SAME",
+ activation_mode="Tanh"))
+
+ # Filter larger than input.
+ with self.assertRaisesRegexp(ValueError, "Negative dimension size"):
+ sess.run(
+ fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
+ _IotaNdF32Constant([32, 20, 20, 3]),
+ _IotaNdF32Constant([20, 21, 3, 2]),
+ _IotaNdF32Constant([2]),
+ strides=[1, 1, 1, 1],
+ padding="VALID",
+ activation_mode="Relu"))
+ with self.assertRaisesRegexp(ValueError, "Negative dimension size"):
+ sess.run(
+ fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
+ _IotaNdF32Constant([32, 20, 20, 3]),
+ _IotaNdF32Constant([21, 20, 3, 2]),
+ _IotaNdF32Constant([2]),
+ strides=[1, 1, 1, 1],
+ padding="VALID",
+ activation_mode="Relu"))
+
+
+# Add InceptionFwd tests to FusedConv2DBiasActivationTest.
+for index, (input_size_, filter_size_, output_size_, stride_,
+ padding_) in enumerate(_GetShrunkInceptionShapes()):
+ setattr(FusedConv2DBiasActivationTest, "testInceptionFwd_" + str(index),
+ _GetInceptionFwdTest(input_size_, filter_size_, stride_, padding_))
+
+# TODO(b/35359731)
+# Fwd, BackInput, and BackFilter tests check that, for certain input parameter
+# sets, the winograd nonfused algorithm is excluded from conv autotune. If, in
+# such a case, the winograd nonfused algorithm were still offered as a conv
+# autotune option and the cuDNN version is smaller than 7, the following tests
+# would fail.
+ishape = [1, 400, 400, 1]
+fshape = [1, 1, 1, 256]
+oshape = [1, 400, 400, 256]
+setattr(FusedConv2DBiasActivationTest, "testInceptionFwd_No_Winograd_Nonfused",
+ _GetInceptionFwdTest(ishape, fshape, 1, "SAME", gpu_only=True))
+
+
+def _CalculateConvolvedOutputDim(input_dim, filter_dim, stride, padding_type):
+ """Calculates the size of an output dimension of a strided convolution.
+
+ Given the sizes of the corresponding dimension of the input and filter shapes,
+ and the stride and padding_type, calculates the size of the output dimension.
+ This function can be called separately for each input dimension.
+
+ Args:
+ input_dim: An `int` specifying the size of the input dimension.
+ filter_dim: An `int` specifying the size of the filter dimension.
+ stride: An `int` specifying the step size of the convolution along the
+ input dimension.
+ padding_type: either 'VALID' or 'SAME'.
+
+ Returns:
+ The size of the output dimension.
+ """
+ if padding_type == "VALID":
+ return (input_dim - filter_dim + stride) // stride
+ else: # padding_type == 'SAME'
+ return (input_dim + stride - 1) // stride
+
+
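+# Illustrative checks (not from the original tests) of the arithmetic above:
+# with VALID padding the filter must fit entirely inside the input, while with
+# SAME padding the output size depends only on the input size and stride.
+assert _CalculateConvolvedOutputDim(8, 6, 2, "VALID") == 2  # (8 - 6 + 2) // 2
+assert _CalculateConvolvedOutputDim(8, 6, 2, "SAME") == 4  # (8 + 2 - 1) // 2
+assert _CalculateConvolvedOutputDim(9, 3, 1, "SAME") == 9  # (9 + 1 - 1) // 1
+
+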
+def _NchwVectCToNchw(in_tensor):
+ # [N, C / 4, H, W, 4] => [N, C / 4, 4, H, W] == [N, C, H, W]
+ t = array_ops.transpose(in_tensor, [0, 1, 4, 2, 3])
+ n = in_tensor.shape.dims[0].value
+ c = in_tensor.shape.dims[1].value * in_tensor.shape.dims[4].value
+ h = in_tensor.shape.dims[2].value
+ w = in_tensor.shape.dims[3].value
+ return array_ops.reshape(t, [n, c, h, w])
+
+
+def _OihwVectIToHwio(in_tensor):
+ # [O, I / 4, H, W, 4] => [H, W, I / 4, 4, O] == [H, W, I, O]
+ t = array_ops.transpose(in_tensor, [2, 3, 1, 4, 0])
+ o = in_tensor.shape.dims[0].value
+ i = in_tensor.shape.dims[1].value * in_tensor.shape.dims[4].value
+ h = in_tensor.shape.dims[2].value
+ w = in_tensor.shape.dims[3].value
+ return array_ops.reshape(t, [h, w, i, o])
+
+
+def _NchwToNchwVectC(in_tensor):
+ n, c, h, w = in_tensor.shape.as_list()
+ assert c % 4 == 0
+ t = array_ops.reshape(in_tensor, [n, c // 4, 4, h, w])
+ return array_ops.transpose(t, [0, 1, 3, 4, 2])
+
+
+def _HwioToOihw(in_tensor):
+ return array_ops.transpose(in_tensor, [3, 2, 0, 1])
+
+
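+# Illustrative numpy sketch (not from the original tests): the NCHW_VECT_C
+# layout handled by the helpers above packs channels into groups of four that
+# become a trailing vector dimension, and the round trip is the identity.
+_example_nchw = np.arange(2 * 8 * 3 * 3).reshape([2, 8, 3, 3])
+# [N, C, H, W] -> [N, C // 4, 4, H, W] -> [N, C // 4, H, W, 4]
+_example_vect_c = _example_nchw.reshape([2, 2, 4, 3, 3]).transpose([0, 1, 3, 4, 2])
+# back to [N, C, H, W]
+_example_nchw_again = _example_vect_c.transpose([0, 1, 4, 2, 3]).reshape([2, 8, 3, 3])
+assert (_example_nchw_again == _example_nchw).all()
+
+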
+def _SimulateFusedConv2dBiasActivationInt8(conv_input_scale, conv_input, kernel,
+ padding, strides, side_input_scale,
+ side_input, biases, apply_relu):
+ """Simulates the int8 fused 2-D convolution op using separate float ops.
+
+ The arguments and return values have the same format, meanings and
+ restrictions as the actual op.
+ Args:
+ conv_input_scale: A scalar 'float'.
+ conv_input: A `Tensor` of type `qint8` in NCHW_VECT_C layout.
+ kernel: A `Tensor` of type `qint8` in OIHW_VECT_I layout.
+ padding: A `string` from: `"SAME", "VALID"`.
+ strides: A list of `ints`.
+ side_input_scale: A scalar 'float'.
+ side_input: A `Tensor` of type `qint8` in NCHW_VECT_C layout.
+ biases: A `Tensor` of type `float32` in NCHW layout.
+ apply_relu: A boolean specifying whether to apply the "Relu" activation
+ function, which clips outputs to the range [0, 127], or the "None"
+ activation, which clips outputs to the range [-128, 127].
+ Returns:
+ A `Tensor` of type `qint8` in NCHW_VECT_C layout.
+ """
+ conv_result = nn_ops.conv2d(
+ _NchwVectCToNchw(gen_array_ops.dequantize(conv_input, -128, 127)),
+ _OihwVectIToHwio(gen_array_ops.dequantize(kernel, -128, 127)),
+ strides=strides,
+ padding=padding,
+ data_format="NCHW") * conv_input_scale
+
+ conv_and_side_inputs = conv_result + side_input_scale * _NchwVectCToNchw(
+ gen_array_ops.dequantize(side_input, -128, 127))
+
+ output = nn_ops.bias_add(conv_and_side_inputs, biases, data_format="NCHW")
+ if apply_relu:
+ output = nn_ops.relu(output)
+
+ result, _, _ = gen_array_ops.quantize_v2(
+ _NchwToNchwVectC(output), -128, 127, dtypes.qint8)
+ return result
+
+
+# TODO(b/114580749): XLA:CPU/GPU don't support int8 at the moment, so this test
+# doesn't currently use XLA.
+class FusedConvInt8Tests(object):
+ _test_params = [
+ {
+ "batch_size": 1,
+ "input_channels": 4,
+ "output_channels": 4,
+ "input_height": 8,
+ "input_width": 8,
+ "filter_height": 6,
+ "filter_width": 6,
+ "vertical_stride": 2,
+ "horizontal_stride": 2,
+ "conv_input_scale": 0.002,
+ "side_input_scale": 0.0,
+ "bias_scale": 1,
+ "padding_type": "SAME"
+ },
+ {
+ "batch_size": 1,
+ "input_channels": 4,
+ "output_channels": 4,
+ "input_height": 6,
+ "input_width": 6,
+ "filter_height": 6,
+ "filter_width": 6,
+ "vertical_stride": 2,
+ "horizontal_stride": 2,
+ "conv_input_scale": 0.002,
+ "side_input_scale": 0.0,
+ "bias_scale": 1,
+ "padding_type": "SAME"
+ },
+ {
+ "batch_size": 2,
+ "input_channels": 8,
+ "output_channels": 16,
+ "input_height": 8,
+ "input_width": 8,
+ "filter_height": 3,
+ "filter_width": 3,
+ "vertical_stride": 2,
+ "horizontal_stride": 2,
+ "conv_input_scale": 0.002,
+ "side_input_scale": 0.0,
+ "bias_scale": 1,
+ "padding_type": "VALID"
+ },
+ {
+ "batch_size": 2,
+ "input_channels": 8,
+ "output_channels": 16,
+ "input_height": 8,
+ "input_width": 8,
+ "filter_height": 3,
+ "filter_width": 3,
+ "vertical_stride": 2,
+ "horizontal_stride": 2,
+ "conv_input_scale": 0.002,
+ "side_input_scale": 0.0,
+ "bias_scale": 1,
+ "padding_type": "SAME"
+ },
+ {
+ "batch_size": 2,
+ "input_channels": 8,
+ "output_channels": 16,
+ "input_height": 8,
+ "input_width": 8,
+ "filter_height": 3,
+ "filter_width": 3,
+ "vertical_stride": 2,
+ "horizontal_stride": 2,
+ "conv_input_scale": 0.002,
+ "side_input_scale": 0.5,
+ "bias_scale": 1,
+ "padding_type": "VALID"
+ },
+ {
+ "batch_size": 2,
+ "input_channels": 16,
+ "output_channels": 16,
+ "input_height": 9,
+ "input_width": 9,
+ "filter_height": 3,
+ "filter_width": 3,
+ "vertical_stride": 1,
+ "horizontal_stride": 1,
+ "conv_input_scale": 0.001,
+ "side_input_scale": 0.5,
+ "bias_scale": 1,
+ "padding_type": "SAME"
+ },
+ {
+ "batch_size": 3,
+ "input_channels": 8,
+ "output_channels": 8,
+ "input_height": 9,
+ "input_width": 9,
+ "filter_height": 5,
+ "filter_width": 5,
+ "vertical_stride": 1,
+ "horizontal_stride": 1,
+ "conv_input_scale": 0.001,
+ "side_input_scale": 0.5,
+ "bias_scale": 1,
+ "padding_type": "SAME"
+ },
+ {
+ "batch_size": 3,
+ "input_channels": 8,
+ "output_channels": 8,
+ "input_height": 9,
+ "input_width": 9,
+ "filter_height": 7,
+ "filter_width": 1,
+ "vertical_stride": 2,
+ "horizontal_stride": 1,
+ "conv_input_scale": 0.002,
+ "side_input_scale": 0.5,
+ "bias_scale": 1,
+ "padding_type": "SAME"
+ },
+ {
+ "batch_size": 3,
+ "input_channels": 8,
+ "output_channels": 8,
+ "input_height": 9,
+ "input_width": 9,
+ "filter_height": 1,
+ "filter_width": 7,
+ "vertical_stride": 1,
+ "horizontal_stride": 1,
+ "conv_input_scale": 0.002,
+ "side_input_scale": 0.5,
+ "bias_scale": 1,
+ "padding_type": "SAME"
+ },
+ ]
+
+ @contextlib.contextmanager
+ def test_scope(self): # pylint: disable=invalid-name
+ """Can be overridden in base classes to provide a test scope."""
+ yield
+
+ def runTest(self, test_param, apply_relu):
+ batch_size = test_param["batch_size"]
+ input_channels = test_param["input_channels"]
+ output_channels = test_param["output_channels"]
+ input_height = test_param["input_height"]
+ input_width = test_param["input_width"]
+ filter_height = test_param["filter_height"]
+ filter_width = test_param["filter_width"]
+ vertical_stride = test_param["vertical_stride"]
+ horizontal_stride = test_param["horizontal_stride"]
+ conv_input_scale = test_param["conv_input_scale"]
+ side_input_scale = test_param["side_input_scale"]
+ bias_scale = test_param["bias_scale"]
+ padding_type = test_param["padding_type"]
+
+ with self.cached_session(use_gpu=True) as sess, self.test_scope():
+ conv_input, _, _ = gen_array_ops.quantize_v2(
+ random_ops.random_uniform(
+ [batch_size, input_channels // 4, input_height, input_width, 4],
+ minval=-0.0,
+ maxval=1.0,
+ dtype=dtypes.float32), -1.0, 1.0, dtypes.qint8)
+
+ kernel, _, _ = gen_array_ops.quantize_v2(
+ random_ops.random_uniform([
+ output_channels, input_channels // 4, filter_height, filter_width,
+ 4
+ ],
+ minval=-1.0,
+ maxval=1.0,
+ dtype=dtypes.float32), -1.0, 1.0,
+ dtypes.qint8)
+
+ output_height = _CalculateConvolvedOutputDim(
+ input_height, filter_height, vertical_stride, padding_type)
+ output_width = _CalculateConvolvedOutputDim(
+ input_width, filter_width, horizontal_stride, padding_type)
+ tf_logging.info("output_height=%s, output_width=%s", output_height,
+ output_width)
+
+ side_input, _, _ = gen_array_ops.quantize_v2(
+ random_ops.random_uniform([
+ batch_size, output_channels // 4, output_height, output_width, 4
+ ],
+ minval=0.0,
+ maxval=1.0,
+ dtype=dtypes.float32), -1.0, 1.0,
+ dtypes.qint8)
+
+ biases = random_ops.random_uniform([output_channels],
+ minval=-10 * bias_scale,
+ maxval=20 * bias_scale,
+ dtype=dtypes.float32)
+
+ strides = [1, 1, vertical_stride, horizontal_stride]
+
+ actual = fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
+ conv_input,
+ kernel,
+ biases,
+ strides=strides,
+ padding=padding_type,
+ conv_input_scale=conv_input_scale,
+ side_input_scale=side_input_scale,
+ side_input=side_input,
+ activation_mode="Relu" if apply_relu else "None",
+ data_format="NCHW_VECT_C",
+ filter_format="OIHW_VECT_I")
+
+ expected = _SimulateFusedConv2dBiasActivationInt8(
+ conv_input_scale, conv_input, kernel, padding_type, strides,
+ side_input_scale, side_input, biases, apply_relu)
+
+ actual_y, expected_y = sess.run([actual, expected])
+ self.assertAllClose(actual_y, expected_y, rtol=0, atol=1)
+
+ def testFusedConvInt8(self):
+ if not test.is_gpu_available(
+ cuda_only=True, min_cuda_compute_capability=(6, 1)):
+ tf_logging.info("int8 test skipped because not run with --config=cuda or "
+ "no GPUs with compute capability >= 6.1 are available.")
+ return
+ for apply_relu in [True, False]:
+ for test_param in self._test_params:
+ self.runTest(test_param, apply_relu)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/gan/python/losses/python/losses_impl.py b/tensorflow/contrib/gan/python/losses/python/losses_impl.py
index d389748374..8bc4db8424 100644
--- a/tensorflow/contrib/gan/python/losses/python/losses_impl.py
+++ b/tensorflow/contrib/gan/python/losses/python/losses_impl.py
@@ -773,9 +773,9 @@ def mutual_information_penalty(
structured_generator_inputs: A list of Tensors representing the random noise
that must have high mutual information with the generator output. List
length should match `predicted_distributions`.
- predicted_distributions: A list of tf.Distributions. Predicted by the
- recognizer, and used to evaluate the likelihood of the structured noise.
- List length should match `structured_generator_inputs`.
+ predicted_distributions: A list of `tfp.distributions.Distribution`s.
+ Predicted by the recognizer, and used to evaluate the likelihood of the
+ structured noise. List length should match `structured_generator_inputs`.
weights: Optional `Tensor` whose rank is either 0, or the same dimensions as
`structured_generator_inputs`.
scope: The scope for the operations performed in computing the loss.
diff --git a/tensorflow/contrib/gan/python/namedtuples.py b/tensorflow/contrib/gan/python/namedtuples.py
index a462b68e28..b9ac1bf151 100644
--- a/tensorflow/contrib/gan/python/namedtuples.py
+++ b/tensorflow/contrib/gan/python/namedtuples.py
@@ -91,9 +91,9 @@ class InfoGANModel(
structured_generator_inputs: A list of Tensors representing the random noise
that must have high mutual information with the generator output. List
length should match `predicted_distributions`.
- predicted_distributions: A list of tf.Distributions. Predicted by the
- recognizer, and used to evaluate the likelihood of the structured noise.
- List length should match `structured_generator_inputs`.
+ predicted_distributions: A list of `tfp.distributions.Distribution`s.
+ Predicted by the recognizer, and used to evaluate the likelihood of the
+ structured noise. List length should match `structured_generator_inputs`.
discriminator_and_aux_fn: The original discriminator function that returns
a tuple of (logits, `predicted_distributions`).
"""
diff --git a/tensorflow/contrib/gan/python/train_test.py b/tensorflow/contrib/gan/python/train_test.py
index 58f348034f..64d6706199 100644
--- a/tensorflow/contrib/gan/python/train_test.py
+++ b/tensorflow/contrib/gan/python/train_test.py
@@ -399,7 +399,7 @@ class StarGANModelTest(test.TestCase):
target_tensor = train._generate_stargan_random_domain_target(
batch_size, domain_numbers)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
targets = sess.run(target_tensor)
self.assertTupleEqual((batch_size, domain_numbers), targets.shape)
for target in targets:
@@ -676,7 +676,7 @@ class GANLossTest(test.TestCase, parameterized.TestCase):
self.assertIsInstance(model_loss, namedtuples.GANLoss)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
diff --git a/tensorflow/contrib/gdr/gdr_memory_manager.cc b/tensorflow/contrib/gdr/gdr_memory_manager.cc
index 726f74c7b7..bb06f1c41c 100644
--- a/tensorflow/contrib/gdr/gdr_memory_manager.cc
+++ b/tensorflow/contrib/gdr/gdr_memory_manager.cc
@@ -138,6 +138,8 @@ class GdrMemoryManager : public RemoteMemoryManager {
Device* device, DeviceContext* device_context, bool on_host,
StatusCallback done) override;
+ static void RegMemVisitors();
+
protected:
Status CreateEndpoint(const string& host, const string& port,
RdmaEndpointPtr& endpoint);
@@ -183,35 +185,51 @@ class GdrMemoryManager : public RemoteMemoryManager {
TF_DISALLOW_COPY_AND_ASSIGN(GdrMemoryManager);
};
-// TODO(byronyi): remove this class and its registration when the default
-// cpu_allocator() returns visitable allocator, or cpu_allocator() is no
-// longer in use.
-class BFCGdrAllocator : public BFCAllocator {
- public:
- BFCGdrAllocator()
- : BFCAllocator(new BasicCPUAllocator(port::kNUMANoAffinity), 1LL << 36,
- true, "cpu_gdr_bfc") {}
-};
-class BFCGdrAllocatorFactory : public AllocatorFactory {
- public:
- Allocator* CreateAllocator() override { return new BFCGdrAllocator; }
-
- virtual SubAllocator* CreateSubAllocator(int numa_node) {
- return new BasicCPUAllocator(numa_node);
- }
-};
-
-REGISTER_MEM_ALLOCATOR("BFCGdrAllocator", 102, BFCGdrAllocatorFactory);
-
GdrMemoryManager::GdrMemoryManager(const string& host, const string& port)
: host_(host),
port_(port),
listening_(nullptr, EndpointDeleter),
stopped_(true),
- next_key_(0) {}
+ next_key_(0) {
+ static std::once_flag flag;
+ std::call_once(flag, []() { RegMemVisitors(); });
+}
GdrMemoryManager::~GdrMemoryManager() { close(epfd_); }
+/*static*/ void GdrMemoryManager::RegMemVisitors() {
+ SubAllocator::Visitor alloc_visitor = [](void* ptr, int numa_node,
+ size_t num_bytes) {
+ GdrMemoryManager::Singleton().InsertMemoryRegion(
+ ptr, num_bytes, strings::StrCat("CPU:", numa_node));
+ };
+ SubAllocator::Visitor free_visitor = [](void* ptr, int numa_node,
+ size_t num_bytes) {
+ GdrMemoryManager::Singleton().EvictMemoryRegion(ptr, num_bytes);
+ };
+ ProcessState::singleton()->AddCPUAllocVisitor(alloc_visitor);
+ ProcessState::singleton()->AddCPUFreeVisitor(free_visitor);
+
+#if GOOGLE_CUDA
+ if (IsGDRAvailable()) {
+ int32_t bus_id = TryToReadNumaNode(rdma_adapter_->context_->device) + 1;
+
+ // Note we don't free allocated GPU memory so there is no free visitor
+ SubAllocator::Visitor cuda_alloc_visitor = [](void* ptr, int gpu_id,
+ size_t num_bytes) {
+ RdmaMemoryMgr::Singleton().InsertMemoryRegion(
+ ptr, num_bytes, strings::StrCat("GPU:", gpu_id));
+ };
+ GPUProcessState::singleton()->AddGPUAllocVisitor(bus_id,
+ cuda_alloc_visitor);
+ GPUProcessState::singleton()->AddCUDAHostAllocVisitor(bus_id,
+ alloc_visitor);
+ GPUProcessState::singleton()->AddCUDAHostFreeVisitor(bus_id, free_visitor);
+ LOG(INFO) << "Instrumenting GPU allocator with bus_id " << bus_id;
+ }
+#endif // GOOGLE_CUDA
+}
+
Status GdrMemoryManager::Init() {
epfd_ = epoll_create1(0);
if (epfd_ == -1) {
@@ -271,48 +289,6 @@ Status GdrMemoryManager::Init() {
"cannot add server to epoll");
}
- Allocator* allocators[] = {
-#if GOOGLE_CUDA
- GPUProcessState::singleton()->GetCUDAHostAllocator(0),
-#endif // GOOGLE_CUDA
- ProcessState::singleton()->GetCPUAllocator(0),
- cpu_allocator(),
- };
-
- using namespace std::placeholders;
- VisitableAllocator::Visitor alloc_visitor =
- std::bind(&GdrMemoryManager::InsertMemoryRegion, this, _1, _2);
- VisitableAllocator::Visitor free_visitor =
- std::bind(&GdrMemoryManager::EvictMemoryRegion, this, _1, _2);
-
- std::set<Allocator*> instrumented_;
-
- // Host memory allocators
- for (Allocator* allocator : allocators) {
- auto* visitable_allocator = dynamic_cast<VisitableAllocator*>(allocator);
- CHECK(visitable_allocator)
- << "is not visitable for instrumentation" << allocator->Name();
- // Make sure we don't instrument the same allocator twice
- if (instrumented_.find(allocator) == std::end(instrumented_)) {
- visitable_allocator->AddAllocVisitor(alloc_visitor);
- visitable_allocator->AddFreeVisitor(free_visitor);
- instrumented_.insert(allocator);
- LOG(INFO) << "Instrumenting CPU allocator " << allocator->Name();
- }
- }
-
-#if GOOGLE_CUDA
- VisitableAllocator::Visitor cuda_alloc_visitor =
- std::bind(&GdrMemoryManager::InsertMemoryRegion, this, _1, _2);
- if (IsGDRAvailable()) {
- // Note we don't free allocated GPU memory so there is no free visitor
- int32_t bus_id = TryToReadNumaNode(listening_->verbs->device) + 1;
- GPUProcessState::singleton()->AddGPUAllocVisitor(bus_id,
- cuda_alloc_visitor);
- LOG(INFO) << "Instrumenting GPU allocator with bus_id " << bus_id;
- }
-#endif // GOOGLE_CUDA
-
return Status::OK();
}
diff --git a/tensorflow/contrib/graph_editor/tests/transform_test.py b/tensorflow/contrib/graph_editor/tests/transform_test.py
index 97f38c923f..0ebcdc2688 100644
--- a/tensorflow/contrib/graph_editor/tests/transform_test.py
+++ b/tensorflow/contrib/graph_editor/tests/transform_test.py
@@ -214,7 +214,7 @@ class TransformTest(test.TestCase):
def test_graph_replace_gradients(self):
ops.reset_default_graph()
- w = variables.Variable(0.0, name="w")
+ w = variables.VariableV1(0.0, name="w")
y = math_ops.multiply(math_ops.multiply(w, w, name="mul1"), w, name="mul2")
g = gradients_impl.gradients(y, w, name="grad")[0]
diff --git a/tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py b/tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py
index fed8a771cc..27aed091c2 100644
--- a/tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py
+++ b/tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py
@@ -233,7 +233,7 @@ class GridRNNCellTest(test.TestCase):
([[1.38917875, 1.49043763]], [[0.83884692, 0.86036491]])))
def testGrid2LSTMCellWithRelu(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
'root', initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 3])
@@ -261,7 +261,7 @@ class GridRNNCellTest(test.TestCase):
"""
def testGrid2BasicRNNCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
'root', initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([2, 2])
@@ -292,7 +292,7 @@ class GridRNNCellTest(test.TestCase):
[[0.80049908, 0.80049908], [0.97574311, 0.97574311]]))
def testGrid2BasicRNNCellTied(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
'root', initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([2, 2])
@@ -323,7 +323,7 @@ class GridRNNCellTest(test.TestCase):
[[0.80049908, 0.80049908], [0.97574311, 0.97574311]]))
def testGrid2BasicRNNCellWithRelu(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
'root', initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
@@ -348,7 +348,7 @@ class GridRNNCellTest(test.TestCase):
"""
def testGrid1LSTMCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
'root', initializer=init_ops.constant_initializer(0.5)) as root_scope:
x = array_ops.zeros([1, 3])
@@ -410,7 +410,7 @@ class GridRNNCellTest(test.TestCase):
"""
def testGrid3LSTMCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
'root', initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 3])
@@ -455,7 +455,7 @@ class GridRNNCellTest(test.TestCase):
"""
def testGridRNNEdgeCasesLikeRelu(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
'root', initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([3, 2])
@@ -481,7 +481,7 @@ class GridRNNCellTest(test.TestCase):
self.assertAllClose(res_g, ([[0, 0], [0, 0], [0.5, 0.5]],))
def testGridRNNEdgeCasesNoOutput(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
'root', initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
@@ -541,7 +541,7 @@ class GridRNNCellTest(test.TestCase):
self.assertEqual(out[0].get_shape()[1], num_units)
self.assertEqual(out[0].dtype, inp.dtype)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
input_value = np.ones((batch_size, input_size))
@@ -581,7 +581,7 @@ class GridRNNCellTest(test.TestCase):
self.assertEqual(out[0].get_shape()[1], num_units)
self.assertEqual(out[0].dtype, inp.dtype)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
input_value = np.ones((batch_size, input_size))
@@ -623,7 +623,7 @@ class GridRNNCellTest(test.TestCase):
self.assertEqual(out[0].get_shape()[1], num_units)
self.assertEqual(out[0].dtype, inp.dtype)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
input_value = np.ones((batch_size, input_size))
@@ -663,7 +663,7 @@ class GridRNNCellTest(test.TestCase):
self.assertEqual(out[0].get_shape(), (3, num_units))
self.assertEqual(out[0].dtype, inp.dtype)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
input_value = np.ones((batch_size, input_size))
@@ -700,7 +700,7 @@ class GridRNNCellTest(test.TestCase):
self.assertEqual(out[0].get_shape()[1], num_units)
self.assertEqual(out[0].dtype, inp.dtype)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
input_value = np.ones((3, input_size))
@@ -715,7 +715,7 @@ class GridRNNCellTest(test.TestCase):
def testGrid2LSTMCellLegacy(self):
"""Test for legacy case (when state_is_tuple=False)."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
'root', initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 3])
diff --git a/tensorflow/contrib/hadoop/python/kernel_tests/hadoop_test.py b/tensorflow/contrib/hadoop/python/kernel_tests/hadoop_test.py
index d796e43d87..f7f1189bb9 100644
--- a/tensorflow/contrib/hadoop/python/kernel_tests/hadoop_test.py
+++ b/tensorflow/contrib/hadoop/python/kernel_tests/hadoop_test.py
@@ -51,7 +51,7 @@ class SequenceFileDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for _ in range(num_repeats): # Dataset is repeated.
for i in range(25): # 25 records.
diff --git a/tensorflow/contrib/hadoop/python/ops/hadoop_dataset_ops.py b/tensorflow/contrib/hadoop/python/ops/hadoop_dataset_ops.py
index 6e0e628655..bf398b838d 100644
--- a/tensorflow/contrib/hadoop/python/ops/hadoop_dataset_ops.py
+++ b/tensorflow/contrib/hadoop/python/ops/hadoop_dataset_ops.py
@@ -19,14 +19,14 @@ from __future__ import print_function
from tensorflow.contrib.hadoop.python.ops import gen_dataset_ops
from tensorflow.contrib.hadoop.python.ops import hadoop_op_loader # pylint: disable=unused-import
-from tensorflow.python.data.ops.dataset_ops import Dataset
+from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
-class SequenceFileDataset(Dataset):
+class SequenceFileDataset(dataset_ops.DatasetSource):
"""A Sequence File Dataset that reads the sequence file."""
def __init__(self, filenames):
diff --git a/tensorflow/contrib/input_pipeline/python/ops/input_pipeline_ops_test.py b/tensorflow/contrib/input_pipeline/python/ops/input_pipeline_ops_test.py
index 9ed017592a..f44edaa14c 100644
--- a/tensorflow/contrib/input_pipeline/python/ops/input_pipeline_ops_test.py
+++ b/tensorflow/contrib/input_pipeline/python/ops/input_pipeline_ops_test.py
@@ -29,7 +29,7 @@ from tensorflow.python.platform import test
class InputPipelineOpsTest(test.TestCase):
def testObtainNext(self):
- with self.test_session():
+ with self.cached_session():
var = state_ops.variable_op([], dtypes.int64)
state_ops.assign(var, -1).op.run()
c = constant_op.constant(["a", "b"])
@@ -45,7 +45,7 @@ class InputPipelineOpsTest(test.TestCase):
def testSeekNext(self):
string_list = ["a", "b", "c"]
- with self.test_session() as session:
+ with self.cached_session() as session:
elem = input_pipeline_ops.seek_next(string_list)
session.run([variables.global_variables_initializer()])
self.assertEqual(b"a", session.run(elem))
@@ -65,7 +65,7 @@ class InputPipelineOpsTest(test.TestCase):
def testSeekNextLimitEpochs(self):
string_list = ["a", "b", "c"]
- with self.test_session() as session:
+ with self.cached_session() as session:
elem = input_pipeline_ops.seek_next(string_list, num_epochs=1)
session.run([
variables.local_variables_initializer(),
@@ -75,7 +75,7 @@ class InputPipelineOpsTest(test.TestCase):
def testSeekNextLimitEpochsThree(self):
string_list = ["a", "b", "c"]
- with self.test_session() as session:
+ with self.cached_session() as session:
elem = input_pipeline_ops.seek_next(string_list, num_epochs=3)
session.run([
variables.local_variables_initializer(),
diff --git a/tensorflow/contrib/kafka/python/kernel_tests/kafka_test.py b/tensorflow/contrib/kafka/python/kernel_tests/kafka_test.py
index 621911876f..08ebcdb544 100644
--- a/tensorflow/contrib/kafka/python/kernel_tests/kafka_test.py
+++ b/tensorflow/contrib/kafka/python/kernel_tests/kafka_test.py
@@ -54,7 +54,7 @@ class KafkaDatasetTest(test.TestCase):
init_batch_op = iterator.make_initializer(batch_dataset)
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Basic test: read from topic 0.
sess.run(init_op, feed_dict={topics: ["test:0:0:4"], num_epochs: 1})
for i in range(5):
diff --git a/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py b/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py
index a1624614d1..7129f09e8b 100644
--- a/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py
+++ b/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py
@@ -17,15 +17,15 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.kafka.python.ops import kafka_op_loader # pylint: disable=unused-import
from tensorflow.contrib.kafka.python.ops import gen_dataset_ops
-from tensorflow.python.data.ops.dataset_ops import Dataset
+from tensorflow.contrib.kafka.python.ops import kafka_op_loader # pylint: disable=unused-import
+from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
-class KafkaDataset(Dataset):
+class KafkaDataset(dataset_ops.DatasetSource):
"""A Kafka Dataset that consumes the message.
"""
diff --git a/tensorflow/contrib/kernel_methods/python/losses_test.py b/tensorflow/contrib/kernel_methods/python/losses_test.py
index 72507539f8..4d5cc24ce0 100644
--- a/tensorflow/contrib/kernel_methods/python/losses_test.py
+++ b/tensorflow/contrib/kernel_methods/python/losses_test.py
@@ -32,7 +32,7 @@ class SparseMulticlassHingeLossTest(test.TestCase):
def testInvalidLogitsShape(self):
"""An error is raised when logits have invalid shape."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([-1.0, 2.1], shape=(2,))
labels = constant_op.constant([0, 1])
with self.assertRaises(ValueError):
@@ -40,7 +40,7 @@ class SparseMulticlassHingeLossTest(test.TestCase):
def testInvalidLabelsShape(self):
"""An error is raised when labels have invalid shape."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([-1.0, 2.1], shape=(2, 1))
labels = constant_op.constant([1, 0], shape=(1, 1, 2))
with self.assertRaises(ValueError):
@@ -48,7 +48,7 @@ class SparseMulticlassHingeLossTest(test.TestCase):
def testInvalidWeightsShape(self):
"""An error is raised when weights have invalid shape."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([-1.0, 2.1], shape=(2, 1))
labels = constant_op.constant([1, 0], shape=(2,))
weights = constant_op.constant([1.5, 0.2], shape=(2, 1, 1))
@@ -57,7 +57,7 @@ class SparseMulticlassHingeLossTest(test.TestCase):
def testInvalidLabelsDtype(self):
"""An error is raised when labels have invalid shape."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([-1.0, 2.1], shape=(2, 1))
labels = constant_op.constant([1, 0], dtype=dtypes.float32)
with self.assertRaises(ValueError):
@@ -65,7 +65,7 @@ class SparseMulticlassHingeLossTest(test.TestCase):
def testNoneWeightRaisesValueError(self):
"""An error is raised when weights are None."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([-1.0, 2.1], shape=(2, 1))
labels = constant_op.constant([1, 0])
with self.assertRaises(ValueError):
@@ -73,7 +73,7 @@ class SparseMulticlassHingeLossTest(test.TestCase):
def testInconsistentLabelsAndWeightsShapesSameRank(self):
"""Error raised when weights and labels have same ranks, different sizes."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([-1.0, 2.1, 4.1], shape=(3, 1))
labels = constant_op.constant([1, 0, 2], shape=(3, 1))
weights = constant_op.constant([1.1, 2.0], shape=(2, 1))
@@ -82,7 +82,7 @@ class SparseMulticlassHingeLossTest(test.TestCase):
def testInconsistentLabelsAndWeightsShapesDifferentRank(self):
"""Error raised when weights and labels have different ranks and sizes."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([-1.0, 2.1], shape=(2, 1))
labels = constant_op.constant([1, 0], shape=(2, 1))
weights = constant_op.constant([1.1, 2.0, 2.8], shape=(3,))
@@ -91,7 +91,7 @@ class SparseMulticlassHingeLossTest(test.TestCase):
def testOutOfRangeLabels(self):
"""An error is raised when labels are not in [0, num_classes)."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[1.2, -1.4, -1.0], [1.4, 1.8, 4.0],
[0.5, 1.8, -1.0]])
labels = constant_op.constant([1, 0, 4])
@@ -101,7 +101,7 @@ class SparseMulticlassHingeLossTest(test.TestCase):
def testZeroLossInt32Labels(self):
"""Loss is 0 if true class logits sufficiently higher than other classes."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[1.2, -1.4, -1.0], [1.4, 1.8, 4.0],
[0.5, 1.8, -1.0]])
labels = constant_op.constant([0, 2, 1], dtype=dtypes.int32)
@@ -110,7 +110,7 @@ class SparseMulticlassHingeLossTest(test.TestCase):
def testZeroLossInt64Labels(self):
"""Loss is 0 if true class logits sufficiently higher than other classes."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[2.1, -0.4, -1.0], [1.4, 2.8, 4.0],
[-0.5, 0.8, -1.0]])
labels = constant_op.constant([0, 2, 1], dtype=dtypes.int64)
@@ -130,7 +130,7 @@ class SparseMulticlassHingeLossTest(test.TestCase):
]
for batch_size, num_classes in logits_shapes:
- with self.test_session():
+ with self.cached_session():
logits = array_ops.placeholder(
dtypes.float32, shape=(batch_size, num_classes))
labels = array_ops.placeholder(dtypes.int32, shape=(batch_size,))
@@ -140,7 +140,7 @@ class SparseMulticlassHingeLossTest(test.TestCase):
def testCorrectPredictionsSomeClassesInsideMargin(self):
"""Loss is > 0 even if true class logits are higher than other classes."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[1.2, -1.4, 0.8], [1.4, 1.8, 4.0],
[1.5, 1.8, -1.0]])
labels = constant_op.constant([0, 2, 1])
@@ -150,7 +150,7 @@ class SparseMulticlassHingeLossTest(test.TestCase):
def testIncorrectPredictions(self):
"""Loss is >0 when an incorrect class has higher logits than true class."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[2.6, 0.4, 0.8], [1.4, 0.8, -1.0],
[0.5, -1.8, 2.0]])
labels = constant_op.constant([1, 0, 2])
@@ -162,7 +162,7 @@ class SparseMulticlassHingeLossTest(test.TestCase):
def testIncorrectPredictionsColumnLabels(self):
"""Same as above but labels is a rank-2 tensor."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[1.6, -0.4, 0.8], [1.5, 0.8, -1.0],
[0.2, -1.8, 4.0]])
labels = constant_op.constant([1, 0, 2], shape=(3, 1))
@@ -174,7 +174,7 @@ class SparseMulticlassHingeLossTest(test.TestCase):
def testIncorrectPredictionsZeroWeights(self):
"""Loss is 0 when all weights are missing even if predictions are wrong."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[1.6, -0.4, 0.8], [1.5, 0.8, -1.0],
[0.2, -1.8, 4.0]])
labels = constant_op.constant([1, 0, 2], shape=(3, 1))
@@ -185,7 +185,7 @@ class SparseMulticlassHingeLossTest(test.TestCase):
def testNonZeroLossWithPythonScalarWeights(self):
"""Weighted loss is correctly computed when weights is a python scalar."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[1.6, -0.4, 0.8], [1.5, 0.8, -1.0],
[0.2, -1.8, 4.0]])
labels = constant_op.constant([1, 0, 2], shape=(3, 1))
@@ -195,7 +195,7 @@ class SparseMulticlassHingeLossTest(test.TestCase):
def testNonZeroLossWithScalarTensorWeights(self):
"""Weighted loss is correctly computed when weights is a rank-0 tensor."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[1.6, -0.4, 0.8], [1.5, 0.8, -1.0],
[0.2, -1.8, 4.0]])
labels = constant_op.constant([1, 0, 2], shape=(3, 1))
@@ -205,7 +205,7 @@ class SparseMulticlassHingeLossTest(test.TestCase):
def testNonZeroLossWith1DTensorWeightsColumnLabels(self):
"""Weighted loss is correctly computed when weights is a rank-0 tensor."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[1.6, -0.4, 0.8], [1.5, 0.8, -1.0],
[0.2, -1.8, 4.0]])
labels = constant_op.constant([1, 0, 2], shape=(3, 1))
@@ -216,7 +216,7 @@ class SparseMulticlassHingeLossTest(test.TestCase):
def testNonZeroLossWith2DTensorWeights1DLabelsSomeWeightsMissing(self):
"""Weighted loss is correctly computed when weights is a rank-0 tensor."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[1.6, -0.4, 0.8], [1.5, 0.8, -1.0],
[0.2, -1.8, 4.0], [1.6, 1.8, -4.0]])
labels = constant_op.constant([1, 0, 2, 1])
diff --git a/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features_test.py b/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features_test.py
index 2ff4d41d75..bad0a596a7 100644
--- a/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features_test.py
+++ b/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features_test.py
@@ -58,7 +58,7 @@ class RandomFourierFeatureMapperTest(TensorFlowTestCase):
def testInvalidInputShape(self):
x = constant_op.constant([[2.0, 1.0]])
- with self.test_session():
+ with self.cached_session():
rffm = RandomFourierFeatureMapper(3, 10)
with self.assertRaisesWithPredicateMatch(
dense_kernel_mapper.InvalidShapeError,
@@ -70,7 +70,7 @@ class RandomFourierFeatureMapperTest(TensorFlowTestCase):
x2 = constant_op.constant([[1.0, -1.0, 2.0], [-1.0, 10.0, 1.0],
[4.0, -2.0, -1.0]])
- with self.test_session():
+ with self.cached_session():
rffm = RandomFourierFeatureMapper(3, 10, 1.0)
mapped_x1 = rffm.map(x1)
mapped_x2 = rffm.map(x2)
@@ -80,7 +80,7 @@ class RandomFourierFeatureMapperTest(TensorFlowTestCase):
def testSameOmegaReused(self):
x = constant_op.constant([[2.0, 1.0, 0.0]])
- with self.test_session():
+ with self.cached_session():
rffm = RandomFourierFeatureMapper(3, 100)
mapped_x = rffm.map(x)
mapped_x_copy = rffm.map(x)
@@ -93,7 +93,7 @@ class RandomFourierFeatureMapperTest(TensorFlowTestCase):
y = constant_op.constant([[1.0, -1.0, 2.0]])
stddev = 3.0
- with self.test_session():
+ with self.cached_session():
# The mapped dimension is fairly small, so the kernel approximation is
# very rough.
rffm1 = RandomFourierFeatureMapper(3, 100, stddev)
@@ -113,7 +113,7 @@ class RandomFourierFeatureMapperTest(TensorFlowTestCase):
y = constant_op.constant([[1.0, -1.0, 2.0]])
stddev = 3.0
- with self.test_session():
+ with self.cached_session():
# The mapped dimension is fairly small, so the kernel approximation is
# very rough.
rffm = RandomFourierFeatureMapper(3, 100, stddev, seed=0)
@@ -139,7 +139,7 @@ class RandomFourierFeatureMapperTest(TensorFlowTestCase):
normalized_points = [nn.l2_normalize(point, dim=1) for point in points]
total_absolute_error = 0.0
- with self.test_session():
+ with self.cached_session():
rffm = RandomFourierFeatureMapper(input_dim, mapped_dim, stddev, seed=0)
# Cache mappings so that they are not computed multiple times.
cached_mappings = dict((point, rffm.map(point))
diff --git a/tensorflow/contrib/kinesis/python/kernel_tests/kinesis_test.py b/tensorflow/contrib/kinesis/python/kernel_tests/kinesis_test.py
index 7289b45c50..bf89922318 100644
--- a/tensorflow/contrib/kinesis/python/kernel_tests/kinesis_test.py
+++ b/tensorflow/contrib/kinesis/python/kernel_tests/kinesis_test.py
@@ -64,7 +64,7 @@ class KinesisDatasetTest(test.TestCase):
init_batch_op = iterator.make_initializer(batch_dataset)
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Basic test: read from shard 0 of stream 1.
sess.run(init_op, feed_dict={stream: stream_name, num_epochs: 1})
for i in range(10):
@@ -108,7 +108,7 @@ class KinesisDatasetTest(test.TestCase):
get_next = iterator.get_next()
data = list()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Basic test: read from shard 0 of stream 2.
sess.run(
init_op, feed_dict={
diff --git a/tensorflow/contrib/kinesis/python/ops/kinesis_dataset_ops.py b/tensorflow/contrib/kinesis/python/ops/kinesis_dataset_ops.py
index ca2df95ba4..75806dbbeb 100644
--- a/tensorflow/contrib/kinesis/python/ops/kinesis_dataset_ops.py
+++ b/tensorflow/contrib/kinesis/python/ops/kinesis_dataset_ops.py
@@ -17,15 +17,15 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.kinesis.python.ops import kinesis_op_loader # pylint: disable=unused-import
from tensorflow.contrib.kinesis.python.ops import gen_dataset_ops
-from tensorflow.python.data.ops.dataset_ops import Dataset
+from tensorflow.contrib.kinesis.python.ops import kinesis_op_loader # pylint: disable=unused-import
+from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
-class KinesisDataset(Dataset):
+class KinesisDataset(dataset_ops.DatasetSource):
"""A Kinesis Dataset that consumes the message.
Kinesis is a managed service provided by AWS for data streaming.
diff --git a/tensorflow/contrib/layers/python/kernel_tests/sparse_feature_cross_op_test.py b/tensorflow/contrib/layers/python/kernel_tests/sparse_feature_cross_op_test.py
index 28ddaa69a1..155d06a08e 100644
--- a/tensorflow/contrib/layers/python/kernel_tests/sparse_feature_cross_op_test.py
+++ b/tensorflow/contrib/layers/python/kernel_tests/sparse_feature_cross_op_test.py
@@ -45,7 +45,7 @@ class SparseCrossOpTest(test.TestCase):
'batch2-FC1-F1_X_batch2-FC2-F1', 'batch2-FC1-F1_X_batch2-FC2-F2',
'batch2-FC1-F2_X_batch2-FC2-F1', 'batch2-FC1-F2_X_batch2-FC2-F2'
]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._assert_sparse_tensor_equals(expected_out, sess.run(op))
def test_dense(self):
@@ -66,7 +66,7 @@ class SparseCrossOpTest(test.TestCase):
'batch2-FC1-F1_X_batch2-FC2-F1', 'batch2-FC1-F1_X_batch2-FC2-F2',
'batch2-FC1-F2_X_batch2-FC2-F1', 'batch2-FC1-F2_X_batch2-FC2-F2'
]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._assert_sparse_tensor_equals(expected_out, sess.run(op))
def test_integer_mixed_string_sparse(self):
@@ -80,7 +80,7 @@ class SparseCrossOpTest(test.TestCase):
'333_X_batch2-FC2-F1', '333_X_batch2-FC2-F2', '55555_X_batch2-FC2-F1',
'55555_X_batch2-FC2-F2'
]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._assert_sparse_tensor_equals(expected_out, sess.run(op))
def test_integer_mixed_string_dense(self):
@@ -99,7 +99,7 @@ class SparseCrossOpTest(test.TestCase):
'55555_X_batch2-FC2-F1', '55555_X_batch2-FC2-F2',
'999999_X_batch2-FC2-F1', '999999_X_batch2-FC2-F2'
]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._assert_sparse_tensor_equals(expected_out, sess.run(op))
def test_sparse_cross_dense(self):
@@ -117,7 +117,7 @@ class SparseCrossOpTest(test.TestCase):
'batch2-FC1-F1_X_batch2-FC2-F1', 'batch2-FC1-F1_X_batch2-FC2-F2',
'batch2-FC1-F2_X_batch2-FC2-F1', 'batch2-FC1-F2_X_batch2-FC2-F2'
]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._assert_sparse_tensor_equals(expected_out, sess.run(op))
def test_integer_sparse_input(self):
@@ -133,7 +133,7 @@ class SparseCrossOpTest(test.TestCase):
'333_X_batch2-FC2-F1', '333_X_batch2-FC2-F2',
'5555_X_batch2-FC2-F1', '5555_X_batch2-FC2-F2'
]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._assert_sparse_tensor_equals(expected_out, sess.run(op))
def test_permutation_3x3x3(self):
@@ -176,7 +176,7 @@ class SparseCrossOpTest(test.TestCase):
'batch1-FC1-F3_X_batch1-FC2-F3_X_batch1-FC3-F2',
'batch1-FC1-F3_X_batch1-FC2-F3_X_batch1-FC3-F3'
]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._assert_sparse_tensor_equals(expected_out, sess.run(op))
def test_permutation_3x1x2(self):
@@ -196,7 +196,7 @@ class SparseCrossOpTest(test.TestCase):
'batch1-FC1-F3_X_batch1-FC2-F1_X_batch1-FC3-F1',
'batch1-FC1-F3_X_batch1-FC2-F1_X_batch1-FC3-F2'
]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._assert_sparse_tensor_equals(expected_out, sess.run(op))
def test_large_batch(self):
@@ -229,7 +229,7 @@ class SparseCrossOpTest(test.TestCase):
])
expected_out = self._sparse_tensor(col_out)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._assert_sparse_tensor_equals(expected_out, sess.run(op))
def test_one_column_empty(self):
@@ -242,7 +242,7 @@ class SparseCrossOpTest(test.TestCase):
self._sparse_tensor([], 1),
self._sparse_tensor([['batch1-FC3-F1', 'batch1-FC3-F2']])
])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._assert_sparse_tensor_empty(sess.run(op))
def test_some_columns_empty(self):
@@ -261,7 +261,7 @@ class SparseCrossOpTest(test.TestCase):
'batch1-FC1-F2_X_batch1-FC2-F1_X_batch1-FC3-F1',
'batch1-FC1-F2_X_batch1-FC2-F1_X_batch1-FC3-F2'
]], 2)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._assert_sparse_tensor_equals(expected_out, sess.run(op))
def test_all_columns_empty(self):
@@ -273,7 +273,7 @@ class SparseCrossOpTest(test.TestCase):
self._sparse_tensor([]), self._sparse_tensor([]),
self._sparse_tensor([])
])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._assert_sparse_tensor_empty(sess.run(op))
def test_hashed_output_zero_bucket(self):
@@ -288,7 +288,7 @@ class SparseCrossOpTest(test.TestCase):
hashed_output=True)
# Check actual hashed output to prevent unintentional hashing changes.
expected_out = self._sparse_tensor([[3735511728867393167]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._assert_sparse_tensor_equals(expected_out, sess.run(op))
def test_hashed_output_zero_bucket_v2(self):
@@ -304,7 +304,7 @@ class SparseCrossOpTest(test.TestCase):
hash_key=layers.SPARSE_FEATURE_CROSS_DEFAULT_HASH_KEY)
# Check actual hashed output to prevent unintentional hashing changes.
expected_out = self._sparse_tensor([[1971693436396284976]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._assert_sparse_tensor_equals(expected_out, sess.run(op))
# TODO(sibyl-Aix6ihai): Add benchmark to compare Hashed vs Non-hashed.
@@ -321,7 +321,7 @@ class SparseCrossOpTest(test.TestCase):
num_buckets=100)
# Check actual hashed output to prevent unintentional hashing changes.
expected_out = self._sparse_tensor([[74]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._assert_sparse_tensor_equals(expected_out, sess.run(op))
def test_hashed_output_v2(self):
@@ -338,7 +338,7 @@ class SparseCrossOpTest(test.TestCase):
hash_key=layers.SPARSE_FEATURE_CROSS_DEFAULT_HASH_KEY)
# Check actual hashed output to prevent unintentional hashing changes.
expected_out = self._sparse_tensor([[83]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._assert_sparse_tensor_equals(expected_out, sess.run(op))
def test_hashed_output_v1_has_collision(self):
@@ -384,7 +384,7 @@ class SparseCrossOpTest(test.TestCase):
],
hashed_output=True,
num_buckets=1000)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
out = sess.run(op)
self.assertEqual(6, len(out.values))
self.assertAllEqual([[0, i] for i in range(6)], out.indices)
diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py
index 85af9de4e4..3b7ae72e9c 100644
--- a/tensorflow/contrib/layers/python/layers/layers_test.py
+++ b/tensorflow/contrib/layers/python/layers/layers_test.py
@@ -2360,7 +2360,7 @@ class BatchNormTest(test.TestCase):
batch_size * height * width, expected_var)
images = constant_op.constant(
image_values, shape=image_shape, dtype=dtypes.float32)
- is_training = variables_lib.Variable(True)
+ is_training = variables_lib.VariableV1(True)
output = _layers.batch_norm(
images,
decay=0.1,
@@ -2507,7 +2507,7 @@ class BatchNormTest(test.TestCase):
batch_size * height * width, expected_var)
images = constant_op.constant(
image_values, shape=image_shape, dtype=dtypes.float32)
- is_training = variables_lib.Variable(True)
+ is_training = variables_lib.VariableV1(True)
output = _layers.batch_norm(
images,
decay=0.1,
diff --git a/tensorflow/contrib/layers/python/layers/optimizers.py b/tensorflow/contrib/layers/python/layers/optimizers.py
index 69d927e1b3..2fdcd849b0 100644
--- a/tensorflow/contrib/layers/python/layers/optimizers.py
+++ b/tensorflow/contrib/layers/python/layers/optimizers.py
@@ -21,8 +21,6 @@ from __future__ import print_function
import six
from tensorflow.contrib import framework as contrib_framework
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
@@ -433,12 +431,11 @@ def _multiply_gradients(grads_and_vars, gradient_multipliers):
if (grad is not None and
(var in gradient_multipliers or var.name in gradient_multipliers)):
key = var if var in gradient_multipliers else var.name
- multiplier = constant_op.constant(
- gradient_multipliers[key], dtype=dtypes.float32)
+ multiplier = gradient_multipliers[key]
if isinstance(grad, ops.IndexedSlices):
grad_values = grad.values * multiplier
grad = ops.IndexedSlices(grad_values, grad.indices, grad.dense_shape)
else:
- grad *= multiplier
+ grad *= math_ops.cast(multiplier, grad.dtype)
multiplied_grads_and_vars.append((grad, var))
return multiplied_grads_and_vars
diff --git a/tensorflow/contrib/layers/python/layers/optimizers_test.py b/tensorflow/contrib/layers/python/layers/optimizers_test.py
index 29dede2a49..b4d1239e76 100644
--- a/tensorflow/contrib/layers/python/layers/optimizers_test.py
+++ b/tensorflow/contrib/layers/python/layers/optimizers_test.py
@@ -250,6 +250,42 @@ class OptimizersTest(test.TestCase):
self.assertAlmostEqual(var_value, 6.5, 4)
self.assertEqual(global_step_value, 1)
+ def testGradientMultiplyInt32Tensor(self):
+ with self.cached_session() as session:
+ x, var, loss, global_step = _setup_model()
+ v = array_ops.placeholder(dtypes.int32, [])
+ train = optimizers_lib.optimize_loss(
+ loss,
+ global_step,
+ learning_rate=0.1,
+ optimizer="SGD",
+ gradient_multipliers={var: v})
+ variables.global_variables_initializer().run()
+ session.run(train, feed_dict={x: 5, v: 7.})
+ var_value, global_step_value = session.run([var, global_step])
+ # var(0) = 10, x = 5, var(0)/dx = 5,
+ # var(1) = var(0) - learning_rate * gradient_multiplier * var(0)/dx
+ self.assertAlmostEqual(var_value, 6.5, 4)
+ self.assertEqual(global_step_value, 1)
+
+ def testGradientMultiplyInt64Tensor(self):
+ with self.cached_session() as session:
+ x, var, loss, global_step = _setup_model()
+ v = array_ops.placeholder(dtypes.int64, [])
+ train = optimizers_lib.optimize_loss(
+ loss,
+ global_step,
+ learning_rate=0.1,
+ optimizer="SGD",
+ gradient_multipliers={var: v})
+ variables.global_variables_initializer().run()
+ session.run(train, feed_dict={x: 5, v: 7.})
+ var_value, global_step_value = session.run([var, global_step])
+ # var(0) = 10, x = 5, var(0)/dx = 5,
+ # var(1) = var(0) - learning_rate * gradient_multiplier * var(0)/dx
+ self.assertAlmostEqual(var_value, 6.5, 4)
+ self.assertEqual(global_step_value, 1)
+
def testIgnoreVariablesWithNoGradients(self):
_, _, loss, global_step = _setup_model()
diff --git a/tensorflow/contrib/layers/python/layers/target_column.py b/tensorflow/contrib/layers/python/layers/target_column.py
index 69bb6be814..8a6b4f68a8 100644
--- a/tensorflow/contrib/layers/python/layers/target_column.py
+++ b/tensorflow/contrib/layers/python/layers/target_column.py
@@ -396,7 +396,7 @@ class _BinarySvmTargetColumn(_MultiClassTargetColumn):
def _mean_squared_loss(logits, target):
# To prevent broadcasting inside "-".
if len(target.get_shape()) == 1:
- target = array_ops.expand_dims(target, dim=[1])
+ target = array_ops.expand_dims(target, axis=1)
logits.get_shape().assert_is_compatible_with(target.get_shape())
return math_ops.square(logits - math_ops.to_float(target))
@@ -405,7 +405,7 @@ def _mean_squared_loss(logits, target):
def _log_loss_with_two_classes(logits, target):
# sigmoid_cross_entropy_with_logits requires [batch_size, 1] target.
if len(target.get_shape()) == 1:
- target = array_ops.expand_dims(target, dim=[1])
+ target = array_ops.expand_dims(target, axis=1)
loss_vec = nn.sigmoid_cross_entropy_with_logits(
labels=math_ops.to_float(target), logits=logits)
return loss_vec
diff --git a/tensorflow/contrib/learn/python/learn/estimators/head.py b/tensorflow/contrib/learn/python/learn/estimators/head.py
index ded93d4a7f..c6f79e00d5 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/head.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/head.py
@@ -563,10 +563,10 @@ def _mean_squared_loss(labels, logits, weights=None):
labels = ops.convert_to_tensor(labels)
# To prevent broadcasting inside "-".
if len(labels.get_shape()) == 1:
- labels = array_ops.expand_dims(labels, axis=(1,))
+ labels = array_ops.expand_dims(labels, axis=1)
# TODO(zakaria): make sure it does not recreate the broadcast bug.
if len(logits.get_shape()) == 1:
- logits = array_ops.expand_dims(logits, axis=(1,))
+ logits = array_ops.expand_dims(logits, axis=1)
logits.get_shape().assert_is_compatible_with(labels.get_shape())
loss = math_ops.square(logits - math_ops.to_float(labels), name=name)
return _compute_weighted_loss(loss, weights)
@@ -579,10 +579,10 @@ def _poisson_loss(labels, logits, weights=None):
labels = ops.convert_to_tensor(labels)
# To prevent broadcasting inside "-".
if len(labels.get_shape()) == 1:
- labels = array_ops.expand_dims(labels, axis=(1,))
+ labels = array_ops.expand_dims(labels, axis=1)
# TODO(zakaria): make sure it does not recreate the broadcast bug.
if len(logits.get_shape()) == 1:
- logits = array_ops.expand_dims(logits, axis=(1,))
+ logits = array_ops.expand_dims(logits, axis=1)
logits.get_shape().assert_is_compatible_with(labels.get_shape())
loss = nn.log_poisson_loss(labels, logits, compute_full_loss=True,
name=name)
@@ -797,7 +797,7 @@ def _log_loss_with_two_classes(labels, logits, weights=None):
# TODO(ptucker): This will break for dynamic shapes.
# sigmoid_cross_entropy_with_logits requires [batch_size, 1] labels.
if len(labels.get_shape()) == 1:
- labels = array_ops.expand_dims(labels, axis=(1,))
+ labels = array_ops.expand_dims(labels, axis=1)
loss = nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits,
name=name)
return _compute_weighted_loss(loss, weights)
diff --git a/tensorflow/contrib/learn/python/learn/graph_actions_test.py b/tensorflow/contrib/learn/python/learn/graph_actions_test.py
index d5c02124ac..a160cb54a3 100644
--- a/tensorflow/contrib/learn/python/learn/graph_actions_test.py
+++ b/tensorflow/contrib/learn/python/learn/graph_actions_test.py
@@ -162,9 +162,9 @@ class GraphActionsTest(test.TestCase):
Tuple of 3 `Tensor` objects, 2 input and 1 output.
"""
variables_lib.create_global_step()
- in0 = variables.Variable(1.0)
+ in0 = variables.VariableV1(1.0)
in1 = variables_lib.local_variable(2.0)
- fake_table = variables.Variable(
+ fake_table = variables.VariableV1(
3.0,
trainable=False,
collections=['fake_tables'],
@@ -234,7 +234,7 @@ class GraphActionsTest(test.TestCase):
self.assertTrue(test_ops.resource_initialized_op(handle).eval())
def test_infer_different_default_graph(self):
- with self.test_session():
+ with self.cached_session():
self._assert_ckpt(self._output_dir, False)
with ops.Graph().as_default():
in0, in1, out = self._build_inference_graph()
@@ -312,8 +312,8 @@ class GraphActionsTest(test.TestCase):
def test_evaluate_ready_for_local_init(self):
with ops.Graph().as_default() as g, self.session(g):
variables_lib.create_global_step()
- v = variables.Variable(1.0)
- variables.Variable(
+ v = variables.VariableV1(1.0)
+ variables.VariableV1(
v + 1, collections=[ops.GraphKeys.LOCAL_VARIABLES], trainable=False)
ready_for_local_init_op = variables.report_uninitialized_variables(
variables.global_variables())
@@ -456,9 +456,9 @@ class GraphActionsTrainTest(test.TestCase):
Tuple of 3 `Tensor` objects, 2 input and 1 output.
"""
variables_lib.create_global_step()
- in0 = variables.Variable(1.0)
+ in0 = variables.VariableV1(1.0)
in1 = variables_lib.local_variable(2.0)
- fake_table = variables.Variable(
+ fake_table = variables.VariableV1(
3.0,
trainable=False,
collections=['fake_tables'],
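These tests now construct `VariableV1` explicitly, which keeps the pre-2.0 reference-variable behavior while `tf.Variable` migrates to the V2 resource variables, and they use `cached_session()`, which hands back one session per test method instead of building a fresh session on every call. A rough sketch of the `cached_session()` pattern in a `tf.test.TestCase`, assuming TensorFlow 1.x; the test body is illustrative only:

```python
# Illustrative sketch of the cached_session() pattern, assuming TF 1.x.
import tensorflow as tf

class ExampleTest(tf.test.TestCase):

  def test_reuses_one_session(self):
    # cached_session() returns the same session for the whole test method,
    # so repeated `with` blocks do not rebuild a session each time.
    with self.cached_session() as sess:
      var = tf.Variable(1.0)
      sess.run(tf.global_variables_initializer())
      self.assertAllClose(1.0, sess.run(var))

if __name__ == "__main__":
  tf.test.main()
```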
diff --git a/tensorflow/contrib/learn/python/learn/monitors_test.py b/tensorflow/contrib/learn/python/learn/monitors_test.py
index 83e48a36e7..d4a7169bb6 100644
--- a/tensorflow/contrib/learn/python/learn/monitors_test.py
+++ b/tensorflow/contrib/learn/python/learn/monitors_test.py
@@ -247,7 +247,7 @@ class MonitorsTest(test.TestCase):
def test_logging_trainable(self):
with ops.Graph().as_default() as g, self.session(g):
- var = variables.Variable(constant_op.constant(42.0), name='foo')
+ var = variables.VariableV1(constant_op.constant(42.0), name='foo')
var.initializer.run()
cof = constant_op.constant(1.0)
loss = math_ops.subtract(
@@ -261,7 +261,7 @@ class MonitorsTest(test.TestCase):
with ops.Graph().as_default() as g, self.session(g):
log_dir = 'log/dir'
summary_writer = testing.FakeSummaryWriter(log_dir, g)
- var = variables.Variable(0.0)
+ var = variables.VariableV1(0.0)
var.initializer.run()
tensor = state_ops.assign_add(var, 1.0)
summary_op = summary.scalar('my_summary', tensor)
@@ -526,8 +526,8 @@ class MonitorsTest(test.TestCase):
monitor0 = learn.monitors.GraphDump()
monitor1 = learn.monitors.GraphDump()
with ops.Graph().as_default() as g, self.session(g):
- const_var = variables.Variable(42.0, name='my_const')
- counter_var = variables.Variable(0.0, name='my_counter')
+ const_var = variables.VariableV1(42.0, name='my_const')
+ counter_var = variables.VariableV1(0.0, name='my_counter')
assign_add = state_ops.assign_add(counter_var, 1.0, name='my_assign_add')
variables.global_variables_initializer().run()
@@ -569,7 +569,7 @@ class MonitorsTest(test.TestCase):
monitor = learn.monitors.CaptureVariable(
var_name='my_assign_add:0', every_n=8, first_n=2)
with ops.Graph().as_default() as g, self.session(g):
- var = variables.Variable(0.0, name='my_var')
+ var = variables.VariableV1(0.0, name='my_var')
var.initializer.run()
state_ops.assign_add(var, 1.0, name='my_assign_add')
self._run_monitor(monitor, num_epochs=3, num_steps_per_epoch=10)
diff --git a/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py b/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py
index 2f33a2b74d..0e5ea6b9f7 100644
--- a/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py
+++ b/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py
@@ -47,7 +47,7 @@ from tensorflow.python.training import adam
class Seq2SeqTest(test.TestCase):
def testRNNDecoder(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
inp = [constant_op.constant(0.5, shape=[2, 2])] * 2
@@ -65,7 +65,7 @@ class Seq2SeqTest(test.TestCase):
self.assertEqual((2, 2), res[0].shape)
def testBasicRNNSeq2Seq(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
inp = [constant_op.constant(0.5, shape=[2, 2])] * 2
@@ -81,7 +81,7 @@ class Seq2SeqTest(test.TestCase):
self.assertEqual((2, 2), res[0].shape)
def testTiedRNNSeq2Seq(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
inp = [constant_op.constant(0.5, shape=[2, 2])] * 2
@@ -98,7 +98,7 @@ class Seq2SeqTest(test.TestCase):
self.assertEqual((2, 2), res[0].shape)
def testEmbeddingRNNDecoder(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
inp = [constant_op.constant(0.5, shape=[2, 2])] * 2
@@ -124,7 +124,7 @@ class Seq2SeqTest(test.TestCase):
self.assertEqual((2, 2), res[0].h.shape)
def testEmbeddingRNNSeq2Seq(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
enc_inp = [
@@ -228,7 +228,7 @@ class Seq2SeqTest(test.TestCase):
self.assertAllClose(res1, res3)
def testEmbeddingTiedRNNSeq2Seq(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
enc_inp = [
@@ -316,7 +316,7 @@ class Seq2SeqTest(test.TestCase):
self.assertAllClose(res1, res3)
def testAttentionDecoder1(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
cell_fn = lambda: rnn_cell.GRUCell(2)
@@ -341,7 +341,7 @@ class Seq2SeqTest(test.TestCase):
self.assertEqual((2, 2), res[0].shape)
def testAttentionDecoder2(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
cell_fn = lambda: rnn_cell.GRUCell(2)
@@ -367,7 +367,7 @@ class Seq2SeqTest(test.TestCase):
self.assertEqual((2, 2), res[0].shape)
def testDynamicAttentionDecoder1(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
cell_fn = lambda: rnn_cell.GRUCell(2)
@@ -391,7 +391,7 @@ class Seq2SeqTest(test.TestCase):
self.assertEqual((2, 2), res[0].shape)
def testDynamicAttentionDecoder2(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
cell_fn = lambda: rnn_cell.GRUCell(2)
@@ -416,7 +416,7 @@ class Seq2SeqTest(test.TestCase):
self.assertEqual((2, 2), res[0].shape)
def testAttentionDecoderStateIsTuple(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
single_cell = lambda: rnn_cell.BasicLSTMCell( # pylint: disable=g-long-lambda
@@ -448,7 +448,7 @@ class Seq2SeqTest(test.TestCase):
self.assertEqual((2, 2), res[0][1].h.shape)
def testDynamicAttentionDecoderStateIsTuple(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
cell_fn = lambda: rnn_cell.MultiRNNCell( # pylint: disable=g-long-lambda
@@ -479,7 +479,7 @@ class Seq2SeqTest(test.TestCase):
self.assertEqual((2, 2), res[0][1].h.shape)
def testEmbeddingAttentionDecoder(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
inp = [constant_op.constant(0.5, shape=[2, 2])] * 2
@@ -513,7 +513,7 @@ class Seq2SeqTest(test.TestCase):
self.assertEqual((2, 2), res[0].shape)
def testEmbeddingAttentionSeq2Seq(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
enc_inp = [
@@ -622,7 +622,7 @@ class Seq2SeqTest(test.TestCase):
# self.assertAllClose(res1, res3)
def testOne2ManyRNNSeq2Seq(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
enc_inp = [
@@ -712,7 +712,7 @@ class Seq2SeqTest(test.TestCase):
self.assertAllClose(res1, res3)
def testSequenceLoss(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
logits = [constant_op.constant(i + 0.5, shape=[2, 5]) for i in range(3)]
targets = [
constant_op.constant(
@@ -748,7 +748,7 @@ class Seq2SeqTest(test.TestCase):
self.assertAllClose(9.656628, res)
def testSequenceLossByExample(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
output_classes = 5
logits = [
constant_op.constant(
@@ -778,7 +778,7 @@ class Seq2SeqTest(test.TestCase):
# classes = 10
# buckets = [(4, 4), (8, 8)]
- # with self.test_session():
+ # with self.cached_session():
# # Here comes a sample Seq2Seq model using GRU cells.
# def SampleGRUSeq2Seq(enc_inp, dec_inp, weights, per_example_loss):
# """Example sequence-to-sequence model that uses GRU cells."""
@@ -839,7 +839,7 @@ class Seq2SeqTest(test.TestCase):
random.seed(111)
np.random.seed(111)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# We use sampled softmax so we keep output projection separate.
w = variable_scope.get_variable("proj_w", [24, classes])
w_t = array_ops.transpose(w)
diff --git a/tensorflow/contrib/linalg/BUILD b/tensorflow/contrib/linalg/BUILD
deleted file mode 100644
index 78b7970069..0000000000
--- a/tensorflow/contrib/linalg/BUILD
+++ /dev/null
@@ -1,44 +0,0 @@
-# Description:
-# Contains classes that provide access to common method of a [batch] matrix,
-# without the need to instantiate the matrix.
-# This allows for exploitation of structure, as well as a generic interface
-# suitable for iterative solvers.
-
-licenses(["notice"]) # Apache 2.0
-
-exports_files(["LICENSE"])
-
-package(default_visibility = ["//tensorflow:__subpackages__"])
-
-load("//tensorflow:tensorflow.bzl", "cuda_py_test")
-
-py_library(
- name = "linalg_py",
- srcs = ["__init__.py"] + glob(["python/ops/*.py"]),
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/python:array_ops",
- "//tensorflow/python:check_ops",
- "//tensorflow/python:framework_for_generated_wrappers",
- "//tensorflow/python:util",
- "//tensorflow/python/ops/linalg",
- "@six_archive//:six",
- ],
-)
-
-cuda_py_test(
- name = "linear_operator_addition_test",
- size = "small",
- srcs = ["python/kernel_tests/linear_operator_addition_test.py"],
- additional_deps = [
- ":linalg_py",
- "//third_party/py/numpy",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:framework",
- "//tensorflow/python:framework_for_generated_wrappers",
- "//tensorflow/python:framework_test_lib",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:platform_test",
- ],
-)
diff --git a/tensorflow/contrib/linalg/__init__.py b/tensorflow/contrib/linalg/__init__.py
deleted file mode 100644
index cbe4c03e4d..0000000000
--- a/tensorflow/contrib/linalg/__init__.py
+++ /dev/null
@@ -1,58 +0,0 @@
-# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Linear algebra libraries.
-
-See the[Contrib Linalg](https://tensorflow.org/api_guides/python/contrib.linalg)
-guide.
-
-@@LinearOperator
-@@LinearOperatorBlockDiag
-@@LinearOperatorCirculant
-@@LinearOperatorCirculant2D
-@@LinearOperatorCirculant3D
-@@LinearOperatorDiag
-@@LinearOperatorIdentity
-@@LinearOperatorScaledIdentity
-@@LinearOperatorFullMatrix
-@@LinearOperatorKronecker
-@@LinearOperatorLowerTriangular
-@@LinearOperatorLowRankUpdate
-@@LinearOperatorComposition
-@@add_operators
-
-"""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-# pylint: disable=unused-import,wildcard-import,line-too-long,g-importing-member
-
-from tensorflow.contrib.linalg.python.ops.linear_operator_addition import *
-from tensorflow.python.ops.linalg.linear_operator import *
-from tensorflow.python.ops.linalg.linear_operator_block_diag import *
-from tensorflow.python.ops.linalg.linear_operator_circulant import *
-from tensorflow.python.ops.linalg.linear_operator_composition import *
-from tensorflow.python.ops.linalg.linear_operator_diag import *
-from tensorflow.python.ops.linalg.linear_operator_full_matrix import *
-from tensorflow.python.ops.linalg.linear_operator_identity import *
-from tensorflow.python.ops.linalg.linear_operator_kronecker import *
-from tensorflow.python.ops.linalg.linear_operator_low_rank_update import *
-from tensorflow.python.ops.linalg.linear_operator_lower_triangular import *
-
-# pylint: enable=unused-import,wildcard-import,line-too-long,g-importing-member
-
-from tensorflow.python.util.all_util import remove_undocumented
-
-remove_undocumented(__name__)
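With `tf.contrib.linalg` removed, the `LinearOperator` classes it re-exported remain in core (the deleted `__init__.py` above already imported them from `tensorflow.python.ops.linalg`); the only contrib-specific piece deleted outright is `add_operators` from `linear_operator_addition`. A sketch of the core import path, assuming a TensorFlow 1.x build where these classes are exported under `tf.linalg`:

```python
# Illustrative only: user code can switch from the contrib aliases to tf.linalg.
import tensorflow as tf

op_a = tf.linalg.LinearOperatorDiag([1., 1.])
op_b = tf.linalg.LinearOperatorLowerTriangular([[2., 0.], [1., 2.]])

# Dense materialization works as it did with the contrib re-exports.
total = op_a.to_dense() + op_b.to_dense()
```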
diff --git a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_addition_test.py b/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_addition_test.py
deleted file mode 100644
index 6a72df6dfd..0000000000
--- a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_addition_test.py
+++ /dev/null
@@ -1,412 +0,0 @@
-# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-
-from tensorflow.contrib.linalg.python.ops import linear_operator_addition
-from tensorflow.python.framework import random_seed
-from tensorflow.python.ops import linalg_ops
-from tensorflow.python.ops.linalg import linalg as linalg_lib
-from tensorflow.python.platform import test
-
-linalg = linalg_lib
-random_seed.set_random_seed(23)
-rng = np.random.RandomState(0)
-
-add_operators = linear_operator_addition.add_operators
-
-
-# pylint: disable=unused-argument
-class _BadAdder(linear_operator_addition._Adder):
- """Adder that will fail if used."""
-
- def can_add(self, op1, op2):
- raise AssertionError("BadAdder.can_add called!")
-
- def _add(self, op1, op2, operator_name, hints):
- raise AssertionError("This line should not be reached")
-
-
-# pylint: enable=unused-argument
-
-
-class LinearOperatorAdditionCorrectnessTest(test.TestCase):
- """Tests correctness of addition with combinations of a few Adders.
-
- Tests here are done with the _DEFAULT_ADDITION_TIERS, which means
- add_operators should reduce all operators resulting in one single operator.
-
- This shows that we are able to correctly combine adders using the tiered
- system. All Adders should be tested separately, and there is no need to test
- every Adder within this class.
- """
-
- def test_one_operator_is_returned_unchanged(self):
- op_a = linalg.LinearOperatorDiag([1., 1.])
- op_sum = add_operators([op_a])
- self.assertEqual(1, len(op_sum))
- self.assertTrue(op_sum[0] is op_a)
-
- def test_at_least_one_operators_required(self):
- with self.assertRaisesRegexp(ValueError, "must contain at least one"):
- add_operators([])
-
- def test_attempting_to_add_numbers_raises(self):
- with self.assertRaisesRegexp(TypeError, "contain only LinearOperator"):
- add_operators([1, 2])
-
- def test_two_diag_operators(self):
- op_a = linalg.LinearOperatorDiag(
- [1., 1.], is_positive_definite=True, name="A")
- op_b = linalg.LinearOperatorDiag(
- [2., 2.], is_positive_definite=True, name="B")
- with self.test_session():
- op_sum = add_operators([op_a, op_b])
- self.assertEqual(1, len(op_sum))
- op = op_sum[0]
- self.assertTrue(isinstance(op, linalg_lib.LinearOperatorDiag))
- self.assertAllClose([[3., 0.], [0., 3.]], op.to_dense().eval())
- # Adding positive definite operators produces positive def.
- self.assertTrue(op.is_positive_definite)
- # Real diagonal ==> self-adjoint.
- self.assertTrue(op.is_self_adjoint)
- # Positive definite ==> non-singular
- self.assertTrue(op.is_non_singular)
- # Enforce particular name for this simple case
- self.assertEqual("Add/B__A/", op.name)
-
- def test_three_diag_operators(self):
- op1 = linalg.LinearOperatorDiag(
- [1., 1.], is_positive_definite=True, name="op1")
- op2 = linalg.LinearOperatorDiag(
- [2., 2.], is_positive_definite=True, name="op2")
- op3 = linalg.LinearOperatorDiag(
- [3., 3.], is_positive_definite=True, name="op3")
- with self.test_session():
- op_sum = add_operators([op1, op2, op3])
- self.assertEqual(1, len(op_sum))
- op = op_sum[0]
- self.assertTrue(isinstance(op, linalg_lib.LinearOperatorDiag))
- self.assertAllClose([[6., 0.], [0., 6.]], op.to_dense().eval())
- # Adding positive definite operators produces positive def.
- self.assertTrue(op.is_positive_definite)
- # Real diagonal ==> self-adjoint.
- self.assertTrue(op.is_self_adjoint)
- # Positive definite ==> non-singular
- self.assertTrue(op.is_non_singular)
-
- def test_diag_tril_diag(self):
- op1 = linalg.LinearOperatorDiag(
- [1., 1.], is_non_singular=True, name="diag_a")
- op2 = linalg.LinearOperatorLowerTriangular(
- [[2., 0.], [0., 2.]],
- is_self_adjoint=True,
- is_non_singular=True,
- name="tril")
- op3 = linalg.LinearOperatorDiag(
- [3., 3.], is_non_singular=True, name="diag_b")
- with self.test_session():
- op_sum = add_operators([op1, op2, op3])
- self.assertEqual(1, len(op_sum))
- op = op_sum[0]
- self.assertTrue(isinstance(op, linalg_lib.LinearOperatorLowerTriangular))
- self.assertAllClose([[6., 0.], [0., 6.]], op.to_dense().eval())
-
- # The diag operators will be self-adjoint (because real and diagonal).
- # The TriL operator has the self-adjoint hint set.
- self.assertTrue(op.is_self_adjoint)
-
- # Even though op1/2/3 are non-singular, this does not imply op is.
- # Since no custom hint was provided, we default to None (unknown).
- self.assertEqual(None, op.is_non_singular)
-
- def test_matrix_diag_tril_diag_uses_custom_name(self):
- op0 = linalg.LinearOperatorFullMatrix(
- [[-1., -1.], [-1., -1.]], name="matrix")
- op1 = linalg.LinearOperatorDiag([1., 1.], name="diag_a")
- op2 = linalg.LinearOperatorLowerTriangular(
- [[2., 0.], [1.5, 2.]], name="tril")
- op3 = linalg.LinearOperatorDiag([3., 3.], name="diag_b")
- with self.test_session():
- op_sum = add_operators([op0, op1, op2, op3], operator_name="my_operator")
- self.assertEqual(1, len(op_sum))
- op = op_sum[0]
- self.assertTrue(isinstance(op, linalg_lib.LinearOperatorFullMatrix))
- self.assertAllClose([[5., -1.], [0.5, 5.]], op.to_dense().eval())
- self.assertEqual("my_operator", op.name)
-
- def test_incompatible_domain_dimensions_raises(self):
- op1 = linalg.LinearOperatorFullMatrix(rng.rand(2, 3))
- op2 = linalg.LinearOperatorDiag(rng.rand(2, 4))
- with self.assertRaisesRegexp(ValueError, "must.*same domain dimension"):
- add_operators([op1, op2])
-
- def test_incompatible_range_dimensions_raises(self):
- op1 = linalg.LinearOperatorFullMatrix(rng.rand(2, 3))
- op2 = linalg.LinearOperatorDiag(rng.rand(3, 3))
- with self.assertRaisesRegexp(ValueError, "must.*same range dimension"):
- add_operators([op1, op2])
-
- def test_non_broadcastable_batch_shape_raises(self):
- op1 = linalg.LinearOperatorFullMatrix(rng.rand(2, 3, 3))
- op2 = linalg.LinearOperatorDiag(rng.rand(4, 3, 3))
- with self.assertRaisesRegexp(ValueError, "Incompatible shapes"):
- add_operators([op1, op2])
-
-
-class LinearOperatorOrderOfAdditionTest(test.TestCase):
- """Test that the order of addition is done as specified by tiers."""
-
- def test_tier_0_additions_done_in_tier_0(self):
- diag1 = linalg.LinearOperatorDiag([1.])
- diag2 = linalg.LinearOperatorDiag([1.])
- diag3 = linalg.LinearOperatorDiag([1.])
- addition_tiers = [
- [linear_operator_addition._AddAndReturnDiag()],
- [_BadAdder()],
- ]
- # Should not raise since all were added in tier 0, and tier 1 (with the
- # _BadAdder) was never reached.
- op_sum = add_operators([diag1, diag2, diag3], addition_tiers=addition_tiers)
- self.assertEqual(1, len(op_sum))
- self.assertTrue(isinstance(op_sum[0], linalg.LinearOperatorDiag))
-
- def test_tier_1_additions_done_by_tier_1(self):
- diag1 = linalg.LinearOperatorDiag([1.])
- diag2 = linalg.LinearOperatorDiag([1.])
- tril = linalg.LinearOperatorLowerTriangular([[1.]])
- addition_tiers = [
- [linear_operator_addition._AddAndReturnDiag()],
- [linear_operator_addition._AddAndReturnTriL()],
- [_BadAdder()],
- ]
- # Should not raise since all were added by tier 1, and the
- # _BadAdder) was never reached.
- op_sum = add_operators([diag1, diag2, tril], addition_tiers=addition_tiers)
- self.assertEqual(1, len(op_sum))
- self.assertTrue(isinstance(op_sum[0], linalg.LinearOperatorLowerTriangular))
-
- def test_tier_1_additions_done_by_tier_1_with_order_flipped(self):
- diag1 = linalg.LinearOperatorDiag([1.])
- diag2 = linalg.LinearOperatorDiag([1.])
- tril = linalg.LinearOperatorLowerTriangular([[1.]])
- addition_tiers = [
- [linear_operator_addition._AddAndReturnTriL()],
- [linear_operator_addition._AddAndReturnDiag()],
- [_BadAdder()],
- ]
- # Tier 0 could convert to TriL, and this converted everything to TriL,
- # including the Diags.
- # Tier 1 was never used.
- # Tier 2 was never used (therefore, _BadAdder didn't raise).
- op_sum = add_operators([diag1, diag2, tril], addition_tiers=addition_tiers)
- self.assertEqual(1, len(op_sum))
- self.assertTrue(isinstance(op_sum[0], linalg.LinearOperatorLowerTriangular))
-
- def test_cannot_add_everything_so_return_more_than_one_operator(self):
- diag1 = linalg.LinearOperatorDiag([1.])
- diag2 = linalg.LinearOperatorDiag([2.])
- tril5 = linalg.LinearOperatorLowerTriangular([[5.]])
- addition_tiers = [
- [linear_operator_addition._AddAndReturnDiag()],
- ]
- # Tier 0 (the only tier) can only convert to Diag, so it combines the two
- # diags, but the TriL is unchanged.
- # Result should contain two operators, one Diag, one TriL.
- op_sum = add_operators([diag1, diag2, tril5], addition_tiers=addition_tiers)
- self.assertEqual(2, len(op_sum))
- found_diag = False
- found_tril = False
- with self.test_session():
- for op in op_sum:
- if isinstance(op, linalg.LinearOperatorDiag):
- found_diag = True
- self.assertAllClose([[3.]], op.to_dense().eval())
- if isinstance(op, linalg.LinearOperatorLowerTriangular):
- found_tril = True
- self.assertAllClose([[5.]], op.to_dense().eval())
- self.assertTrue(found_diag and found_tril)
-
- def test_intermediate_tier_is_not_skipped(self):
- diag1 = linalg.LinearOperatorDiag([1.])
- diag2 = linalg.LinearOperatorDiag([1.])
- tril = linalg.LinearOperatorLowerTriangular([[1.]])
- addition_tiers = [
- [linear_operator_addition._AddAndReturnDiag()],
- [_BadAdder()],
- [linear_operator_addition._AddAndReturnTriL()],
- ]
- # tril cannot be added in tier 0, and the intermediate tier 1 with the
- # BadAdder will catch it and raise.
- with self.assertRaisesRegexp(AssertionError, "BadAdder.can_add called"):
- add_operators([diag1, diag2, tril], addition_tiers=addition_tiers)
-
-
-class AddAndReturnScaledIdentityTest(test.TestCase):
-
- def setUp(self):
- self._adder = linear_operator_addition._AddAndReturnScaledIdentity()
-
- def test_identity_plus_identity(self):
- id1 = linalg.LinearOperatorIdentity(num_rows=2)
- id2 = linalg.LinearOperatorIdentity(num_rows=2, batch_shape=[3])
- hints = linear_operator_addition._Hints(
- is_positive_definite=True, is_non_singular=True)
-
- self.assertTrue(self._adder.can_add(id1, id2))
- operator = self._adder.add(id1, id2, "my_operator", hints)
- self.assertTrue(isinstance(operator, linalg.LinearOperatorScaledIdentity))
-
- with self.test_session():
- self.assertAllClose(2 *
- linalg_ops.eye(num_rows=2, batch_shape=[3]).eval(),
- operator.to_dense().eval())
- self.assertTrue(operator.is_positive_definite)
- self.assertTrue(operator.is_non_singular)
- self.assertEqual("my_operator", operator.name)
-
- def test_identity_plus_scaled_identity(self):
- id1 = linalg.LinearOperatorIdentity(num_rows=2, batch_shape=[3])
- id2 = linalg.LinearOperatorScaledIdentity(num_rows=2, multiplier=2.2)
- hints = linear_operator_addition._Hints(
- is_positive_definite=True, is_non_singular=True)
-
- self.assertTrue(self._adder.can_add(id1, id2))
- operator = self._adder.add(id1, id2, "my_operator", hints)
- self.assertTrue(isinstance(operator, linalg.LinearOperatorScaledIdentity))
-
- with self.test_session():
- self.assertAllClose(3.2 *
- linalg_ops.eye(num_rows=2, batch_shape=[3]).eval(),
- operator.to_dense().eval())
- self.assertTrue(operator.is_positive_definite)
- self.assertTrue(operator.is_non_singular)
- self.assertEqual("my_operator", operator.name)
-
- def test_scaled_identity_plus_scaled_identity(self):
- id1 = linalg.LinearOperatorScaledIdentity(
- num_rows=2, multiplier=[2.2, 2.2, 2.2])
- id2 = linalg.LinearOperatorScaledIdentity(num_rows=2, multiplier=-1.0)
- hints = linear_operator_addition._Hints(
- is_positive_definite=True, is_non_singular=True)
-
- self.assertTrue(self._adder.can_add(id1, id2))
- operator = self._adder.add(id1, id2, "my_operator", hints)
- self.assertTrue(isinstance(operator, linalg.LinearOperatorScaledIdentity))
-
- with self.test_session():
- self.assertAllClose(1.2 *
- linalg_ops.eye(num_rows=2, batch_shape=[3]).eval(),
- operator.to_dense().eval())
- self.assertTrue(operator.is_positive_definite)
- self.assertTrue(operator.is_non_singular)
- self.assertEqual("my_operator", operator.name)
-
-
-class AddAndReturnDiagTest(test.TestCase):
-
- def setUp(self):
- self._adder = linear_operator_addition._AddAndReturnDiag()
-
- def test_identity_plus_identity_returns_diag(self):
- id1 = linalg.LinearOperatorIdentity(num_rows=2)
- id2 = linalg.LinearOperatorIdentity(num_rows=2, batch_shape=[3])
- hints = linear_operator_addition._Hints(
- is_positive_definite=True, is_non_singular=True)
-
- self.assertTrue(self._adder.can_add(id1, id2))
- operator = self._adder.add(id1, id2, "my_operator", hints)
- self.assertTrue(isinstance(operator, linalg.LinearOperatorDiag))
-
- with self.test_session():
- self.assertAllClose(2 *
- linalg_ops.eye(num_rows=2, batch_shape=[3]).eval(),
- operator.to_dense().eval())
- self.assertTrue(operator.is_positive_definite)
- self.assertTrue(operator.is_non_singular)
- self.assertEqual("my_operator", operator.name)
-
- def test_diag_plus_diag(self):
- diag1 = rng.rand(2, 3, 4)
- diag2 = rng.rand(4)
- op1 = linalg.LinearOperatorDiag(diag1)
- op2 = linalg.LinearOperatorDiag(diag2)
- hints = linear_operator_addition._Hints(
- is_positive_definite=True, is_non_singular=True)
-
- self.assertTrue(self._adder.can_add(op1, op2))
- operator = self._adder.add(op1, op2, "my_operator", hints)
- self.assertTrue(isinstance(operator, linalg.LinearOperatorDiag))
-
- with self.test_session():
- self.assertAllClose(
- linalg.LinearOperatorDiag(diag1 + diag2).to_dense().eval(),
- operator.to_dense().eval())
- self.assertTrue(operator.is_positive_definite)
- self.assertTrue(operator.is_non_singular)
- self.assertEqual("my_operator", operator.name)
-
-
-class AddAndReturnTriLTest(test.TestCase):
-
- def setUp(self):
- self._adder = linear_operator_addition._AddAndReturnTriL()
-
- def test_diag_plus_tril(self):
- diag = linalg.LinearOperatorDiag([1., 2.])
- tril = linalg.LinearOperatorLowerTriangular([[10., 0.], [30., 0.]])
- hints = linear_operator_addition._Hints(
- is_positive_definite=True, is_non_singular=True)
-
- self.assertTrue(self._adder.can_add(diag, diag))
- self.assertTrue(self._adder.can_add(diag, tril))
- operator = self._adder.add(diag, tril, "my_operator", hints)
- self.assertTrue(isinstance(operator, linalg.LinearOperatorLowerTriangular))
-
- with self.test_session():
- self.assertAllClose([[11., 0.], [30., 2.]], operator.to_dense().eval())
- self.assertTrue(operator.is_positive_definite)
- self.assertTrue(operator.is_non_singular)
- self.assertEqual("my_operator", operator.name)
-
-
-class AddAndReturnMatrixTest(test.TestCase):
-
- def setUp(self):
- self._adder = linear_operator_addition._AddAndReturnMatrix()
-
- def test_diag_plus_diag(self):
- diag1 = linalg.LinearOperatorDiag([1., 2.])
- diag2 = linalg.LinearOperatorDiag([-1., 3.])
- hints = linear_operator_addition._Hints(
- is_positive_definite=False, is_non_singular=False)
-
- self.assertTrue(self._adder.can_add(diag1, diag2))
- operator = self._adder.add(diag1, diag2, "my_operator", hints)
- self.assertTrue(isinstance(operator, linalg.LinearOperatorFullMatrix))
-
- with self.test_session():
- self.assertAllClose([[0., 0.], [0., 5.]], operator.to_dense().eval())
- self.assertFalse(operator.is_positive_definite)
- self.assertFalse(operator.is_non_singular)
- self.assertEqual("my_operator", operator.name)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator_addition.py b/tensorflow/contrib/linalg/python/ops/linear_operator_addition.py
deleted file mode 100644
index 86130a2c07..0000000000
--- a/tensorflow/contrib/linalg/python/ops/linear_operator_addition.py
+++ /dev/null
@@ -1,432 +0,0 @@
-# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Add one or more `LinearOperators` efficiently."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import abc
-
-import six
-
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import check_ops
-from tensorflow.python.ops.linalg import linear_operator
-from tensorflow.python.ops.linalg import linear_operator_diag
-from tensorflow.python.ops.linalg import linear_operator_full_matrix
-from tensorflow.python.ops.linalg import linear_operator_identity
-from tensorflow.python.ops.linalg import linear_operator_lower_triangular
-
-__all__ = []
-
-
-def add_operators(operators,
- operator_name=None,
- addition_tiers=None,
- name=None):
- """Efficiently add one or more linear operators.
-
- Given operators `[A1, A2,...]`, this `Op` returns a possibly shorter list of
- operators `[B1, B2,...]` such that
-
- ```sum_k Ak.matmul(x) = sum_k Bk.matmul(x).```
-
- The operators `Bk` result by adding some of the `Ak`, as allowed by
- `addition_tiers`.
-
- Example of efficient adding of diagonal operators.
-
- ```python
- A1 = LinearOperatorDiag(diag=[1., 1.], name="A1")
- A2 = LinearOperatorDiag(diag=[2., 2.], name="A2")
-
- # Use two tiers, the first contains an Adder that returns Diag. Since both
- # A1 and A2 are Diag, they can use this Adder. The second tier will not be
- # used.
- addition_tiers = [
- [_AddAndReturnDiag()],
- [_AddAndReturnMatrix()]]
- B_list = add_operators([A1, A2], addition_tiers=addition_tiers)
-
- len(B_list)
- ==> 1
-
- B_list[0].__class__.__name__
- ==> 'LinearOperatorDiag'
-
- B_list[0].to_dense()
- ==> [[3., 0.],
- [0., 3.]]
-
- B_list[0].name
- ==> 'Add/A1__A2/'
- ```
-
- Args:
- operators: Iterable of `LinearOperator` objects with same `dtype`, domain
- and range dimensions, and broadcastable batch shapes.
- operator_name: String name for returned `LinearOperator`. Defaults to
- concatenation of "Add/A__B/" that indicates the order of addition steps.
- addition_tiers: List tiers, like `[tier_0, tier_1, ...]`, where `tier_i`
- is a list of `Adder` objects. This function attempts to do all additions
- in tier `i` before trying tier `i + 1`.
- name: A name for this `Op`. Defaults to `add_operators`.
-
- Returns:
- Subclass of `LinearOperator`. Class and order of addition may change as new
- (and better) addition strategies emerge.
-
- Raises:
- ValueError: If `operators` argument is empty.
- ValueError: If shapes are incompatible.
- """
- # Default setting
- if addition_tiers is None:
- addition_tiers = _DEFAULT_ADDITION_TIERS
-
- # Argument checking.
- check_ops.assert_proper_iterable(operators)
- operators = list(reversed(operators))
- if len(operators) < 1:
- raise ValueError(
- "Argument 'operators' must contain at least one operator. "
- "Found: %s" % operators)
- if not all(
- isinstance(op, linear_operator.LinearOperator) for op in operators):
- raise TypeError(
- "Argument 'operators' must contain only LinearOperator instances. "
- "Found: %s" % operators)
- _static_check_for_same_dimensions(operators)
- _static_check_for_broadcastable_batch_shape(operators)
-
- graph_parents = []
- for operator in operators:
- graph_parents.extend(operator.graph_parents)
-
- with ops.name_scope(name or "add_operators", values=graph_parents):
-
- # Additions done in one of the tiers. Try tier 0, 1,...
- ops_to_try_at_next_tier = list(operators)
- for tier in addition_tiers:
- ops_to_try_at_this_tier = ops_to_try_at_next_tier
- ops_to_try_at_next_tier = []
- while ops_to_try_at_this_tier:
- op1 = ops_to_try_at_this_tier.pop()
- op2, adder = _pop_a_match_at_tier(op1, ops_to_try_at_this_tier, tier)
- if op2 is not None:
- # Will try to add the result of this again at this same tier.
- new_operator = adder.add(op1, op2, operator_name)
- ops_to_try_at_this_tier.append(new_operator)
- else:
- ops_to_try_at_next_tier.append(op1)
-
- return ops_to_try_at_next_tier
-
-
-def _pop_a_match_at_tier(op1, operator_list, tier):
- # Search from the back of list to the front in order to create nice default
- # order of operations.
- for i in range(1, len(operator_list) + 1):
- op2 = operator_list[-i]
- for adder in tier:
- if adder.can_add(op1, op2):
- return operator_list.pop(-i), adder
- return None, None
-
-
-def _infer_hints_allowing_override(op1, op2, hints):
- """Infer hints from op1 and op2. hints argument is an override.
-
- Args:
- op1: LinearOperator
- op2: LinearOperator
- hints: _Hints object holding "is_X" boolean hints to use for returned
- operator.
- If some hint is None, try to set using op1 and op2. If the
- hint is provided, ignore op1 and op2 hints. This allows an override
- of previous hints, but does not allow forbidden hints (e.g. you still
- cannot say a real diagonal operator is not self-adjoint.
-
- Returns:
- _Hints object.
- """
- hints = hints or _Hints()
- # If A, B are self-adjoint, then so is A + B.
- if hints.is_self_adjoint is None:
- is_self_adjoint = op1.is_self_adjoint and op2.is_self_adjoint
- else:
- is_self_adjoint = hints.is_self_adjoint
-
- # If A, B are positive definite, then so is A + B.
- if hints.is_positive_definite is None:
- is_positive_definite = op1.is_positive_definite and op2.is_positive_definite
- else:
- is_positive_definite = hints.is_positive_definite
-
- # A positive definite operator is always non-singular.
- if is_positive_definite and hints.is_positive_definite is None:
- is_non_singular = True
- else:
- is_non_singular = hints.is_non_singular
-
- return _Hints(
- is_non_singular=is_non_singular,
- is_self_adjoint=is_self_adjoint,
- is_positive_definite=is_positive_definite)
-
-
-def _static_check_for_same_dimensions(operators):
- """ValueError if operators determined to have different dimensions."""
- if len(operators) < 2:
- return
-
- domain_dimensions = [(op.name, op.domain_dimension.value) for op in operators
- if op.domain_dimension.value is not None]
- if len(set(value for name, value in domain_dimensions)) > 1:
- raise ValueError("Operators must have the same domain dimension. Found: %s"
- % domain_dimensions)
-
- range_dimensions = [(op.name, op.range_dimension.value) for op in operators
- if op.range_dimension.value is not None]
- if len(set(value for name, value in range_dimensions)) > 1:
- raise ValueError("Operators must have the same range dimension. Found: %s" %
- range_dimensions)
-
-
-def _static_check_for_broadcastable_batch_shape(operators):
- """ValueError if operators determined to have non-broadcastable shapes."""
- if len(operators) < 2:
- return
-
- # This will fail if they cannot be broadcast together.
- batch_shape = operators[0].batch_shape
- for op in operators[1:]:
- batch_shape = array_ops.broadcast_static_shape(batch_shape, op.batch_shape)
-
-
-class _Hints(object):
- """Holds 'is_X' flags that every LinearOperator is initialized with."""
-
- def __init__(self,
- is_non_singular=None,
- is_positive_definite=None,
- is_self_adjoint=None):
- self.is_non_singular = is_non_singular
- self.is_positive_definite = is_positive_definite
- self.is_self_adjoint = is_self_adjoint
-
-
-################################################################################
-# Classes to add two linear operators.
-################################################################################
-
-
-@six.add_metaclass(abc.ABCMeta)
-class _Adder(object):
- """Abstract base class to add two operators.
-
- Each `Adder` acts independently, adding everything it can, paying no attention
- as to whether another `Adder` could have done the addition more efficiently.
- """
-
- @property
- def name(self):
- return self.__class__.__name__
-
- @abc.abstractmethod
- def can_add(self, op1, op2):
- """Returns `True` if this `Adder` can add `op1` and `op2`. Else `False`."""
- pass
-
- @abc.abstractmethod
- def _add(self, op1, op2, operator_name, hints):
- # Derived classes can assume op1 and op2 have been validated, e.g. they have
- # the same dtype, and their domain/range dimensions match.
- pass
-
- def add(self, op1, op2, operator_name, hints=None):
- """Return new `LinearOperator` acting like `op1 + op2`.
-
- Args:
- op1: `LinearOperator`
- op2: `LinearOperator`, with `shape` and `dtype` such that adding to
- `op1` is allowed.
- operator_name: `String` name to give to returned `LinearOperator`
- hints: `_Hints` object. Returned `LinearOperator` will be created with
- these hints.
-
- Returns:
- `LinearOperator`
- """
- updated_hints = _infer_hints_allowing_override(op1, op2, hints)
-
- if operator_name is None:
- operator_name = "Add/" + op1.name + "__" + op2.name + "/"
-
- values = op1.graph_parents + op2.graph_parents
- scope_name = self.name
- if scope_name.startswith("_"):
- scope_name = scope_name[1:]
- with ops.name_scope(scope_name, values=values):
- return self._add(op1, op2, operator_name, updated_hints)
-
-
-class _AddAndReturnScaledIdentity(_Adder):
- """Handles additions resulting in an Identity family member.
-
- The Identity (`LinearOperatorScaledIdentity`, `LinearOperatorIdentity`) family
- is closed under addition. This `Adder` respects that, and returns an Identity
- """
-
- def can_add(self, op1, op2):
- types = {_type(op1), _type(op2)}
- return not types.difference(_IDENTITY_FAMILY)
-
- def _add(self, op1, op2, operator_name, hints):
- # Will build a LinearOperatorScaledIdentity.
-
- if _type(op1) == _SCALED_IDENTITY:
- multiplier_1 = op1.multiplier
- else:
- multiplier_1 = array_ops.ones(op1.batch_shape_tensor(), dtype=op1.dtype)
-
- if _type(op2) == _SCALED_IDENTITY:
- multiplier_2 = op2.multiplier
- else:
- multiplier_2 = array_ops.ones(op2.batch_shape_tensor(), dtype=op2.dtype)
-
- return linear_operator_identity.LinearOperatorScaledIdentity(
- num_rows=op1.range_dimension_tensor(),
- multiplier=multiplier_1 + multiplier_2,
- is_non_singular=hints.is_non_singular,
- is_self_adjoint=hints.is_self_adjoint,
- is_positive_definite=hints.is_positive_definite,
- name=operator_name)
-
-
-class _AddAndReturnDiag(_Adder):
- """Handles additions resulting in a Diag operator."""
-
- def can_add(self, op1, op2):
- types = {_type(op1), _type(op2)}
- return not types.difference(_DIAG_LIKE)
-
- def _add(self, op1, op2, operator_name, hints):
- return linear_operator_diag.LinearOperatorDiag(
- diag=op1.diag_part() + op2.diag_part(),
- is_non_singular=hints.is_non_singular,
- is_self_adjoint=hints.is_self_adjoint,
- is_positive_definite=hints.is_positive_definite,
- name=operator_name)
-
-
-class _AddAndReturnTriL(_Adder):
- """Handles additions resulting in a TriL operator."""
-
- def can_add(self, op1, op2):
- types = {_type(op1), _type(op2)}
- return not types.difference(_DIAG_LIKE.union({_TRIL}))
-
- def _add(self, op1, op2, operator_name, hints):
- if _type(op1) in _EFFICIENT_ADD_TO_TENSOR:
- op_add_to_tensor, op_other = op1, op2
- else:
- op_add_to_tensor, op_other = op2, op1
-
- return linear_operator_lower_triangular.LinearOperatorLowerTriangular(
- tril=op_add_to_tensor.add_to_tensor(op_other.to_dense()),
- is_non_singular=hints.is_non_singular,
- is_self_adjoint=hints.is_self_adjoint,
- is_positive_definite=hints.is_positive_definite,
- name=operator_name)
-
-
-class _AddAndReturnMatrix(_Adder):
- """"Handles additions resulting in a `LinearOperatorFullMatrix`."""
-
- def can_add(self, op1, op2): # pylint: disable=unused-argument
- return isinstance(op1, linear_operator.LinearOperator) and isinstance(
- op2, linear_operator.LinearOperator)
-
- def _add(self, op1, op2, operator_name, hints):
- if _type(op1) in _EFFICIENT_ADD_TO_TENSOR:
- op_add_to_tensor, op_other = op1, op2
- else:
- op_add_to_tensor, op_other = op2, op1
- return linear_operator_full_matrix.LinearOperatorFullMatrix(
- matrix=op_add_to_tensor.add_to_tensor(op_other.to_dense()),
- is_non_singular=hints.is_non_singular,
- is_self_adjoint=hints.is_self_adjoint,
- is_positive_definite=hints.is_positive_definite,
- name=operator_name)
-
-
-################################################################################
-# Constants designating types of LinearOperators
-################################################################################
-
-# Type name constants for LinearOperator classes.
-_IDENTITY = "identity"
-_SCALED_IDENTITY = "scaled_identity"
-_DIAG = "diag"
-_TRIL = "tril"
-_MATRIX = "matrix"
-
-# Groups of operators.
-_DIAG_LIKE = {_DIAG, _IDENTITY, _SCALED_IDENTITY}
-_IDENTITY_FAMILY = {_IDENTITY, _SCALED_IDENTITY}
-# operators with an efficient .add_to_tensor() method.
-_EFFICIENT_ADD_TO_TENSOR = _DIAG_LIKE
-
-
-def _type(operator):
- """Returns the type name constant (e.g. _TRIL) for operator."""
- if isinstance(operator, linear_operator_diag.LinearOperatorDiag):
- return _DIAG
- if isinstance(operator,
- linear_operator_lower_triangular.LinearOperatorLowerTriangular):
- return _TRIL
- if isinstance(operator, linear_operator_full_matrix.LinearOperatorFullMatrix):
- return _MATRIX
- if isinstance(operator, linear_operator_identity.LinearOperatorIdentity):
- return _IDENTITY
- if isinstance(operator,
- linear_operator_identity.LinearOperatorScaledIdentity):
- return _SCALED_IDENTITY
- raise TypeError("Operator type unknown: %s" % operator)
-
-
-################################################################################
-# Addition tiers:
-# We attempt to use Adders in tier K before K+1.
-#
-# Organize tiers to
-# (i) reduce O(..) complexity of forming final operator, and
-# (ii) produce the "most efficient" final operator.
-# Dev notes:
-# * Results of addition at tier K will be added at tier K or higher.
-# * Tiers may change, and we warn the user that it may change.
-################################################################################
-
-# Note that the final tier, _AddAndReturnMatrix, will convert everything to a
-# dense matrix. So it is sometimes very inefficient.
-_DEFAULT_ADDITION_TIERS = [
- [_AddAndReturnScaledIdentity()],
- [_AddAndReturnDiag()],
- [_AddAndReturnTriL()],
- [_AddAndReturnMatrix()],
-]
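The deleted module's tier logic, in short: each tier holds adders; any pair that some adder in the current tier can combine is combined, the result is retried within the same tier, and whatever is left moves on to the next tier. A conceptual, plain-Python sketch of that loop, with toy adders standing in for the LinearOperator machinery (this is a reduction for illustration, not the deleted implementation):

```python
# Conceptual sketch of tiered addition; adders are (can_add, combine) pairs.
def add_with_tiers(items, tiers):
  """Greedily combines items, trying each tier of adders in order."""
  pending = list(items)
  for tier in tiers:
    current, pending = pending, []
    while current:
      a = current.pop()
      match = None
      # Search later entries first, as the deleted helper did.
      for i in range(len(current) - 1, -1, -1):
        for can_add, combine in tier:
          if can_add(a, current[i]):
            match = (current.pop(i), combine)
            break
        if match:
          break
      if match:
        b, combine = match
        current.append(combine(a, b))  # retry the result in this tier
      else:
        pending.append(a)              # hand off to the next tier
  return pending

# Toy example: ints are "diag-like" and can be summed; strings cannot.
tiers = [[(lambda x, y: isinstance(x, int) and isinstance(y, int),
           lambda x, y: x + y)]]
print(add_with_tiers([1, 2, "tril"], tiers))  # -> ['tril', 3]
```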
diff --git a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py
index 1d2db1cec8..8466dc36d1 100644
--- a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py
+++ b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py
@@ -125,7 +125,7 @@ def make_random_examples_and_variables_dicts(num_examples, dim, num_non_zero):
],
example_ids=[str(i) for i in range(num_examples)])
- weights = variables_lib.Variable(
+ weights = variables_lib.VariableV1(
array_ops.zeros([dim], dtype=dtypes.float32))
variables_dict = dict(
sparse_features_weights=[weights],
@@ -134,7 +134,7 @@ def make_random_examples_and_variables_dicts(num_examples, dim, num_non_zero):
return examples_dict, variables_dict
-def make_variable_dict(max_age, max_gender, partitioned=False):
+def make_variable_dict(max_age, max_gender, num_shards=None, partitioned=False):
# TODO(sibyl-toe9oF2e): Figure out how to derive max_age & max_gender from
# examples_dict.
partitioner = None
@@ -142,14 +142,15 @@ def make_variable_dict(max_age, max_gender, partitioned=False):
partitioner = partitioned_variables.fixed_size_partitioner(num_shards=2,
axis=0)
with variable_scope.variable_scope(
- name_or_scope='variables',
+ name_or_scope=('variables/shard_{}'.format(num_shards)
+ if num_shards else 'variables'),
partitioner=partitioner):
- age_weights = variables_lib.Variable(
- array_ops.zeros(
- [max_age + 1], dtype=dtypes.float32))
- gender_weights = variables_lib.Variable(
- array_ops.zeros(
- [max_gender + 1], dtype=dtypes.float32))
+ age_weights = variable_scope.get_variable(
+ name='age',
+ initializer=array_ops.zeros([max_age + 1], dtype=dtypes.float32))
+ gender_weights = variable_scope.get_variable(
+ name='gender',
+ initializer=array_ops.zeros([max_gender + 1], dtype=dtypes.float32))
return dict(
sparse_features_weights=[age_weights, gender_weights],
dense_features_weights=[])
@@ -183,7 +184,7 @@ def make_dense_examples_and_variables_dicts(dense_features_values, weights,
dense_tensors.append(dense_tensor)
# Add variables of shape [feature_column_dimension].
dense_weights.append(
- variables_lib.Variable(
+ variables_lib.VariableV1(
array_ops.zeros(
[dense_tensor.get_shape().as_list()[1]], dtype=dtypes.float32)))
@@ -242,7 +243,7 @@ class SdcaWithLogisticLossTest(SdcaModelTest):
for num_shards in _SHARD_NUMBERS:
with self._single_threaded_test_session():
examples = make_example_dict(example_protos, example_weights)
- variables = make_variable_dict(1, 1)
+ variables = make_variable_dict(1, 1, num_shards)
options = dict(
symmetric_l2_regularization=1,
symmetric_l1_regularization=0,
@@ -290,7 +291,7 @@ class SdcaWithLogisticLossTest(SdcaModelTest):
for num_shards in _SHARD_NUMBERS:
with self._single_threaded_test_session():
examples = make_example_dict(example_protos, example_weights)
- variables = make_variable_dict(1, 1, partitioned=True)
+ variables = make_variable_dict(1, 1, num_shards, partitioned=True)
options = dict(
symmetric_l2_regularization=1,
symmetric_l1_regularization=0,
@@ -322,6 +323,68 @@ class SdcaWithLogisticLossTest(SdcaModelTest):
self.assertAllClose(
0.01, lr.approximate_duality_gap().eval(), rtol=1e-2, atol=1e-2)
+ def testSomePartitionedPrimals(self):
+ # Setup test data
+ example_protos = [
+ make_example_proto({
+ 'age': [0],
+ 'gender': [0]
+ }, 0),
+ make_example_proto({
+ 'age': [0],
+ 'gender': [1]
+ }, 1),
+ ]
+ example_weights = [1.0, 1.0]
+ for num_shards in _SHARD_NUMBERS:
+ with self._single_threaded_test_session():
+ examples = make_example_dict(example_protos, example_weights)
+ # Explicitly make age a [1]-shaped Variable (which cannot be
+ # partitioned), while making gender a PartitionedVariable.
+ age_weights = variables_lib.VariableV1(
+ array_ops.zeros([1], dtype=dtypes.float32))
+ with variable_scope.variable_scope(
+ name_or_scope=('variables/shard_{}'.format(num_shards)
+ if num_shards else 'variables'),
+ partitioner=partitioned_variables.fixed_size_partitioner(
+ num_shards=2, axis=0)):
+ gender_weights = variable_scope.get_variable(
+ name='gender',
+ initializer=array_ops.zeros([2], dtype=dtypes.float32))
+ variables = dict(
+ sparse_features_weights=[age_weights, gender_weights],
+ dense_features_weights=[])
+ options = dict(
+ symmetric_l2_regularization=1,
+ symmetric_l1_regularization=0,
+ num_table_shards=num_shards,
+ loss_type='logistic_loss')
+
+ lr = SdcaModel(examples, variables, options)
+ variables_lib.global_variables_initializer().run()
+ unregularized_loss = lr.unregularized_loss(examples)
+ loss = lr.regularized_loss(examples)
+ predictions = lr.predictions(examples)
+ self.assertAllClose(0.693147, unregularized_loss.eval())
+ self.assertAllClose(0.693147, loss.eval())
+ train_op = lr.minimize()
+ for _ in range(_MAX_ITERATIONS):
+ train_op.run()
+ lr.update_weights(train_op).run()
+ # The high tolerance in unregularized_loss comparisons is due to the
+ # fact that it's possible to trade off unregularized_loss vs.
+ # regularization and still have a sum that is quite close to the
+ # optimal regularized_loss value. SDCA's duality gap only ensures that
+ # the regularized_loss is within 0.01 of optimal.
+ # 0.525457 is the optimal regularized_loss.
+ # 0.593014 is the unregularized_loss at that optimum.
+ self.assertAllClose(0.512591, unregularized_loss.eval(), atol=0.05)
+ self.assertAllClose(0.593014, loss.eval(), atol=0.01)
+ predicted_labels = get_binary_predictions_for_logistic(predictions)
+ self.assertAllEqual([0, 1], predicted_labels.eval())
+ self.assertAllClose(
+ 0.01, lr.approximate_duality_gap().eval(), rtol=1e-2, atol=1e-2)
+
def testSparseRandom(self):
dim = 20
num_examples = 1000
@@ -463,7 +526,7 @@ class SdcaWithLogisticLossTest(SdcaModelTest):
for num_shards in _SHARD_NUMBERS:
with self._single_threaded_test_session():
examples = make_example_dict(example_protos, example_weights)
- variables = make_variable_dict(1, 1)
+ variables = make_variable_dict(1, 1, num_shards)
options = dict(
symmetric_l2_regularization=0,
symmetric_l1_regularization=0,
@@ -521,7 +584,7 @@ class SdcaWithLogisticLossTest(SdcaModelTest):
with self._single_threaded_test_session():
# Only use examples 0 and 2
examples = make_example_dict(example_protos, example_weights)
- variables = make_variable_dict(1, 1)
+ variables = make_variable_dict(1, 1, num_shards)
options = dict(
symmetric_l2_regularization=1,
symmetric_l1_regularization=0,
@@ -561,7 +624,7 @@ class SdcaWithLogisticLossTest(SdcaModelTest):
for num_shards in _SHARD_NUMBERS:
with self._single_threaded_test_session():
examples = make_example_dict(example_protos, example_weights)
- variables = make_variable_dict(1, 1)
+ variables = make_variable_dict(1, 1, num_shards)
options = dict(
symmetric_l2_regularization=1,
symmetric_l1_regularization=0,
@@ -598,7 +661,7 @@ class SdcaWithLogisticLossTest(SdcaModelTest):
for num_shards in _SHARD_NUMBERS:
with self._single_threaded_test_session():
examples = make_example_dict(example_protos, example_weights)
- variables = make_variable_dict(3, 1)
+ variables = make_variable_dict(3, 1, num_shards)
options = dict(
symmetric_l2_regularization=1,
symmetric_l1_regularization=0,
@@ -639,7 +702,7 @@ class SdcaWithLogisticLossTest(SdcaModelTest):
for num_shards in _SHARD_NUMBERS:
with self._single_threaded_test_session():
examples = make_example_dict(example_protos, example_weights)
- variables = make_variable_dict(1, 1)
+ variables = make_variable_dict(1, 1, num_shards)
options = dict(
symmetric_l2_regularization=1,
symmetric_l1_regularization=0,
@@ -679,7 +742,7 @@ class SdcaWithLogisticLossTest(SdcaModelTest):
for num_shards in _SHARD_NUMBERS:
with self._single_threaded_test_session():
examples = make_example_dict(example_protos, example_weights)
- variables = make_variable_dict(1, 1)
+ variables = make_variable_dict(1, 1, num_shards)
options = dict(
symmetric_l2_regularization=1,
symmetric_l1_regularization=0,
@@ -738,7 +801,7 @@ class SdcaWithLogisticLossTest(SdcaModelTest):
labels=[1.0, 0.0])
# Replace with a variable of size 1 instead of 2.
variables['dense_features_weights'] = [
- variables_lib.Variable(array_ops.zeros(
+ variables_lib.VariableV1(array_ops.zeros(
[1], dtype=dtypes.float32))
]
options = dict(
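The updated tests build the sharded weights with `get_variable` under a `fixed_size_partitioner` and give each shard count its own scope name, so repeated loop iterations do not collide on variable names. A rough sketch of that construction, assuming TensorFlow 1.x partitioned-variable APIs; the scope name and shapes are illustrative:

```python
# Illustrative sketch: building a partitioned weight the way the test does.
import tensorflow as tf

num_shards = 2
with tf.variable_scope(
    "variables/shard_{}".format(num_shards),
    partitioner=tf.fixed_size_partitioner(num_shards=2, axis=0)):
  gender_weights = tf.get_variable(
      name="gender",
      initializer=tf.zeros([2], dtype=tf.float32))

# The result is a PartitionedVariable split into two [1]-shaped pieces.
print(type(gender_weights).__name__)
```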
diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py
index 14f59a3f64..b98adf862b 100644
--- a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py
+++ b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py
@@ -400,14 +400,16 @@ class SdcaModel(object):
sparse_weights = []
sparse_indices = []
- # If we have partitioned variables, keep a few lists of Tensors around
- # that we need for the assign_add after the op call to
- # gen_sdca_ops.sdca_optimizer().
- num_partitions_by_var = []
- p_assignments_by_var = []
- gather_ids_by_var = []
- for w, i in zip(self._slots['unshrinked_sparse_features_weights'],
- sparse_feature_indices):
+ # If we have partitioned variables, keep a few dictionaries of Tensors
+ # around that we need for the assign_add after the op call to
+ # gen_sdca_ops.sdca_optimizer(). These are keyed because we may have a
+ # mix of partitioned and un-partitioned variables.
+ num_partitions_by_var = {}
+ p_assignments_by_var = {}
+ gather_ids_by_var = {}
+ for v_num, (w, i) in enumerate(
+ zip(self._slots['unshrinked_sparse_features_weights'],
+ sparse_feature_indices)):
# Append the sparse_indices (in full-variable space).
sparse_idx = math_ops.cast(
array_ops.unique(math_ops.cast(i, dtypes.int32))[0],
@@ -456,10 +458,10 @@ class SdcaModel(object):
gather_ids = data_flow_ops.dynamic_partition(new_ids,
p_assignments,
num_partitions)
- # Append these to the lists for use in the later update.
- num_partitions_by_var.append(num_partitions)
- p_assignments_by_var.append(p_assignments)
- gather_ids_by_var.append(gather_ids)
+ # Add these into the dictionaries for use in the later update.
+ num_partitions_by_var[v_num] = num_partitions
+ p_assignments_by_var[v_num] = p_assignments
+ gather_ids_by_var[v_num] = gather_ids
# Gather the weights from each partition.
partition_gathered_weights = []
diff --git a/tensorflow/contrib/lite/README.md b/tensorflow/contrib/lite/README.md
index a676b705f1..a4b3d83efe 100644
--- a/tensorflow/contrib/lite/README.md
+++ b/tensorflow/contrib/lite/README.md
@@ -4,5 +4,5 @@ TensorFlow Lite is TensorFlow's lightweight solution for mobile and embedded
devices. It enables low-latency inference of on-device machine learning models
with a small binary size and fast performance supporting hardware acceleration.
-See the documentation: https://www.tensorflow.org/mobile/tflite/
-Documentation edits can be made here: [tensorflow/docs_src/mobile/tflite](../../docs_src/mobile/tflite)
+See the documentation: https://www.tensorflow.org/lite/
+Documentation edits can be made here: [tensorflow/contrib/lite/g3doc](./g3doc/)
diff --git a/tensorflow/contrib/lite/build_def.bzl b/tensorflow/contrib/lite/build_def.bzl
index 52b994ee92..7ef26de69f 100644
--- a/tensorflow/contrib/lite/build_def.bzl
+++ b/tensorflow/contrib/lite/build_def.bzl
@@ -294,13 +294,14 @@ def generated_test_models():
#"transpose_conv", # disabled due to b/111213074
"unpack",
"where",
+ "zeros_like",
]
def generated_test_conversion_modes():
"""Returns a list of conversion modes."""
# TODO(nupurgarg): Add "pb2lite" when it's in open source. b/113614050.
- return ["toco-extended", ""]
+ return ["toco-flex", ""]
def generated_test_models_all():
"""Generates a list of all tests with the different converters.
@@ -334,7 +335,7 @@ def gen_zip_test(name, test_name, conversion_mode, **kwargs):
# TODO(nupurgarg): Comment in when pb2lite is in open source. b/113614050.
# if conversion_mode == "pb2lite":
# toco = "//tensorflow/contrib/lite/experimental/pb2lite:pb2lite"
- flags = "--ignore_toco_errors --run_with_extended"
+ flags = "--ignore_toco_errors --run_with_flex"
kwargs["tags"].append("skip_already_failing")
kwargs["tags"].append("no_oss")
kwargs["tags"].append("notap")
@@ -390,3 +391,41 @@ def gen_selected_ops(name, model):
(tool, model, out, tflite_path[2:]),
tools = [tool],
)
+
+def gen_full_model_test(conversion_modes, models, data, test_suite_tag):
+ """Generates Python test targets for testing TFLite models.
+
+ Args:
+ conversion_modes: List of conversion modes to test the models on.
+ models: List of models to test.
+ data: List of BUILD targets linking the data.
+ test_suite_tag: Tag identifying the model test suite.
+ """
+ options = [
+ (conversion_mode, model)
+ for model in models
+ for conversion_mode in conversion_modes
+ ]
+
+ for conversion_mode, model_name in options:
+ native.py_test(
+ name = "model_coverage_test_%s_%s" % (model_name, conversion_mode.lower()),
+ srcs = ["model_coverage_test.py"],
+ main = "model_coverage_test.py",
+ args = [
+ "--model_name=%s" % model_name,
+ "--converter_mode=%s" % conversion_mode,
+ ],
+ data = data,
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_windows",
+ "notap",
+ ] + [test_suite_tag],
+ deps = [
+ "//tensorflow/contrib/lite/testing:model_coverage_lib",
+ "//tensorflow/contrib/lite/python:lite",
+ "//tensorflow/python:client_testlib",
+ ],
+ )
diff --git a/tensorflow/contrib/lite/builtin_ops.h b/tensorflow/contrib/lite/builtin_ops.h
index 5e97b777fc..7809d114e2 100644
--- a/tensorflow/contrib/lite/builtin_ops.h
+++ b/tensorflow/contrib/lite/builtin_ops.h
@@ -118,6 +118,8 @@ typedef enum {
kTfLiteBuiltinFloorDiv = 90,
kTfLiteBuiltinReduceAny = 91,
kTfLiteBuiltinSquare = 92,
+ kTfLiteBuiltinZerosLike = 93,
+ kTfLiteBuiltinFill = 94,
} TfLiteBuiltinOperator;
#ifdef __cplusplus
diff --git a/tensorflow/contrib/lite/c/c_api_internal.c b/tensorflow/contrib/lite/c/c_api_internal.c
index 1846bad4b7..8a0c177b19 100644
--- a/tensorflow/contrib/lite/c/c_api_internal.c
+++ b/tensorflow/contrib/lite/c/c_api_internal.c
@@ -14,15 +14,29 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#ifndef TF_LITE_STATIC_MEMORY
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
+#endif // TF_LITE_STATIC_MEMORY
int TfLiteIntArrayGetSizeInBytes(int size) {
static TfLiteIntArray dummy;
return sizeof(dummy) + sizeof(dummy.data[0]) * size;
}
+int TfLiteIntArrayEqual(TfLiteIntArray* a, TfLiteIntArray* b) {
+ if (a == b) return 1;
+ if (a == NULL || b == NULL) return 0;
+ if (a->size != b->size) return 0;
+ int i = 0;
+ for (; i < a->size; i++)
+ if (a->data[i] != b->data[i]) return 0;
+ return 1;
+}
+
+#ifndef TF_LITE_STATIC_MEMORY
+
TfLiteIntArray* TfLiteIntArrayCreate(int size) {
TfLiteIntArray* ret =
(TfLiteIntArray*)malloc(TfLiteIntArrayGetSizeInBytes(size));
@@ -40,16 +54,6 @@ void TfLiteIntArrayPrint(const char* s, TfLiteIntArray* a) {
printf("]\n");
}
-int TfLiteIntArrayEqual(TfLiteIntArray* a, TfLiteIntArray* b) {
- if (a == b) return 1;
- if (a == NULL || b == NULL) return 0;
- if (a->size != b->size) return 0;
- int i = 0;
- for (; i < a->size; i++)
- if (a->data[i] != b->data[i]) return 0;
- return 1;
-}
-
TfLiteIntArray* TfLiteIntArrayCopy(TfLiteIntArray* src) {
if (!src) return NULL;
TfLiteIntArray* ret = TfLiteIntArrayCreate(src->size);
@@ -102,3 +106,4 @@ void TfLiteTensorRealloc(size_t num_bytes, TfLiteTensor* tensor) {
}
tensor->bytes = num_bytes;
}
+#endif // TF_LITE_STATIC_MEMORY
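For context, the equality helper moved above the TF_LITE_STATIC_MEMORY guard needs no heap allocation, which is why it can remain available in static-memory builds while the malloc/free-based helpers are compiled out. Below is a minimal standalone sketch of the same comparison semantics; it uses a simplified stand-in struct rather than the real TfLiteIntArray (a flexible-array struct declared in c_api_internal.h), purely for illustration.

// Standalone sketch of the equality semantics shown in the hunk above.
#include <cstdio>

struct IntArray {   // simplified stand-in for TfLiteIntArray
  int size;
  int data[8];      // fixed capacity here purely for the example
};

// Mirrors the moved TfLiteIntArrayEqual: pointer equality, NULL checks,
// size check, then element-wise comparison. No allocation involved.
static int IntArrayEqual(const IntArray* a, const IntArray* b) {
  if (a == b) return 1;
  if (a == nullptr || b == nullptr) return 0;
  if (a->size != b->size) return 0;
  for (int i = 0; i < a->size; ++i)
    if (a->data[i] != b->data[i]) return 0;
  return 1;
}

int main() {
  IntArray x{3, {1, 2, 3}};
  IntArray y{3, {1, 2, 3}};
  IntArray z{3, {1, 2, 4}};
  std::printf("%d %d\n", IntArrayEqual(&x, &y), IntArrayEqual(&x, &z));  // prints "1 0"
}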
diff --git a/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc b/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc
index f4d2839b1b..e6900e0950 100644
--- a/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc
+++ b/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc
@@ -44,16 +44,6 @@ void FlatBufferIntVectorToArray(int max_size_of_buffer,
}
}
-// Allocate a structure using malloc, but make sure the structure is a POD
-// structure that doesn't require constructors to run. The reason we do this,
-// is that Interpreter's C extension part will take ownership so destructors
-// will not be run during deallocation.
-template <class T>
-T* MallocPOD() {
- static_assert(std::is_pod<T>::value, "Builtin data structure must be POD.");
- return static_cast<T*>(malloc(sizeof(T)));
-}
-
} // namespace
TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type,
@@ -98,7 +88,8 @@ TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type,
// need to be released by calling `free`.`
// If it returns kTfLiteError, `builtin_data` will be `nullptr`.
TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
- ErrorReporter* error_reporter, void** builtin_data) {
+ ErrorReporter* error_reporter,
+ BuiltinDataAllocator* allocator, void** builtin_data) {
auto parse_padding = [](Padding padding) {
switch (padding) {
case Padding_SAME:
@@ -150,7 +141,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
*builtin_data = nullptr;
switch (op_type) {
case BuiltinOperator_CONV_2D: {
- TfLiteConvParams* params = MallocPOD<TfLiteConvParams>();
+ TfLiteConvParams* params = allocator->AllocatePOD<TfLiteConvParams>();
if (auto* conv_params = op->builtin_options_as_Conv2DOptions()) {
params->padding = parse_padding(conv_params->padding());
params->stride_width = conv_params->stride_w();
@@ -165,7 +156,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
break;
}
case BuiltinOperator_CAST: {
- TfLiteCastParams* params = MallocPOD<TfLiteCastParams>();
+ TfLiteCastParams* params = allocator->AllocatePOD<TfLiteCastParams>();
if (auto* schema_params = op->builtin_options_as_CastOptions()) {
auto in_status =
ConvertTensorType(schema_params->in_data_type(),
@@ -174,7 +165,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
ConvertTensorType(schema_params->out_data_type(),
&params->out_data_type, error_reporter);
if (in_status != kTfLiteOk || out_status != kTfLiteOk) {
- free(params);
+ allocator->Deallocate(params);
return kTfLiteError;
}
}
@@ -183,7 +174,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
}
case BuiltinOperator_LSH_PROJECTION: {
TfLiteLSHProjectionParams* params =
- MallocPOD<TfLiteLSHProjectionParams>();
+ allocator->AllocatePOD<TfLiteLSHProjectionParams>();
if (auto* lshParams = op->builtin_options_as_LSHProjectionOptions()) {
params->type = parseLSHProjectionType(lshParams->type());
}
@@ -193,7 +184,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
case BuiltinOperator_AVERAGE_POOL_2D:
case BuiltinOperator_MAX_POOL_2D:
case BuiltinOperator_L2_POOL_2D: {
- TfLitePoolParams* params = MallocPOD<TfLitePoolParams>();
+ TfLitePoolParams* params = allocator->AllocatePOD<TfLitePoolParams>();
if (auto* pool_params = op->builtin_options_as_Pool2DOptions()) {
params->padding = parse_padding(pool_params->padding());
params->stride_width = pool_params->stride_w();
@@ -208,7 +199,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
}
case BuiltinOperator_DEPTHWISE_CONV_2D: {
TfLiteDepthwiseConvParams* params =
- MallocPOD<TfLiteDepthwiseConvParams>();
+ allocator->AllocatePOD<TfLiteDepthwiseConvParams>();
if (auto* conv_params = op->builtin_options_as_DepthwiseConv2DOptions()) {
params->padding = parse_padding(conv_params->padding());
params->stride_width = conv_params->stride_w();
@@ -224,7 +215,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
break;
}
case BuiltinOperator_SVDF: {
- TfLiteSVDFParams* params = MallocPOD<TfLiteSVDFParams>();
+ TfLiteSVDFParams* params = allocator->AllocatePOD<TfLiteSVDFParams>();
if (auto* svdf_params = op->builtin_options_as_SVDFOptions()) {
params->rank = svdf_params->rank();
params->activation =
@@ -235,7 +226,8 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
}
case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN:
case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN: {
- TfLiteSequenceRNNParams* params = MallocPOD<TfLiteSequenceRNNParams>();
+ TfLiteSequenceRNNParams* params =
+ allocator->AllocatePOD<TfLiteSequenceRNNParams>();
if (auto* sequence_rnn_params =
op->builtin_options_as_SequenceRNNOptions()) {
params->activation =
@@ -246,7 +238,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
break;
}
case BuiltinOperator_RNN: {
- TfLiteRNNParams* params = MallocPOD<TfLiteRNNParams>();
+ TfLiteRNNParams* params = allocator->AllocatePOD<TfLiteRNNParams>();
if (auto* rnn_params = op->builtin_options_as_RNNOptions()) {
params->activation =
parse_activation(rnn_params->fused_activation_function());
@@ -256,7 +248,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
}
case BuiltinOperator_EMBEDDING_LOOKUP_SPARSE: {
TfLiteEmbeddingLookupSparseParams* params =
- MallocPOD<TfLiteEmbeddingLookupSparseParams>();
+ allocator->AllocatePOD<TfLiteEmbeddingLookupSparseParams>();
if (auto* embedding_params =
op->builtin_options_as_EmbeddingLookupSparseOptions()) {
params->combiner = parseCombinerType(embedding_params->combiner());
@@ -266,7 +258,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
}
case BuiltinOperator_FULLY_CONNECTED: {
TfLiteFullyConnectedParams* params =
- MallocPOD<TfLiteFullyConnectedParams>();
+ allocator->AllocatePOD<TfLiteFullyConnectedParams>();
if (auto* fully_connected_params =
op->builtin_options_as_FullyConnectedOptions()) {
params->activation = parse_activation(
@@ -291,7 +283,8 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
// no-op.
break;
case BuiltinOperator_SOFTMAX: {
- TfLiteSoftmaxParams* params = MallocPOD<TfLiteSoftmaxParams>();
+ TfLiteSoftmaxParams* params =
+ allocator->AllocatePOD<TfLiteSoftmaxParams>();
if (auto* softmax_params = op->builtin_options_as_SoftmaxOptions()) {
params->beta = softmax_params->beta();
}
@@ -300,7 +293,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
}
case BuiltinOperator_CONCATENATION: {
TfLiteConcatenationParams* params =
- MallocPOD<TfLiteConcatenationParams>();
+ allocator->AllocatePOD<TfLiteConcatenationParams>();
if (auto* concatenation_params =
op->builtin_options_as_ConcatenationOptions()) {
params->activation =
@@ -311,7 +304,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
break;
}
case BuiltinOperator_MUL: {
- auto* params = MallocPOD<TfLiteMulParams>();
+ auto* params = allocator->AllocatePOD<TfLiteMulParams>();
if (auto* schema_params = op->builtin_options_as_MulOptions()) {
params->activation =
parse_activation(schema_params->fused_activation_function());
@@ -320,7 +313,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
break;
}
case BuiltinOperator_ADD: {
- auto* params = MallocPOD<TfLiteAddParams>();
+ auto* params = allocator->AllocatePOD<TfLiteAddParams>();
if (auto* schema_params = op->builtin_options_as_AddOptions()) {
params->activation =
parse_activation(schema_params->fused_activation_function());
@@ -329,7 +322,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
break;
}
case BuiltinOperator_DIV: {
- auto* params = MallocPOD<TfLiteDivParams>();
+ auto* params = allocator->AllocatePOD<TfLiteDivParams>();
if (auto* schema_params = op->builtin_options_as_DivOptions()) {
params->activation =
parse_activation(schema_params->fused_activation_function());
@@ -338,7 +331,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
break;
}
case BuiltinOperator_SUB: {
- auto* params = MallocPOD<TfLiteSubParams>();
+ auto* params = allocator->AllocatePOD<TfLiteSubParams>();
if (auto* schema_params = op->builtin_options_as_SubOptions()) {
params->activation =
parse_activation(schema_params->fused_activation_function());
@@ -347,7 +340,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
break;
}
case BuiltinOperator_L2_NORMALIZATION: {
- auto* params = MallocPOD<TfLiteL2NormParams>();
+ auto* params = allocator->AllocatePOD<TfLiteL2NormParams>();
if (auto* schema_params = op->builtin_options_as_L2NormOptions()) {
params->activation =
parse_activation(schema_params->fused_activation_function());
@@ -356,7 +349,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
break;
}
case BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION: {
- auto* params = MallocPOD<TfLiteLocalResponseNormParams>();
+ auto* params = allocator->AllocatePOD<TfLiteLocalResponseNormParams>();
if (auto* schema_params =
op->builtin_options_as_LocalResponseNormalizationOptions()) {
params->radius = schema_params->radius();
@@ -370,7 +363,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM:
case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM:
case BuiltinOperator_LSTM: {
- TfLiteLSTMParams* params = MallocPOD<TfLiteLSTMParams>();
+ TfLiteLSTMParams* params = allocator->AllocatePOD<TfLiteLSTMParams>();
if (auto* lstm_params = op->builtin_options_as_LSTMOptions()) {
params->activation =
parse_activation(lstm_params->fused_activation_function());
@@ -389,7 +382,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
break;
}
case BuiltinOperator_RESIZE_BILINEAR: {
- auto* params = MallocPOD<TfLiteResizeBilinearParams>();
+ auto* params = allocator->AllocatePOD<TfLiteResizeBilinearParams>();
if (auto* schema_params =
op->builtin_options_as_ResizeBilinearOptions()) {
params->align_corners = schema_params->align_corners();
@@ -398,7 +391,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
break;
}
case BuiltinOperator_RESHAPE: {
- auto* params = MallocPOD<TfLiteReshapeParams>();
+ auto* params = allocator->AllocatePOD<TfLiteReshapeParams>();
if (auto* schema_params = op->builtin_options_as_ReshapeOptions()) {
auto* new_shape = schema_params->new_shape();
FlatBufferIntVectorToArray(sizeof(params->shape), new_shape,
@@ -409,7 +402,8 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
break;
}
case BuiltinOperator_SKIP_GRAM: {
- TfLiteSkipGramParams* params = MallocPOD<TfLiteSkipGramParams>();
+ TfLiteSkipGramParams* params =
+ allocator->AllocatePOD<TfLiteSkipGramParams>();
if (auto* skip_gram_params = op->builtin_options_as_SkipGramOptions()) {
params->ngram_size = skip_gram_params->ngram_size();
params->max_skip_size = skip_gram_params->max_skip_size();
@@ -419,7 +413,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
break;
}
case BuiltinOperator_SPACE_TO_DEPTH: {
- auto* params = MallocPOD<TfLiteSpaceToDepthParams>();
+ auto* params = allocator->AllocatePOD<TfLiteSpaceToDepthParams>();
if (auto* schema_params = op->builtin_options_as_SpaceToDepthOptions()) {
params->block_size = schema_params->block_size();
}
@@ -427,7 +421,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
break;
}
case BuiltinOperator_GATHER: {
- TfLiteGatherParams* params = MallocPOD<TfLiteGatherParams>();
+ TfLiteGatherParams* params = allocator->AllocatePOD<TfLiteGatherParams>();
params->axis = 0;
if (auto* gather_params = op->builtin_options_as_GatherOptions()) {
params->axis = gather_params->axis();
@@ -442,7 +436,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
case BuiltinOperator_REDUCE_PROD:
case BuiltinOperator_REDUCE_ANY:
case BuiltinOperator_SUM: {
- auto* params = MallocPOD<TfLiteReducerParams>();
+ auto* params = allocator->AllocatePOD<TfLiteReducerParams>();
if (auto* schema_params = op->builtin_options_as_ReducerOptions()) {
params->keep_dims = schema_params->keep_dims();
}
@@ -450,7 +444,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
break;
}
case BuiltinOperator_SPLIT: {
- auto* params = MallocPOD<TfLiteSplitParams>();
+ auto* params = allocator->AllocatePOD<TfLiteSplitParams>();
if (auto* schema_params = op->builtin_options_as_SplitOptions()) {
params->num_splits = schema_params->num_splits();
}
@@ -458,7 +452,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
break;
}
case BuiltinOperator_SQUEEZE: {
- auto* params = MallocPOD<TfLiteSqueezeParams>();
+ auto* params = allocator->AllocatePOD<TfLiteSqueezeParams>();
if (auto* schema_params = op->builtin_options_as_SqueezeOptions()) {
const auto& squeeze_dims = schema_params->squeeze_dims();
FlatBufferIntVectorToArray(sizeof(params->squeeze_dims), squeeze_dims,
@@ -469,7 +463,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
break;
}
case BuiltinOperator_STRIDED_SLICE: {
- auto* params = MallocPOD<TfLiteStridedSliceParams>();
+ auto* params = allocator->AllocatePOD<TfLiteStridedSliceParams>();
if (auto* schema_params = op->builtin_options_as_StridedSliceOptions()) {
params->begin_mask = schema_params->begin_mask();
params->end_mask = schema_params->end_mask();
@@ -481,7 +475,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
break;
}
case BuiltinOperator_ARG_MAX: {
- auto* params = MallocPOD<TfLiteArgMaxParams>();
+ auto* params = allocator->AllocatePOD<TfLiteArgMaxParams>();
if (auto* schema_params = op->builtin_options_as_ArgMaxOptions()) {
ConvertTensorType(schema_params->output_type(), &params->output_type,
error_reporter);
@@ -490,7 +484,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
break;
}
case BuiltinOperator_ARG_MIN: {
- auto* params = MallocPOD<TfLiteArgMinParams>();
+ auto* params = allocator->AllocatePOD<TfLiteArgMinParams>();
if (const auto* schema_params = op->builtin_options_as_ArgMinOptions()) {
ConvertTensorType(schema_params->output_type(), &params->output_type,
error_reporter);
@@ -500,7 +494,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
}
case BuiltinOperator_TRANSPOSE_CONV: {
TfLiteTransposeConvParams* params =
- MallocPOD<TfLiteTransposeConvParams>();
+ allocator->AllocatePOD<TfLiteTransposeConvParams>();
if (auto* transpose_conv_params =
op->builtin_options_as_TransposeConvOptions()) {
params->padding = parse_padding(transpose_conv_params->padding());
@@ -512,7 +506,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
}
case BuiltinOperator_SPARSE_TO_DENSE: {
TfLiteSparseToDenseParams* params =
- MallocPOD<TfLiteSparseToDenseParams>();
+ allocator->AllocatePOD<TfLiteSparseToDenseParams>();
if (auto* sparse_to_dense_params =
op->builtin_options_as_SparseToDenseOptions()) {
params->validate_indices = sparse_to_dense_params->validate_indices();
@@ -521,7 +515,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
break;
}
case BuiltinOperator_SHAPE: {
- auto* params = MallocPOD<TfLiteShapeParams>();
+ auto* params = allocator->AllocatePOD<TfLiteShapeParams>();
if (auto* schema_params = op->builtin_options_as_ShapeOptions()) {
ConvertTensorType(schema_params->out_type(), &params->out_type,
error_reporter);
@@ -530,7 +524,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
break;
}
case BuiltinOperator_PACK: {
- TfLitePackParams* params = MallocPOD<TfLitePackParams>();
+ TfLitePackParams* params = allocator->AllocatePOD<TfLitePackParams>();
if (auto* pack_params = op->builtin_options_as_PackOptions()) {
params->values_count = pack_params->values_count();
params->axis = pack_params->axis();
@@ -544,7 +538,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
return kTfLiteError;
}
case BuiltinOperator_FAKE_QUANT: {
- auto* params = MallocPOD<TfLiteFakeQuantParams>();
+ auto* params = allocator->AllocatePOD<TfLiteFakeQuantParams>();
if (auto* schema_params = op->builtin_options_as_FakeQuantOptions()) {
params->min = schema_params->min();
params->max = schema_params->max();
@@ -555,7 +549,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
break;
}
case BuiltinOperator_ONE_HOT: {
- auto* params = MallocPOD<TfLiteOneHotParams>();
+ auto* params = allocator->AllocatePOD<TfLiteOneHotParams>();
if (auto* schema_params = op->builtin_options_as_OneHotOptions()) {
params->axis = schema_params->axis();
}
@@ -563,7 +557,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
break;
}
case BuiltinOperator_UNPACK: {
- TfLiteUnpackParams* params = MallocPOD<TfLiteUnpackParams>();
+ TfLiteUnpackParams* params = allocator->AllocatePOD<TfLiteUnpackParams>();
if (auto* unpack_params = op->builtin_options_as_UnpackOptions()) {
params->num = unpack_params->num();
params->axis = unpack_params->axis();
@@ -618,6 +612,8 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
case BuiltinOperator_LOGICAL_NOT:
case BuiltinOperator_FLOOR_DIV:
case BuiltinOperator_SQUARE:
+ case BuiltinOperator_ZEROS_LIKE:
+ case BuiltinOperator_FILL:
break;
}

return kTfLiteOk;
diff --git a/tensorflow/contrib/lite/core/api/flatbuffer_conversions.h b/tensorflow/contrib/lite/core/api/flatbuffer_conversions.h
index 4dec6f9cfc..c770e627fd 100644
--- a/tensorflow/contrib/lite/core/api/flatbuffer_conversions.h
+++ b/tensorflow/contrib/lite/core/api/flatbuffer_conversions.h
@@ -26,6 +26,25 @@ limitations under the License.
namespace tflite {
+// Interface class for builtin data allocations.
+class BuiltinDataAllocator {
+ public:
+ virtual void* Allocate(size_t size) = 0;
+ virtual void Deallocate(void* data) = 0;
+
+ // Allocate a structure, but make sure it is a POD structure that doesn't
+ // require constructors to run. The reason we do this is that Interpreter's C
+ // extension part will take ownership so destructors will not be run during
+ // deallocation.
+ template <typename T>
+ T* AllocatePOD() {
+ static_assert(std::is_pod<T>::value, "Builtin data structure must be POD.");
+ return static_cast<T*>(this->Allocate(sizeof(T)));
+ }
+
+ virtual ~BuiltinDataAllocator() {}
+};
+
// Parse the appropriate data out of the op.
//
// This handles builtin data explicitly as there are flatbuffer schemas.
@@ -36,7 +55,8 @@ namespace tflite {
// function's responsibility to free it.
// If it returns kTfLiteError, `builtin_data` will be `nullptr`.
TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
- ErrorReporter* error_reporter, void** builtin_data);
+ ErrorReporter* error_reporter,
+ BuiltinDataAllocator* allocator, void** builtin_data);
// Converts the tensor data type used in the flat buffer to the representation
// used by the runtime.
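The new BuiltinDataAllocator interface lets the caller of ParseOpData decide how builtin parameter structs are allocated and released. A minimal sketch of a heap-backed implementation follows, assuming nothing beyond the interface added above; the interface is re-declared so the snippet stands alone, and the allocator the interpreter actually passes to ParseOpData may differ.

// Illustrative only: a malloc/free-backed BuiltinDataAllocator, analogous to
// what the removed MallocPOD helper did before the interface existed.
#include <cstddef>
#include <cstdlib>
#include <type_traits>

class BuiltinDataAllocator {
 public:
  virtual void* Allocate(size_t size) = 0;
  virtual void Deallocate(void* data) = 0;
  // POD-only allocation, as in the header above: ownership passes to C code,
  // so no constructors or destructors may be required.
  template <typename T>
  T* AllocatePOD() {
    static_assert(std::is_pod<T>::value, "Builtin data structure must be POD.");
    return static_cast<T*>(this->Allocate(sizeof(T)));
  }
  virtual ~BuiltinDataAllocator() {}
};

class MallocDataAllocator : public BuiltinDataAllocator {
 public:
  void* Allocate(size_t size) override { return std::malloc(size); }
  void Deallocate(void* data) override { std::free(data); }
};

struct ExampleParams { int stride_w; int stride_h; };  // stand-in for a TfLite*Params struct

int main() {
  MallocDataAllocator allocator;
  ExampleParams* params = allocator.AllocatePOD<ExampleParams>();
  params->stride_w = 2;
  params->stride_h = 2;
  allocator.Deallocate(params);
}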
diff --git a/tensorflow/contrib/lite/core/api/flatbuffer_conversions_test.cc b/tensorflow/contrib/lite/core/api/flatbuffer_conversions_test.cc
index b12bdf43b2..8ae94e1d33 100644
--- a/tensorflow/contrib/lite/core/api/flatbuffer_conversions_test.cc
+++ b/tensorflow/contrib/lite/core/api/flatbuffer_conversions_test.cc
@@ -39,11 +39,31 @@ class MockErrorReporter : public ErrorReporter {
int buffer_size_;
};
+// Used to determine how the op data parsing function creates its working space.
+class MockDataAllocator : public BuiltinDataAllocator {
+ public:
+ MockDataAllocator() : is_allocated_(false) {}
+ void* Allocate(size_t size) override {
+ EXPECT_FALSE(is_allocated_);
+ const int max_size = kBufferSize;
+ EXPECT_LE(size, max_size);
+ is_allocated_ = true;
+ return buffer_;
+ }
+ void Deallocate(void* data) override { is_allocated_ = false; }
+
+ private:
+ static constexpr int kBufferSize = 1024;
+ char buffer_[kBufferSize];
+ bool is_allocated_;
+};
+
} // namespace
TEST(FlatbufferConversions, TestParseOpDataConv) {
MockErrorReporter mock_reporter;
ErrorReporter* reporter = &mock_reporter;
+ MockDataAllocator mock_allocator;
flatbuffers::FlatBufferBuilder builder;
flatbuffers::Offset<void> conv_options =
@@ -58,7 +78,7 @@ TEST(FlatbufferConversions, TestParseOpDataConv) {
const Operator* conv_op = flatbuffers::GetRoot<Operator>(conv_pointer);
void* output_data = nullptr;
EXPECT_EQ(kTfLiteOk, ParseOpData(conv_op, BuiltinOperator_CONV_2D, reporter,
- &output_data));
+ &mock_allocator, &output_data));
EXPECT_NE(nullptr, output_data);
TfLiteConvParams* params = reinterpret_cast<TfLiteConvParams*>(output_data);
EXPECT_EQ(kTfLitePaddingSame, params->padding);
@@ -67,12 +87,12 @@ TEST(FlatbufferConversions, TestParseOpDataConv) {
EXPECT_EQ(kTfLiteActRelu, params->activation);
EXPECT_EQ(3, params->dilation_width_factor);
EXPECT_EQ(4, params->dilation_height_factor);
- free(output_data);
}
TEST(FlatbufferConversions, TestParseOpDataCustom) {
MockErrorReporter mock_reporter;
ErrorReporter* reporter = &mock_reporter;
+ MockDataAllocator mock_allocator;
flatbuffers::FlatBufferBuilder builder;
flatbuffers::Offset<void> null_options;
@@ -84,7 +104,7 @@ TEST(FlatbufferConversions, TestParseOpDataCustom) {
const Operator* custom_op = flatbuffers::GetRoot<Operator>(custom_pointer);
void* output_data = nullptr;
EXPECT_EQ(kTfLiteOk, ParseOpData(custom_op, BuiltinOperator_CUSTOM, reporter,
- &output_data));
+ &mock_allocator, &output_data));
EXPECT_EQ(nullptr, output_data);
}
diff --git a/tensorflow/contrib/lite/delegates/eager/BUILD b/tensorflow/contrib/lite/delegates/flex/BUILD
index bf5d91899c..bf5d91899c 100644
--- a/tensorflow/contrib/lite/delegates/eager/BUILD
+++ b/tensorflow/contrib/lite/delegates/flex/BUILD
diff --git a/tensorflow/contrib/lite/delegates/eager/buffer_map.cc b/tensorflow/contrib/lite/delegates/flex/buffer_map.cc
index e5a19c3997..63e39196d9 100644
--- a/tensorflow/contrib/lite/delegates/eager/buffer_map.cc
+++ b/tensorflow/contrib/lite/delegates/flex/buffer_map.cc
@@ -12,15 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/delegates/eager/buffer_map.h"
+#include "tensorflow/contrib/lite/delegates/flex/buffer_map.h"
#include "tensorflow/c/c_api_internal.h"
-#include "tensorflow/contrib/lite/delegates/eager/util.h"
+#include "tensorflow/contrib/lite/delegates/flex/util.h"
#include "tensorflow/core/framework/allocation_description.pb.h"
#include "tensorflow/core/framework/log_memory.h"
namespace tflite {
-namespace eager {
+namespace flex {
namespace {
// A tensor buffer that is allocated, deallocated and populated by TF Lite.
class TfLiteTensorBuffer : public tensorflow::TensorBuffer {
@@ -107,5 +107,5 @@ void BufferMap::SetFromTensorFlow(int tensor_index, tensorflow::Tensor tensor) {
id_to_tensor_[tensor_index] = std::move(tensor);
}
-} // namespace eager
+} // namespace flex
} // namespace tflite
diff --git a/tensorflow/contrib/lite/delegates/eager/buffer_map.h b/tensorflow/contrib/lite/delegates/flex/buffer_map.h
index aaaa045840..4ce886568a 100644
--- a/tensorflow/contrib/lite/delegates/eager/buffer_map.h
+++ b/tensorflow/contrib/lite/delegates/flex/buffer_map.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_BUFFER_MAP_H_
-#define TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_BUFFER_MAP_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_BUFFER_MAP_H_
+#define TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_BUFFER_MAP_H_
#include <map>
@@ -21,12 +21,12 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
namespace tflite {
-namespace eager {
+namespace flex {
// Maps a TF Lite tensor index into a TensorFlow tensor.
//
// The TF Lite interpreter assigns integer indices to each of its tensors, but
-// the Eager delegate deals in terms of TensorFlow tensors. This class maps
+// the Flex delegate deals in terms of TensorFlow tensors. This class maps
// from indices to tensors and allows the creation of new tensors to be
// associated with a given index.
class BufferMap {
@@ -55,7 +55,7 @@ class BufferMap {
std::map<int, tensorflow::Tensor> id_to_tensor_;
};
-} // namespace eager
+} // namespace flex
} // namespace tflite
-#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_BUFFER_MAP_H_
+#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_BUFFER_MAP_H_
diff --git a/tensorflow/contrib/lite/delegates/eager/buffer_map_test.cc b/tensorflow/contrib/lite/delegates/flex/buffer_map_test.cc
index a046943e56..bb80e25e80 100644
--- a/tensorflow/contrib/lite/delegates/eager/buffer_map_test.cc
+++ b/tensorflow/contrib/lite/delegates/flex/buffer_map_test.cc
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/delegates/eager/buffer_map.h"
+#include "tensorflow/contrib/lite/delegates/flex/buffer_map.h"
#include <gmock/gmock.h>
#include <gtest/gtest.h>
@@ -21,7 +21,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/util.h"
namespace tflite {
-namespace eager {
+namespace flex {
namespace {
using ::testing::ElementsAre;
@@ -164,7 +164,7 @@ TEST(BufferMapTest, TensorFlowOverwritesTfLite) {
}
} // namespace
-} // namespace eager
+} // namespace flex
} // namespace tflite
int main(int argc, char** argv) {
diff --git a/tensorflow/contrib/lite/delegates/eager/delegate.cc b/tensorflow/contrib/lite/delegates/flex/delegate.cc
index 45fc158157..ba065a8ff5 100644
--- a/tensorflow/contrib/lite/delegates/eager/delegate.cc
+++ b/tensorflow/contrib/lite/delegates/flex/delegate.cc
@@ -12,19 +12,19 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/delegates/eager/delegate.h"
+#include "tensorflow/contrib/lite/delegates/flex/delegate.h"
#include <vector>
#include "tensorflow/contrib/lite/context_util.h"
-#include "tensorflow/contrib/lite/delegates/eager/buffer_map.h"
-#include "tensorflow/contrib/lite/delegates/eager/kernel.h"
-#include "tensorflow/contrib/lite/delegates/eager/util.h"
+#include "tensorflow/contrib/lite/delegates/flex/buffer_map.h"
+#include "tensorflow/contrib/lite/delegates/flex/kernel.h"
+#include "tensorflow/contrib/lite/delegates/flex/util.h"
#include "tensorflow/contrib/lite/util.h"
#include "tensorflow/core/lib/core/status.h"
namespace tflite {
-namespace eager {
+namespace flex {
namespace delegate {
TfLiteStatus Prepare(TfLiteContext* context, TfLiteDelegate* delegate) {
@@ -32,7 +32,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteDelegate* delegate) {
TfLiteIntArray* plan;
TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &plan));
- // Add all custom ops starting with "Eager" to list of supported nodes.
+ // Add all custom ops starting with "Flex" to the list of supported nodes.
std::vector<int> supported_nodes;
for (int node_index : TfLiteIntArrayView(plan)) {
TfLiteNode* node;
@@ -40,7 +40,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteDelegate* delegate) {
TF_LITE_ENSURE_STATUS(context->GetNodeAndRegistration(
context, node_index, &node, &registration));
- if (IsEagerOp(registration->custom_name)) {
+ if (IsFlexOp(registration->custom_name)) {
supported_nodes.push_back(node_index);
}
}
@@ -81,28 +81,28 @@ TfLiteStatus CopyFromBufferHandle(TfLiteContext* context,
}
} // namespace delegate
-} // namespace eager
+} // namespace flex
-std::unique_ptr<EagerDelegate> EagerDelegate::Create() {
- std::unique_ptr<eager::DelegateData> delegate_data;
- if (!eager::DelegateData::Create(&delegate_data).ok()) {
+std::unique_ptr<FlexDelegate> FlexDelegate::Create() {
+ std::unique_ptr<flex::DelegateData> delegate_data;
+ if (!flex::DelegateData::Create(&delegate_data).ok()) {
fprintf(stderr, "Unable to initialize TensorFlow context.\n");
return nullptr;
}
- return std::unique_ptr<EagerDelegate>(
- new EagerDelegate(std::move(delegate_data)));
+ return std::unique_ptr<FlexDelegate>(
+ new FlexDelegate(std::move(delegate_data)));
}
-EagerDelegate::EagerDelegate(std::unique_ptr<eager::DelegateData> delegate_data)
+FlexDelegate::FlexDelegate(std::unique_ptr<flex::DelegateData> delegate_data)
: TfLiteDelegate{
/*data_=*/delegate_data.get(),
- /*nullptr,*/ &eager::delegate::Prepare,
- /*CopyFromBufferHandle=*/&eager::delegate::CopyFromBufferHandle,
+ /*nullptr,*/ &flex::delegate::Prepare,
+ /*CopyFromBufferHandle=*/&flex::delegate::CopyFromBufferHandle,
/*CopyToBufferHandle=*/nullptr,
/*FreeBufferHandle=*/nullptr},
delegate_data_(std::move(delegate_data)) {}
-EagerDelegate::~EagerDelegate() {}
+FlexDelegate::~FlexDelegate() {}
} // namespace tflite
diff --git a/tensorflow/contrib/lite/delegates/eager/delegate.h b/tensorflow/contrib/lite/delegates/flex/delegate.h
index 70f3c15af4..1017780dc7 100644
--- a/tensorflow/contrib/lite/delegates/eager/delegate.h
+++ b/tensorflow/contrib/lite/delegates/flex/delegate.h
@@ -12,11 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_DELEGATE_H_
-#define TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_DELEGATE_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_DELEGATE_H_
+#define TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_DELEGATE_H_
#include "tensorflow/contrib/lite/c/c_api_internal.h"
-#include "tensorflow/contrib/lite/delegates/eager/delegate_data.h"
+#include "tensorflow/contrib/lite/delegates/flex/delegate_data.h"
namespace tflite {
@@ -24,12 +24,12 @@ namespace tflite {
// Delegate that can be used to extract parts of a graph that are designed to be
// executed by TensorFlow's runtime via Eager.
//
-// The interpreter must be constructed after the EagerDelegate and destructed
-// before the EagerDelegate. This delegate may be used with multiple
+// The interpreter must be constructed after the FlexDelegate and destructed
+// before the FlexDelegate. This delegate may be used with multiple
// interpreters, but it is *not* thread-safe.
//
// Usage:
-// auto delegate = EagerDelegate::Create();
+// auto delegate = FlexDelegate::Create();
// ... build interpreter ...
//
// if (delegate) {
@@ -39,21 +39,21 @@ namespace tflite {
// ... run inference ...
// ... destroy interpreter ...
// ... destroy delegate ...
-class EagerDelegate : public TfLiteDelegate {
+class FlexDelegate : public TfLiteDelegate {
public:
// Creates a delegate that supports TF ops.
//
- // If the underyling TF Eager context creation fails, returns null.
- static std::unique_ptr<EagerDelegate> Create();
+ // If the underlying TF Flex context creation fails, returns null.
+ static std::unique_ptr<FlexDelegate> Create();
- ~EagerDelegate();
+ ~FlexDelegate();
private:
- explicit EagerDelegate(std::unique_ptr<eager::DelegateData> delegate_data);
+ explicit FlexDelegate(std::unique_ptr<flex::DelegateData> delegate_data);
- std::unique_ptr<eager::DelegateData> delegate_data_;
+ std::unique_ptr<flex::DelegateData> delegate_data_;
};
} // namespace tflite
-#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_DELEGATE_H_
+#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_DELEGATE_H_
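The usage flow described in the header comment above, written out as a compilable sketch under the renamed API. The model path, resolver choice, and error handling are placeholders rather than part of this change.

// Sketch only: apply a FlexDelegate to an interpreter built from a model that
// contains Flex (TensorFlow) ops.
#include <memory>

#include "tensorflow/contrib/lite/delegates/flex/delegate.h"
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/model.h"

int main() {
  // Create the delegate first; it must outlive the interpreter.
  auto delegate = tflite::FlexDelegate::Create();

  auto model =
      tflite::FlatBufferModel::BuildFromFile("model_with_flex_ops.tflite");
  if (!model) return 1;

  tflite::ops::builtin::BuiltinOpResolver resolver;
  std::unique_ptr<tflite::Interpreter> interpreter;
  tflite::InterpreterBuilder(*model, resolver)(&interpreter);
  if (!interpreter) return 1;

  // Hand the Flex-op subgraphs to the delegate, then run as usual.
  if (delegate) {
    interpreter->ModifyGraphWithDelegate(delegate.get());
  }
  interpreter->AllocateTensors();
  interpreter->Invoke();

  // Locals are destroyed in reverse order, so the interpreter goes away
  // before the delegate, matching the required lifetime ordering.
  return 0;
}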
diff --git a/tensorflow/contrib/lite/delegates/eager/delegate_data.cc b/tensorflow/contrib/lite/delegates/flex/delegate_data.cc
index 0fd5c976f8..8f985f770c 100644
--- a/tensorflow/contrib/lite/delegates/eager/delegate_data.cc
+++ b/tensorflow/contrib/lite/delegates/flex/delegate_data.cc
@@ -12,13 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/delegates/eager/delegate_data.h"
+#include "tensorflow/contrib/lite/delegates/flex/delegate_data.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/lib/core/status.h"
namespace tflite {
-namespace eager {
+namespace flex {
tensorflow::Status DelegateData::Create(std::unique_ptr<DelegateData>* data) {
std::vector<tensorflow::Device*> devices;
@@ -43,5 +43,5 @@ DelegateData::DelegateData(tensorflow::EagerContext* eager_context)
DelegateData::~DelegateData() {}
-} // namespace eager
+} // namespace flex
} // namespace tflite
diff --git a/tensorflow/contrib/lite/delegates/eager/delegate_data.h b/tensorflow/contrib/lite/delegates/flex/delegate_data.h
index 772d26f44e..8d75f0b0ef 100644
--- a/tensorflow/contrib/lite/delegates/eager/delegate_data.h
+++ b/tensorflow/contrib/lite/delegates/flex/delegate_data.h
@@ -12,16 +12,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_DELEGATE_DATA_H_
-#define TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_DELEGATE_DATA_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_DELEGATE_DATA_H_
+#define TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_DELEGATE_DATA_H_
-#include "tensorflow/contrib/lite/delegates/eager/buffer_map.h"
+#include "tensorflow/contrib/lite/delegates/flex/buffer_map.h"
#include "tensorflow/core/common_runtime/eager/context.h"
namespace tflite {
-namespace eager {
+namespace flex {
-// Data kept by the Eager delegate for the lifetime of an Interpreter.
+// Data kept by the Flex delegate for the lifetime of an Interpreter.
class DelegateData {
public:
// Create a new DelegateData, initialized with a newly-created EagerContext.
@@ -29,7 +29,7 @@ class DelegateData {
~DelegateData();
- // The EagerContext that is required for execution of Eager Ops.
+ // The EagerContext that is required for execution of Flex Ops.
tensorflow::EagerContext* GetEagerContext() { return eager_context_.get(); }
// Map from TF Lite tensor index to TensorFlow tensor for a given context.
@@ -46,7 +46,7 @@ class DelegateData {
std::unordered_map<const TfLiteContext*, BufferMap> buffer_map_;
};
-} // namespace eager
+} // namespace flex
} // namespace tflite
-#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_DELEGATE_DATA_H_
+#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_DELEGATE_DATA_H_
diff --git a/tensorflow/contrib/lite/delegates/eager/delegate_data_test.cc b/tensorflow/contrib/lite/delegates/flex/delegate_data_test.cc
index def063309f..30b10f435a 100644
--- a/tensorflow/contrib/lite/delegates/eager/delegate_data_test.cc
+++ b/tensorflow/contrib/lite/delegates/flex/delegate_data_test.cc
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/delegates/eager/delegate_data.h"
+#include "tensorflow/contrib/lite/delegates/flex/delegate_data.h"
#include <gmock/gmock.h>
#include <gtest/gtest.h>
@@ -20,7 +20,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/testing/util.h"
namespace tflite {
-namespace eager {
+namespace flex {
namespace {
TEST(DelegateDataTest, Basic) {
@@ -39,7 +39,7 @@ TEST(DelegateDataTest, Basic) {
}
} // namespace
-} // namespace eager
+} // namespace flex
} // namespace tflite
int main(int argc, char** argv) {
diff --git a/tensorflow/contrib/lite/delegates/eager/delegate_test.cc b/tensorflow/contrib/lite/delegates/flex/delegate_test.cc
index 43ec5d53b8..1813952cef 100644
--- a/tensorflow/contrib/lite/delegates/eager/delegate_test.cc
+++ b/tensorflow/contrib/lite/delegates/flex/delegate_test.cc
@@ -12,23 +12,23 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/delegates/eager/delegate.h"
+#include "tensorflow/contrib/lite/delegates/flex/delegate.h"
#include <gmock/gmock.h>
#include <gtest/gtest.h>
-#include "tensorflow/contrib/lite/delegates/eager/test_util.h"
+#include "tensorflow/contrib/lite/delegates/flex/test_util.h"
namespace tflite {
-namespace eager {
+namespace flex {
namespace {
using ::testing::ContainsRegex;
using ::testing::ElementsAre;
-class DelegateTest : public testing::EagerModelTest {
+class DelegateTest : public testing::FlexModelTest {
public:
DelegateTest() {
- delegate_ = EagerDelegate::Create();
+ delegate_ = FlexDelegate::Create();
interpreter_.reset(new Interpreter(&error_reporter_));
}
@@ -46,7 +46,7 @@ class DelegateTest : public testing::EagerModelTest {
}
private:
- std::unique_ptr<EagerDelegate> delegate_;
+ std::unique_ptr<FlexDelegate> delegate_;
};
TEST_F(DelegateTest, FullGraph) {
@@ -236,7 +236,7 @@ TEST_F(DelegateTest, MultipleInterpretersSameDelegate) {
}
} // namespace
-} // namespace eager
+} // namespace flex
} // namespace tflite
int main(int argc, char** argv) {
diff --git a/tensorflow/contrib/lite/delegates/eager/kernel.cc b/tensorflow/contrib/lite/delegates/flex/kernel.cc
index 274c3c082a..e4f1aea990 100644
--- a/tensorflow/contrib/lite/delegates/eager/kernel.cc
+++ b/tensorflow/contrib/lite/delegates/flex/kernel.cc
@@ -12,14 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/delegates/eager/kernel.h"
+#include "tensorflow/contrib/lite/delegates/flex/kernel.h"
-#include "flatbuffers/flexbuffers.h" // flatbuffers
+#include "flatbuffers/flexbuffers.h" // TF:flatbuffers
#include "tensorflow/contrib/lite/builtin_ops.h"
#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/context_util.h"
-#include "tensorflow/contrib/lite/delegates/eager/delegate_data.h"
-#include "tensorflow/contrib/lite/delegates/eager/util.h"
+#include "tensorflow/contrib/lite/delegates/flex/delegate_data.h"
+#include "tensorflow/contrib/lite/delegates/flex/util.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/string.h"
#include "tensorflow/core/common_runtime/eager/context.h"
@@ -28,10 +28,10 @@ limitations under the License.
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
-// Note: this is part of TF Lite's Eager delegation code which is to be
+// Note: this is part of TF Lite's Flex delegation code which is to be
// completed soon.
-// This is the TF Lite op that is created by the eager delegate to handle
+// This is the TF Lite op that is created by the flex delegate to handle
// execution of a supported subgraph. The usual flow is that the delegate
// informs the interpreter of supported nodes in a graph, and each supported
// subgraph is replaced with one instance of this kernel.
@@ -46,7 +46,7 @@ limitations under the License.
// corresponding TensorFlow/Eager Op.
namespace tflite {
-namespace eager {
+namespace flex {
namespace kernel {
// Controls the lifetime of tensor handles in a vector.
@@ -72,11 +72,11 @@ class VectorOfHandles {
// Executes the TensorFlow op given by 'op_name', with the attributes specified
// in 'nodedef'. Inputs and outputs are given as indices into the 'buffer_map'.
-tensorflow::Status ExecuteEagerOp(tensorflow::EagerContext* eager_context,
- BufferMap* buffer_map, const string& op_name,
- const tensorflow::NodeDef& nodedef,
- const std::vector<int>& inputs,
- const std::vector<int>& outputs) {
+tensorflow::Status ExecuteFlexOp(tensorflow::EagerContext* eager_context,
+ BufferMap* buffer_map, const string& op_name,
+ const tensorflow::NodeDef& nodedef,
+ const std::vector<int>& inputs,
+ const std::vector<int>& outputs) {
const tensorflow::AttrTypeMap* attr_types;
TF_RETURN_WITH_CONTEXT_IF_ERROR(
tensorflow::AttrTypeMapForOp(op_name.c_str(), &attr_types),
@@ -258,13 +258,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
// Execute the TensorFlow Ops sequentially.
for (const auto& node_data : op_data->nodes) {
if (node_data.nodedef.op().empty()) {
- context->ReportError(context, "Invalid NodeDef in Eager op '%s'",
+ context->ReportError(context, "Invalid NodeDef in Flex op '%s'",
node_data.name.c_str());
return kTfLiteError;
}
auto status =
- ExecuteEagerOp(eager_context, buffer_map, node_data.name,
- node_data.nodedef, node_data.inputs, node_data.outputs);
+ ExecuteFlexOp(eager_context, buffer_map, node_data.name,
+ node_data.nodedef, node_data.inputs, node_data.outputs);
TF_LITE_ENSURE_OK(context, ConvertStatus(context, status));
}
@@ -295,5 +295,5 @@ TfLiteRegistration GetKernel() {
return registration;
}
-} // namespace eager
+} // namespace flex
} // namespace tflite
diff --git a/tensorflow/contrib/lite/delegates/eager/kernel.h b/tensorflow/contrib/lite/delegates/flex/kernel.h
index 2478abccaa..ac9313a37b 100644
--- a/tensorflow/contrib/lite/delegates/eager/kernel.h
+++ b/tensorflow/contrib/lite/delegates/flex/kernel.h
@@ -12,23 +12,23 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_KERNEL_H_
-#define TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_KERNEL_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_KERNEL_H_
+#define TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_KERNEL_H_
#include "tensorflow/contrib/lite/c/c_api_internal.h"
namespace tflite {
-namespace eager {
+namespace flex {
// Return the registration object used to initialize and execute ops that will
// be delegated to TensorFlow's Eager runtime. This TF Lite op is created by
-// the eager delegate to handle execution of a supported subgraph. The usual
+// the flex delegate to handle execution of a supported subgraph. The usual
// flow is that the delegate informs the interpreter of supported nodes in a
// graph, and each supported subgraph is replaced with one instance of this
// kernel.
TfLiteRegistration GetKernel();
-} // namespace eager
+} // namespace flex
} // namespace tflite
-#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_KERNEL_H_
+#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_KERNEL_H_
diff --git a/tensorflow/contrib/lite/delegates/eager/kernel_test.cc b/tensorflow/contrib/lite/delegates/flex/kernel_test.cc
index 66f2226626..94a6f8b61a 100644
--- a/tensorflow/contrib/lite/delegates/eager/kernel_test.cc
+++ b/tensorflow/contrib/lite/delegates/flex/kernel_test.cc
@@ -12,15 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/delegates/eager/kernel.h"
+#include "tensorflow/contrib/lite/delegates/flex/kernel.h"
#include <gmock/gmock.h>
#include <gtest/gtest.h>
-#include "tensorflow/contrib/lite/delegates/eager/delegate_data.h"
-#include "tensorflow/contrib/lite/delegates/eager/test_util.h"
+#include "tensorflow/contrib/lite/delegates/flex/delegate_data.h"
+#include "tensorflow/contrib/lite/delegates/flex/test_util.h"
namespace tflite {
-namespace eager {
+namespace flex {
namespace {
using ::testing::ContainsRegex;
@@ -31,12 +31,12 @@ TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteDelegate* delegate,
TfLiteIntArray* size_and_nodes =
ConvertVectorToTfLiteIntArray(supported_nodes);
TF_LITE_ENSURE_STATUS(context->ReplaceSubgraphsWithDelegateKernels(
- context, eager::GetKernel(), size_and_nodes, delegate));
+ context, flex::GetKernel(), size_and_nodes, delegate));
TfLiteIntArrayFree(size_and_nodes);
return kTfLiteOk;
}
-class KernelTest : public testing::EagerModelTest {
+class KernelTest : public testing::FlexModelTest {
public:
KernelTest() {
CHECK(DelegateData::Create(&delegate_data_).ok());
@@ -167,7 +167,7 @@ TEST_F(KernelTest, WrongSetOfNodes) {
ASSERT_FALSE(Invoke());
ASSERT_THAT(error_reporter().error_messages(),
- ContainsRegex("Invalid NodeDef in Eager op"));
+ ContainsRegex("Invalid NodeDef in Flex op"));
}
TEST_F(KernelTest, MixedGraph) {
@@ -220,7 +220,7 @@ TEST_F(KernelTest, SplitGraph) {
}
} // namespace
-} // namespace eager
+} // namespace flex
} // namespace tflite
int main(int argc, char** argv) {
diff --git a/tensorflow/contrib/lite/delegates/eager/test_util.cc b/tensorflow/contrib/lite/delegates/flex/test_util.cc
index 8584999ace..69c336a01a 100644
--- a/tensorflow/contrib/lite/delegates/eager/test_util.cc
+++ b/tensorflow/contrib/lite/delegates/flex/test_util.cc
@@ -13,25 +13,24 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/delegates/eager/test_util.h"
+#include "tensorflow/contrib/lite/delegates/flex/test_util.h"
#include "absl/memory/memory.h"
-#include "flatbuffers/flexbuffers.h" // flatbuffers
+#include "flatbuffers/flexbuffers.h" // TF:flatbuffers
#include "tensorflow/contrib/lite/string.h"
namespace tflite {
-namespace eager {
+namespace flex {
namespace testing {
-bool EagerModelTest::Invoke() { return interpreter_->Invoke() == kTfLiteOk; }
+bool FlexModelTest::Invoke() { return interpreter_->Invoke() == kTfLiteOk; }
-void EagerModelTest::SetShape(int tensor_index,
- const std::vector<int>& values) {
+void FlexModelTest::SetShape(int tensor_index, const std::vector<int>& values) {
ASSERT_EQ(interpreter_->ResizeInputTensor(tensor_index, values), kTfLiteOk);
ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk);
}
-std::vector<int> EagerModelTest::GetShape(int tensor_index) {
+std::vector<int> FlexModelTest::GetShape(int tensor_index) {
std::vector<int> result;
auto* dims = interpreter_->tensor(tensor_index)->dims;
result.reserve(dims->size);
@@ -41,13 +40,13 @@ std::vector<int> EagerModelTest::GetShape(int tensor_index) {
return result;
}
-TfLiteType EagerModelTest::GetType(int tensor_index) {
+TfLiteType FlexModelTest::GetType(int tensor_index) {
return interpreter_->tensor(tensor_index)->type;
}
-void EagerModelTest::AddTensors(int num_tensors, const std::vector<int>& inputs,
- const std::vector<int>& outputs,
- TfLiteType type, const std::vector<int>& dims) {
+void FlexModelTest::AddTensors(int num_tensors, const std::vector<int>& inputs,
+ const std::vector<int>& outputs, TfLiteType type,
+ const std::vector<int>& dims) {
interpreter_->AddTensors(num_tensors);
for (int i = 0; i < num_tensors; ++i) {
TfLiteQuantizationParams quant;
@@ -66,8 +65,8 @@ void EagerModelTest::AddTensors(int num_tensors, const std::vector<int>& inputs,
CHECK_EQ(interpreter_->SetOutputs(outputs), kTfLiteOk);
}
-void EagerModelTest::AddTfLiteMulOp(const std::vector<int>& inputs,
- const std::vector<int>& outputs) {
+void FlexModelTest::AddTfLiteMulOp(const std::vector<int>& inputs,
+ const std::vector<int>& outputs) {
static TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr};
reg.builtin_code = BuiltinOperator_MUL;
reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
@@ -90,8 +89,8 @@ void EagerModelTest::AddTfLiteMulOp(const std::vector<int>& inputs,
kTfLiteOk);
}
-void EagerModelTest::AddTfOp(TfOpType op, const std::vector<int>& inputs,
- const std::vector<int>& outputs) {
+void FlexModelTest::AddTfOp(TfOpType op, const std::vector<int>& inputs,
+ const std::vector<int>& outputs) {
auto attr = [](const string& key, const string& value) {
return " attr{ key: '" + key + "' value {" + value + "}}";
};
@@ -107,28 +106,28 @@ void EagerModelTest::AddTfOp(TfOpType op, const std::vector<int>& inputs,
if (op == kUnpack) {
string attributes =
type_attribute + attr("num", "i: 2") + attr("axis", "i: 0");
- AddTfOp("EagerUnpack", "Unpack", attributes, inputs, outputs);
+ AddTfOp("FlexUnpack", "Unpack", attributes, inputs, outputs);
} else if (op == kIdentity) {
string attributes = type_attribute;
- AddTfOp("EagerIdentity", "Identity", attributes, inputs, outputs);
+ AddTfOp("FlexIdentity", "Identity", attributes, inputs, outputs);
} else if (op == kAdd) {
string attributes = type_attribute;
- AddTfOp("EagerAdd", "Add", attributes, inputs, outputs);
+ AddTfOp("FlexAdd", "Add", attributes, inputs, outputs);
} else if (op == kMul) {
string attributes = type_attribute;
- AddTfOp("EagerMul", "Mul", attributes, inputs, outputs);
+ AddTfOp("FlexMul", "Mul", attributes, inputs, outputs);
} else if (op == kNonExistent) {
AddTfOp("NonExistentOp", "NonExistentOp", "", inputs, outputs);
} else if (op == kIncompatibleNodeDef) {
// "Cast" op is created without attributes - making it incompatible.
- AddTfOp("EagerCast", "Cast", "", inputs, outputs);
+ AddTfOp("FlexCast", "Cast", "", inputs, outputs);
}
}
-void EagerModelTest::AddTfOp(const char* tflite_name, const string& tf_name,
- const string& nodedef_str,
- const std::vector<int>& inputs,
- const std::vector<int>& outputs) {
+void FlexModelTest::AddTfOp(const char* tflite_name, const string& tf_name,
+ const string& nodedef_str,
+ const std::vector<int>& inputs,
+ const std::vector<int>& outputs) {
static TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr};
reg.builtin_code = BuiltinOperator_CUSTOM;
reg.custom_name = tflite_name;
@@ -154,5 +153,5 @@ void EagerModelTest::AddTfOp(const char* tflite_name, const string& tf_name,
}
} // namespace testing
-} // namespace eager
+} // namespace flex
} // namespace tflite
diff --git a/tensorflow/contrib/lite/delegates/eager/test_util.h b/tensorflow/contrib/lite/delegates/flex/test_util.h
index 816db41931..a8c81b90a3 100644
--- a/tensorflow/contrib/lite/delegates/eager/test_util.h
+++ b/tensorflow/contrib/lite/delegates/flex/test_util.h
@@ -13,14 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_TEST_UTIL_H_
-#define TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_TEST_UTIL_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_TEST_UTIL_H_
+#define TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_TEST_UTIL_H_
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/test_util.h"
namespace tflite {
-namespace eager {
+namespace flex {
namespace testing {
enum TfOpType {
@@ -35,12 +35,12 @@ enum TfOpType {
};
// This class creates models with TF and TFLite ops. In order to use this class
-// to test the Eager delegate, implement a function that calls
+// to test the Flex delegate, implement a function that calls
// interpreter->ModifyGraphWithDelegate.
-class EagerModelTest : public ::testing::Test {
+class FlexModelTest : public ::testing::Test {
public:
- EagerModelTest() {}
- ~EagerModelTest() {}
+ FlexModelTest() {}
+ ~FlexModelTest() {}
bool Invoke();
@@ -104,7 +104,7 @@ class EagerModelTest : public ::testing::Test {
private:
// Helper method to add a TensorFlow op. tflite_names needs to start with
- // "Eager" in order to work with the Eager delegate.
+ // "Flex" in order to work with the Flex delegate.
void AddTfOp(const char* tflite_name, const string& tf_name,
const string& nodedef_str, const std::vector<int>& inputs,
const std::vector<int>& outputs);
@@ -113,7 +113,7 @@ class EagerModelTest : public ::testing::Test {
};
} // namespace testing
-} // namespace eager
+} // namespace flex
} // namespace tflite
-#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_TEST_UTIL_H_
+#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_TEST_UTIL_H_
diff --git a/tensorflow/contrib/lite/delegates/eager/util.cc b/tensorflow/contrib/lite/delegates/flex/util.cc
index 051246bf86..829bc388bf 100644
--- a/tensorflow/contrib/lite/delegates/eager/util.cc
+++ b/tensorflow/contrib/lite/delegates/flex/util.cc
@@ -12,10 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/delegates/eager/util.h"
+#include "tensorflow/contrib/lite/delegates/flex/util.h"
namespace tflite {
-namespace eager {
+namespace flex {
TfLiteStatus ConvertStatus(TfLiteContext* context,
const tensorflow::Status& status) {
@@ -100,5 +100,5 @@ TfLiteType GetTensorFlowLiteType(TF_DataType type) {
}
}
-} // namespace eager
+} // namespace flex
} // namespace tflite
diff --git a/tensorflow/contrib/lite/delegates/eager/util.h b/tensorflow/contrib/lite/delegates/flex/util.h
index 930cb99cb9..7f910e7316 100644
--- a/tensorflow/contrib/lite/delegates/eager/util.h
+++ b/tensorflow/contrib/lite/delegates/flex/util.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_UTIL_H_
-#define TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_UTIL_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_UTIL_H_
+#define TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_UTIL_H_
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/contrib/lite/c/c_api_internal.h"
@@ -21,7 +21,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
namespace tflite {
-namespace eager {
+namespace flex {
// Converts a tensorflow:Status into a TfLiteStatus. If the original status
// represented an error, reports it using the given 'context'.
@@ -41,7 +41,7 @@ TF_DataType GetTensorFlowDataType(TfLiteType type);
// Returns the TfLiteType that corresponds to the given TF C API Data type.
TfLiteType GetTensorFlowLiteType(TF_DataType);
-} // namespace eager
+} // namespace flex
} // namespace tflite
-#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_UTIL_H_
+#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_UTIL_H_
diff --git a/tensorflow/contrib/lite/delegates/eager/util_test.cc b/tensorflow/contrib/lite/delegates/flex/util_test.cc
index aebc91149c..5f049e7b0a 100644
--- a/tensorflow/contrib/lite/delegates/eager/util_test.cc
+++ b/tensorflow/contrib/lite/delegates/flex/util_test.cc
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/delegates/eager/util.h"
+#include "tensorflow/contrib/lite/delegates/flex/util.h"
#include <cstdarg>
@@ -22,7 +22,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/testing/util.h"
namespace tflite {
-namespace eager {
+namespace flex {
namespace {
using tensorflow::DT_FLOAT;
@@ -132,7 +132,7 @@ TEST(UtilTest, TypeConversionsFromTensorFlow) {
}
} // namespace
-} // namespace eager
+} // namespace flex
} // namespace tflite
int main(int argc, char** argv) {
diff --git a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc
index c6587b3d3f..d85e576284 100644
--- a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc
+++ b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc
@@ -518,7 +518,7 @@ class NNAPIDelegateKernel {
}
break;
case kTfLiteBuiltinReshape:
- if (version == 1) {
+ if (version == 1 && node->inputs->size == 2) {
return [](const NNAPIOpMappingArgs& mapping_args)
-> ANeuralNetworksOperationType {
return ANEURALNETWORKS_RESHAPE;
diff --git a/tensorflow/contrib/lite/examples/android/app/README.md b/tensorflow/contrib/lite/examples/android/app/README.md
index cbdeeac879..7347147f99 100644
--- a/tensorflow/contrib/lite/examples/android/app/README.md
+++ b/tensorflow/contrib/lite/examples/android/app/README.md
@@ -1,8 +1,43 @@
# TF Lite Android App Example
+A simple Android example that demonstrates image classification and object
+detection using the camera, as well as speech recognition using the microphone.
+
+## Building in Android Studio with TensorFlow Lite AAR from JCenter.
+The build.gradle is configured to use TensorFlow Lite's nightly build.
+
+If you see a build error related to compatibility with TensorFlow Lite's Java
+API (for example: method X is undefined for type Interpreter), there has likely
+been a backwards-incompatible change to the API. You will need to pull new app
+code that's compatible with the nightly build, and may need to wait a few days
+for our external and internal code to merge.
+
## Building from Source with Bazel
-1. Install [Bazel](https://docs.bazel.build/versions/master/install.html), the Android NDK and SDK. The recommended versions are specified on this [webpage](https://www.tensorflow.org/mobile/tflite/demo_android#build_tensorflow_lite_and_the_demo_app_from_source).
+1. Follow the [Bazel steps for the TF Demo App](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android#bazel):
+
+ 1. [Install Bazel and Android Prerequisites](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android#install-bazel-and-android-prerequisites).
+ It's easiest with Android Studio.
+
+ - You'll need at least SDK version 23.
+ - Make sure to install the latest version of Bazel. Some distributions
+ ship with Bazel 0.5.4, which is too old.
+ - Bazel requires Android Build Tools `26.0.1` or higher.
+ - You also need to install the Android Support Repository, available
+ through Android Studio under `Android SDK Manager -> SDK Tools ->
+ Android Support Repository`.
+
+ 2. [Edit your `WORKSPACE`](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android#edit-workspace)
+ to add SDK and NDK targets.
+
+ NOTE: As long as you have the SDK and NDK installed, the `./configure`
+ script will create these rules for you. Answer "Yes" when the script asks
+ to automatically configure the `./WORKSPACE`.
+
+ - Make sure the `api_level` in `WORKSPACE` is set to an SDK version that
+ you have installed.
+ - By default, Android Studio will install the SDK to `~/Android/Sdk` and
+ the NDK to `~/Android/Sdk/ndk-bundle`.
2. Build this demo app with Bazel. The demo needs C++11. We configure the fat_apk_cpu flag to package support for 4 hardware variants. You may replace it with --config=android_arm64 on a 64-bit device and --config=android_arm for 32-bit device:
diff --git a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h
index 6fdcf78b69..21ad39a6bf 100644
--- a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h
+++ b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h
@@ -80,8 +80,7 @@ void resize(T* out, uint8_t* in, int image_height, int image_width,
interpreter->Invoke();
auto output = interpreter->typed_tensor<float>(2);
- auto output_number_of_pixels =
- wanted_height * wanted_height * wanted_channels;
+ auto output_number_of_pixels = wanted_height * wanted_width * wanted_channels;
for (int i = 0; i < output_number_of_pixels; i++) {
if (s->input_floating)
diff --git a/tensorflow/contrib/lite/experimental/c/BUILD b/tensorflow/contrib/lite/experimental/c/BUILD
index ea4a543252..52e71619de 100644
--- a/tensorflow/contrib/lite/experimental/c/BUILD
+++ b/tensorflow/contrib/lite/experimental/c/BUILD
@@ -1,5 +1,12 @@
package(default_visibility = ["//visibility:private"])
+package_group(
+ name = "experimental",
+ packages = [
+ "//tensorflow/contrib/lite/experimental/...",
+ ],
+)
+
licenses(["notice"]) # Apache 2.0
load(
@@ -51,6 +58,9 @@ cc_library(
srcs = ["c_api.cc"],
hdrs = ["c_api.h"],
copts = tflite_copts(),
+ visibility = [
+ ":experimental",
+ ],
deps = [
":c_api_internal",
"//tensorflow/contrib/lite:context",
@@ -68,6 +78,7 @@ cc_library(
deps = [
":c_api",
":c_api_internal",
+ "//tensorflow/contrib/lite:kernel_api",
],
)
@@ -93,6 +104,7 @@ cc_test(
deps = [
":c_api",
":c_api_experimental",
+ "//tensorflow/contrib/lite:kernel_api",
"//tensorflow/contrib/lite/testing:util",
"@com_google_googletest//:gtest",
],
diff --git a/tensorflow/contrib/lite/experimental/c/c_api.cc b/tensorflow/contrib/lite/experimental/c/c_api.cc
index c589cf71ea..9c29f9d8b9 100644
--- a/tensorflow/contrib/lite/experimental/c/c_api.cc
+++ b/tensorflow/contrib/lite/experimental/c/c_api.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include <memory>
#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/error_reporter.h"
#include "tensorflow/contrib/lite/experimental/c/c_api_internal.h"
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
@@ -26,6 +27,26 @@ limitations under the License.
extern "C" {
#endif // __cplusplus
+namespace {
+class CallbackErrorReporter : public tflite::ErrorReporter {
+ public:
+ using ErrorCallback = void (*)(void* user_data, const char* format,
+ va_list args);
+
+ CallbackErrorReporter(ErrorCallback callback, void* user_data)
+ : callback_(callback), user_data_(user_data) {}
+
+ int Report(const char* format, va_list args) override {
+ callback_(user_data_, format, args);
+ return 0;
+ }
+
+ private:
+ ErrorCallback callback_;
+ void* user_data_;
+};
+} // namespace
+
// LINT.IfChange
TFL_Model* TFL_NewModel(const void* model_data, size_t model_size) {
@@ -56,14 +77,38 @@ void TFL_InterpreterOptionsSetNumThreads(TFL_InterpreterOptions* options,
options->num_threads = num_threads;
}
+TFL_CAPI_EXPORT extern void TFL_InterpreterOptionsSetErrorReporter(
+ TFL_InterpreterOptions* options,
+ void (*reporter)(void* user_data, const char* format, va_list args),
+ void* user_data) {
+ options->error_reporter = reporter;
+ options->error_reporter_user_data = user_data;
+}
+
TFL_Interpreter* TFL_NewInterpreter(
const TFL_Model* model, const TFL_InterpreterOptions* optional_options) {
if (!model || !model->impl) {
return nullptr;
}
+ std::unique_ptr<tflite::ErrorReporter> optional_error_reporter;
+ if (optional_options && optional_options->error_reporter != nullptr) {
+ optional_error_reporter.reset(
+ new CallbackErrorReporter(optional_options->error_reporter,
+ optional_options->error_reporter_user_data));
+ }
+
+ // TODO(b/111881878): Allow use of C API without pulling in all builtin ops.
tflite::ops::builtin::BuiltinOpResolver resolver;
- tflite::InterpreterBuilder builder(*model->impl, resolver);
+ if (optional_options) {
+ resolver.AddAll(optional_options->op_resolver);
+ }
+ tflite::ErrorReporter* error_reporter = optional_error_reporter
+ ? optional_error_reporter.get()
+ : tflite::DefaultErrorReporter();
+ tflite::InterpreterBuilder builder(model->impl->GetModel(), resolver,
+ error_reporter);
+
std::unique_ptr<tflite::Interpreter> interpreter;
if (builder(&interpreter) != kTfLiteOk) {
return nullptr;
@@ -76,7 +121,8 @@ TFL_Interpreter* TFL_NewInterpreter(
}
}
- return new TFL_Interpreter{model->impl, std::move(interpreter)};
+ return new TFL_Interpreter{model->impl, std::move(optional_error_reporter),
+ std::move(interpreter)};
}
void TFL_DeleteInterpreter(TFL_Interpreter* interpreter) { delete interpreter; }
diff --git a/tensorflow/contrib/lite/experimental/c/c_api.h b/tensorflow/contrib/lite/experimental/c/c_api.h
index b429e76870..f52ab8f9ed 100644
--- a/tensorflow/contrib/lite/experimental/c/c_api.h
+++ b/tensorflow/contrib/lite/experimental/c/c_api.h
@@ -15,6 +15,7 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_C_C_API_H_
#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_C_C_API_H_
+#include <stdarg.h>
#include <stdint.h>
// Eventually the various C APIs defined in context.h will be migrated into
@@ -52,8 +53,9 @@ limitations under the License.
extern "C" {
#endif // __cplusplus
-typedef TfLiteTensor TFL_Tensor;
+typedef TfLiteRegistration TFL_Registration;
typedef TfLiteStatus TFL_Status;
+typedef TfLiteTensor TFL_Tensor;
typedef TfLiteType TFL_Type;
// --------------------------------------------------------------------------
@@ -85,6 +87,17 @@ TFL_CAPI_EXPORT extern void TFL_DeleteInterpreterOptions(
TFL_CAPI_EXPORT extern void TFL_InterpreterOptionsSetNumThreads(
TFL_InterpreterOptions* options, int32_t num_threads);
+// Sets a custom error reporter for interpreter execution.
+//
+// * `reporter` takes the provided `user_data` object, as well as a C-style
+// format string and arg list (see also vprintf).
+// * `user_data` is optional. If provided, it is owned by the client and must
+// remain valid for the duration of the interpreter lifetime.
+TFL_CAPI_EXPORT extern void TFL_InterpreterOptionsSetErrorReporter(
+ TFL_InterpreterOptions* options,
+ void (*reporter)(void* user_data, const char* format, va_list args),
+ void* user_data);
+
// --------------------------------------------------------------------------
// TFL_Interpreter provides inference from a provided model.
typedef struct TFL_Interpreter TFL_Interpreter;
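Not part of the patch itself, but for orientation: a minimal C++ sketch of how a client might wire up the new error-reporter hook. The model path and the logging body are illustrative assumptions; the calls are the ones declared above and exercised in the test added later in this change.

```c++
#include <cstdarg>
#include <cstdio>

#include "tensorflow/contrib/lite/experimental/c/c_api.h"

// Forwards interpreter errors to stderr, tagging them with the user_data
// string supplied below. Matches the callback signature declared above.
static void ReportToStderr(void* user_data, const char* format, va_list args) {
  std::fprintf(stderr, "[%s] ", static_cast<const char*>(user_data));
  std::vfprintf(stderr, format, args);
  std::fprintf(stderr, "\n");
}

int main() {
  static char tag[] = "tflite";  // Owned by the client for the interpreter's lifetime.

  // Illustrative model path; any valid .tflite flatbuffer works here.
  TFL_Model* model = TFL_NewModelFromFile("/tmp/model.tflite");
  TFL_InterpreterOptions* options = TFL_NewInterpreterOptions();
  TFL_InterpreterOptionsSetErrorReporter(options, ReportToStderr, tag);

  TFL_Interpreter* interpreter = TFL_NewInterpreter(model, options);

  // The options and model can be released once the interpreter exists.
  TFL_DeleteInterpreterOptions(options);
  TFL_DeleteModel(model);

  // Invoking before AllocateTensors triggers the reporter, as in the tests.
  TFL_InterpreterInvoke(interpreter);

  TFL_DeleteInterpreter(interpreter);
  return 0;
}
```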
diff --git a/tensorflow/contrib/lite/experimental/c/c_api_experimental.cc b/tensorflow/contrib/lite/experimental/c/c_api_experimental.cc
index c4dbc55cbf..29f8701f53 100644
--- a/tensorflow/contrib/lite/experimental/c/c_api_experimental.cc
+++ b/tensorflow/contrib/lite/experimental/c/c_api_experimental.cc
@@ -21,9 +21,24 @@ limitations under the License.
extern "C" {
#endif // __cplusplus
-TFL_Status TFL_InterpreterResetVariableTensorsToZero(
- TFL_Interpreter* interpreter) {
- return interpreter->impl->ResetVariableTensorsToZero();
+TFL_Status TFL_InterpreterResetVariableTensors(TFL_Interpreter* interpreter) {
+ return interpreter->impl->ResetVariableTensors();
+}
+
+void TFL_InterpreterOptionsAddBuiltinOp(TFL_InterpreterOptions* options,
+ TFL_BuiltinOperator op,
+ const TFL_Registration* registration,
+ int32_t min_version,
+ int32_t max_version) {
+ options->op_resolver.AddBuiltin(static_cast<tflite::BuiltinOperator>(op),
+ registration, min_version, max_version);
+}
+
+void TFL_InterpreterOptionsAddCustomOp(TFL_InterpreterOptions* options,
+ const char* name,
+ const TFL_Registration* registration,
+ int min_version, int max_version) {
+ options->op_resolver.AddCustom(name, registration, min_version, max_version);
}
#ifdef __cplusplus
diff --git a/tensorflow/contrib/lite/experimental/c/c_api_experimental.h b/tensorflow/contrib/lite/experimental/c/c_api_experimental.h
index b0ac258dcf..fca5d92f77 100644
--- a/tensorflow/contrib/lite/experimental/c/c_api_experimental.h
+++ b/tensorflow/contrib/lite/experimental/c/c_api_experimental.h
@@ -15,16 +15,41 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_C_C_API_EXPERIMENTAL_H_
#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_C_C_API_EXPERIMENTAL_H_
+#include "tensorflow/contrib/lite/builtin_ops.h"
#include "tensorflow/contrib/lite/experimental/c/c_api.h"
#ifdef __cplusplus
extern "C" {
#endif // __cplusplus
+typedef TfLiteBuiltinOperator TFL_BuiltinOperator;
+
// Resets all variable tensors to zero.
-TFL_CAPI_EXPORT extern TFL_Status TFL_InterpreterResetVariableTensorsToZero(
+TFL_CAPI_EXPORT extern TFL_Status TFL_InterpreterResetVariableTensors(
TFL_Interpreter* interpreter);
+// Adds an op registration for a builtin operator.
+//
+// NOTE: The interpreter will make a copy of `registration` internally, so the
+// caller should ensure that its contents (function pointers, etc...) remain
+// valid for the duration of the interpreter's lifetime. A common practice is
+// making the provided TFL_Registration instance static.
+void TFL_InterpreterOptionsAddBuiltinOp(TFL_InterpreterOptions* options,
+ TFL_BuiltinOperator op,
+ const TFL_Registration* registration,
+ int min_version, int max_version);
+
+// Adds an op registration for a custom operator.
+//
+// NOTE: The interpreter will make a copy of `registration` internally, so the
+// caller should ensure that its contents (function pointers, etc...) remain
+// valid for the duration of the interpreter's lifetime. A common practice is
+// making the provided TFL_Registration instance static.
+void TFL_InterpreterOptionsAddCustomOp(TFL_InterpreterOptions* options,
+ const char* name,
+ const TFL_Registration* registration,
+ int min_version, int max_version);
+
#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus
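Not part of the patch itself: a minimal sketch of how the two new registration entry points might be used together. The no-op kernel mirrors the dummy registration in the experimental test that follows; the custom op name `MyCustomOp` is an illustrative assumption.

```c++
#include "tensorflow/contrib/lite/builtin_ops.h"
#include "tensorflow/contrib/lite/experimental/c/c_api.h"
#include "tensorflow/contrib/lite/experimental/c/c_api_experimental.h"

// A no-op kernel. Kept static, following the recommendation above, since the
// registration's contents must stay valid for the interpreter's lifetime.
static TFL_Registration* GetNoopRegistration() {
  static TFL_Registration registration = {nullptr, nullptr, nullptr, nullptr};
  registration.invoke = [](TfLiteContext*, TfLiteNode*) { return kTfLiteOk; };
  return &registration;
}

// Registers a builtin override and a hypothetical custom op on the options.
void ConfigureOps(TFL_InterpreterOptions* options) {
  // Replace the built-in ADD kernel for versions 1..1 with the no-op kernel.
  TFL_InterpreterOptionsAddBuiltinOp(options, kTfLiteBuiltinAdd,
                                     GetNoopRegistration(), /*min_version=*/1,
                                     /*max_version=*/1);
  // Register a custom op; "MyCustomOp" is an illustrative name only.
  TFL_InterpreterOptionsAddCustomOp(options, "MyCustomOp",
                                    GetNoopRegistration(), /*min_version=*/1,
                                    /*max_version=*/1);
}
```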
diff --git a/tensorflow/contrib/lite/experimental/c/c_api_experimental_test.cc b/tensorflow/contrib/lite/experimental/c/c_api_experimental_test.cc
index db6e5251de..1b1bedb754 100644
--- a/tensorflow/contrib/lite/experimental/c/c_api_experimental_test.cc
+++ b/tensorflow/contrib/lite/experimental/c/c_api_experimental_test.cc
@@ -16,25 +16,40 @@ limitations under the License.
#include "tensorflow/contrib/lite/experimental/c/c_api_experimental.h"
#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/builtin_ops.h"
#include "tensorflow/contrib/lite/experimental/c/c_api.h"
#include "tensorflow/contrib/lite/testing/util.h"
namespace {
+TfLiteRegistration* GetDummyRegistration() {
+ static TfLiteRegistration registration = {
+ .init = nullptr,
+ .free = nullptr,
+ .prepare = nullptr,
+ .invoke = [](TfLiteContext*, TfLiteNode*) { return kTfLiteOk; },
+ };
+ return &registration;
+}
+
TEST(CApiExperimentalSimple, Smoke) {
TFL_Model* model = TFL_NewModelFromFile(
"tensorflow/contrib/lite/testdata/add.bin");
ASSERT_NE(model, nullptr);
- TFL_Interpreter* interpreter =
- TFL_NewInterpreter(model, /*optional_options=*/nullptr);
+ TFL_InterpreterOptions* options = TFL_NewInterpreterOptions();
+ TFL_InterpreterOptionsAddBuiltinOp(options, kTfLiteBuiltinAdd,
+ GetDummyRegistration(), 1, 1);
+
+ TFL_Interpreter* interpreter = TFL_NewInterpreter(model, options);
ASSERT_NE(interpreter, nullptr);
ASSERT_EQ(TFL_InterpreterAllocateTensors(interpreter), kTfLiteOk);
+ EXPECT_EQ(TFL_InterpreterResetVariableTensors(interpreter), kTfLiteOk);
+ EXPECT_EQ(TFL_InterpreterInvoke(interpreter), kTfLiteOk);
- EXPECT_EQ(TFL_InterpreterResetVariableTensorsToZero(interpreter), kTfLiteOk);
-
- TFL_DeleteModel(model);
TFL_DeleteInterpreter(interpreter);
+ TFL_DeleteInterpreterOptions(options);
+ TFL_DeleteModel(model);
}
} // namespace
diff --git a/tensorflow/contrib/lite/experimental/c/c_api_internal.h b/tensorflow/contrib/lite/experimental/c/c_api_internal.h
index 60c2e4e2cd..da3af3cad4 100644
--- a/tensorflow/contrib/lite/experimental/c/c_api_internal.h
+++ b/tensorflow/contrib/lite/experimental/c/c_api_internal.h
@@ -19,9 +19,13 @@ limitations under the License.
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/op_resolver.h"
// Internal structures used by the C API. These are likely to change and should
// not be depended on.
+//
+// NOTE: This header does not follow C conventions and does not define a C API.
+// It is effectively an (internal) implementation detail of the C API.
struct TFL_Model {
// Sharing is safe as FlatBufferModel is const.
@@ -33,12 +37,24 @@ struct TFL_InterpreterOptions {
kDefaultNumThreads = -1,
};
int num_threads = kDefaultNumThreads;
+
+ tflite::MutableOpResolver op_resolver;
+
+ void (*error_reporter)(void* user_data, const char* format,
+ va_list args) = nullptr;
+ void* error_reporter_user_data = nullptr;
};
struct TFL_Interpreter {
// Taking a reference to the (const) model data avoids lifetime-related issues
// and complexity with the TFL_Model's existence.
std::shared_ptr<const tflite::FlatBufferModel> model;
+
+ // The interpreter does not take ownership of the provided ErrorReporter
+ // instance, so we ensure its validity here. Note that the interpreter may use
+ // the reporter in its destructor, so it should be declared first.
+ std::unique_ptr<tflite::ErrorReporter> optional_error_reporter;
+
std::unique_ptr<tflite::Interpreter> impl;
};
diff --git a/tensorflow/contrib/lite/experimental/c/c_api_test.cc b/tensorflow/contrib/lite/experimental/c/c_api_test.cc
index 649dac8d1a..48a3714ec3 100644
--- a/tensorflow/contrib/lite/experimental/c/c_api_test.cc
+++ b/tensorflow/contrib/lite/experimental/c/c_api_test.cc
@@ -85,6 +85,37 @@ TEST(CApiSimple, Smoke) {
TFL_DeleteInterpreter(interpreter);
}
+TEST(CApiSimple, ErrorReporter) {
+ TFL_Model* model = TFL_NewModelFromFile(
+ "tensorflow/contrib/lite/testdata/add.bin");
+ TFL_InterpreterOptions* options = TFL_NewInterpreterOptions();
+
+ // Install a custom error reporter into the interpreter by way of options.
+ tflite::TestErrorReporter reporter;
+ TFL_InterpreterOptionsSetErrorReporter(
+ options,
+ [](void* user_data, const char* format, va_list args) {
+ reinterpret_cast<tflite::TestErrorReporter*>(user_data)->Report(format,
+ args);
+ },
+ &reporter);
+ TFL_Interpreter* interpreter = TFL_NewInterpreter(model, options);
+
+ // The options/model can be deleted immediately after interpreter creation.
+ TFL_DeleteInterpreterOptions(options);
+ TFL_DeleteModel(model);
+
+ // Invoke the interpreter before tensor allocation.
+ EXPECT_EQ(TFL_InterpreterInvoke(interpreter), kTfLiteError);
+
+ // The error should propagate to the custom error reporter.
+ EXPECT_EQ(reporter.error_messages(),
+ "Invoke called on model that is not ready.");
+ EXPECT_EQ(reporter.num_calls(), 1);
+
+ TFL_DeleteInterpreter(interpreter);
+}
+
} // namespace
int main(int argc, char** argv) {
diff --git a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder.cc b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder.cc
index 8442c4d46c..b1ebe4a804 100644
--- a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder.cc
+++ b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <vector>
-#include "flatbuffers/flexbuffers.h" // flatbuffers
+#include "flatbuffers/flexbuffers.h" // TF:flatbuffers
#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/experimental/kernels/ctc_beam_search.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
diff --git a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder_test.cc b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder_test.cc
index aa42b495bd..942dbbbeae 100644
--- a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder_test.cc
+++ b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder_test.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include <vector>
#include <gtest/gtest.h>
-#include "flatbuffers/flexbuffers.h" // flatbuffers
+#include "flatbuffers/flexbuffers.h" // TF:flatbuffers
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/kernels/test_util.h"
diff --git a/tensorflow/contrib/lite/experimental/writer/option_writer_generator.cc b/tensorflow/contrib/lite/experimental/writer/option_writer_generator.cc
index e6d5a776b3..b35c6e0655 100644
--- a/tensorflow/contrib/lite/experimental/writer/option_writer_generator.cc
+++ b/tensorflow/contrib/lite/experimental/writer/option_writer_generator.cc
@@ -16,7 +16,7 @@ limitations under the License.
#include <iostream>
#include <unordered_map>
#include <unordered_set>
-#include "flatbuffers/minireflect.h" // flatbuffers
+#include "flatbuffers/minireflect.h" // TF:flatbuffers
#include "tensorflow/contrib/lite/schema/reflection/schema_generated.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/g3doc/_book.yaml b/tensorflow/contrib/lite/g3doc/_book.yaml
index 1dffe30790..de6914e536 100644
--- a/tensorflow/contrib/lite/g3doc/_book.yaml
+++ b/tensorflow/contrib/lite/g3doc/_book.yaml
@@ -5,7 +5,7 @@ upper_tabs:
# Dropdown menu
- name: Ecosystem
path: /ecosystem
- is_default: True
+ is_default: true
menu:
- include: /ecosystem/_menu_toc.yaml
lower_tabs:
@@ -14,46 +14,50 @@ upper_tabs:
- name: Guide
contents:
- title: Overview
- path: /mobile/overview
- - title: Developer Guide
- path: /mobile/devguide
- - title: Android Demo App
- path: /mobile/demo_android
- - title: iOS Demo App
- path: /mobile/demo_ios
+ path: /lite/overview
+ - title: Developer guide
+ path: /lite/devguide
+ - title: Android demo app
+ path: /lite/demo_android
+ - title: iOS demo app
+ path: /lite/demo_ios
- title: Performance
- path: /mobile/performance
- - break: True
+ path: /lite/performance
+ - break: true
- title: TensorFlow Lite APIs
- path: /mobile/apis
+ path: /lite/apis
- title: Custom operators
- path: /mobile/custom_operators
- - title: TensorFlow Lite Ops Versioning
- path: /mobile/ops_versioning
- - title: TensorFlow Lite Compatibility Guide
- path: /mobile/tf_ops_compatibility
- - title: List of Hosted Models
- path: /mobile/models
+ path: /lite/custom_operators
+ - title: TensorFlow Lite ops versioning
+ path: /lite/ops_versioning
+ - title: TensorFlow Lite compatibility guide
+ path: /lite/tf_ops_compatibility
+ - title: List of hosted models
+ path: /lite/models
- title: TensorFlow Lite for iOS
- path: /mobile/ios
+ path: /lite/ios
- title: TensorFlow Lite for Raspberry Pi
- path: /mobile/rpi
+ path: /lite/rpi
- - heading: TF Mobile
+ - title: TF Mobile
+ style: accordion
status: deprecated
- - title: Overview
- path: /mobile/tfmobile/
- - title: Building TensorFlow on Android
- path: /mobile/tfmobile/android_build
- - title: Building TensorFlow on IOS
- path: /mobile/tfmobile/ios_build
- - title: Integrating TensorFlow libraries
- path: /mobile/tfmobile/linking_libs
- - title: Preparing models for mobile deployment
- path: /mobile/tfmobile/prepare_models
- - title: Optimizing for mobile
- path: /mobile/tfmobile/optimizing
+ section:
+ - title: Overview
+ path: /lite/tfmobile/
+ - title: Building TensorFlow on Android
+ path: /lite/tfmobile/android_build
+ - title: Building TensorFlow on IOS
+ path: /lite/tfmobile/ios_build
+ - title: Integrating TensorFlow libraries
+ path: /lite/tfmobile/linking_libs
+ - title: Preparing models for mobile deployment
+ path: /lite/tfmobile/prepare_models
+ - title: Optimizing for mobile
+ path: /lite/tfmobile/optimizing
- name: API
+ skip_translation: true
contents:
- - include: /mobile/api_docs/python/_toc.yaml
+ - title: API
+ path: /api_docs/python/tf/contrib/lite
diff --git a/tensorflow/contrib/lite/g3doc/_index.yaml b/tensorflow/contrib/lite/g3doc/_index.yaml
index b3f21e21ac..bc66cc5dc1 100644
--- a/tensorflow/contrib/lite/g3doc/_index.yaml
+++ b/tensorflow/contrib/lite/g3doc/_index.yaml
@@ -1,60 +1,209 @@
-book_path: /mobile/_book.yaml
-project_path: /mobile/_project.yaml
+project_path: /lite/_project.yaml
+book_path: /lite/_book.yaml
description: <!--no description-->
landing_page:
+ custom_css_path: /site-assets/css/style.css
rows:
- - heading: TensorFlow Lite is a lightweight solution for mobile and embedded devices.
+ - heading: TensorFlow Lite is for mobile and embedded devices.
+ description: >
+ <p style="max-width: 75%;">
+ TensorFlow Lite is the official solution for running machine learning
+ models on mobile and embedded devices. It enables on&#8209;device machine
+ learning inference with low latency and a small binary size on Android,
+ iOS, and other operating systems.
+ </p>
+ <style>
+ .tfo-landing-row-heading {
+ padding-top: 0 !important;
+ }
+ .tfo-landing-row-heading h2 {
+ margin-top: 0 !important;
+ }
+ .tfo-landing-row-heading-list ol, .tfo-landing-row-heading-list ul {
+ margin-top: 0;
+ }
+ </style>
+
+ - classname: tfo-landing-row-heading tfo-landing-row-heading-list
+ heading: Many benefits
+ description: >
+ On-device ML inference is difficult because of the many constraints—TensorFlow Lite can solve these:
items:
- - classname: devsite-landing-row-50
- description: >
- TensorFlow Lite is TensorFlow’s lightweight solution for mobile and
- embedded devices. It enables on-device machine learning inference with
- low latency and a small binary size. TensorFlow Lite also supports
- hardware acceleration with the
- <a href='https://developer.android.com/ndk/guides/neuralnetworks/index.html'>Android Neural Networks API</a>.
- list:
- - heading: Key point 1
+ - list:
+ - heading: Performance
+ description: >
+ TF Lite is fast with no noticeable accuracy loss—see the <a href="./performance">metrics</a>.
+ icon:
+ icon_name: lens
+ foreground: theme
+ - heading: Portability
description: >
- [high-level overview]
+ <a href="https://developer.android.com/ndk/guides/neuralnetworks/" class="external">Android</a>,
+ iOS, and more specialized IoT devices.
icon:
- icon_name: chevron_right
+ icon_name: lens
foreground: theme
- background: grey
- - heading: Key point 2
+ - list:
+ - heading: Low latency
description: >
- [high-level overview]
+ Optimized float- and fixed-point CPU kernels, op&#8209;fusing, and more.
icon:
- icon_name: chevron_right
+ icon_name: lens
foreground: theme
- background: grey
- - heading: Key point 3
+ - heading: Acceleration
description: >
- [high-level overview]
+ Integration with GPU and internal/external accelerators.
icon:
- icon_name: chevron_right
+ icon_name: lens
foreground: theme
- background: grey
- code_block: |
- <pre class = "prettyprint">
- $ toco --input_file=$(pwd)/mobilenet_v1_1.0_224/frozen_graph.pb \
- --input_format=TENSORFLOW_GRAPHDEF \
- --output_format=TFLITE \
- --output_file=/tmp/mobilenet_v1_1.0_224.tflite \
- --inference_type=FLOAT \
- --input_type=FLOAT \
- --input_arrays=input \
- --output_arrays=MobilenetV1/Predictions/Reshape_1 \
- --input_shapes=1,224,224,3
- </pre>
+ - list:
+ - heading: Small model size
+ description: >
+ Controlled dependencies, <a href="https://medium.com/tensorflow/introducing-the-model-optimization-toolkit-for-tensorflow-254aca1ba0a3" class="external">quantization</a>,
+ and op&nbsp;registration.
+ icon:
+ icon_name: lens
+ foreground: theme
+ - heading: Tooling
+ description: >
+ Conversion, compression, benchmarking, power-consumption, and more.
+ icon:
+ icon_name: lens
+ foreground: theme
+
+ - classname: devsite-landing-row-logos tfo-landing-row-heading
+ heading: Companies using TensorFlow Lite
+ items:
+ - custom_image:
+ path: ./images/landing-page/photos_logo.png
+ path: https://www.photos.google.com
+ - custom_image:
+ path: ./images/landing-page/gboard_logo.png
+ path: https://play.google.com/store/apps/details?id=com.google.android.inputmethod.latin&hl=en_US
+ - custom_image:
+ path: ./images/landing-page/gmail_logo.png
+ path: https://www.google.com/gmail/
+ - custom_image:
+ path: ./images/landing-page/assistant_logo.png
+ path: https://assistant.google.com/
+
+ - classname: devsite-landing-row-logos
+ items:
+ - custom_image:
+ path: ./images/landing-page/vsco_logo.png
+ path: https://vsco.co
+ - custom_image:
+ path: ./images/landing-page/shazam_logo.png
+ path: https://www.shazam.com/
+ - custom_image:
+ path: ./images/landing-page/nest_logo.png
+ path: https://nest.com/
+ - custom_image:
+ path: ./images/landing-page/loseit_logo.png
+ path: https://www.loseit.com/
+
+ - classname: devsite-landing-row-no-image-background devsite-landing-row-67
+ background: grey
+ items:
+ - description: >
+ <em>“TensorFlow Lite helped us introduce machine learning and AI into our
+ app in an easy and streamlined way. We could reduce the size of our
+ models while keeping the accuracy high. This helped us create an amazing
+ fishing experience for our users by allowing them to identify any fish
+ species with just a photo.”</em>
+ image_path: ./images/landing-page/fishbrain_logo_big.png
+
+ - heading: How it works
+ items:
+ - heading: Build
+ icon:
+ icon_name: build
+ description: >
+ Build a new model or retrain an existing one, such as using transfer learning.
+ buttons:
+ - label: Read the developer guide
+ path: /lite/devguide
+ classname: button button-primary tfo-button-primary
+ - heading: Convert
+ icon:
+ icon_name: autorenew
+ description: >
+ Convert a TensorFlow model into a compressed flat buffer with the
+ TensorFlow Lite Optimizing Converter (TOCO).
+ buttons:
+ - label: Read the TOCO guide
+ path: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/toco/g3doc/python_api.md
+ classname: button button-primary tfo-button-primary
+ - heading: Deploy
+ icon:
+ icon_name: bolt
+ description: >
+ Take the compressed <code>.tflite</code> file and load it into a mobile
+ or embedded device.<br/>
+ See the <a href="#build-your-first-tensorflow-lite-app">tutorials below</a> to build an app.
+
+ - heading: Build your first TensorFlow Lite app
+ background: grey
+ items:
+ - classname: tfo-landing-row-item-inset-white
+ heading: Get started
+ description: >
+ <ul>
+ <li>Beginner: <a href="https://codelabs.developers.google.com/codelabs/tensorflow-for-poets/" class="external">TensorFlow for Poets</a></li>
+ <li>Beginner: <a href="https://codelabs.developers.google.com/codelabs/tensorflow-for-poets-2-tflite/" class="external">TensorFlow for Poets 2: Android</a></li>
+ <li>Beginner: <a href="https://codelabs.developers.google.com/codelabs/tensorflow-for-poets-2-ios/" class="external">TensorFlow for Poets 2: iOS </a></li>
+ <li>Intermediate: <a href="https://medium.com/tensorflow/training-and-serving-a-realtime-mobile-object-detector-in-30-minutes-with-cloud-tpus-b78971cf1193" class="external">Object detection tutorial</a>
+ </ul>
+ - classname: tfo-landing-row-item-inset-white
+ heading: Share your TensorFlow Lite story
+ description: >
+ We love to hear what you're working on—it may even get highlighted on
+ our social media! <a href="https://groups.google.com/a/tensorflow.org/forum/#!forum/discuss" class="external">Tell us</a>.
+
+ - classname: devsite-landing-row-no-image-background devsite-landing-row-67
+ items:
+ - description: >
+ <p>
+ <em>“The release of TensorFlow Lite has allowed us to deploy an engaging
+ real-time experience to our users that eliminates the requirement
+ for a data connection. TensorFlow Lite’s ability to compress and
+ optimize the TensorFlow graph for mobile deployment has been
+ transformative in expanding the capabilities of Snap It.</em>
+ </p>
+ <p>
+ <em>Through TensorFlow Lite, our users can now enjoy a state of the
+ art, computer-vision-based food logging experience without worrying
+ about signal strength. We look forward to future collaborations
+ with the TensorFlow Lite team.”</em>
+ </p>
+ image_path: ./images/landing-page/loseit_logo_big.png
- classname: devsite-landing-row-cards
+ background: grey
+ heading: Updates
items:
+ - heading: Introducing the Model Optimization Toolkit
+ image_path: /ecosystem/images/tf-logo-card-16x9.png
+ path: https://medium.com/tensorflow/introducing-the-model-optimization-toolkit-for-tensorflow-254aca1ba0a3
+ buttons:
+ - label: Read on TensorFlow blog
+ path: https://medium.com/tensorflow/introducing-the-model-optimization-toolkit-for-tensorflow-254aca1ba0a3
+ - heading: East Africa Cassava App
+ image_path: ./images/landing-page/detect_crop_disease_in_africa.png
+ path: https://heartbeat.fritz.ai/community-spotlight-nuru-a-mobile-app-by-plantvillage-to-detect-crop-disease-in-africa-28d142bf63d5
+ buttons:
+ - label: Read more
+ path: https://heartbeat.fritz.ai/community-spotlight-nuru-a-mobile-app-by-plantvillage-to-detect-crop-disease-in-africa-28d142bf63d5
- heading: Using TensorFlow Lite on Android
image_path: /ecosystem/images/tf-logo-card-16x9.png
path: https://medium.com/tensorflow/using-tensorflow-lite-on-android-9bbc9cb7d69d
buttons:
- label: Read on TensorFlow blog
path: https://medium.com/tensorflow/using-tensorflow-lite-on-android-9bbc9cb7d69d
+
+ - classname: devsite-landing-row-cards
+ background: grey
+ items:
- heading: TensorFlow Lite at the Dev Summit
youtube_id: FAMfy7izB6A
buttons:
@@ -66,3 +215,4 @@ landing_page:
buttons:
- label: View on GitHub
path: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite
+ - classname: devsite-landing-row-item-hidden
diff --git a/tensorflow/contrib/lite/g3doc/_project.yaml b/tensorflow/contrib/lite/g3doc/_project.yaml
index b39666516b..3ce6986396 100644
--- a/tensorflow/contrib/lite/g3doc/_project.yaml
+++ b/tensorflow/contrib/lite/g3doc/_project.yaml
@@ -1,10 +1,10 @@
name: TensorFlow Lite
-breadcrumb_name: Mobile
-home_url: /mobile/
+breadcrumb_name: TensorFlow Lite
+home_url: /lite/
parent_project_metadata_path: /_project.yaml
description: >
TensorFlow Lite is a lightweight solution for mobile and embedded devices.
-use_site_branding: True
-hide_from_products_list: True
+use_site_branding: true
+hide_from_products_list: true
content_license: cc3-apache2
buganizer_id: 316308
diff --git a/tensorflow/contrib/lite/g3doc/api_docs/python/_toc.yaml b/tensorflow/contrib/lite/g3doc/api_docs/python/_toc.yaml
deleted file mode 100644
index 1e1c44c692..0000000000
--- a/tensorflow/contrib/lite/g3doc/api_docs/python/_toc.yaml
+++ /dev/null
@@ -1,6 +0,0 @@
-# Automatically generated file; please do not edit
-toc:
- - title: TensorFlow Lite
- section:
- - title: Overview
- path: /mobile/api_docs/python/
diff --git a/tensorflow/contrib/lite/g3doc/devguide.md b/tensorflow/contrib/lite/g3doc/devguide.md
index 90e7915c52..0eed516000 100644
--- a/tensorflow/contrib/lite/g3doc/devguide.md
+++ b/tensorflow/contrib/lite/g3doc/devguide.md
@@ -1,5 +1,4 @@
-
-# Developer Guide
+# TF Lite Developer Guide
Using a TensorFlow Lite model in your mobile app requires multiple
considerations: you must choose a pre-trained or custom model, convert the model
@@ -55,7 +54,7 @@ both floating point and quantized inference.
### Train a custom model
A developer may choose to train a custom model using Tensorflow (see the
-[TensorFlow tutorials](../../tutorials/) for examples of building and training
+[TensorFlow tutorials](../tutorials/) for examples of building and training
models). If you have already written a model, the first step is to export this
to a `tf.GraphDef` file. This is required because some formats do not store the
model structure outside the code, and we must communicate with other parts of the
@@ -205,7 +204,7 @@ The open source Android demo app uses the JNI interface and is available
[on GitHub](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/java/demo/app).
You can also download a
[prebuilt APK](http://download.tensorflow.org/deps/tflite/TfLiteCameraDemo.apk).
-See the <a href="../demo_android.md">Android demo</a> guide for details.
+See the <a href="./demo_android.md">Android demo</a> guide for details.
The <a href="./android_build.md">Android mobile</a> guide has instructions for
installing TensorFlow on Android and setting up `bazel` and Android Studio.
@@ -214,7 +213,7 @@ installing TensorFlow on Android and setting up `bazel` and Android Studio.
To integrate a TensorFlow model in an iOS app, see the
[TensorFlow Lite for iOS](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/g3doc/ios.md)
-guide and <a href="../demo_ios.md">iOS demo</a> guide.
+guide and <a href="./demo_ios.md">iOS demo</a> guide.
#### Core ML support
diff --git a/tensorflow/contrib/lite/g3doc/images/landing-page/assistant_logo.png b/tensorflow/contrib/lite/g3doc/images/landing-page/assistant_logo.png
new file mode 100644
index 0000000000..ced0872ab2
--- /dev/null
+++ b/tensorflow/contrib/lite/g3doc/images/landing-page/assistant_logo.png
Binary files differ
diff --git a/tensorflow/contrib/lite/g3doc/images/landing-page/detect_crop_disease_in_africa.png b/tensorflow/contrib/lite/g3doc/images/landing-page/detect_crop_disease_in_africa.png
new file mode 100644
index 0000000000..45b3b4f6fe
--- /dev/null
+++ b/tensorflow/contrib/lite/g3doc/images/landing-page/detect_crop_disease_in_africa.png
Binary files differ
diff --git a/tensorflow/contrib/lite/g3doc/images/landing-page/fishbrain_logo.png b/tensorflow/contrib/lite/g3doc/images/landing-page/fishbrain_logo.png
new file mode 100644
index 0000000000..bc1bf6e1e7
--- /dev/null
+++ b/tensorflow/contrib/lite/g3doc/images/landing-page/fishbrain_logo.png
Binary files differ
diff --git a/tensorflow/contrib/lite/g3doc/images/landing-page/fishbrain_logo_big.png b/tensorflow/contrib/lite/g3doc/images/landing-page/fishbrain_logo_big.png
new file mode 100644
index 0000000000..d76fca86a9
--- /dev/null
+++ b/tensorflow/contrib/lite/g3doc/images/landing-page/fishbrain_logo_big.png
Binary files differ
diff --git a/tensorflow/contrib/lite/g3doc/images/landing-page/gboard_logo.png b/tensorflow/contrib/lite/g3doc/images/landing-page/gboard_logo.png
new file mode 100644
index 0000000000..f1a93ab763
--- /dev/null
+++ b/tensorflow/contrib/lite/g3doc/images/landing-page/gboard_logo.png
Binary files differ
diff --git a/tensorflow/contrib/lite/g3doc/images/landing-page/gmail_logo.png b/tensorflow/contrib/lite/g3doc/images/landing-page/gmail_logo.png
new file mode 100644
index 0000000000..21aa2c84ea
--- /dev/null
+++ b/tensorflow/contrib/lite/g3doc/images/landing-page/gmail_logo.png
Binary files differ
diff --git a/tensorflow/contrib/lite/g3doc/images/landing-page/loseit_logo.png b/tensorflow/contrib/lite/g3doc/images/landing-page/loseit_logo.png
new file mode 100644
index 0000000000..b6b3d14df9
--- /dev/null
+++ b/tensorflow/contrib/lite/g3doc/images/landing-page/loseit_logo.png
Binary files differ
diff --git a/tensorflow/contrib/lite/g3doc/images/landing-page/loseit_logo_big.png b/tensorflow/contrib/lite/g3doc/images/landing-page/loseit_logo_big.png
new file mode 100644
index 0000000000..b3e46d4bd8
--- /dev/null
+++ b/tensorflow/contrib/lite/g3doc/images/landing-page/loseit_logo_big.png
Binary files differ
diff --git a/tensorflow/contrib/lite/g3doc/images/landing-page/nest_logo.png b/tensorflow/contrib/lite/g3doc/images/landing-page/nest_logo.png
new file mode 100644
index 0000000000..35bfd97373
--- /dev/null
+++ b/tensorflow/contrib/lite/g3doc/images/landing-page/nest_logo.png
Binary files differ
diff --git a/tensorflow/contrib/lite/g3doc/images/landing-page/photos_logo.png b/tensorflow/contrib/lite/g3doc/images/landing-page/photos_logo.png
new file mode 100644
index 0000000000..4333426dfe
--- /dev/null
+++ b/tensorflow/contrib/lite/g3doc/images/landing-page/photos_logo.png
Binary files differ
diff --git a/tensorflow/contrib/lite/g3doc/images/landing-page/shazam_logo.png b/tensorflow/contrib/lite/g3doc/images/landing-page/shazam_logo.png
new file mode 100644
index 0000000000..6ec412c75c
--- /dev/null
+++ b/tensorflow/contrib/lite/g3doc/images/landing-page/shazam_logo.png
Binary files differ
diff --git a/tensorflow/contrib/lite/g3doc/images/landing-page/vsco_logo.png b/tensorflow/contrib/lite/g3doc/images/landing-page/vsco_logo.png
new file mode 100644
index 0000000000..f408f9024b
--- /dev/null
+++ b/tensorflow/contrib/lite/g3doc/images/landing-page/vsco_logo.png
Binary files differ
diff --git a/tensorflow/contrib/lite/g3doc/ios.md b/tensorflow/contrib/lite/g3doc/ios.md
index a83d2c8fec..3b9fcca811 100644
--- a/tensorflow/contrib/lite/g3doc/ios.md
+++ b/tensorflow/contrib/lite/g3doc/ios.md
@@ -1,5 +1,10 @@
-# TensorFlow Lite for iOS
+# Build TensorFlow Lite for iOS
+
+This document describes how to build the TensorFlow Lite iOS library. If you
+just want to use it, the easiest way is to use the TensorFlow Lite CocoaPod
+releases. See [TensorFlow Lite iOS Demo](demo_ios.md) for examples.
+
## Building
diff --git a/tensorflow/contrib/lite/g3doc/models.md b/tensorflow/contrib/lite/g3doc/models.md
index a4267eee4c..279764ce96 100644
--- a/tensorflow/contrib/lite/g3doc/models.md
+++ b/tensorflow/contrib/lite/g3doc/models.md
@@ -1,6 +1,23 @@
# List of Hosted Models
+## AutoML mobile image classification models (Float Models)
+
+Model Name | Paper_Model_Files | Model_Size | Top-1 Accuracy | Top-5 Accuracy | TF Lite Performance^
+------------------- | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | ---------: | -------------: | -------------: | ---------------------:
+MnasNet_0.50_224| [paper](https://arxiv.org/abs/1807.11626), [tflite&pb](https://storage.cloud.google.com/download.tensorflow.org/models/tflite/mnasnet_0.5_224_09_07_2018.tgz) | 8.5 Mb | 68.03% | 87.79% | 37 ms
+MnasNet_0.75_224| [paper](https://arxiv.org/abs/1807.11626), [tflite&pb](https://storage.cloud.google.com/download.tensorflow.org/models/tflite/mnasnet_0.75_224_09_07_2018.tgz) | 12 Mb | 71.72% | 90.17% | 61 ms
+MnasNet_1.0_224| [paper](https://arxiv.org/abs/1807.11626), [tflite&pb](https://storage.cloud.google.com/download.tensorflow.org/models/tflite/mnasnet_1.0_224_09_07_2018.tgz) | 17 Mb | 74.08% | 91.75% | 93 ms
+MnasNet_1.3_224| [paper](https://arxiv.org/abs/1807.11626), [tflite&pb](https://storage.cloud.google.com/download.tensorflow.org/models/tflite/mnasnet_1.3_224_09_07_2018.tgz) | 24 Mb | 75.24% | 92.55% | 152 ms
+MnasNet_1.0_96| [paper](https://arxiv.org/abs/1807.11626), [tflite&pb](https://storage.cloud.google.com/download.tensorflow.org/models/tflite/mnasnet_1.0_96_09_07_2018.tgz) | 17 Mb | 62.33% | 83.98% | 23 ms
+MnasNet_1.0_128| [paper](https://arxiv.org/abs/1807.11626), [tflite&pb](https://storage.cloud.google.com/download.tensorflow.org/models/tflite/mnasnet_1.0_128_09_07_2018.tgz) | 17 Mb | 67.32% | 87.70% | 34 ms
+MnasNet_1.0_160| [paper](https://arxiv.org/abs/1807.11626), [tflite&pb](https://storage.cloud.google.com/download.tensorflow.org/models/tflite/mnasnet_1.0_160_09_07_2018.tgz) | 17 Mb | 70.63% | 89.58% | 51 ms
+MnasNet_1.0_192| [paper](https://arxiv.org/abs/1807.11626), [tflite&pb](https://storage.cloud.google.com/download.tensorflow.org/models/tflite/mnasnet_1.0_192_09_07_2018.tgz) | 17 Mb | 72.56% | 90.76% | 70 ms
+MnasNet_1.0_224| [paper](https://arxiv.org/abs/1807.11626), [tflite&pb](https://storage.cloud.google.com/download.tensorflow.org/models/tflite/mnasnet_1.0_224_09_07_2018.tgz) | 17 Mb | 74.08% | 91.75% | 93 ms
+
+^ Performance numbers are generated on a Pixel 1 using a single thread on the large (BIG) core.
+
+
## Image classification (Float Models)
Model Name | Paper_Model_Files^ | Model_Size | Top-1 Accuracy | Top-5 Accuracy | TF Lite Performance^^ | Tensorflow Performance
diff --git a/tensorflow/contrib/lite/g3doc/overview.md b/tensorflow/contrib/lite/g3doc/overview.md
index 8cf43496df..9d035a6921 100644
--- a/tensorflow/contrib/lite/g3doc/overview.md
+++ b/tensorflow/contrib/lite/g3doc/overview.md
@@ -25,7 +25,7 @@ models.
TensorFlow Lite defines a new model file format, based on
[FlatBuffers](https://google.github.io/flatbuffers/). FlatBuffers is an
-open-sourced, efficient cross platform serialization library. It is similar to
+efficient open-source cross-platform serialization library. It is similar to
[protocol buffers](https://developers.google.com/protocol-buffers/?hl=en), but
the primary difference is that FlatBuffers does not need a parsing/unpacking
step to a secondary representation before you can access data, often coupled
diff --git a/tensorflow/contrib/lite/g3doc/performance.md b/tensorflow/contrib/lite/g3doc/performance.md
index 28cb6aba6e..0ae9400068 100644
--- a/tensorflow/contrib/lite/g3doc/performance.md
+++ b/tensorflow/contrib/lite/g3doc/performance.md
@@ -1,174 +1,38 @@
-# Performance
+# Performance best practices
-This document lists TensorFlow Lite performance benchmarks when running well
-known models on some Android and iOS devices.
+Mobile and embedded devices have limited computational resources, so it is important to keep your application resource efficient. We have compiled a list of best practices and strategies that you can use to optimize your model and application when using TensorFlow Lite.
-These performance benchmark numbers were generated with the
-[Android TFLite benchmark binary](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark)
-and the [iOS benchmark app](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark/ios).
+## Choose the most efficient model for the problem
+Some models may be too large to run on embedded devices. For embedded devices, it is often better to use a slightly less precise but smaller model. Smaller models not only use less disk space and memory, but are generally faster and more energy efficient. One example of models optimized for mobile devices is [MobileNets](https://arxiv.org/abs/1704.04861), a family of models designed for mobile vision applications. The TensorFlow Lite [models page](models.md) lists several other models that have been optimized specifically for mobile and embedded devices.
-# Android performance benchmarks
+You can retrain the listed models on your own dataset by using transfer learning. Check out our transfer learning tutorials for
+[image classification](https://codelabs.developers.google.com/codelabs/tensorflow-for-poets/#0) and
+[object detection](https://medium.com/tensorflow/training-and-serving-a-realtime-mobile-object-detector-in-30-minutes-with-cloud-tpus-b78971cf1193).
-For Android benchmarks, the CPU affinity is set to use big cores on the device to
-reduce variance (see [details](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark#reducing-variance-between-runs-on-android)).
-It assumes that models were download and unzipped to the
-`/data/local/tmp/tflite_models` directory. The benchmark binary is built
-using [these instructions](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark#on-android)
-and assumed in the `/data/local/tmp` directory.
+## Profile your model
+Before starting any optimization, it is good practice to profile and benchmark your model. The TensorFlow Lite [benchmarking tool](../tools/benchmark) has a built-in profiler that shows per-operator profiling statistics. This can help you understand performance bottlenecks and identify which operators dominate the computation time.
-To run the benchmark:
+## Profile and optimize operators in the graph
+If a particular operator appears frequently in the model and profiling shows that it consumes the most time, you can look into optimizing that operator.
+This scenario should be rare, as TensorFlow Lite has optimized versions of most ops. However, you may be able to write a faster version of a custom op if you know the constraints under which the operator executes. Check out our [custom operator documentation](custom_operators.md).
-```
-adb shell taskset ${CPU_MASK} /data/local/tmp/benchmark_model \
- --num_threads=1 \
- --graph=/data/local/tmp/tflite_models/${GRAPH} \
- --warmup_runs=1 \
- --num_runs=50 \
- --use_nnapi=false
-```
+## Quantize your model
+If your model uses floating-point weights or activations, it may be possible to reduce the model size by up to ~4x through quantization and other model optimizations. Check out our [model optimization toolkit](https://www.tensorflow.org/performance/model_optimization) for details about optimizing your model. Fully quantized models can be remarkably power efficient as well.
-Here, `${GRAPH}` is the name of model and `${CPU_MASK}` is the CPU affinity
-chosen according to the following table:
+## Tweak the number of threads
+TensorFlow Lite supports multi-threaded kernels for many operators. You can increase the number of threads to speed up the execution of operators, although this will make your model use more resources and power. For some applications, latency may be more important than energy efficiency. You can increase the number of threads by setting the number of [interpreter](../interpreter.h) threads, as in the sketch below.
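A minimal C++ sketch of this, assuming an illustrative model path and an arbitrary choice of four threads (error handling omitted):

```c++
#include <memory>

#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/model.h"

// Builds an interpreter capped at four kernel threads and runs it once.
// The model must outlive the interpreter, so both live in this scope.
void RunWithFourThreads() {
  auto model = tflite::FlatBufferModel::BuildFromFile("/tmp/model.tflite");
  tflite::ops::builtin::BuiltinOpResolver resolver;
  std::unique_ptr<tflite::Interpreter> interpreter;
  tflite::InterpreterBuilder(*model, resolver)(&interpreter);
  interpreter->SetNumThreads(4);  // More threads: lower latency, more power.
  interpreter->AllocateTensors();
  interpreter->Invoke();
}
```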
-Device | CPU_MASK |
--------| ----------
-Pixel 2 | f0 |
-Pixel xl | 0c |
+## Eliminate redundant copies
+TensorFlow Lite is optimized to reduce redundant copies. The APIs allow the user to [mmap a model file](https://github.com/tensorflow/tensorflow/blob/9982fd6c8831cbd2f58954f79ea71f26660393bc/tensorflow/contrib/lite/model.h#L152) and avoid copies. If your application is not careful, it can introduce redundant copies when feeding input to the model and reading output from it, so make sure to eliminate them. If you are using higher-level APIs, like the Java API, carefully check the documentation for performance caveats. For example, the Java API is much faster if ByteBuffers are used as [inputs](https://github.com/tensorflow/tensorflow/blob/6305a6d83552ba6a472cd72398b60d9241467f1f/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java#L151).
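A minimal C++ sketch of the in-place pattern, assuming a float input at index 0 with an illustrative 224x224x3 shape:

```c++
#include "tensorflow/contrib/lite/interpreter.h"

// Writes input data directly into the interpreter's own input buffer rather
// than staging it in a separate array and copying it in afterwards.
void RunInference(tflite::Interpreter* interpreter, const float* pixels) {
  float* input = interpreter->typed_input_tensor<float>(0);
  for (int i = 0; i < 224 * 224 * 3; ++i) {
    input[i] = pixels[i];  // No intermediate copy of the whole frame.
  }
  interpreter->Invoke();
  // Read results in place as well, instead of copying them out first.
  const float* scores = interpreter->typed_output_tensor<float>(0);
  (void)scores;
}
```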
-<table>
- <thead>
- <tr>
- <th>Model Name</th>
- <th>Device </th>
- <th>Mean inference time (std dev)</th>
- </tr>
- </thead>
- <tr>
- <td rowspan = 2>
- <a href="http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224.tgz">Mobilenet_1.0_224(float)</a>
- </td>
- <td>Pixel 2 </td>
- <td>166.5 ms (2.6 ms)</td>
- </tr>
- <tr>
- <td>Pixel xl </td>
- <td>122.9 ms (1.8 ms) </td>
- </tr>
- <tr>
- <td rowspan = 2>
- <a href="http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz">Mobilenet_1.0_224 (quant)</a>
- </td>
- <td>Pixel 2 </td>
- <td>69.5 ms (0.9 ms)</td>
- </tr>
- <tr>
- <td>Pixel xl </td>
- <td>78.9 ms (2.2 ms) </td>
- </tr>
- <tr>
- <td rowspan = 2>
- <a href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/nasnet_mobile_2018_04_27.tgz">NASNet mobile</a>
- </td>
- <td>Pixel 2 </td>
- <td>273.8 ms (3.5 ms)</td>
- </tr>
- <tr>
- <td>Pixel xl </td>
- <td>210.8 ms (4.2 ms)</td>
- </tr>
- <tr>
- <td rowspan = 2>
- <a href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/squeezenet_2018_04_27.tgz">SqueezeNet</a>
- </td>
- <td>Pixel 2 </td>
- <td>234.0 ms (2.1 ms)</td>
- </tr>
- <tr>
- <td>Pixel xl </td>
- <td>158.0 ms (2.1 ms)</td>
- </tr>
- <tr>
- <td rowspan = 2>
- <a href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_resnet_v2_2018_04_27.tgz">Inception_ResNet_V2</a>
- </td>
- <td>Pixel 2 </td>
- <td>2846.0 ms (15.0 ms)</td>
- </tr>
- <tr>
- <td>Pixel xl </td>
- <td>1973.0 ms (15.0 ms) </td>
- </tr>
- <tr>
- <td rowspan = 2>
- <a href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_v4_2018_04_27.tgz">Inception_V4</a>
- </td>
- <td>Pixel 2 </td>
- <td>3180.0 ms (11.7 ms)</td>
- </tr>
- <tr>
- <td>Pixel xl </td>
- <td>2262.0 ms (21.0 ms) </td>
- </tr>
+## Profile your application with platform specific tools
+Platform-specific tools like the [Android profiler](https://developer.android.com/studio/profile/android-profiler) and [Instruments](https://help.apple.com/instruments/mac/current/) provide a wealth of profiling information that can be used to debug your app. Sometimes the performance bug is not in the model but in the parts of application code that interact with the model. Make sure to familiarize yourself with the profiling tools and best practices for your platform.
- </table>
+## Use hardware accelerators available on the device
+TensorFlow Lite is adding support for hardware accelerators such as the GPU, and it already provides acceleration through [NNAPI](https://developer.android.com/ndk/guides/neuralnetworks/) on Android.
+You can use these hardware accelerator backends to improve the speed and power efficiency of your model. To enable NNAPI, call [UseNNAPI](https://github.com/tensorflow/tensorflow/blob/6305a6d83552ba6a472cd72398b60d9241467f1f/tensorflow/contrib/lite/interpreter.h#L334) on the interpreter instance.
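+
+If you are using the Java API, the same switch is exposed through `Interpreter.Options#setUseNNAPI`. A minimal sketch, assuming that option is available in your TensorFlow Lite version:
+
+```java
+import java.io.File;
+import org.tensorflow.lite.Interpreter;
+
+class NnapiInterpreterFactory {
+  // Builds an interpreter that asks TensorFlow Lite to use NNAPI where the device supports it.
+  static Interpreter create(File modelFile) {
+    Interpreter.Options options = new Interpreter.Options().setUseNNAPI(true);
+    return new Interpreter(modelFile, options);
+  }
+}
+```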
-# iOS benchmarks
-
-To run iOS benchmarks, the [benchmark
-app](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark/ios)
-was modified to include the appropriate model and `benchmark_params.json` was
-modified to set `num_threads` to 1.
-
-<table>
- <thead>
- <tr>
- <th>Model Name</th>
- <th>Device </th>
- <th>Mean inference time (std dev)</th>
- </tr>
- </thead>
- <tr>
- <td>
- <a href="http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224.tgz">Mobilenet_1.0_224(float)</a>
- </td>
- <td>iPhone 8 </td>
- <td>32.2 ms (0.8 ms)</td>
- </tr>
- <tr>
- <td>
- <a href="http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz)">Mobilenet_1.0_224 (quant)</a>
- </td>
- <td>iPhone 8 </td>
- <td>24.4 ms (0.8 ms)</td>
- </tr>
- <tr>
- <td>
- <a href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/nasnet_mobile_2018_04_27.tgz">NASNet mobile</a>
- </td>
- <td>iPhone 8 </td>
- <td>60.3 ms (0.6 ms)</td>
- </tr>
- <tr>
- <td>
- <a href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/squeezenet_2018_04_27.tgz">SqueezeNet</a>
- </td>
- <td>iPhone 8 </td>
- <td>44.3 (0.7 ms)</td>
- </tr>
- <tr>
- <td>
- <a href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_resnet_v2_2018_04_27.tgz">Inception_ResNet_V2</a>
- </td>
- <td>iPhone 8</td>
- <td>562.4 ms (18.2 ms)</td>
- </tr>
- <tr>
- <td>
- <a href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_v4_2018_04_27.tgz">Inception_V4</a>
- </td>
- <td>iPhone 8 </td>
- <td>661.0 ms (29.2 ms)</td>
- </tr>
- </table>
+## Need more help?
+The TensorFlow team is happy to help diagnose and address specific performance issues you may be facing. Please file a bug on [GitHub](https://github.com/tensorflow/tensorflow/issues) with details of the issue.
diff --git a/tensorflow/contrib/lite/g3doc/performance_benchmarks.md b/tensorflow/contrib/lite/g3doc/performance_benchmarks.md
new file mode 100644
index 0000000000..28cb6aba6e
--- /dev/null
+++ b/tensorflow/contrib/lite/g3doc/performance_benchmarks.md
@@ -0,0 +1,174 @@
+
+# Performance
+
+This document lists TensorFlow Lite performance benchmarks when running
+well-known models on some Android and iOS devices.
+
+These performance benchmark numbers were generated with the
+[Android TFLite benchmark binary](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark)
+and the [iOS benchmark app](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark/ios).
+
+# Android performance benchmarks
+
+For Android benchmarks, the CPU affinity is set to use big cores on the device to
+reduce variance (see [details](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark#reducing-variance-between-runs-on-android)).
+
+It assumes that models were downloaded and unzipped to the
+`/data/local/tmp/tflite_models` directory. The benchmark binary is built
+using [these instructions](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark#on-android)
+and is assumed to be in the `/data/local/tmp` directory.
+
+To run the benchmark:
+
+```
+adb shell taskset ${CPU_MASK} /data/local/tmp/benchmark_model \
+ --num_threads=1 \
+ --graph=/data/local/tmp/tflite_models/${GRAPH} \
+ --warmup_runs=1 \
+ --num_runs=50 \
+ --use_nnapi=false
+```
+
+Here, `${GRAPH}` is the name of the model and `${CPU_MASK}` is the CPU affinity
+chosen according to the following table:
+
+Device   | CPU_MASK
+---------|---------
+Pixel 2  | f0
+Pixel xl | 0c
+
+<table>
+ <thead>
+ <tr>
+ <th>Model Name</th>
+ <th>Device </th>
+ <th>Mean inference time (std dev)</th>
+ </tr>
+ </thead>
+ <tr>
+ <td rowspan = 2>
+ <a href="http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224.tgz">Mobilenet_1.0_224(float)</a>
+ </td>
+ <td>Pixel 2 </td>
+ <td>166.5 ms (2.6 ms)</td>
+ </tr>
+ <tr>
+ <td>Pixel xl </td>
+ <td>122.9 ms (1.8 ms) </td>
+ </tr>
+ <tr>
+ <td rowspan = 2>
+ <a href="http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz">Mobilenet_1.0_224 (quant)</a>
+ </td>
+ <td>Pixel 2 </td>
+ <td>69.5 ms (0.9 ms)</td>
+ </tr>
+ <tr>
+ <td>Pixel xl </td>
+ <td>78.9 ms (2.2 ms) </td>
+ </tr>
+ <tr>
+ <td rowspan = 2>
+ <a href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/nasnet_mobile_2018_04_27.tgz">NASNet mobile</a>
+ </td>
+ <td>Pixel 2 </td>
+ <td>273.8 ms (3.5 ms)</td>
+ </tr>
+ <tr>
+ <td>Pixel xl </td>
+ <td>210.8 ms (4.2 ms)</td>
+ </tr>
+ <tr>
+ <td rowspan = 2>
+ <a href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/squeezenet_2018_04_27.tgz">SqueezeNet</a>
+ </td>
+ <td>Pixel 2 </td>
+ <td>234.0 ms (2.1 ms)</td>
+ </tr>
+ <tr>
+ <td>Pixel xl </td>
+ <td>158.0 ms (2.1 ms)</td>
+ </tr>
+ <tr>
+ <td rowspan = 2>
+ <a href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_resnet_v2_2018_04_27.tgz">Inception_ResNet_V2</a>
+ </td>
+ <td>Pixel 2 </td>
+ <td>2846.0 ms (15.0 ms)</td>
+ </tr>
+ <tr>
+ <td>Pixel xl </td>
+ <td>1973.0 ms (15.0 ms) </td>
+ </tr>
+ <tr>
+ <td rowspan = 2>
+ <a href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_v4_2018_04_27.tgz">Inception_V4</a>
+ </td>
+ <td>Pixel 2 </td>
+ <td>3180.0 ms (11.7 ms)</td>
+ </tr>
+ <tr>
+ <td>Pixel xl </td>
+ <td>2262.0 ms (21.0 ms) </td>
+ </tr>
+
+ </table>
+
+# iOS benchmarks
+
+To run iOS benchmarks, the [benchmark
+app](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark/ios)
+was modified to include the appropriate model and `benchmark_params.json` was
+modified to set `num_threads` to 1.
+
+<table>
+ <thead>
+ <tr>
+ <th>Model Name</th>
+ <th>Device </th>
+ <th>Mean inference time (std dev)</th>
+ </tr>
+ </thead>
+ <tr>
+ <td>
+ <a href="http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224.tgz">Mobilenet_1.0_224(float)</a>
+ </td>
+ <td>iPhone 8 </td>
+ <td>32.2 ms (0.8 ms)</td>
+ </tr>
+ <tr>
+ <td>
+ <a href="http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz)">Mobilenet_1.0_224 (quant)</a>
+ </td>
+ <td>iPhone 8 </td>
+ <td>24.4 ms (0.8 ms)</td>
+ </tr>
+ <tr>
+ <td>
+ <a href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/nasnet_mobile_2018_04_27.tgz">NASNet mobile</a>
+ </td>
+ <td>iPhone 8 </td>
+ <td>60.3 ms (0.6 ms)</td>
+ </tr>
+ <tr>
+ <td>
+ <a href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/squeezenet_2018_04_27.tgz">SqueezeNet</a>
+ </td>
+ <td>iPhone 8 </td>
+    <td>44.3 ms (0.7 ms)</td>
+ </tr>
+ <tr>
+ <td>
+ <a href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_resnet_v2_2018_04_27.tgz">Inception_ResNet_V2</a>
+ </td>
+ <td>iPhone 8</td>
+ <td>562.4 ms (18.2 ms)</td>
+ </tr>
+ <tr>
+ <td>
+ <a href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_v4_2018_04_27.tgz">Inception_V4</a>
+ </td>
+ <td>iPhone 8 </td>
+ <td>661.0 ms (29.2 ms)</td>
+ </tr>
+ </table>
diff --git a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md
index 8660d29855..b0dfb0fed1 100644
--- a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md
+++ b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md
@@ -866,6 +866,17 @@ Outputs {
}
```
+**ZEROS_LIKE**
+
+```
+Inputs {
+ 0: a tensor
+}
+Outputs {
+  0: A tensor of the same shape and type as the input but filled with zeros
+}
+```
+
And these are TensorFlow Lite operations that are present but not ready for
custom models yet:
diff --git a/tensorflow/contrib/lite/g3doc/tfmobile/android_build.md b/tensorflow/contrib/lite/g3doc/tfmobile/android_build.md
index c7cdee07de..b0f32a8d6c 100644
--- a/tensorflow/contrib/lite/g3doc/tfmobile/android_build.md
+++ b/tensorflow/contrib/lite/g3doc/tfmobile/android_build.md
@@ -93,7 +93,7 @@ requires some knowledge of build systems and Android developer tools, but we'll
guide you through the basics here.
- First, follow our instructions for
- <a href="http://www.tensorflow.org/install/install_sources">installing from sources</a>.
+ <a href="http://www.tensorflow.org/install/source">installing from sources</a>.
This will also guide you through installing Bazel and cloning the
TensorFlow code.
diff --git a/tensorflow/contrib/lite/g3doc/tfmobile/index.md b/tensorflow/contrib/lite/g3doc/tfmobile/index.md
index d003bb2f38..49ad35d4e6 100644
--- a/tensorflow/contrib/lite/g3doc/tfmobile/index.md
+++ b/tensorflow/contrib/lite/g3doc/tfmobile/index.md
@@ -4,7 +4,7 @@
TensorFlow was designed to be a good deep learning solution for mobile
platforms. Currently we have two solutions for deploying machine learning
applications on mobile and embedded devices: TensorFlow for Mobile and
-<a href="../index.md">TensorFlow Lite</a>.
+<a href="../../lite">TensorFlow Lite</a>.
## TensorFlow Lite versus TensorFlow Mobile
diff --git a/tensorflow/contrib/lite/interpreter.cc b/tensorflow/contrib/lite/interpreter.cc
index 2657bcd42b..88e41ffc55 100644
--- a/tensorflow/contrib/lite/interpreter.cc
+++ b/tensorflow/contrib/lite/interpreter.cc
@@ -451,16 +451,15 @@ TfLiteStatus Interpreter::AllocateTensors() {
// Reset the variable tensors to zero after (re)allocating the tensors.
// Developers shouldn't rely on the side effect of this function to reset
- // variable tesnsors. They should call `ResetVariableTensorsToZero` directly
+ // variable tensors. They should call `ResetVariableTensors` directly
// instead.
- ResetVariableTensorsToZero();
+ ResetVariableTensors();
return kTfLiteOk;
}
-// TODO(ycling): Consider to provide other functions to initialize variable
-// tensors to non-zero values.
-TfLiteStatus Interpreter::ResetVariableTensorsToZero() {
+// TODO(ycling): Support non-zero default values.
+TfLiteStatus Interpreter::ResetVariableTensors() {
for (auto& tensor : tensors_) {
if (!tensor.is_variable) {
continue;
diff --git a/tensorflow/contrib/lite/interpreter.h b/tensorflow/contrib/lite/interpreter.h
index aa2bc4def6..7ef736d01b 100644
--- a/tensorflow/contrib/lite/interpreter.h
+++ b/tensorflow/contrib/lite/interpreter.h
@@ -421,9 +421,12 @@ class Interpreter {
allow_buffer_handle_output_ = allow_buffer_handle_output;
}
- // Reset all variable tensors to zero.
+ // Reset all variable tensors to the default value.
+ // If a variable tensor doesn't have a buffer, reset it to zero.
+ // TODO(b/115961645): Implement - If a variable tensor has a buffer, reset it
+ // to the value of the buffer.
// WARNING: This is an experimental API and subject to change.
- TfLiteStatus ResetVariableTensorsToZero();
+ TfLiteStatus ResetVariableTensors();
// Retrieve an operator's description of its work, for profiling purposes.
const char* OpProfilingString(const TfLiteRegistration& op_reg,
diff --git a/tensorflow/contrib/lite/java/demo/README.md b/tensorflow/contrib/lite/java/demo/README.md
index 6a3f0651d0..c04b2a6194 100644
--- a/tensorflow/contrib/lite/java/demo/README.md
+++ b/tensorflow/contrib/lite/java/demo/README.md
@@ -1,4 +1,6 @@
-# TF Lite Android App
+# TF Lite Android Image Classifier App Example
+
+A simple Android example that demonstrates image classification using the camera.
## Building in Android Studio with TensorFlow Lite AAR from JCenter.
The build.gradle is configured to use TensorFlow Lite's nightly build.
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java
index 4f5662bc2d..3596e42011 100644
--- a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java
@@ -58,9 +58,9 @@ import android.view.View;
import android.view.ViewGroup;
import android.widget.CompoundButton;
import android.widget.NumberPicker;
-import android.widget.ToggleButton;
import android.widget.TextView;
import android.widget.Toast;
+import android.widget.ToggleButton;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
@@ -305,22 +305,24 @@ public class Camera2BasicFragment extends Fragment
textView = (TextView) view.findViewById(R.id.text);
toggle = (ToggleButton) view.findViewById(R.id.button);
- toggle.setOnCheckedChangeListener(new CompoundButton.OnCheckedChangeListener() {
- public void onCheckedChanged(CompoundButton buttonView, boolean isChecked) {
- classifier.setUseNNAPI(isChecked);
- }
- });
+ toggle.setOnCheckedChangeListener(
+ new CompoundButton.OnCheckedChangeListener() {
+ public void onCheckedChanged(CompoundButton buttonView, boolean isChecked) {
+ backgroundHandler.post(() -> classifier.setUseNNAPI(isChecked));
+ }
+ });
np = (NumberPicker) view.findViewById(R.id.np);
np.setMinValue(1);
np.setMaxValue(10);
np.setWrapSelectorWheel(true);
- np.setOnValueChangedListener(new NumberPicker.OnValueChangeListener() {
- @Override
- public void onValueChange(NumberPicker picker, int oldVal, int newVal){
- classifier.setNumThreads(newVal);
- }
- });
+ np.setOnValueChangedListener(
+ new NumberPicker.OnValueChangeListener() {
+ @Override
+ public void onValueChange(NumberPicker picker, int oldVal, int newVal) {
+ backgroundHandler.post(() -> classifier.setNumThreads(newVal));
+ }
+ });
}
/** Load the model and labels. */
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java
index 7bb6afd9d8..2d11a57434 100644
--- a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java
@@ -59,9 +59,15 @@ public abstract class ImageClassifier {
private static final int DIM_PIXEL_SIZE = 3;
- /* Preallocated buffers for storing image data in. */
+ /** Preallocated buffers for storing image data in. */
private int[] intValues = new int[getImageSizeX() * getImageSizeY()];
+ /** Options for configuring the Interpreter. */
+ private final Interpreter.Options tfliteOptions = new Interpreter.Options();
+
+ /** The loaded TensorFlow Lite model. */
+ private MappedByteBuffer tfliteModel;
+
/** An instance of the driver class to run model inference with Tensorflow Lite. */
protected Interpreter tflite;
@@ -89,7 +95,8 @@ public abstract class ImageClassifier {
/** Initializes an {@code ImageClassifier}. */
ImageClassifier(Activity activity) throws IOException {
- tflite = new Interpreter(loadModelFile(activity));
+ tfliteModel = loadModelFile(activity);
+ tflite = new Interpreter(tfliteModel, tfliteOptions);
labelList = loadLabelList(activity);
imgData =
ByteBuffer.allocateDirect(
@@ -150,20 +157,28 @@ public abstract class ImageClassifier {
}
}
+ private void recreateInterpreter() {
+ if (tflite != null) {
+ tflite.close();
+ tflite = new Interpreter(tfliteModel, tfliteOptions);
+ }
+ }
+
public void setUseNNAPI(Boolean nnapi) {
- if (tflite != null)
- tflite.setUseNNAPI(nnapi);
+ tfliteOptions.setUseNNAPI(nnapi);
+ recreateInterpreter();
}
- public void setNumThreads(int num_threads) {
- if (tflite != null)
- tflite.setNumThreads(num_threads);
+ public void setNumThreads(int numThreads) {
+ tfliteOptions.setNumThreads(numThreads);
+ recreateInterpreter();
}
/** Closes tflite to release resources. */
public void close() {
tflite.close();
tflite = null;
+ tfliteModel = null;
}
/** Reads label list from Assets. */
diff --git a/tensorflow/contrib/lite/java/ovic/BUILD b/tensorflow/contrib/lite/java/ovic/BUILD
index 781289ceb2..bb0be04ca2 100644
--- a/tensorflow/contrib/lite/java/ovic/BUILD
+++ b/tensorflow/contrib/lite/java/ovic/BUILD
@@ -44,6 +44,7 @@ java_binary(
android_library(
name = "ovicbenchmarkerlib",
srcs = [
+ "src/main/java/org/tensorflow/ovic/OvicBenchmarker.java",
"src/main/java/org/tensorflow/ovic/OvicClassifier.java",
"src/main/java/org/tensorflow/ovic/OvicSingleImageResult.java",
],
diff --git a/tensorflow/contrib/lite/java/ovic/README.md b/tensorflow/contrib/lite/java/ovic/README.md
index 26349347fa..df77bfaab3 100644
--- a/tensorflow/contrib/lite/java/ovic/README.md
+++ b/tensorflow/contrib/lite/java/ovic/README.md
@@ -4,7 +4,7 @@ This folder contains building code for track one of the [Low Power ImageNet Reco
## Pre-requisite
-Follow the steps [here](https://www.tensorflow.org/mobile/tflite/demo_android) to install Tensorflow, Bazel, and the Android NDK and SDK.
+Follow the steps [here](https://www.tensorflow.org/lite/demo_android) to install Tensorflow, Bazel, and the Android NDK and SDK.
## Test the benchmarker:
diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/BUILD b/tensorflow/contrib/lite/java/ovic/demo/app/BUILD
index a8d751ade2..b2e3a9bd7d 100644
--- a/tensorflow/contrib/lite/java/ovic/demo/app/BUILD
+++ b/tensorflow/contrib/lite/java/ovic/demo/app/BUILD
@@ -6,7 +6,6 @@ licenses(["notice"]) # Apache 2.0
android_binary(
name = "ovic_benchmarker_binary",
srcs = [
- "OvicBenchmarker.java",
"OvicBenchmarkerActivity.java",
],
assets = [
diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/OvicBenchmarkerActivity.java b/tensorflow/contrib/lite/java/ovic/demo/app/OvicBenchmarkerActivity.java
index 59457c308a..4adf94aeb6 100644
--- a/tensorflow/contrib/lite/java/ovic/demo/app/OvicBenchmarkerActivity.java
+++ b/tensorflow/contrib/lite/java/ovic/demo/app/OvicBenchmarkerActivity.java
@@ -34,8 +34,10 @@ import java.io.InputStream;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.text.DecimalFormat;
+import org.tensorflow.ovic.OvicBenchmarker;
import org.tensorflow.ovic.OvicSingleImageResult;
+
/** Class that benchmark image classifier models. */
public class OvicBenchmarkerActivity extends Activity {
/** Tag for the {@link Log}. */
diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/OvicBenchmarker.java b/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicBenchmarker.java
index 113ab74a20..4cda258bee 100644
--- a/tensorflow/contrib/lite/java/ovic/demo/app/OvicBenchmarker.java
+++ b/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicBenchmarker.java
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-package ovic.demo.app;
+package org.tensorflow.ovic;
import android.graphics.Bitmap;
import android.os.SystemClock;
@@ -22,8 +22,6 @@ import java.io.InputStream;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.MappedByteBuffer;
-import org.tensorflow.ovic.OvicClassifier;
-import org.tensorflow.ovic.OvicSingleImageResult;
/**
* Class that benchmarks image classifier models.
diff --git a/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicClassifier.java b/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicClassifier.java
index 4cf51bb0fa..fd610b054f 100644
--- a/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicClassifier.java
+++ b/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicClassifier.java
@@ -74,7 +74,7 @@ public class OvicClassifier {
}
labelList = loadLabelList(labelInputStream);
// OVIC uses one thread for CPU inference.
- tflite = new Interpreter(model, 1);
+ tflite = new Interpreter(model, new Interpreter.Options().setNumThreads(1));
inputDims = TestHelper.getInputDims(tflite, 0);
if (inputDims.length != 4) {
throw new RuntimeException("The model's input dimensions must be 4 (BWHC).");
diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java
index b84720ae8e..5cc6e754f3 100644
--- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java
+++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java
@@ -56,16 +56,47 @@ import org.checkerframework.checker.nullness.qual.NonNull;
*/
public final class Interpreter implements AutoCloseable {
+ /** An options class for controlling runtime interpreter behavior. */
+ public static class Options {
+ public Options() {}
+
+ /**
+ * Sets the number of threads to be used for ops that support multi-threading. Defaults to a
+ * platform-dependent value.
+ */
+ public Options setNumThreads(int numThreads) {
+ this.numThreads = numThreads;
+ return this;
+ }
+
+ /** Sets whether to use NN API (if available) for op execution. Defaults to false (disabled). */
+ public Options setUseNNAPI(boolean useNNAPI) {
+ this.useNNAPI = useNNAPI;
+ return this;
+ }
+
+ /**
+ * Sets whether to allow float16 precision for FP32 calculation when possible. Defaults to false
+ * (disallow).
+ * WARNING: This is an experimental API and subject to change.
+ */
+ public Options setAllowFp16PrecisionForFp32(boolean allow) {
+ this.allowFp16PrecisionForFp32 = allow;
+ return this;
+ }
+
+ int numThreads = -1;
+ boolean useNNAPI = false;
+ boolean allowFp16PrecisionForFp32 = false;
+ }
+
/**
* Initializes a {@code Interpreter}
*
* @param modelFile: a File of a pre-trained TF Lite model.
*/
public Interpreter(@NonNull File modelFile) {
- if (modelFile == null) {
- return;
- }
- wrapper = new NativeInterpreterWrapper(modelFile.getAbsolutePath());
+ this(modelFile, /*options = */ null);
}
/**
@@ -73,12 +104,22 @@ public final class Interpreter implements AutoCloseable {
*
* @param modelFile: a file of a pre-trained TF Lite model
* @param numThreads: number of threads to use for inference
+ * @deprecated Prefer using the {@link #Interpreter(File,Options)} constructor. This method will
+ * be removed in a future release.
*/
+ @Deprecated
public Interpreter(@NonNull File modelFile, int numThreads) {
- if (modelFile == null) {
- return;
- }
- wrapper = new NativeInterpreterWrapper(modelFile.getAbsolutePath(), numThreads);
+ this(modelFile, new Options().setNumThreads(numThreads));
+ }
+
+ /**
+   * Initializes a {@code Interpreter} with a {@code File} of a model and a set of custom
+   * {@link #Options}.
+ *
+ * @param modelFile: a file of a pre-trained TF Lite model
+ * @param options: a set of options for customizing interpreter behavior
+ */
+ public Interpreter(@NonNull File modelFile, Options options) {
+ wrapper = new NativeInterpreterWrapper(modelFile.getAbsolutePath(), options);
}
/**
@@ -89,7 +130,7 @@ public final class Interpreter implements AutoCloseable {
* direct {@code ByteBuffer} of nativeOrder() that contains the bytes content of a model.
*/
public Interpreter(@NonNull ByteBuffer byteBuffer) {
- wrapper = new NativeInterpreterWrapper(byteBuffer);
+ this(byteBuffer, /* options= */ null);
}
/**
@@ -99,9 +140,13 @@ public final class Interpreter implements AutoCloseable {
* <p>The ByteBuffer should not be modified after the construction of a {@code Interpreter}. The
* {@code ByteBuffer} can be either a {@code MappedByteBuffer} that memory-maps a model file, or a
* direct {@code ByteBuffer} of nativeOrder() that contains the bytes content of a model.
+ *
+ * @deprecated Prefer using the {@link #Interpreter(ByteBuffer,Options)} constructor. This method
+ * will be removed in a future release.
*/
+ @Deprecated
public Interpreter(@NonNull ByteBuffer byteBuffer, int numThreads) {
- wrapper = new NativeInterpreterWrapper(byteBuffer, numThreads);
+ this(byteBuffer, new Options().setNumThreads(numThreads));
}
/**
@@ -109,20 +154,25 @@ public final class Interpreter implements AutoCloseable {
*
* <p>The {@code MappedByteBuffer} should remain unchanged after the construction of a {@code
* Interpreter}.
+ *
+ * @deprecated Prefer using the {@link #Interpreter(ByteBuffer,Options)} constructor. This method
+ * will be removed in a future release.
*/
+ @Deprecated
public Interpreter(@NonNull MappedByteBuffer mappedByteBuffer) {
- wrapper = new NativeInterpreterWrapper(mappedByteBuffer);
+ this(mappedByteBuffer, /* options= */ null);
}
/**
- * Initializes a {@code Interpreter} with a {@code MappedByteBuffer} to the model file and
- * specifies the number of threads used for inference.
+ * Initializes a {@code Interpreter} with a {@code ByteBuffer} of a model file and a set of custom
+ * {@link #Options}.
*
- * <p>The {@code MappedByteBuffer} should remain unchanged after the construction of a {@code
- * Interpreter}.
+ * <p>The ByteBuffer should not be modified after the construction of a {@code Interpreter}. The
+ * {@code ByteBuffer} can be either a {@code MappedByteBuffer} that memory-maps a model file, or a
+ * direct {@code ByteBuffer} of nativeOrder() that contains the bytes content of a model.
*/
- public Interpreter(@NonNull MappedByteBuffer mappedByteBuffer, int numThreads) {
- wrapper = new NativeInterpreterWrapper(mappedByteBuffer, numThreads);
+ public Interpreter(@NonNull ByteBuffer byteBuffer, Options options) {
+ wrapper = new NativeInterpreterWrapper(byteBuffer, options);
}
/**
@@ -232,20 +282,34 @@ public final class Interpreter implements AutoCloseable {
/**
* Returns native inference timing.
- * <p>IllegalArgumentException will be thrown if the model is not initialized by the
- * {@link Interpreter}.
+ *
+ * <p>IllegalArgumentException will be thrown if the model is not initialized by the {@link
+ * Interpreter}.
*/
public Long getLastNativeInferenceDurationNanoseconds() {
checkNotClosed();
return wrapper.getLastNativeInferenceDurationNanoseconds();
}
- /** Turns on/off Android NNAPI for hardware acceleration when it is available. */
+ /**
+ * Turns on/off Android NNAPI for hardware acceleration when it is available.
+ *
+ * @deprecated Prefer using {@link Options#setUseNNAPI(boolean)} directly for enabling NN API.
+ * This method will be removed in a future release.
+ */
+ @Deprecated
public void setUseNNAPI(boolean useNNAPI) {
checkNotClosed();
wrapper.setUseNNAPI(useNNAPI);
}
+ /**
+ * Sets the number of threads to be used for ops that support multi-threading.
+ *
+   * @deprecated Prefer using {@link Options#setNumThreads(int)} directly for controlling
+   *     multi-threading. This method will be removed in a future release.
+ */
+ @Deprecated
public void setNumThreads(int numThreads) {
checkNotClosed();
wrapper.setNumThreads(numThreads);
diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
index fa25082304..9bc44bf797 100644
--- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
+++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
@@ -23,7 +23,7 @@ import java.util.HashMap;
import java.util.Map;
/**
- * A wrapper wraps native interpreter and controls model execution.
+ * An internal wrapper that wraps the native interpreter and controls model execution.
*
* <p><b>WARNING:</b> Resources consumed by the {@code NativeInterpreterWrapper} object must be
* explicitly freed by invoking the {@link #close()} method when the {@code
@@ -32,36 +32,32 @@ import java.util.Map;
final class NativeInterpreterWrapper implements AutoCloseable {
NativeInterpreterWrapper(String modelPath) {
- this(modelPath, /* numThreads= */ -1);
+ this(modelPath, /* options= */ null);
}
- NativeInterpreterWrapper(String modelPath, int numThreads) {
+ NativeInterpreterWrapper(String modelPath, Interpreter.Options options) {
+ if (options == null) {
+ options = new Interpreter.Options();
+ }
errorHandle = createErrorReporter(ERROR_BUFFER_SIZE);
modelHandle = createModel(modelPath, errorHandle);
- interpreterHandle = createInterpreter(modelHandle, errorHandle, numThreads);
+ interpreterHandle = createInterpreter(modelHandle, errorHandle, options.numThreads);
isMemoryAllocated = true;
inputTensors = new Tensor[getInputCount(interpreterHandle)];
outputTensors = new Tensor[getOutputCount(interpreterHandle)];
+ if (options.allowFp16PrecisionForFp32) {
+ setAllowFp16PrecisionForFp32(options.allowFp16PrecisionForFp32);
+ }
}
- /**
- * Initializes a {@code NativeInterpreterWrapper} with a {@code ByteBuffer}. The ByteBuffer should
- * not be modified after the construction of a {@code NativeInterpreterWrapper}. The {@code
- * ByteBuffer} can be either a {@code MappedByteBuffer} that memory-maps a model file, or a direct
- * {@code ByteBuffer} of nativeOrder() that contains the bytes content of a model.
- */
NativeInterpreterWrapper(ByteBuffer byteBuffer) {
- this(byteBuffer, /* numThreads= */ -1);
+ this(byteBuffer, /* options= */ null);
}
- /**
- * Initializes a {@code NativeInterpreterWrapper} with a {@code ByteBuffer} and specifies the
- * number of inference threads. The ByteBuffer should not be modified after the construction of a
- * {@code NativeInterpreterWrapper}. The {@code ByteBuffer} can be either a {@code
- * MappedByteBuffer} that memory-maps a model file, or a direct {@code ByteBuffer} of
- * nativeOrder() that contains the bytes content of a model.
- */
- NativeInterpreterWrapper(ByteBuffer buffer, int numThreads) {
+ NativeInterpreterWrapper(ByteBuffer buffer, Interpreter.Options options) {
+ if (options == null) {
+ options = new Interpreter.Options();
+ }
if (buffer == null
|| (!(buffer instanceof MappedByteBuffer)
&& (!buffer.isDirect() || buffer.order() != ByteOrder.nativeOrder()))) {
@@ -72,10 +68,16 @@ final class NativeInterpreterWrapper implements AutoCloseable {
modelByteBuffer = buffer;
errorHandle = createErrorReporter(ERROR_BUFFER_SIZE);
modelHandle = createModelWithBuffer(modelByteBuffer, errorHandle);
- interpreterHandle = createInterpreter(modelHandle, errorHandle, numThreads);
+ interpreterHandle = createInterpreter(modelHandle, errorHandle, options.numThreads);
isMemoryAllocated = true;
inputTensors = new Tensor[getInputCount(interpreterHandle)];
outputTensors = new Tensor[getOutputCount(interpreterHandle)];
+ if (options.useNNAPI) {
+ setUseNNAPI(options.useNNAPI);
+ }
+ if (options.allowFp16PrecisionForFp32) {
+ setAllowFp16PrecisionForFp32(options.allowFp16PrecisionForFp32);
+ }
}
/** Releases resources associated with this {@code NativeInterpreterWrapper}. */
@@ -163,6 +165,10 @@ final class NativeInterpreterWrapper implements AutoCloseable {
useNNAPI(interpreterHandle, useNNAPI);
}
+ void setAllowFp16PrecisionForFp32(boolean allow) {
+ allowFp16PrecisionForFp32(interpreterHandle, allow);
+ }
+
void setNumThreads(int numThreads) {
numThreads(interpreterHandle, numThreads);
}
@@ -327,6 +333,8 @@ final class NativeInterpreterWrapper implements AutoCloseable {
private static native void numThreads(long interpreterHandle, int numThreads);
+ private static native void allowFp16PrecisionForFp32(long interpreterHandle, boolean allow);
+
private static native long createErrorReporter(int size);
private static native long createModel(String modelPathOrBuffer, long errorHandle);
diff --git a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
index fdcf00a0a0..abb7320bc5 100644
--- a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
+++ b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
@@ -59,7 +59,6 @@ std::vector<int> convertJIntArrayToVector(JNIEnv* env, jintArray inputs) {
return outputs;
}
-
int getDataType(TfLiteType data_type) {
switch (data_type) {
case kTfLiteFloat32:
@@ -234,10 +233,18 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_useNNAPI(JNIEnv* env,
}
JNIEXPORT void JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_allowFp16PrecisionForFp32(
+ JNIEnv* env, jclass clazz, jlong handle, jboolean allow) {
+ tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+ if (interpreter == nullptr) return;
+ interpreter->SetAllowFp16PrecisionForFp32(static_cast<bool>(allow));
+}
+
+JNIEXPORT void JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_numThreads(JNIEnv* env,
- jclass clazz,
- jlong handle,
- jint num_threads) {
+ jclass clazz,
+ jlong handle,
+ jint num_threads) {
tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
if (interpreter == nullptr) return;
interpreter->SetNumThreads(static_cast<int>(num_threads));
diff --git a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h
index 06b35d77c8..aa809dff8a 100644
--- a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h
+++ b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h
@@ -120,6 +120,15 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_useNNAPI(JNIEnv* env,
/*
* Class: org_tensorflow_lite_NativeInterpreterWrapper
* Method:
+ * Signature: (JZ)V
+ */
+JNIEXPORT void JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_allowFp16PrecisionForFp32(
+ JNIEnv* env, jclass clazz, jlong handle, jboolean allow);
+
+/*
+ * Class: org_tensorflow_lite_NativeInterpreterWrapper
+ * Method:
* Signature: (JI)V
*/
JNIEXPORT void JNICALL
diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java
index 9070b788b6..a98fca0132 100644
--- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java
+++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java
@@ -55,11 +55,23 @@ public final class InterpreterTest {
}
@Test
+ public void testInterpreterWithOptions() throws Exception {
+ Interpreter interpreter =
+ new Interpreter(MODEL_FILE, new Interpreter.Options().setNumThreads(2).setUseNNAPI(true));
+ assertThat(interpreter).isNotNull();
+ assertThat(interpreter.getInputTensorCount()).isEqualTo(1);
+ assertThat(interpreter.getInputTensor(0).dataType()).isEqualTo(DataType.FLOAT32);
+ assertThat(interpreter.getOutputTensorCount()).isEqualTo(1);
+ assertThat(interpreter.getOutputTensor(0).dataType()).isEqualTo(DataType.FLOAT32);
+ interpreter.close();
+ }
+
+ @Test
public void testRunWithMappedByteBufferModel() throws Exception {
Path path = MODEL_FILE.toPath();
FileChannel fileChannel =
(FileChannel) Files.newByteChannel(path, EnumSet.of(StandardOpenOption.READ));
- MappedByteBuffer mappedByteBuffer =
+ ByteBuffer mappedByteBuffer =
fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, fileChannel.size());
Interpreter interpreter = new Interpreter(mappedByteBuffer);
float[] oneD = {1.23f, 6.54f, 7.81f};
@@ -106,7 +118,7 @@ public final class InterpreterTest {
byteBuffer.order(ByteOrder.nativeOrder());
fileChannel.read(byteBuffer);
try {
- Interpreter interpreter = new Interpreter(byteBuffer);
+ new Interpreter(byteBuffer);
fail();
} catch (IllegalArgumentException e) {
assertThat(e)
@@ -304,40 +316,16 @@ public final class InterpreterTest {
}
@Test
- public void testTurnOffNNAPI() throws Exception {
- Path path = MODEL_FILE.toPath();
- FileChannel fileChannel =
- (FileChannel) Files.newByteChannel(path, EnumSet.of(StandardOpenOption.READ));
- MappedByteBuffer mappedByteBuffer =
- fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, fileChannel.size());
- Interpreter interpreter = new Interpreter(mappedByteBuffer);
- interpreter.setUseNNAPI(true);
- float[] oneD = {1.23f, 6.54f, 7.81f};
- float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
- float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
- float[][][][] fourD = {threeD, threeD};
- float[][][][] parsedOutputs = new float[2][8][8][3];
- interpreter.run(fourD, parsedOutputs);
- float[] outputOneD = parsedOutputs[0][0][0];
- float[] expected = {3.69f, 19.62f, 23.43f};
- assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
- interpreter.setUseNNAPI(false);
- interpreter.run(fourD, parsedOutputs);
- outputOneD = parsedOutputs[0][0][0];
- assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
- interpreter.close();
- fileChannel.close();
- }
-
- @Test
public void testTurnOnNNAPI() throws Exception {
Path path = MODEL_FILE.toPath();
FileChannel fileChannel =
(FileChannel) Files.newByteChannel(path, EnumSet.of(StandardOpenOption.READ));
MappedByteBuffer mappedByteBuffer =
fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, fileChannel.size());
- Interpreter interpreter = new Interpreter(mappedByteBuffer);
- interpreter.setUseNNAPI(true);
+ Interpreter interpreter =
+ new Interpreter(
+ mappedByteBuffer,
+ new Interpreter.Options().setUseNNAPI(true).setAllowFp16PrecisionForFp32(true));
float[] oneD = {1.23f, 6.54f, 7.81f};
float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java
index 9c4a5acd79..270bd6703a 100644
--- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java
+++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java
@@ -63,6 +63,15 @@ public final class NativeInterpreterWrapperTest {
}
@Test
+ public void testConstructorWithOptions() {
+ NativeInterpreterWrapper wrapper =
+ new NativeInterpreterWrapper(
+ FLOAT_MODEL_PATH, new Interpreter.Options().setNumThreads(2).setUseNNAPI(true));
+ assertThat(wrapper).isNotNull();
+ wrapper.close();
+ }
+
+ @Test
public void testConstructorWithInvalidModel() {
try {
NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(INVALID_MODEL_PATH);
diff --git a/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java b/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java
index 38b740021b..af20e3280b 100644
--- a/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java
+++ b/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java
@@ -19,21 +19,6 @@ package org.tensorflow.lite;
public class TestHelper {
/**
- * Turns on/off NNAPI of an {@code Interpreter}.
- *
- * @param interpreter an instance of {@code Interpreter}. If it is not initialized, an {@code
- * IllegalArgumentException} will be thrown.
- * @param useNNAPI a boolean value indicating to turn on or off NNAPI.
- */
- public static void setUseNNAPI(Interpreter interpreter, boolean useNNAPI) {
- if (interpreter != null && interpreter.wrapper != null) {
- interpreter.wrapper.setUseNNAPI(useNNAPI);
- } else {
- throw new IllegalArgumentException("Interpreter has not initialized; Failed to setUseNNAPI.");
- }
- }
-
- /**
* Gets the last inference duration in nanoseconds. It returns null if there is no previous
* inference run or the last inference run failed.
*
diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD
index 40f28aeab4..daaf6714cc 100644
--- a/tensorflow/contrib/lite/kernels/BUILD
+++ b/tensorflow/contrib/lite/kernels/BUILD
@@ -223,6 +223,7 @@ cc_library(
"unidirectional_sequence_lstm.cc",
"unidirectional_sequence_rnn.cc",
"unpack.cc",
+ "zeros_like.cc",
],
hdrs = [
],
@@ -508,6 +509,7 @@ tf_cc_test(
":builtin_ops",
"//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_absl//absl/memory",
"@com_google_googletest//:gtest",
],
)
@@ -1284,6 +1286,20 @@ tf_cc_test(
],
)
+tf_cc_test(
+ name = "zeros_like_test",
+ size = "small",
+ srcs = ["zeros_like_test.cc"],
+ tags = ["tflite_not_portable_ios"],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:builtin_op_data",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
filegroup(
name = "all_files",
srcs = glob(
diff --git a/tensorflow/contrib/lite/kernels/activations.cc b/tensorflow/contrib/lite/kernels/activations.cc
index b2d9b84979..cf9441aee3 100644
--- a/tensorflow/contrib/lite/kernels/activations.cc
+++ b/tensorflow/contrib/lite/kernels/activations.cc
@@ -348,18 +348,22 @@ TfLiteStatus TanhEval(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
} break;
case kTfLiteInt16: {
- optimized_ops::Tanh(GetTensorData<int16_t>(input), GetTensorShape(input),
- data->input_left_shift,
- GetTensorData<int16_t>(output),
- GetTensorShape(output));
+ TanhParams params;
+ params.input_left_shift = data->input_left_shift;
+ optimized_ops::Tanh(params, GetTensorShape(input),
+ GetTensorData<int16_t>(input), GetTensorShape(output),
+ GetTensorData<int16_t>(output));
return kTfLiteOk;
} break;
case kTfLiteUInt8: {
- optimized_ops::Tanh(GetTensorData<uint8_t>(input), GetTensorShape(input),
- input->params.zero_point, data->input_range_radius,
- data->input_multiplier, data->input_left_shift,
- GetTensorData<uint8_t>(output),
- GetTensorShape(output));
+ TanhParams params;
+ params.input_zero_point = input->params.zero_point;
+ params.input_range_radius = data->input_range_radius;
+ params.input_multiplier = data->input_multiplier;
+ params.input_left_shift = data->input_left_shift;
+ optimized_ops::Tanh(params, GetTensorShape(input),
+ GetTensorData<uint8_t>(input), GetTensorShape(output),
+ GetTensorData<uint8_t>(output));
return kTfLiteOk;
} break;
default:
@@ -385,17 +389,21 @@ TfLiteStatus SigmoidEval(TfLiteContext* context, TfLiteNode* node) {
break;
}
case kTfLiteInt16: {
+ LogisticParams params;
optimized_ops::Logistic(
- GetTensorData<int16>(input), GetTensorShape(input),
- GetTensorData<int16_t>(output), GetTensorShape(output));
+ params, GetTensorShape(input), GetTensorData<int16_t>(input),
+ GetTensorShape(output), GetTensorData<int16_t>(output));
break;
}
case kTfLiteUInt8: {
+ LogisticParams params;
+ params.input_zero_point = input->params.zero_point;
+ params.input_range_radius = data->input_range_radius;
+ params.input_multiplier = data->input_multiplier;
+ params.input_left_shift = data->input_left_shift;
optimized_ops::Logistic(
- GetTensorData<uint8_t>(input), GetTensorShape(input),
- input->params.zero_point, data->input_range_radius,
- data->input_multiplier, data->input_left_shift,
- GetTensorData<uint8_t>(output), GetTensorShape(output));
+ params, GetTensorShape(input), GetTensorData<uint8_t>(input),
+ GetTensorShape(output), GetTensorData<uint8_t>(output));
break;
}
default:
@@ -459,11 +467,13 @@ void Softmax3DFloat(const TfLiteTensor* input, TfLiteTensor* output,
const int batch_size = input->dims->data[0];
const int intermediate_size = input->dims->data[1];
const int input_size = input->dims->data[2];
+ SoftmaxParams op_params;
+ op_params.beta = params->beta;
optimized_ops::Softmax(
+ op_params, GetTensorShape({batch_size, intermediate_size, 1, input_size}),
GetTensorData<float>(input),
GetTensorShape({batch_size, intermediate_size, 1, input_size}),
- params->beta, GetTensorData<float>(output),
- GetTensorShape({batch_size, intermediate_size, 1, input_size}));
+ GetTensorData<float>(output));
}
void Softmax1DQuantized(const TfLiteTensor* input, TfLiteTensor* output,
@@ -473,10 +483,14 @@ void Softmax1DQuantized(const TfLiteTensor* input, TfLiteTensor* output,
// tensor is 4D in a special way. We will convert a (Y) shape into a (1,
// 1, 1, Y) shape.
const int input_size = input->dims->data[0];
- optimized_ops::Softmax(
- GetTensorData<uint8_t>(input), GetTensorShape({1, 1, 1, input_size}),
- data->input_multiplier, data->input_left_shift, data->diff_min,
- GetTensorData<uint8_t>(output), GetTensorShape({1, 1, 1, input_size}));
+ SoftmaxParams op_params;
+ op_params.input_multiplier = data->input_multiplier;
+ op_params.input_left_shift = data->input_left_shift;
+ op_params.diff_min = data->diff_min;
+ optimized_ops::Softmax(op_params, GetTensorShape({1, 1, 1, input_size}),
+ GetTensorData<uint8_t>(input),
+ GetTensorShape({1, 1, 1, input_size}),
+ GetTensorData<uint8_t>(output));
}
void Softmax2DQuantized(const TfLiteTensor* input, TfLiteTensor* output,
TfLiteSoftmaxParams* params, OpData* data) {
@@ -486,11 +500,15 @@ void Softmax2DQuantized(const TfLiteTensor* input, TfLiteTensor* output,
// 1, 1, Y) shape.
const int batch_size = input->dims->data[0];
const int input_size = input->dims->data[1];
- optimized_ops::Softmax(GetTensorData<uint8_t>(input),
+ SoftmaxParams op_params;
+ op_params.input_multiplier = data->input_multiplier;
+ op_params.input_left_shift = data->input_left_shift;
+ op_params.diff_min = data->diff_min;
+ optimized_ops::Softmax(op_params,
+ GetTensorShape({batch_size, 1, 1, input_size}),
+ GetTensorData<uint8_t>(input),
GetTensorShape({batch_size, 1, 1, input_size}),
- data->input_multiplier, data->input_left_shift,
- data->diff_min, GetTensorData<uint8_t>(output),
- GetTensorShape({batch_size, 1, 1, input_size}));
+ GetTensorData<uint8_t>(output));
}
void Softmax3DQuantized(const TfLiteTensor* input, TfLiteTensor* output,
@@ -498,28 +516,36 @@ void Softmax3DQuantized(const TfLiteTensor* input, TfLiteTensor* output,
const int batch_size = input->dims->data[0];
const int intermediate_size = input->dims->data[1];
const int input_size = input->dims->data[2];
+ SoftmaxParams op_params;
+ op_params.input_multiplier = data->input_multiplier;
+ op_params.input_left_shift = data->input_left_shift;
+ op_params.diff_min = data->diff_min;
optimized_ops::Softmax(
+ op_params, GetTensorShape({batch_size, intermediate_size, 1, input_size}),
GetTensorData<uint8_t>(input),
GetTensorShape({batch_size, intermediate_size, 1, input_size}),
- data->input_multiplier, data->input_left_shift, data->diff_min,
- GetTensorData<uint8_t>(output),
- GetTensorShape({batch_size, intermediate_size, 1, input_size}));
+ GetTensorData<uint8_t>(output));
}
// Takes a 4D tensor and perform softmax along the forth dimension.
void Softmax4DFloat(const TfLiteTensor* input, TfLiteTensor* output,
TfLiteSoftmaxParams* params) {
- optimized_ops::Softmax(GetTensorData<float>(input), GetTensorShape(input),
- params->beta, GetTensorData<float>(output),
- GetTensorShape(output));
+ SoftmaxParams op_params;
+ op_params.beta = params->beta;
+ optimized_ops::Softmax(op_params, GetTensorShape(input),
+ GetTensorData<float>(input), GetTensorShape(output),
+ GetTensorData<float>(output));
}
void Softmax4DQuantized(const TfLiteTensor* input, TfLiteTensor* output,
TfLiteSoftmaxParams* params, OpData* data) {
- optimized_ops::Softmax(GetTensorData<uint8_t>(input), GetTensorShape(input),
- data->input_multiplier, data->input_left_shift,
- data->diff_min, GetTensorData<uint8_t>(output),
- GetTensorShape(output));
+ SoftmaxParams op_params;
+ op_params.input_multiplier = data->input_multiplier;
+ op_params.input_left_shift = data->input_left_shift;
+ op_params.diff_min = data->diff_min;
+ optimized_ops::Softmax(op_params, GetTensorShape(input),
+ GetTensorData<uint8_t>(input), GetTensorShape(output),
+ GetTensorData<uint8_t>(output));
}
TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
@@ -591,17 +617,20 @@ TfLiteStatus LogSoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output = GetOutput(context, node, 0);
switch (input->type) {
case kTfLiteFloat32:
+ SoftmaxParams op_params;
optimized_ops::LogSoftmax(
- GetTensorData<float>(input), GetTensorShape(input),
- GetTensorData<float>(output), GetTensorShape(output));
+ op_params, GetTensorShape(input), GetTensorData<float>(input),
+ GetTensorShape(output), GetTensorData<float>(output));
return kTfLiteOk;
case kTfLiteUInt8:
+ op_params.input_multiplier = data->input_multiplier;
+ op_params.input_left_shift = data->input_left_shift;
+ op_params.reverse_scaling_divisor = data->reverse_scaling_divisor;
+ op_params.reverse_scaling_right_shift = data->reverse_scaling_right_shift;
+ op_params.diff_min = data->diff_min;
optimized_ops::LogSoftmax(
- GetTensorData<uint8_t>(input), GetTensorShape(input),
- data->input_multiplier, data->input_left_shift,
- data->reverse_scaling_divisor, data->reverse_scaling_right_shift,
- data->diff_min, GetTensorData<uint8_t>(output),
- GetTensorShape(output));
+ op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
+ GetTensorShape(output), GetTensorData<uint8_t>(output));
return kTfLiteOk;
default:
context->ReportError(context, "Only float32 supported currently., got %d",
diff --git a/tensorflow/contrib/lite/kernels/audio_spectrogram.cc b/tensorflow/contrib/lite/kernels/audio_spectrogram.cc
index 44ef587244..0d2d5e775f 100644
--- a/tensorflow/contrib/lite/kernels/audio_spectrogram.cc
+++ b/tensorflow/contrib/lite/kernels/audio_spectrogram.cc
@@ -22,7 +22,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
-#include "flatbuffers/flexbuffers.h" // flatbuffers
+#include "flatbuffers/flexbuffers.h" // TF:flatbuffers
namespace tflite {
namespace ops {
diff --git a/tensorflow/contrib/lite/kernels/audio_spectrogram_test.cc b/tensorflow/contrib/lite/kernels/audio_spectrogram_test.cc
index 7346b9fd80..7e4ff6fc16 100644
--- a/tensorflow/contrib/lite/kernels/audio_spectrogram_test.cc
+++ b/tensorflow/contrib/lite/kernels/audio_spectrogram_test.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include <vector>
#include <gtest/gtest.h>
-#include "flatbuffers/flexbuffers.h" // flatbuffers
+#include "flatbuffers/flexbuffers.h" // TF:flatbuffers
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/kernels/test_util.h"
diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc
index 541f320138..66b947771c 100644
--- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc
+++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc
@@ -770,51 +770,29 @@ TfLiteStatus EvalFloat(
}
// Loop through the sequence.
- if (forward_sequence) {
- for (int t = 0; t < max_time; t++) {
- const float* input_ptr = input->data.f + t * n_batch * n_input;
- float* output_ptr_time = output->data.f + t * n_batch * n_output;
-
- kernel_utils::LstmStepWithAuxInput(
- input_ptr, input_to_input_weights_ptr,
- input_to_forget_weights->data.f, input_to_cell_weights->data.f,
- input_to_output_weights->data.f, aux_input_ptr,
- aux_input_to_input_weights_ptr, aux_input_to_forget_weights_ptr,
- aux_input_to_cell_weights_ptr, aux_input_to_output_weights_ptr,
- recurrent_to_input_weights_ptr, recurrent_to_forget_weights->data.f,
- recurrent_to_cell_weights->data.f,
- recurrent_to_output_weights->data.f, cell_to_input_weights_ptr,
- cell_to_forget_weights_ptr, cell_to_output_weights_ptr,
- input_gate_bias_ptr, forget_gate_bias->data.f, cell_bias->data.f,
- output_gate_bias->data.f, projection_weights_ptr, projection_bias_ptr,
- params, n_batch, n_cell, n_input, aux_input_size, n_output,
- activation_state->data.f, cell_state->data.f, input_gate_scratch,
- forget_gate_scratch, cell_scratch, output_gate_scratch,
- output_ptr_time);
- }
- } else {
- // Loop through the sequence backwards.
- for (int t = max_time - 1; t >= 0; t--) {
- const float* input_ptr = input->data.f + t * n_batch * n_input;
- float* output_ptr_time = output->data.f + t * n_batch * n_output;
-
- kernel_utils::LstmStepWithAuxInput(
- input_ptr, input_to_input_weights_ptr,
- input_to_forget_weights->data.f, input_to_cell_weights->data.f,
- input_to_output_weights->data.f, aux_input_ptr,
- aux_input_to_input_weights_ptr, aux_input_to_forget_weights_ptr,
- aux_input_to_cell_weights_ptr, aux_input_to_output_weights_ptr,
- recurrent_to_input_weights_ptr, recurrent_to_forget_weights->data.f,
- recurrent_to_cell_weights->data.f,
- recurrent_to_output_weights->data.f, cell_to_input_weights_ptr,
- cell_to_forget_weights_ptr, cell_to_output_weights_ptr,
- input_gate_bias_ptr, forget_gate_bias->data.f, cell_bias->data.f,
- output_gate_bias->data.f, projection_weights_ptr, projection_bias_ptr,
- params, n_batch, n_cell, n_input, aux_input_size, n_output,
- activation_state->data.f, cell_state->data.f, input_gate_scratch,
- forget_gate_scratch, cell_scratch, output_gate_scratch,
- output_ptr_time);
- }
+ const int input_step = n_batch * n_input;
+ const int output_step = n_batch * n_output;
+ for (int t = 0; t < max_time; t++) {
+ // If this is the forward_sequence, step forward, otherwise step backwards.
+ const int t_rel = forward_sequence ? t : max_time - t - 1;
+ const float* input_ptr = input->data.f + t_rel * input_step;
+ float* output_ptr_time = output->data.f + t_rel * output_step;
+
+ kernel_utils::LstmStepWithAuxInput(
+ input_ptr, input_to_input_weights_ptr, input_to_forget_weights->data.f,
+ input_to_cell_weights->data.f, input_to_output_weights->data.f,
+ aux_input_ptr, aux_input_to_input_weights_ptr,
+ aux_input_to_forget_weights_ptr, aux_input_to_cell_weights_ptr,
+ aux_input_to_output_weights_ptr, recurrent_to_input_weights_ptr,
+ recurrent_to_forget_weights->data.f, recurrent_to_cell_weights->data.f,
+ recurrent_to_output_weights->data.f, cell_to_input_weights_ptr,
+ cell_to_forget_weights_ptr, cell_to_output_weights_ptr,
+ input_gate_bias_ptr, forget_gate_bias->data.f, cell_bias->data.f,
+ output_gate_bias->data.f, projection_weights_ptr, projection_bias_ptr,
+ params, n_batch, n_cell, n_input, aux_input_size, n_output,
+ activation_state->data.f, cell_state->data.f, input_gate_scratch,
+ forget_gate_scratch, cell_scratch, output_gate_scratch,
+ output_ptr_time);
}
return kTfLiteOk;
}
@@ -991,72 +969,41 @@ TfLiteStatus EvalHybrid(
aux_input_to_output_weights_scale =
aux_input_to_output_weights->params.scale;
}
- if (forward_sequence) {
- // Feed the sequence into the LSTM step-by-step.
- for (int t = 0; t < max_time; t++) {
- const float* input_ptr = input->data.f + t * n_batch * n_input;
- float* output_ptr = output->data.f + t * n_batch * n_output;
-
- kernel_utils::LstmStepWithAuxInput(
- input_ptr, input_to_input_weights_ptr, input_to_input_weights_scale,
- input_to_forget_weights_ptr, input_to_forget_weights_scale,
- input_to_cell_weights_ptr, input_to_cell_weights_scale,
- input_to_output_weights_ptr, input_to_output_weights_scale,
- aux_input_ptr, aux_input_to_input_weights_ptr,
- aux_input_to_input_weights_scale, aux_input_to_forget_weights_ptr,
- aux_input_to_forget_weights_scale, aux_input_to_cell_weights_ptr,
- aux_input_to_cell_weights_scale, aux_input_to_output_weights_ptr,
- aux_input_to_output_weights_scale, recurrent_to_input_weights_ptr,
- recurrent_to_input_weights_scale, recurrent_to_forget_weights_ptr,
- recurrent_to_forget_weights_scale, recurrent_to_cell_weights_ptr,
- recurrent_to_cell_weights_scale, recurrent_to_output_weights_ptr,
- recurrent_to_output_weights_scale, cell_to_input_weights_ptr,
- cell_to_input_weights_scale, cell_to_forget_weights_ptr,
- cell_to_forget_weights_scale, cell_to_output_weights_ptr,
- cell_to_output_weights_scale, input_gate_bias_ptr,
- forget_gate_bias_ptr, cell_bias_ptr, output_gate_bias_ptr,
- projection_weights_ptr, projection_weights_scale, projection_bias_ptr,
- params, n_batch, n_cell, n_input, aux_input_size, n_output,
- input_gate_scratch, forget_gate_scratch, cell_scratch,
- output_gate_scratch, scaling_factors_ptr, prod_scaling_factors_ptr,
- recovered_cell_weights_ptr, quantized_input_ptr,
- quantized_aux_input_ptr, quantized_output_state_ptr,
- quantized_cell_state_ptr, output_state_ptr, cell_state_ptr,
- output_ptr);
- }
- } else {
- // Loop through the sequence backwards.
- for (int t = max_time - 1; t >= 0; t--) {
- const float* input_ptr = input->data.f + t * n_batch * n_input;
- float* output_ptr = output->data.f + t * n_batch * n_output;
-
- kernel_utils::LstmStepWithAuxInput(
- input_ptr, input_to_input_weights_ptr, input_to_input_weights_scale,
- input_to_forget_weights_ptr, input_to_forget_weights_scale,
- input_to_cell_weights_ptr, input_to_cell_weights_scale,
- input_to_output_weights_ptr, input_to_output_weights_scale,
- aux_input_ptr, aux_input_to_input_weights_ptr,
- aux_input_to_input_weights_scale, aux_input_to_forget_weights_ptr,
- aux_input_to_forget_weights_scale, aux_input_to_cell_weights_ptr,
- aux_input_to_cell_weights_scale, aux_input_to_output_weights_ptr,
- aux_input_to_output_weights_scale, recurrent_to_input_weights_ptr,
- recurrent_to_input_weights_scale, recurrent_to_forget_weights_ptr,
- recurrent_to_forget_weights_scale, recurrent_to_cell_weights_ptr,
- recurrent_to_cell_weights_scale, recurrent_to_output_weights_ptr,
- recurrent_to_output_weights_scale, cell_to_input_weights_ptr,
- cell_to_input_weights_scale, cell_to_forget_weights_ptr,
- cell_to_forget_weights_scale, cell_to_output_weights_ptr,
- cell_to_output_weights_scale, input_gate_bias_ptr,
- forget_gate_bias_ptr, cell_bias_ptr, output_gate_bias_ptr,
- projection_weights_ptr, projection_weights_scale, projection_bias_ptr,
- params, n_batch, n_cell, n_input, aux_input_size, n_output,
- input_gate_scratch, forget_gate_scratch, cell_scratch,
- output_gate_scratch, scaling_factors_ptr, prod_scaling_factors_ptr,
- recovered_cell_weights_ptr, quantized_input_ptr,
- quantized_aux_input_ptr, quantized_output_state_ptr,
- quantized_cell_state_ptr, output_state_ptr, cell_state_ptr,
- output_ptr);
- }
+
+ // Feed the sequence into the LSTM step-by-step.
+ const int input_step = n_batch * n_input;
+ const int output_step = n_batch * n_output;
+ for (int t = 0; t < max_time; t++) {
+ // If this is the forward_sequence, step forward, otherwise step backwards.
+ const int t_rel = forward_sequence ? t : max_time - t - 1;
+ const float* input_ptr = input->data.f + t_rel * input_step;
+ float* output_ptr = output->data.f + t_rel * output_step;
+
+ kernel_utils::LstmStepWithAuxInput(
+ input_ptr, input_to_input_weights_ptr, input_to_input_weights_scale,
+ input_to_forget_weights_ptr, input_to_forget_weights_scale,
+ input_to_cell_weights_ptr, input_to_cell_weights_scale,
+ input_to_output_weights_ptr, input_to_output_weights_scale,
+ aux_input_ptr, aux_input_to_input_weights_ptr,
+ aux_input_to_input_weights_scale, aux_input_to_forget_weights_ptr,
+ aux_input_to_forget_weights_scale, aux_input_to_cell_weights_ptr,
+ aux_input_to_cell_weights_scale, aux_input_to_output_weights_ptr,
+ aux_input_to_output_weights_scale, recurrent_to_input_weights_ptr,
+ recurrent_to_input_weights_scale, recurrent_to_forget_weights_ptr,
+ recurrent_to_forget_weights_scale, recurrent_to_cell_weights_ptr,
+ recurrent_to_cell_weights_scale, recurrent_to_output_weights_ptr,
+ recurrent_to_output_weights_scale, cell_to_input_weights_ptr,
+ cell_to_input_weights_scale, cell_to_forget_weights_ptr,
+ cell_to_forget_weights_scale, cell_to_output_weights_ptr,
+ cell_to_output_weights_scale, input_gate_bias_ptr, forget_gate_bias_ptr,
+ cell_bias_ptr, output_gate_bias_ptr, projection_weights_ptr,
+ projection_weights_scale, projection_bias_ptr, params, n_batch, n_cell,
+ n_input, aux_input_size, n_output, input_gate_scratch,
+ forget_gate_scratch, cell_scratch, output_gate_scratch,
+ scaling_factors_ptr, prod_scaling_factors_ptr,
+ recovered_cell_weights_ptr, quantized_input_ptr,
+ quantized_aux_input_ptr, quantized_output_state_ptr,
+ quantized_cell_state_ptr, output_state_ptr, cell_state_ptr, output_ptr);
}
return kTfLiteOk;
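
Note on the two LSTM hunks above: the separate forward and backward time loops collapse into a single loop that remaps the step index. The following standalone sketch (illustrative values only, not the kernel itself) shows the index arithmetic and the per-step buffer strides that the unified loop relies on.

    #include <cstdio>

    // Illustrative only: one loop body serves both directions by mapping t to
    // t_rel. Pointer offsets then follow from the per-step strides
    // (n_batch * n_input for the input, n_batch * n_output for the output).
    int main() {
      const int max_time = 4, n_batch = 2, n_input = 3;
      const int input_step = n_batch * n_input;
      for (int forward_sequence = 1; forward_sequence >= 0; --forward_sequence) {
        for (int t = 0; t < max_time; ++t) {
          const int t_rel = forward_sequence ? t : max_time - t - 1;
          std::printf("forward=%d t=%d t_rel=%d input_offset=%d\n",
                      forward_sequence, t, t_rel, t_rel * input_step);
        }
      }
      return 0;
    }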
diff --git a/tensorflow/contrib/lite/kernels/comparisons.cc b/tensorflow/contrib/lite/kernels/comparisons.cc
index 4cd96348a2..f765235e04 100644
--- a/tensorflow/contrib/lite/kernels/comparisons.cc
+++ b/tensorflow/contrib/lite/kernels/comparisons.cc
@@ -83,20 +83,24 @@ TfLiteStatus ComparisonPrepare(TfLiteContext* context, TfLiteNode* node) {
QuantizeMultiplierSmallerThanOneExp(real_input2_multiplier, \
&input2_multiplier, &input2_shift); \
\
+ ComparisonParams op_params; \
+ op_params.left_shift = left_shift; \
+ op_params.input1_offset = input1_offset; \
+ op_params.input1_multiplier = input1_multiplier; \
+ op_params.input1_shift = -input1_shift; \
+ op_params.input2_offset = input2_offset; \
+ op_params.input2_multiplier = input2_multiplier; \
+ op_params.input2_shift = -input2_shift; \
if (requires_broadcast) { \
- reference_ops::Broadcast##opname( \
- left_shift, GetTensorData<uint8_t>(input1), GetTensorDims(input1), \
- input1_offset, input1_multiplier, input1_shift, \
- GetTensorData<uint8_t>(input2), GetTensorDims(input2), \
- input2_offset, input2_multiplier, input2_shift, \
- GetTensorData<bool>(output), GetTensorDims(output)); \
+ reference_ops::Broadcast4DSlow##opname##WithScaling( \
+ op_params, GetTensorShape(input1), GetTensorData<uint8_t>(input1), \
+ GetTensorShape(input2), GetTensorData<uint8_t>(input2), \
+ GetTensorShape(output), GetTensorData<bool>(output)); \
} else { \
- reference_ops::opname( \
- left_shift, GetTensorData<uint8_t>(input1), GetTensorDims(input1), \
- input1_offset, input1_multiplier, input1_shift, \
- GetTensorData<uint8_t>(input2), GetTensorDims(input2), \
- input2_offset, input2_multiplier, input2_shift, \
- GetTensorData<bool>(output), GetTensorDims(output)); \
+ reference_ops::opname##WithScaling( \
+ op_params, GetTensorShape(input1), GetTensorData<uint8_t>(input1), \
+ GetTensorShape(input2), GetTensorData<uint8_t>(input2), \
+ GetTensorShape(output), GetTensorData<bool>(output)); \
} \
} \
}
@@ -108,16 +112,19 @@ TF_LITE_QUANTIZE_COMPARISON(Less);
TF_LITE_QUANTIZE_COMPARISON(LessEqual);
#undef TF_LITE_QUANTIZE_COMPARISON
-#define TF_LITE_COMPARISON(type, opname, requires_broadcast) \
- requires_broadcast \
- ? reference_ops::Broadcast##opname( \
- GetTensorData<type>(input1), GetTensorDims(input1), \
- GetTensorData<type>(input2), GetTensorDims(input2), \
- GetTensorData<bool>(output), GetTensorDims(output)) \
- : reference_ops::opname( \
- GetTensorData<type>(input1), GetTensorDims(input1), \
- GetTensorData<type>(input2), GetTensorDims(input2), \
- GetTensorData<bool>(output), GetTensorDims(output));
+#define TF_LITE_COMPARISON(type, opname, requires_broadcast) \
+ { \
+ ComparisonParams op_params; \
+ requires_broadcast \
+ ? reference_ops::Broadcast4DSlow##opname##NoScaling( \
+ op_params, GetTensorShape(input1), GetTensorData<type>(input1), \
+ GetTensorShape(input2), GetTensorData<type>(input2), \
+ GetTensorShape(output), GetTensorData<bool>(output)) \
+ : reference_ops::opname##NoScaling( \
+ op_params, GetTensorShape(input1), GetTensorData<type>(input1), \
+ GetTensorShape(input2), GetTensorData<type>(input2), \
+ GetTensorShape(output), GetTensorData<bool>(output)); \
+ }
TfLiteStatus EqualEval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
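
The quantized comparison macro above now routes the per-input offsets, multipliers, and shifts through a ComparisonParams struct (note the sign flip on the shifts). Conceptually, the kernel rescales both uint8 inputs into a common real-valued domain before comparing. The sketch below shows that idea in plain floating point with made-up scales and zero points; it is not the fixed-point code used by the reference ops.

    #include <cstdint>
    #include <cstdio>

    // Conceptual sketch only: each input carries its own (scale, zero_point);
    // the comparison is meaningful once both are mapped back to real values.
    // The real kernels do this with integer multipliers/shifts bundled in
    // ComparisonParams rather than with floats.
    struct QuantInput {
      float scale;
      int32_t zero_point;
    };

    inline float ToReal(uint8_t q, const QuantInput& p) {
      return p.scale * (static_cast<int32_t>(q) - p.zero_point);
    }

    int main() {
      const QuantInput in1{0.5f, 128};   // hypothetical quantization of input1
      const QuantInput in2{0.25f, 0};    // hypothetical quantization of input2
      const uint8_t x1 = 130;            // represents 1.0
      const uint8_t x2 = 3;              // represents 0.75
      std::printf("Greater -> %d\n", ToReal(x1, in1) > ToReal(x2, in2));
      return 0;
    }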
diff --git a/tensorflow/contrib/lite/kernels/concatenation.cc b/tensorflow/contrib/lite/kernels/concatenation.cc
index 25ea556d5a..7ad3399ffd 100644
--- a/tensorflow/contrib/lite/kernels/concatenation.cc
+++ b/tensorflow/contrib/lite/kernels/concatenation.cc
@@ -100,20 +100,31 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
// allocate and populate these during Prepare().
// TODO(ycling): Activation function parameter is ignored. For now we don't have
// a model with a Concatenation with fused activation function.
-#define TF_LITE_CONCATENATION(type, scalar) \
- VectorOfTensors<scalar> all_inputs(*context, *node->inputs); \
- type::Concatenation<FusedActivationFunctionType::kNone, scalar>( \
- RemapDim(NumDimensions(output), axis), all_inputs.data(), \
- all_inputs.dims(), node->inputs->size, GetTensorData<scalar>(output), \
- GetTensorDims(output))
-
-#define TF_LITE_CONCATENATION_QUANTIZED(type) \
- VectorOfQuantizedTensors all_inputs(*context, *node->inputs); \
- type::Concatenation( \
- RemapDim(NumDimensions(output), axis), all_inputs.data(), \
- all_inputs.dims(), all_inputs.zero_point(), all_inputs.scale(), \
- node->inputs->size, GetTensorData<uint8>(output), GetTensorDims(output), \
- output->params.zero_point, output->params.scale)
+#define TF_LITE_CONCATENATION(type, scalar) \
+ { \
+ VectorOfTensors<scalar> all_inputs(*context, *node->inputs); \
+ tflite::ConcatenationParams op_params; \
+ op_params.axis = axis; \
+ op_params.inputs_count = node->inputs->size; \
+ type::Concatenation(op_params, all_inputs.shapes(), all_inputs.data(), \
+ GetTensorShape(output), \
+ GetTensorData<scalar>(output)); \
+ }
+
+#define TF_LITE_CONCATENATION_QUANTIZED(type) \
+ { \
+ VectorOfQuantizedTensors all_inputs(*context, *node->inputs); \
+ tflite::ConcatenationParams op_params; \
+ op_params.axis = axis; \
+ op_params.input_zeropoint = all_inputs.zero_point(); \
+ op_params.input_scale = all_inputs.scale(); \
+ op_params.inputs_count = node->inputs->size; \
+ op_params.output_zeropoint = output->params.zero_point; \
+ op_params.output_scale = output->params.scale; \
+ type::ConcatenationWithScaling(op_params, all_inputs.shapes(), \
+ all_inputs.data(), GetTensorShape(output), \
+ GetTensorData<uint8>(output)); \
+ }
switch (output->type) { // Already know in/out types are same.
case kTfLiteFloat32:
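
The concatenation macros above now pass the axis and input count through ConcatenationParams. As a reminder of what the axis controls, here is a small self-contained example concatenating two 2x2 matrices along the second axis; it uses plain vectors, not TF Lite tensors or the real Concatenation routine.

    #include <cstdio>
    #include <vector>

    // Concatenating two [2 x 2] inputs along axis 1 yields a [2 x 4] output.
    // The kernel above does the same for N inputs over 4-D shapes, with the
    // axis and input count carried in ConcatenationParams.
    int main() {
      const std::vector<std::vector<float>> a = {{1, 2}, {3, 4}};
      const std::vector<std::vector<float>> b = {{5, 6}, {7, 8}};
      std::vector<std::vector<float>> out;
      for (size_t row = 0; row < a.size(); ++row) {
        std::vector<float> merged = a[row];                        // row of a
        merged.insert(merged.end(), b[row].begin(), b[row].end()); // append b
        out.push_back(merged);
      }
      for (const auto& row : out) {
        for (float v : row) std::printf("%.0f ", v);
        std::printf("\n");
      }
      return 0;
    }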
diff --git a/tensorflow/contrib/lite/kernels/conv.cc b/tensorflow/contrib/lite/kernels/conv.cc
index ab6bdaecaa..dbcadbee14 100644
--- a/tensorflow/contrib/lite/kernels/conv.cc
+++ b/tensorflow/contrib/lite/kernels/conv.cc
@@ -86,6 +86,18 @@ struct OpData {
bool run_multithreaded_kernel;
};
+inline PaddingType RuntimePaddingType(TfLitePadding padding) {
+ switch (padding) {
+ case TfLitePadding::kTfLitePaddingSame:
+ return PaddingType::kSame;
+ case TfLitePadding::kTfLitePaddingValid:
+ return PaddingType::kValid;
+ case TfLitePadding::kTfLitePaddingUnknown:
+ default:
+ return PaddingType::kNone;
+ }
+}
+
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
// This is a builtin op, so we don't use the contents in 'buffer', if any.
// Instead, we allocate a new object to use as scratch space for im2col, and
@@ -414,35 +426,57 @@ void EvalQuantized(TfLiteContext* context, TfLiteNode* node,
}
switch (effective_kernel_type) {
- case kReference:
+ case kReference: {
+ ConvParams op_params;
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = data->padding.width;
+ op_params.padding_values.height = data->padding.height;
+ op_params.stride_width = params->stride_width;
+ op_params.stride_height = params->stride_height;
+ op_params.dilation_width_factor = params->dilation_width_factor;
+ op_params.dilation_height_factor = params->dilation_height_factor;
+ op_params.input_offset = input_offset;
+ op_params.weights_offset = filter_offset;
+ op_params.output_offset = output_offset;
+ op_params.output_multiplier = data->output_multiplier;
+ op_params.output_shift = -data->output_shift;
+ op_params.quantized_activation_min = data->output_activation_min;
+ op_params.quantized_activation_max = data->output_activation_max;
reference_ops::Conv(
- GetTensorData<uint8_t>(input), GetTensorDims(input), input_offset,
- GetTensorData<uint8_t>(filter), GetTensorDims(filter), filter_offset,
- GetTensorData<int32_t>(bias), GetTensorDims(bias),
- params->stride_width, params->stride_height,
- params->dilation_width_factor, params->dilation_height_factor,
- data->padding.width, data->padding.height, output_offset,
- data->output_multiplier, data->output_shift,
- data->output_activation_min, data->output_activation_max,
- GetTensorData<uint8_t>(output), GetTensorDims(output),
- GetTensorData<uint8_t>(im2col), GetTensorDims(im2col), gemm_context);
+ op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
+ GetTensorShape(filter), GetTensorData<uint8_t>(filter),
+ GetTensorShape(bias), GetTensorData<int32_t>(bias),
+ GetTensorShape(output), GetTensorData<uint8_t>(output),
+ GetTensorShape(im2col), GetTensorData<uint8_t>(im2col), gemm_context);
break;
+ }
case kGenericOptimized:
case kMultithreadOptimized:
- case kCblasOptimized:
+ case kCblasOptimized: {
// There is only one optimized implementation for Quantized Conv.
+ ConvParams op_params;
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = data->padding.width;
+ op_params.padding_values.height = data->padding.height;
+ op_params.stride_width = params->stride_width;
+ op_params.stride_height = params->stride_height;
+ op_params.dilation_width_factor = params->dilation_width_factor;
+ op_params.dilation_height_factor = params->dilation_height_factor;
+ op_params.input_offset = input_offset;
+ op_params.weights_offset = filter_offset;
+ op_params.output_offset = output_offset;
+ op_params.output_multiplier = data->output_multiplier;
+ op_params.output_shift = -data->output_shift;
+ op_params.quantized_activation_min = data->output_activation_min;
+ op_params.quantized_activation_max = data->output_activation_max;
optimized_ops::Conv(
- GetTensorData<uint8_t>(input), GetTensorDims(input), input_offset,
- GetTensorData<uint8_t>(filter), GetTensorDims(filter), filter_offset,
- GetTensorData<int32_t>(bias), GetTensorDims(bias),
- params->stride_width, params->stride_height,
- params->dilation_width_factor, params->dilation_height_factor,
- data->padding.width, data->padding.height, output_offset,
- data->output_multiplier, data->output_shift,
- data->output_activation_min, data->output_activation_max,
- GetTensorData<uint8_t>(output), GetTensorDims(output),
- GetTensorData<uint8_t>(im2col), GetTensorDims(im2col), gemm_context);
+ op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
+ GetTensorShape(filter), GetTensorData<uint8_t>(filter),
+ GetTensorShape(bias), GetTensorData<int32_t>(bias),
+ GetTensorShape(output), GetTensorData<uint8_t>(output),
+ GetTensorShape(im2col), GetTensorData<uint8_t>(im2col), gemm_context);
break;
+ }
}
}
@@ -465,29 +499,33 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node,
} else {
effective_kernel_type = kernel_type;
}
+ ConvParams op_params;
+ op_params.padding_type = RuntimePaddingType(params->padding);
+ op_params.padding_values.width = data->padding.width;
+ op_params.padding_values.height = data->padding.height;
+ op_params.stride_width = params->stride_width;
+ op_params.stride_height = params->stride_height;
+ op_params.dilation_width_factor = params->dilation_width_factor;
+ op_params.dilation_height_factor = params->dilation_height_factor;
+ op_params.float_activation_min = output_activation_min;
+ op_params.float_activation_max = output_activation_max;
switch (effective_kernel_type) {
case kReference: {
- reference_ops::Conv(
- GetTensorData<float>(input), GetTensorDims(input),
- GetTensorData<float>(filter), GetTensorDims(filter),
- GetTensorData<float>(bias), GetTensorDims(bias), params->stride_width,
- params->stride_height, params->dilation_width_factor,
- params->dilation_height_factor, data->padding.width,
- data->padding.height, output_activation_min, output_activation_max,
- GetTensorData<float>(output), GetTensorDims(output),
- GetTensorData<float>(im2col), GetTensorDims(im2col));
+ reference_ops::Conv(op_params, GetTensorShape(input),
+ GetTensorData<float>(input), GetTensorShape(filter),
+ GetTensorData<float>(filter), GetTensorShape(bias),
+ GetTensorData<float>(bias), GetTensorShape(output),
+ GetTensorData<float>(output), GetTensorShape(im2col),
+ GetTensorData<float>(im2col));
break;
}
case kGenericOptimized: {
- optimized_ops::Conv(
- GetTensorData<float>(input), GetTensorDims(input),
- GetTensorData<float>(filter), GetTensorDims(filter),
- GetTensorData<float>(bias), GetTensorDims(bias), params->stride_width,
- params->stride_height, params->dilation_width_factor,
- params->dilation_height_factor, data->padding.width,
- data->padding.height, output_activation_min, output_activation_max,
- GetTensorData<float>(output), GetTensorDims(output),
- GetTensorData<float>(im2col), GetTensorDims(im2col));
+ optimized_ops::Conv(op_params, GetTensorShape(input),
+ GetTensorData<float>(input), GetTensorShape(filter),
+ GetTensorData<float>(filter), GetTensorShape(bias),
+ GetTensorData<float>(bias), GetTensorShape(output),
+ GetTensorData<float>(output), GetTensorShape(im2col),
+ GetTensorData<float>(im2col));
break;
}
case kMultithreadOptimized: {
@@ -498,25 +536,21 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node,
filter_data = GetTensorData<float>(filter);
}
multithreaded_ops::Conv(
- *eigen_support::GetThreadPoolDevice(context),
- GetTensorData<float>(input), GetTensorDims(input), filter_data,
- GetTensorDims(filter), GetTensorData<float>(bias),
- GetTensorDims(bias), params->stride_width, params->stride_height,
- data->padding.width, data->padding.height, params->padding,
- output_activation_min, output_activation_max,
- GetTensorData<float>(output), GetTensorDims(output),
- GetTensorData<float>(im2col), GetTensorDims(im2col));
+ *eigen_support::GetThreadPoolDevice(context), op_params,
+ GetTensorShape(input), GetTensorData<float>(input),
+ GetTensorShape(filter), filter_data, GetTensorShape(bias),
+ GetTensorData<float>(bias), GetTensorShape(output),
+ GetTensorData<float>(output), GetTensorShape(im2col),
+ GetTensorData<float>(im2col));
break;
}
case kCblasOptimized: {
- cblas_ops::Conv(GetTensorData<float>(input), GetTensorDims(input),
- GetTensorData<float>(filter), GetTensorDims(filter),
- GetTensorData<float>(bias), GetTensorDims(bias),
- params->stride_width, params->stride_height,
- data->padding.width, data->padding.height,
- output_activation_min, output_activation_max,
- GetTensorData<float>(output), GetTensorDims(output),
- GetTensorData<float>(im2col), GetTensorDims(im2col));
+ cblas_ops::Conv(op_params, GetTensorShape(input),
+ GetTensorData<float>(input), GetTensorShape(filter),
+ GetTensorData<float>(filter), GetTensorShape(bias),
+ GetTensorData<float>(bias), GetTensorShape(output),
+ GetTensorData<float>(output), GetTensorShape(im2col),
+ GetTensorData<float>(im2col));
break;
}
}
@@ -561,18 +595,27 @@ void EvalHybrid(TfLiteContext* context, TfLiteNode* node,
case kReference:
case kGenericOptimized:
case kMultithreadOptimized:
- case kCblasOptimized:
+ case kCblasOptimized: {
// There is only one implementation for hybrid kernel. Note
// this does not make use of gemmlowp nor supports multithreading.
+ ConvParams op_params;
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = data->padding.width;
+ op_params.padding_values.height = data->padding.height;
+ op_params.stride_width = params->stride_width;
+ op_params.stride_height = params->stride_height;
+ op_params.dilation_width_factor = 1;
+ op_params.dilation_height_factor = 1;
+ op_params.float_activation_min = output_activation_min;
+ op_params.float_activation_max = output_activation_max;
optimized_ops::HybridConv(
- quantized_input_ptr_batch, GetTensorDims(input), filter_ptr,
- GetTensorDims(filter), GetTensorData<float>(bias),
- GetTensorDims(bias), params->stride_width, params->stride_height,
- data->padding.width, data->padding.height, scaling_factors_ptr,
- output_activation_min, output_activation_max,
- GetTensorData<float>(output), GetTensorDims(output), im2col_ptr,
- GetTensorDims(im2col));
+ op_params, scaling_factors_ptr, GetTensorShape(input),
+ quantized_input_ptr_batch, GetTensorShape(filter), filter_ptr,
+ GetTensorShape(bias), GetTensorData<float>(bias),
+ GetTensorShape(output), GetTensorData<float>(output),
+ GetTensorShape(im2col), im2col_ptr);
break;
+ }
}
}
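
In the conv hunks above, EvalFloat derives the runtime padding type via RuntimePaddingType and every variant reads precomputed padding values from ConvParams. The padding values themselves come from the usual SAME-padding arithmetic done during Prepare, which this diff does not show; the sketch below is a hedged reconstruction of that standard formula, including the dilation term, with illustrative numbers.

    #include <algorithm>
    #include <cstdio>

    // Standard SAME-padding arithmetic: with dilation, the filter's effective
    // extent grows before the padding is derived. This is a sketch of the
    // common formula, not the exact TF Lite helper.
    int ComputeSamePadding(int stride, int dilation, int in_size,
                           int filter_size) {
      const int effective_filter = (filter_size - 1) * dilation + 1;
      const int out_size = (in_size + stride - 1) / stride;  // ceil division
      const int total = (out_size - 1) * stride + effective_filter - in_size;
      return std::max(0, total / 2);
    }

    int main() {
      // 9-wide input, 3-wide filter, stride 1, dilation 3 -> effective width 7.
      std::printf("pad = %d\n", ComputeSamePadding(/*stride=*/1, /*dilation=*/3,
                                                   /*in_size=*/9,
                                                   /*filter_size=*/3));
      return 0;
    }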
diff --git a/tensorflow/contrib/lite/kernels/depthwise_conv.cc b/tensorflow/contrib/lite/kernels/depthwise_conv.cc
index 3e1ce60113..19958844a1 100644
--- a/tensorflow/contrib/lite/kernels/depthwise_conv.cc
+++ b/tensorflow/contrib/lite/kernels/depthwise_conv.cc
@@ -180,34 +180,31 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node,
CalculateActivationRange(params->activation, &output_activation_min,
&output_activation_max);
- void (*depthwise_conv)(const float*, const Dims<4>&, const float*,
- const Dims<4>&, const float*, const Dims<4>&, int, int,
- int, int, int, int, int, float, float, float*,
- const Dims<4>&);
- KernelType effective_kernel_type;
- // TODO(suharshs): Currently only the reference implementation supports
- // dilations.
- if ((params->dilation_width_factor != 1) ||
- (params->dilation_height_factor != 1)) {
- effective_kernel_type = kReference;
- } else {
- effective_kernel_type = kernel_type;
- }
-
- if (effective_kernel_type == kReference) {
+ void (*depthwise_conv)(const DepthwiseParams&, const RuntimeShape&,
+ const float*, const RuntimeShape&, const float*,
+ const RuntimeShape&, const float*, const RuntimeShape&,
+ float*);
+ if (kernel_type == kReference) {
depthwise_conv = &reference_ops::DepthwiseConv;
} else {
depthwise_conv = &optimized_ops::DepthwiseConv;
}
- depthwise_conv(
- GetTensorData<float>(input), GetTensorDims(input),
- GetTensorData<float>(filter), GetTensorDims(filter),
- GetTensorData<float>(bias), GetTensorDims(bias), params->stride_width,
- params->stride_height, params->dilation_width_factor,
- params->dilation_height_factor, data->padding.width, data->padding.height,
- params->depth_multiplier, output_activation_min, output_activation_max,
- GetTensorData<float>(output), GetTensorDims(output));
+ DepthwiseParams op_params;
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = data->padding.width;
+ op_params.padding_values.height = data->padding.height;
+ op_params.stride_width = params->stride_width;
+ op_params.stride_height = params->stride_height;
+ op_params.dilation_width_factor = params->dilation_width_factor;
+ op_params.dilation_height_factor = params->dilation_height_factor;
+ op_params.depth_multiplier = params->depth_multiplier;
+ op_params.float_activation_min = output_activation_min;
+ op_params.float_activation_max = output_activation_max;
+ depthwise_conv(op_params, GetTensorShape(input), GetTensorData<float>(input),
+ GetTensorShape(filter), GetTensorData<float>(filter),
+ GetTensorShape(bias), GetTensorData<float>(bias),
+ GetTensorShape(output), GetTensorData<float>(output));
}
template <KernelType kernel_type>
@@ -219,37 +216,38 @@ void EvalQuantized(TfLiteContext* context, TfLiteNode* node,
auto filter_offset = -filter->params.zero_point;
auto output_offset = output->params.zero_point;
- void (*depthwise_conv)(const uint8*, const Dims<4>&, int32, const uint8*,
- const Dims<4>&, int32, const int32*, const Dims<4>&,
- int, int, int, int, int, int, int, int32, int32, int,
- int32, int32, uint8*, const Dims<4>&);
-
- KernelType effective_kernel_type;
- // TODO(suharshs): Currently only the reference implementation supports
- // dilations.
- if ((params->dilation_width_factor != 1) ||
- (params->dilation_height_factor != 1)) {
- effective_kernel_type = kReference;
- } else {
- effective_kernel_type = kernel_type;
- }
+ void (*depthwise_conv)(const DepthwiseParams&, const RuntimeShape&,
+ const uint8*, const RuntimeShape&, const uint8*,
+ const RuntimeShape&, const int32*, const RuntimeShape&,
+ uint8*);
- if (effective_kernel_type == kReference) {
+ if (kernel_type == kReference) {
depthwise_conv = &reference_ops::DepthwiseConv;
} else {
depthwise_conv = &optimized_ops::DepthwiseConv;
}
- depthwise_conv(
- GetTensorData<uint8_t>(input), GetTensorDims(input), input_offset,
- GetTensorData<uint8_t>(filter), GetTensorDims(filter), filter_offset,
- GetTensorData<int32_t>(bias), GetTensorDims(bias), params->stride_width,
- params->stride_height, params->dilation_width_factor,
- params->dilation_height_factor, data->padding.width, data->padding.height,
- params->depth_multiplier, output_offset, data->output_multiplier,
- data->output_shift, data->output_activation_min,
- data->output_activation_max, GetTensorData<uint8_t>(output),
- GetTensorDims(output));
+ DepthwiseParams op_params;
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = data->padding.width;
+ op_params.padding_values.height = data->padding.height;
+ op_params.stride_width = params->stride_width;
+ op_params.stride_height = params->stride_height;
+ op_params.dilation_width_factor = params->dilation_width_factor;
+ op_params.dilation_height_factor = params->dilation_height_factor;
+ op_params.depth_multiplier = params->depth_multiplier;
+ op_params.input_offset = input_offset;
+ op_params.weights_offset = filter_offset;
+ op_params.output_offset = output_offset;
+ op_params.output_multiplier = data->output_multiplier;
+ op_params.output_shift = -data->output_shift;
+ op_params.quantized_activation_min = data->output_activation_min;
+ op_params.quantized_activation_max = data->output_activation_max;
+ depthwise_conv(op_params, GetTensorShape(input),
+ GetTensorData<uint8_t>(input), GetTensorShape(filter),
+ GetTensorData<uint8_t>(filter), GetTensorShape(bias),
+ GetTensorData<int32_t>(bias), GetTensorShape(output),
+ GetTensorData<uint8_t>(output));
}
template <KernelType kernel_type>
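
Both depthwise paths above now pick between the reference and optimized kernels through a function pointer with a shared params-struct signature, and the dilation-based fallback to the reference kernel is removed because both implementations accept dilation factors. A minimal sketch of that dispatch pattern, using illustrative names rather than the real TF Lite types:

    #include <cstdio>

    // Reference and optimized kernels share one signature, so a single
    // function pointer selects the implementation once and the call site
    // stays identical.
    struct Params { float scale; };

    void ReferenceKernel(const Params& p, const float* in, float* out, int n) {
      for (int i = 0; i < n; ++i) out[i] = p.scale * in[i];
    }

    void OptimizedKernel(const Params& p, const float* in, float* out, int n) {
      // Stand-in for the vectorized path; same contract as the reference.
      for (int i = 0; i < n; ++i) out[i] = p.scale * in[i];
    }

    int main() {
      const bool use_reference = false;
      void (*kernel)(const Params&, const float*, float*, int) =
          use_reference ? &ReferenceKernel : &OptimizedKernel;
      const float in[3] = {1.f, 2.f, 3.f};
      float out[3];
      kernel(Params{2.f}, in, out, 3);
      std::printf("%.0f %.0f %.0f\n", out[0], out[1], out[2]);
      return 0;
    }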
diff --git a/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc b/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc
index 2af26ab80a..4a33a0319d 100644
--- a/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc
+++ b/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc
@@ -14,12 +14,24 @@ limitations under the License.
==============================================================================*/
#include <cstdarg>
#include <gtest/gtest.h>
+#include "absl/memory/memory.h"
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/kernels/test_util.h"
#include "tensorflow/contrib/lite/model.h"
namespace tflite {
+
+namespace ops {
+namespace builtin {
+
+TfLiteRegistration* Register_DEPTHWISE_CONVOLUTION_REF();
+TfLiteRegistration* Register_DEPTHWISE_CONVOLUTION_GENERIC_OPT();
+TfLiteRegistration* Register_DEPTHWISE_CONVOLUTION_NEON_OPT();
+
+} // namespace builtin
+} // namespace ops
+
namespace {
using ::testing::ElementsAreArray;
@@ -28,9 +40,11 @@ class BaseDepthwiseConvolutionOpModel : public SingleOpModel {
public:
// TODO(ahentz): Also test different activation types, bias, padding types,
// stride values.
- BaseDepthwiseConvolutionOpModel(const TensorData& input,
+ BaseDepthwiseConvolutionOpModel(TfLiteRegistration* registration,
+ const TensorData& input,
const TensorData& filter,
const TensorData& output,
+ Padding padding_type,
int dilation_factor = 1) {
input_ = AddInput(input);
filter_ = AddInput(filter);
@@ -56,11 +70,14 @@ class BaseDepthwiseConvolutionOpModel : public SingleOpModel {
SetBuiltinOp(
BuiltinOperator_DEPTHWISE_CONV_2D,
BuiltinOptions_DepthwiseConv2DOptions,
- CreateDepthwiseConv2DOptions(builder_, Padding_VALID, 1, 1, depth_mul,
+ CreateDepthwiseConv2DOptions(builder_, padding_type, 1, 1, depth_mul,
ActivationFunctionType_NONE,
dilation_factor, dilation_factor)
.Union());
+ resolver_ = absl::make_unique<SingleOpResolver>(
+ BuiltinOperator_DEPTHWISE_CONV_2D, registration);
+
BuildInterpreter({GetShape(input_), GetShape(filter_), GetShape(bias_)});
}
@@ -86,10 +103,25 @@ class DepthwiseConvolutionOpModel : public BaseDepthwiseConvolutionOpModel {
std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
};
-TEST(DepthwiseConvolutionOpTest, SimpleTest) {
- DepthwiseConvolutionOpModel m({TensorType_FLOAT32, {1, 3, 2, 2}},
+const auto kKernelMap = new std::map<string, TfLiteRegistration*>({
+ {"Reference", ops::builtin::Register_DEPTHWISE_CONVOLUTION_REF()},
+ {"GenericOptimized",
+ ops::builtin::Register_DEPTHWISE_CONVOLUTION_GENERIC_OPT()},
+ {"NeonOptimized", ops::builtin::Register_DEPTHWISE_CONVOLUTION_NEON_OPT()},
+});
+
+class DepthwiseConvolutionOpTest : public SingleOpTest {
+ protected:
+ const std::map<string, TfLiteRegistration*>& GetKernelMap() override {
+ return *kKernelMap;
+ }
+};
+
+TEST_P(DepthwiseConvolutionOpTest, SimpleTest) {
+ DepthwiseConvolutionOpModel m(GetRegistration(),
+ {TensorType_FLOAT32, {1, 3, 2, 2}},
{TensorType_FLOAT32, {1, 2, 2, 4}},
- {TensorType_FLOAT32, {}});
+ {TensorType_FLOAT32, {}}, Padding_VALID);
m.SetInput({
1, 2, 7, 8, // column 1
@@ -112,7 +144,7 @@ TEST(DepthwiseConvolutionOpTest, SimpleTest) {
}));
}
-TEST(DepthwiseConvolutionOpTest, SimpleDilatedTest) {
+TEST_P(DepthwiseConvolutionOpTest, SimpleDilatedTestPaddingValid) {
const int depth = 1;
const int image_width = 9;
const int image_height = 9;
@@ -121,10 +153,11 @@ TEST(DepthwiseConvolutionOpTest, SimpleDilatedTest) {
const int filter_count = 1;
const int dilation_factor = 3;
DepthwiseConvolutionOpModel m(
+ GetRegistration(),
{TensorType_FLOAT32,
{image_batch_count, image_height, image_width, depth}},
{TensorType_FLOAT32, {depth, filter_size, filter_size, filter_count}},
- {TensorType_FLOAT32, {}}, dilation_factor);
+ {TensorType_FLOAT32, {}}, Padding_VALID, dilation_factor);
// The image matrix is:
// | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
@@ -164,6 +197,41 @@ TEST(DepthwiseConvolutionOpTest, SimpleDilatedTest) {
EXPECT_THAT(m.GetOutput(), ElementsAreArray({5, 5, 5, 5, 5, 5, 5, 5, 5}));
}
+TEST_P(DepthwiseConvolutionOpTest, SimpleDilatedTestPaddingSame) {
+ const int depth = 1;
+ const int image_width = 3;
+ const int image_height = 3;
+ const int image_batch_count = 1;
+ const int filter_size = 2;
+ const int filter_count = 1;
+ const int dilation_factor = 2;
+ DepthwiseConvolutionOpModel m(
+ GetRegistration(),
+ {TensorType_FLOAT32,
+ {image_batch_count, image_height, image_width, depth}},
+ {TensorType_FLOAT32, {depth, filter_size, filter_size, filter_count}},
+ {TensorType_FLOAT32, {}}, Padding_SAME, dilation_factor);
+
+ // The image matrix is:
+ // | 1 | 1 | 1 |
+ // | 1 | 1 | 1 |
+ // | 1 | 1 | 1 |
+ m.SetInput({1, 1, 1, 1, 1, 1, 1, 1, 1});
+ // The filter matrix is:
+ // | 1 | 2 |
+ // | 3 | 4 |
+ m.SetFilter({1, 2, 3, 4});
+ // No bias for this test.
+ m.SetBias({0});
+ m.Invoke();
+
+ // Output:
+ // | 4 | 7 | 3 |
+ // | 6 |10 | 4 |
+ // | 2 | 3 | 1 |
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({4, 7, 3, 6, 10, 4, 2, 3, 1}));
+}
+
class QuantizedDepthwiseConvolutionOpModel
: public BaseDepthwiseConvolutionOpModel {
public:
@@ -188,13 +256,20 @@ class QuantizedDepthwiseConvolutionOpModel
}
};
+class QuantizedDepthwiseConvolutionOpTest : public SingleOpTest {
+ protected:
+ const std::map<string, TfLiteRegistration*>& GetKernelMap() override {
+ return *kKernelMap;
+ }
+};
+
// In this test we set the input and output scales so that the results match
// exactly the 'non-quantized' version.
-TEST(QuantizedDepthwiseConvolutionOpTest, SimpleTestQuantized) {
+TEST_P(QuantizedDepthwiseConvolutionOpTest, SimpleTestQuantized) {
QuantizedDepthwiseConvolutionOpModel m(
- {TensorType_UINT8, {1, 3, 2, 2}, -63.5, 64},
+ GetRegistration(), {TensorType_UINT8, {1, 3, 2, 2}, -63.5, 64},
{TensorType_UINT8, {1, 2, 2, 4}, -63.5, 64},
- {TensorType_UINT8, {}, -127, 128});
+ {TensorType_UINT8, {}, -127, 128}, Padding_VALID);
m.SetInput({
1, 2, 7, 8, // column 1
@@ -224,15 +299,16 @@ TEST(QuantizedDepthwiseConvolutionOpTest, SimpleTestQuantized) {
}));
}
-TEST(QuantizedDepthwiseConvolutionOpTest,
- SimpleTestQuantizedFilterMultiplierGreaterThan1) {
+TEST_P(QuantizedDepthwiseConvolutionOpTest,
+ SimpleTestQuantizedFilterMultiplierGreaterThan1) {
QuantizedDepthwiseConvolutionOpModel quant_op(
- {TensorType_UINT8, {1, 3, 2, 2}, -63.5, 64},
+ GetRegistration(), {TensorType_UINT8, {1, 3, 2, 2}, -63.5, 64},
{TensorType_UINT8, {1, 2, 2, 4}, -128.5, 128},
- {TensorType_UINT8, {}, -127, 128});
- DepthwiseConvolutionOpModel float_op({TensorType_FLOAT32, {1, 3, 2, 2}},
+ {TensorType_UINT8, {}, -127, 128}, Padding_VALID);
+ DepthwiseConvolutionOpModel float_op(GetRegistration(),
+ {TensorType_FLOAT32, {1, 3, 2, 2}},
{TensorType_FLOAT32, {1, 2, 2, 4}},
- {TensorType_FLOAT32, {}});
+ {TensorType_FLOAT32, {}}, Padding_VALID);
std::initializer_list<float> input = {
1, 2, 7, 8, // column 1
@@ -261,7 +337,7 @@ TEST(QuantizedDepthwiseConvolutionOpTest,
ElementsAreArray(ArrayFloatNear(float_op.GetOutput(), 1)));
}
-TEST(QuantizedDepthwiseConvolutionOpTest, SimpleDilatedTest) {
+TEST_P(QuantizedDepthwiseConvolutionOpTest, SimpleDilatedTestPaddingValid) {
const int depth = 1;
const int image_width = 9;
const int image_height = 9;
@@ -270,6 +346,7 @@ TEST(QuantizedDepthwiseConvolutionOpTest, SimpleDilatedTest) {
const int filter_count = 1;
const int dilation_factor = 3;
QuantizedDepthwiseConvolutionOpModel m(
+ GetRegistration(),
{TensorType_UINT8,
{image_batch_count, image_height, image_width, depth},
0,
@@ -278,7 +355,7 @@ TEST(QuantizedDepthwiseConvolutionOpTest, SimpleDilatedTest) {
{depth, filter_size, filter_size, filter_count},
0,
255},
- {TensorType_UINT8, {}, 0, 255}, dilation_factor);
+ {TensorType_UINT8, {}, 0, 255}, Padding_VALID, dilation_factor);
// The image matrix is:
// | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
@@ -319,6 +396,55 @@ TEST(QuantizedDepthwiseConvolutionOpTest, SimpleDilatedTest) {
ElementsAreArray({5, 5, 5, 5, 5, 5, 5, 5, 5}));
}
+TEST_P(QuantizedDepthwiseConvolutionOpTest, SimpleDilatedTestPaddingSame) {
+ const int depth = 1;
+ const int image_width = 3;
+ const int image_height = 3;
+ const int image_batch_count = 1;
+ const int filter_size = 2;
+ const int filter_count = 1;
+ const int dilation_factor = 2;
+ QuantizedDepthwiseConvolutionOpModel m(
+ GetRegistration(),
+ {TensorType_UINT8,
+ {image_batch_count, image_height, image_width, depth},
+ 0,
+ 255},
+ {TensorType_UINT8,
+ {depth, filter_size, filter_size, filter_count},
+ 0,
+ 255},
+ {TensorType_UINT8, {}, 0, 255}, Padding_SAME, dilation_factor);
+
+ // The image matrix is:
+ // | 1 | 1 | 1 |
+ // | 1 | 1 | 1 |
+ // | 1 | 1 | 1 |
+ m.SetInput({1, 1, 1, 1, 1, 1, 1, 1, 1});
+ // The filter matrix is:
+ // | 1 | 2 |
+ // | 3 | 4 |
+ m.SetFilter({1, 2, 3, 4});
+ // No bias for this test.
+ m.SetBias({0});
+ m.Invoke();
+
+ // Output:
+ // | 4 | 7 | 3 |
+ // | 6 |10 | 4 |
+ // | 2 | 3 | 1 |
+ EXPECT_THAT(m.GetDequantizedOutput(),
+ ElementsAreArray({4, 7, 3, 6, 10, 4, 2, 3, 1}));
+}
+
+INSTANTIATE_TEST_CASE_P(
+ DepthwiseConvolutionOpTest, DepthwiseConvolutionOpTest,
+ ::testing::ValuesIn(SingleOpTest::GetKernelTags(*kKernelMap)));
+
+INSTANTIATE_TEST_CASE_P(
+ QuantizedDepthwiseConvolutionOpTest, QuantizedDepthwiseConvolutionOpTest,
+ ::testing::ValuesIn(SingleOpTest::GetKernelTags(*kKernelMap)));
+
} // namespace
} // namespace tflite
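
The depthwise tests above move from plain TEST to value-parameterized TEST_P, instantiated once per registered kernel variant, so every assertion runs against the reference, generic-optimized, and NEON registrations. A minimal sketch of that googletest pattern follows (strings stand in for the TfLiteRegistration pointers keyed by kKernelMap; assumes the target links gtest_main):

    #include <string>
    #include <gtest/gtest.h>

    // One TEST_P body, instantiated once per kernel variant name, so identical
    // assertions run against every registration.
    class KernelVariantTest : public ::testing::TestWithParam<std::string> {};

    TEST_P(KernelVariantTest, RunsOncePerVariant) {
      EXPECT_FALSE(GetParam().empty());
    }

    INSTANTIATE_TEST_CASE_P(AllVariants, KernelVariantTest,
                            ::testing::Values("Reference", "GenericOptimized",
                                              "NeonOptimized"));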
diff --git a/tensorflow/contrib/lite/kernels/dequantize.cc b/tensorflow/contrib/lite/kernels/dequantize.cc
index 3a08f48b00..59bf64e0af 100644
--- a/tensorflow/contrib/lite/kernels/dequantize.cc
+++ b/tensorflow/contrib/lite/kernels/dequantize.cc
@@ -77,13 +77,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
- auto zero_point = op_context.input->params.zero_point;
- auto scale = op_context.input->params.scale;
-
- optimized_ops::Dequantize(GetTensorData<uint8_t>(op_context.input),
- GetTensorDims(op_context.input), zero_point, scale,
- GetTensorData<float>(op_context.output),
- GetTensorDims(op_context.output));
+ tflite::DequantizationParams op_params;
+ op_params.zero_point = op_context.input->params.zero_point;
+ op_params.scale = op_context.input->params.scale;
+ optimized_ops::Dequantize(op_params, GetTensorShape(op_context.input),
+ GetTensorData<uint8_t>(op_context.input),
+ GetTensorShape(op_context.output),
+ GetTensorData<float>(op_context.output));
if (IsConstantTensor(op_context.input)) {
op_data->float_dequantized_weights_initialized = true;
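
Dequantize above now reads zero_point and scale from a DequantizationParams struct; the underlying affine mapping is unchanged. A self-contained sketch of that mapping over a small buffer (values illustrative):

    #include <cstdint>
    #include <cstdio>

    // real_value = scale * (quantized_value - zero_point), applied elementwise.
    void Dequantize(const uint8_t* in, int n, float scale, int32_t zero_point,
                    float* out) {
      for (int i = 0; i < n; ++i) {
        out[i] = scale * (static_cast<int32_t>(in[i]) - zero_point);
      }
    }

    int main() {
      const uint8_t q[4] = {0, 128, 200, 255};
      float r[4];
      Dequantize(q, 4, /*scale=*/0.1f, /*zero_point=*/128, r);
      for (float v : r) std::printf("%.1f ", v);
      std::printf("\n");
      return 0;
    }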
diff --git a/tensorflow/contrib/lite/kernels/detection_postprocess.cc b/tensorflow/contrib/lite/kernels/detection_postprocess.cc
index d2906632d7..e21dc5ced9 100644
--- a/tensorflow/contrib/lite/kernels/detection_postprocess.cc
+++ b/tensorflow/contrib/lite/kernels/detection_postprocess.cc
@@ -15,7 +15,7 @@ limitations under the License.
#include <string.h>
#include <numeric>
#include <vector>
-#include "flatbuffers/flexbuffers.h" // flatbuffers
+#include "flatbuffers/flexbuffers.h" // TF:flatbuffers
#include "tensorflow/contrib/lite/c/builtin_op_data.h"
#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
diff --git a/tensorflow/contrib/lite/kernels/detection_postprocess_test.cc b/tensorflow/contrib/lite/kernels/detection_postprocess_test.cc
index 94c91a6bd6..1e8caebd82 100644
--- a/tensorflow/contrib/lite/kernels/detection_postprocess_test.cc
+++ b/tensorflow/contrib/lite/kernels/detection_postprocess_test.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include <vector>
#include <gtest/gtest.h>
-#include "flatbuffers/flexbuffers.h" // flatbuffers
+#include "flatbuffers/flexbuffers.h" // TF:flatbuffers
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/kernels/test_util.h"
diff --git a/tensorflow/contrib/lite/kernels/div.cc b/tensorflow/contrib/lite/kernels/div.cc
index 7945c095b1..8d4bb51006 100644
--- a/tensorflow/contrib/lite/kernels/div.cc
+++ b/tensorflow/contrib/lite/kernels/div.cc
@@ -81,24 +81,27 @@ template <KernelType kernel_type>
void EvalDiv(TfLiteContext* context, TfLiteNode* node, TfLiteDivParams* params,
const OpData* data, const TfLiteTensor* input1,
const TfLiteTensor* input2, TfLiteTensor* output) {
-#define TF_LITE_DIV(type, opname, data_type) \
- data_type output_activation_min, output_activation_max; \
- CalculateActivationRange(params->activation, &output_activation_min, \
- &output_activation_max); \
- type::opname(GetTensorData<data_type>(input1), GetTensorDims(input1), \
- GetTensorData<data_type>(input2), GetTensorDims(input2), \
- output_activation_min, output_activation_max, \
- GetTensorData<data_type>(output), GetTensorDims(output))
+#define TF_LITE_DIV(type, opname, data_type) \
+ tflite::ArithmeticParams op_params; \
+ data_type output_activation_min, output_activation_max; \
+ CalculateActivationRange(params->activation, &output_activation_min, \
+ &output_activation_max); \
+ SetActivationParams(output_activation_min, output_activation_max, \
+ &op_params); \
+ type::opname(op_params, GetTensorShape(input1), \
+ GetTensorData<data_type>(input1), GetTensorShape(input2), \
+ GetTensorData<data_type>(input2), GetTensorShape(output), \
+ GetTensorData<data_type>(output))
if (output->type == kTfLiteInt32) {
if (kernel_type == kReference) {
if (data->requires_broadcast) {
- TF_LITE_DIV(reference_ops, BroadcastDiv, int32_t);
+ TF_LITE_DIV(reference_ops, BroadcastDiv4DSlow, int32_t);
} else {
TF_LITE_DIV(reference_ops, Div, int32_t);
}
} else {
if (data->requires_broadcast) {
- TF_LITE_DIV(optimized_ops, BroadcastDiv, int32_t);
+ TF_LITE_DIV(optimized_ops, BroadcastDiv4DSlow, int32_t);
} else {
TF_LITE_DIV(optimized_ops, Div, int32_t);
}
@@ -106,13 +109,13 @@ void EvalDiv(TfLiteContext* context, TfLiteNode* node, TfLiteDivParams* params,
} else if (output->type == kTfLiteFloat32) {
if (kernel_type == kReference) {
if (data->requires_broadcast) {
- TF_LITE_DIV(reference_ops, BroadcastDiv, float);
+ TF_LITE_DIV(reference_ops, BroadcastDiv4DSlow, float);
} else {
TF_LITE_DIV(reference_ops, Div, float);
}
} else {
if (data->requires_broadcast) {
- TF_LITE_DIV(optimized_ops, BroadcastDiv, float);
+ TF_LITE_DIV(optimized_ops, BroadcastDiv4DSlow, float);
} else {
TF_LITE_DIV(optimized_ops, Div, float);
}
diff --git a/tensorflow/contrib/lite/kernels/fake_quant.cc b/tensorflow/contrib/lite/kernels/fake_quant.cc
index f9bc3747cb..b51af72fe6 100644
--- a/tensorflow/contrib/lite/kernels/fake_quant.cc
+++ b/tensorflow/contrib/lite/kernels/fake_quant.cc
@@ -68,11 +68,14 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const auto* params =
reinterpret_cast<TfLiteFakeQuantParams*>(node->builtin_data);
- reference_ops::FakeQuant(GetTensorData<float>(op_context.input),
- GetTensorDims(op_context.input), params->min,
- params->max, params->num_bits,
- GetTensorData<float>(op_context.output),
- GetTensorDims(op_context.output));
+ tflite::FakeQuantParams op_params;
+ op_params.num_bits = params->num_bits;
+ op_params.minmax.min = params->min;
+ op_params.minmax.max = params->max;
+ reference_ops::FakeQuant(op_params, GetTensorShape(op_context.input),
+ GetTensorData<float>(op_context.input),
+ GetTensorShape(op_context.output),
+ GetTensorData<float>(op_context.output));
return kTfLiteOk;
}
diff --git a/tensorflow/contrib/lite/kernels/fully_connected.cc b/tensorflow/contrib/lite/kernels/fully_connected.cc
index 7a71fcc219..f6d2f76dbe 100644
--- a/tensorflow/contrib/lite/kernels/fully_connected.cc
+++ b/tensorflow/contrib/lite/kernels/fully_connected.cc
@@ -281,15 +281,23 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
int32_t input_offset = -input->params.zero_point;
int32_t filter_offset = -filter->params.zero_point;
int32_t output_offset = output->params.zero_point;
-#define TF_LITE_FULLY_CONNECTED(type, output_data_type) \
- type::FullyConnected( \
- GetTensorData<uint8_t>(input), GetTensorDims(input), input_offset, \
- GetTensorData<uint8_t>(filter), GetTensorDims(filter), filter_offset, \
- GetTensorData<int32_t>(bias), GetTensorDims(bias), output_offset, \
- data->output_multiplier, data->output_shift, \
- data->output_activation_min, data->output_activation_max, \
- GetTensorData<output_data_type>(output), GetTensorDims(output), \
- gemm_context)
+#define TF_LITE_FULLY_CONNECTED(type, output_data_type) \
+ { \
+ FullyConnectedParams op_params; \
+ op_params.input_offset = input_offset; \
+ op_params.weights_offset = filter_offset; \
+ op_params.output_offset = output_offset; \
+ op_params.output_multiplier = data->output_multiplier; \
+ op_params.output_shift = -data->output_shift; \
+ op_params.quantized_activation_min = data->output_activation_min; \
+ op_params.quantized_activation_max = data->output_activation_max; \
+ type::FullyConnected( \
+ op_params, GetTensorShape(input), GetTensorData<uint8_t>(input), \
+ GetTensorShape(filter), GetTensorData<uint8_t>(filter), \
+ GetTensorShape(bias), GetTensorData<int32_t>(bias), \
+ GetTensorShape(output), GetTensorData<output_data_type>(output), \
+ gemm_context); \
+ }
if (kernel_type == kReference) {
switch (output->type) {
case kTfLiteUInt8:
@@ -349,15 +357,20 @@ TfLiteStatus EvalShuffledQuantized(TfLiteContext* context, TfLiteNode* node,
return kTfLiteError;
}
-#define TF_LITE_SHUFFLED_FULLY_CONNECTED(type) \
- type::ShuffledFullyConnected( \
- GetTensorData<uint8_t>(input), GetTensorDims(input), \
- GetTensorData<uint8_t>(filter), GetTensorDims(filter), \
- GetTensorData<int32_t>(bias), GetTensorDims(bias), \
- data->output_multiplier, data->output_shift, \
- data->output_activation_min, data->output_activation_max, \
- GetTensorData<int16_t>(output), GetTensorDims(output), \
- GetTensorData<uint8_t>(shuffled_input_workspace), gemm_context)
+#define TF_LITE_SHUFFLED_FULLY_CONNECTED(type) \
+ { \
+ FullyConnectedParams op_params; \
+ op_params.output_multiplier = data->output_multiplier; \
+ op_params.output_shift = -data->output_shift; \
+ op_params.quantized_activation_min = data->output_activation_min; \
+ op_params.quantized_activation_max = data->output_activation_max; \
+ type::ShuffledFullyConnected( \
+ op_params, GetTensorShape(input), GetTensorData<uint8_t>(input), \
+ GetTensorShape(filter), GetTensorData<uint8_t>(filter), \
+ GetTensorShape(bias), GetTensorData<int32_t>(bias), \
+ GetTensorShape(output), GetTensorData<int16_t>(output), \
+ GetTensorData<uint8_t>(shuffled_input_workspace), gemm_context); \
+ }
if (kernel_type == kReference) {
TF_LITE_SHUFFLED_FULLY_CONNECTED(reference_ops);
} else {
@@ -376,12 +389,17 @@ TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node,
float output_activation_min, output_activation_max;
CalculateActivationRange(params->activation, &output_activation_min,
&output_activation_max);
-#define TF_LITE_FULLY_CONNECTED(type) \
- type::FullyConnected(GetTensorData<float>(input), GetTensorDims(input), \
- GetTensorData<float>(filter), GetTensorDims(filter), \
- GetTensorData<float>(bias), GetTensorDims(bias), \
- output_activation_min, output_activation_max, \
- GetTensorData<float>(output), GetTensorDims(output))
+#define TF_LITE_FULLY_CONNECTED(type) \
+ { \
+ FullyConnectedParams op_params; \
+ op_params.float_activation_min = output_activation_min; \
+ op_params.float_activation_max = output_activation_max; \
+ type::FullyConnected(op_params, GetTensorShape(input), \
+ GetTensorData<float>(input), GetTensorShape(filter), \
+ GetTensorData<float>(filter), GetTensorShape(bias), \
+ GetTensorData<float>(bias), GetTensorShape(output), \
+ GetTensorData<float>(output)); \
+ }
if (kernel_type == kReference) {
TF_LITE_FULLY_CONNECTED(reference_ops);
} else if (kernel_type == kPie) {
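
For the float fully-connected path above, the only change is that the activation range now travels in FullyConnectedParams. A compact reminder of the computation itself, for one batch with row-major weights (illustrative, not the optimized kernel):

    #include <algorithm>
    #include <cstdio>

    // output[o] = clamp(bias[o] + sum_i input[i] * weights[o][i]) for each
    // output unit o, with weights stored row-major as weights[o * n_input + i].
    void FullyConnectedFloat(const float* input, const float* weights,
                             const float* bias, int n_input, int n_output,
                             float act_min, float act_max, float* output) {
      for (int o = 0; o < n_output; ++o) {
        float acc = bias[o];
        for (int i = 0; i < n_input; ++i) {
          acc += input[i] * weights[o * n_input + i];
        }
        output[o] = std::min(act_max, std::max(act_min, acc));
      }
    }

    int main() {
      const float input[2] = {1.f, 2.f};
      const float weights[4] = {1.f, 0.f,    // output 0
                                0.f, -1.f};  // output 1
      const float bias[2] = {0.5f, 0.f};
      float out[2];
      FullyConnectedFloat(input, weights, bias, 2, 2, 0.f, 6.f, out);
      std::printf("%.1f %.1f\n", out[0], out[1]);  // 1.5 0.0
      return 0;
    }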
diff --git a/tensorflow/contrib/lite/kernels/gather.cc b/tensorflow/contrib/lite/kernels/gather.cc
index badd2de11a..b5afeb1a7b 100644
--- a/tensorflow/contrib/lite/kernels/gather.cc
+++ b/tensorflow/contrib/lite/kernels/gather.cc
@@ -84,11 +84,15 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* positions = GetInput(context, node, kInputPositions);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const int input_rank = NumDimensions(input);
-#define TF_LITE_GATHER(data_type, index_type) \
- optimized_ops::Gather( \
- GetTensorData<data_type>(input), GetTensorDims(input), input_rank, \
- GetTensorData<index_type>(positions), GetTensorDims(positions), \
- GetTensorData<data_type>(output), GetTensorDims(output));
+#define TF_LITE_GATHER(data_type, index_type) \
+ { \
+ tflite::GatherParams op_params; \
+ op_params.input_rank = input_rank; \
+ optimized_ops::Gather( \
+ op_params, GetTensorShape(input), GetTensorData<data_type>(input), \
+ GetTensorShape(positions), GetTensorData<index_type>(positions), \
+ GetTensorShape(output), GetTensorData<data_type>(output)); \
+ }
switch (input->type) {
case kTfLiteFloat32:
TF_LITE_GATHER(float, int32_t);
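
Gather above now receives the input rank through GatherParams; the operation itself still copies the slices selected by the positions tensor. A minimal sketch for rank-2 input gathered along the first dimension (illustrative names, plain vectors instead of TF Lite tensors):

    #include <cstdio>
    #include <vector>

    // For each position p, copy the whole row input[p] into the output.
    std::vector<std::vector<float>> Gather(
        const std::vector<std::vector<float>>& input,
        const std::vector<int>& positions) {
      std::vector<std::vector<float>> output;
      output.reserve(positions.size());
      for (int p : positions) output.push_back(input[p]);
      return output;
    }

    int main() {
      const std::vector<std::vector<float>> input = {{1, 2}, {3, 4}, {5, 6}};
      for (const auto& row : Gather(input, {2, 0})) {
        std::printf("%.0f %.0f\n", row[0], row[1]);
      }
      return 0;
    }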
diff --git a/tensorflow/contrib/lite/kernels/internal/BUILD b/tensorflow/contrib/lite/kernels/internal/BUILD
index a6fd4ac2dd..afb5ec05df 100644
--- a/tensorflow/contrib/lite/kernels/internal/BUILD
+++ b/tensorflow/contrib/lite/kernels/internal/BUILD
@@ -43,6 +43,10 @@ cc_library(
"compatibility.h",
"types.h",
],
+ deps = [
+ "//tensorflow/contrib/lite/kernels:op_macros",
+ "@com_google_absl//absl/base:core_headers",
+ ],
)
config_setting(
@@ -259,6 +263,7 @@ cc_library(
deps = [
":round",
":types",
+ "//tensorflow/contrib/lite/kernels:op_macros",
],
)
@@ -290,7 +295,9 @@ cc_library(
"common.h",
"reference/depthwiseconv_float.h",
"reference/depthwiseconv_uint8.h",
+ "reference/fully_connected.h",
"reference/reference_ops.h",
+ "reference/softmax.h",
],
deps = [
":quantization_util",
@@ -299,6 +306,7 @@ cc_library(
":types",
"@gemmlowp",
"//tensorflow/contrib/lite/c:c_api_internal",
+ "//tensorflow/contrib/lite/kernels:op_macros",
] + select({
":haswell": tflite_deps_intel,
":ios_x86_64": tflite_deps_intel,
@@ -319,8 +327,10 @@ cc_library(
"common.h",
"reference/depthwiseconv_float.h",
"reference/depthwiseconv_uint8.h",
+ "reference/fully_connected.h",
"reference/legacy_reference_ops.h",
"reference/reference_ops.h",
+ "reference/softmax.h",
],
deps = [
":quantization_util",
@@ -329,6 +339,7 @@ cc_library(
":types",
"@gemmlowp",
"//tensorflow/contrib/lite/c:c_api_internal",
+ "//tensorflow/contrib/lite/kernels:op_macros",
] + select({
":haswell": tflite_deps_intel,
":ios_x86_64": tflite_deps_intel,
@@ -458,9 +469,10 @@ cc_library(
],
copts = NEON_FLAGS_IF_APPLICABLE,
deps = [
- "//tensorflow/contrib/lite/kernels:activation_functor",
+ "@com_google_absl//absl/base:core_headers",
"//tensorflow/contrib/lite/c:c_api_internal",
"@arm_neon_2_x86_sse",
+ "//tensorflow/contrib/lite/kernels:op_macros",
"@gemmlowp",
] + select({
":arm": [
diff --git a/tensorflow/contrib/lite/kernels/internal/compatibility.h b/tensorflow/contrib/lite/kernels/internal/compatibility.h
index 93fc6b6a76..b87cf2b60d 100644
--- a/tensorflow/contrib/lite/kernels/internal/compatibility.h
+++ b/tensorflow/contrib/lite/kernels/internal/compatibility.h
@@ -15,65 +15,65 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMPATIBILITY_H_
#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMPATIBILITY_H_
-#include <cassert>
#include <cstdint>
-#include <cstdlib>
+
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
#ifndef TFLITE_DCHECK
-#define TFLITE_DCHECK(condition) (condition) ? (void)0 : assert(false)
+#define TFLITE_DCHECK(condition) (condition) ? (void)0 : TFLITE_ASSERT_FALSE
#endif
#ifndef TFLITE_DCHECK_EQ
-#define TFLITE_DCHECK_EQ(x, y) ((x) == (y)) ? (void)0 : assert(false)
+#define TFLITE_DCHECK_EQ(x, y) ((x) == (y)) ? (void)0 : TFLITE_ASSERT_FALSE
#endif
#ifndef TFLITE_DCHECK_NE
-#define TFLITE_DCHECK_NE(x, y) ((x) != (y)) ? (void)0 : assert(false)
+#define TFLITE_DCHECK_NE(x, y) ((x) != (y)) ? (void)0 : TFLITE_ASSERT_FALSE
#endif
#ifndef TFLITE_DCHECK_GE
-#define TFLITE_DCHECK_GE(x, y) ((x) >= (y)) ? (void)0 : assert(false)
+#define TFLITE_DCHECK_GE(x, y) ((x) >= (y)) ? (void)0 : TFLITE_ASSERT_FALSE
#endif
#ifndef TFLITE_DCHECK_GT
-#define TFLITE_DCHECK_GT(x, y) ((x) > (y)) ? (void)0 : assert(false)
+#define TFLITE_DCHECK_GT(x, y) ((x) > (y)) ? (void)0 : TFLITE_ASSERT_FALSE
#endif
#ifndef TFLITE_DCHECK_LE
-#define TFLITE_DCHECK_LE(x, y) ((x) <= (y)) ? (void)0 : assert(false)
+#define TFLITE_DCHECK_LE(x, y) ((x) <= (y)) ? (void)0 : TFLITE_ASSERT_FALSE
#endif
#ifndef TFLITE_DCHECK_LT
-#define TFLITE_DCHECK_LT(x, y) ((x) < (y)) ? (void)0 : assert(false)
+#define TFLITE_DCHECK_LT(x, y) ((x) < (y)) ? (void)0 : TFLITE_ASSERT_FALSE
#endif
// TODO(ahentz): Clean up: We should stick to the DCHECK versions.
#ifndef TFLITE_CHECK
-#define TFLITE_CHECK(condition) (condition) ? (void)0 : abort()
+#define TFLITE_CHECK(condition) (condition) ? (void)0 : TFLITE_ABORT
#endif
#ifndef TFLITE_CHECK_EQ
-#define TFLITE_CHECK_EQ(x, y) ((x) == (y)) ? (void)0 : abort()
+#define TFLITE_CHECK_EQ(x, y) ((x) == (y)) ? (void)0 : TFLITE_ABORT
#endif
#ifndef TFLITE_CHECK_NE
-#define TFLITE_CHECK_NE(x, y) ((x) != (y)) ? (void)0 : abort()
+#define TFLITE_CHECK_NE(x, y) ((x) != (y)) ? (void)0 : TFLITE_ABORT
#endif
#ifndef TFLITE_CHECK_GE
-#define TFLITE_CHECK_GE(x, y) ((x) >= (y)) ? (void)0 : abort()
+#define TFLITE_CHECK_GE(x, y) ((x) >= (y)) ? (void)0 : TFLITE_ABORT
#endif
#ifndef TFLITE_CHECK_GT
-#define TFLITE_CHECK_GT(x, y) ((x) > (y)) ? (void)0 : abort()
+#define TFLITE_CHECK_GT(x, y) ((x) > (y)) ? (void)0 : TFLITE_ABORT
#endif
#ifndef TFLITE_CHECK_LE
-#define TFLITE_CHECK_LE(x, y) ((x) <= (y)) ? (void)0 : abort()
+#define TFLITE_CHECK_LE(x, y) ((x) <= (y)) ? (void)0 : TFLITE_ABORT
#endif
#ifndef TFLITE_CHECK_LT
-#define TFLITE_CHECK_LT(x, y) ((x) < (y)) ? (void)0 : abort()
+#define TFLITE_CHECK_LT(x, y) ((x) < (y)) ? (void)0 : TFLITE_ABORT
#endif
// TODO(ahentz): Clean up.
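
The check macros above keep the same expression shape and only swap the failure action from assert(false)/abort() to the portable TFLITE_ASSERT_FALSE/TFLITE_ABORT hooks from op_macros.h. The ternary form matters because it keeps each macro a single void expression that composes safely in any statement position. A reduced sketch of the pattern (MY_ABORT and MY_CHECK_GE are stand-ins, not the real macros):

    #include <cstdio>
    #include <cstdlib>

    // Stand-in failure hook; in TF Lite this comes from op_macros.h so that
    // platforms without abort()/assert() can substitute their own handler.
    #define MY_ABORT (std::fprintf(stderr, "check failed\n"), std::abort())

    // The ternary keeps the macro a single void expression, so it can be used
    // inside if/else without braces and without surprises from a trailing
    // semicolon.
    #define MY_CHECK_GE(x, y) ((x) >= (y)) ? (void)0 : MY_ABORT

    int main() {
      MY_CHECK_GE(3, 2);  // passes, evaluates to (void)0
      std::printf("ok\n");
      return 0;
    }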
diff --git a/tensorflow/contrib/lite/kernels/internal/depthwiseconv_float_test.cc b/tensorflow/contrib/lite/kernels/internal/depthwiseconv_float_test.cc
index 844ee6a53d..41862a21a6 100644
--- a/tensorflow/contrib/lite/kernels/internal/depthwiseconv_float_test.cc
+++ b/tensorflow/contrib/lite/kernels/internal/depthwiseconv_float_test.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include <vector>
#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/kernels/internal/common.h"
#include "tensorflow/contrib/lite/kernels/internal/test_util.h"
#include "tensorflow/contrib/lite/kernels/internal/types.h"
@@ -28,23 +29,21 @@ namespace tflite {
namespace {
// Runs the DepthwiseConv and compares against the reference implementation.
-template <FusedActivationFunctionType Ac>
-void TestOneDepthwiseConv(const float* input_data, const Dims<4>& input_dims,
- const float* filter_data, const Dims<4>& filter_dims,
- const float* bias_data, const Dims<4>& bias_dims,
- int stride, int pad_width, int pad_height,
- int depth_multiplier, const Dims<4>& output_dims) {
- const int output_buffer_size = RequiredBufferSizeForDims(output_dims);
+void TestOneDepthwiseConv(
+ const DepthwiseParams& params, const RuntimeShape& input_shape,
+ const float* input_data, const RuntimeShape& filter_shape,
+ const float* filter_data, const RuntimeShape& bias_shape,
+ const float* bias_data, const RuntimeShape& output_shape) {
+ const int output_buffer_size = output_shape.FlatSize();
std::vector<float> output_data(output_buffer_size);
std::vector<float> reference_output_data(output_buffer_size);
- reference_ops::DepthwiseConv<Ac>(input_data, input_dims, filter_data,
- filter_dims, bias_data, bias_dims, stride,
- pad_width, pad_height, depth_multiplier,
- reference_output_data.data(), output_dims);
- optimized_ops::DepthwiseConv<Ac>(input_data, input_dims, filter_data,
- filter_dims, bias_data, bias_dims, stride,
- pad_width, pad_height, depth_multiplier,
- output_data.data(), output_dims);
+ reference_ops::DepthwiseConv(params, input_shape, input_data, filter_shape,
+ filter_data, bias_shape, bias_data, output_shape,
+ reference_output_data.data());
+ optimized_ops::DepthwiseConv(params, input_shape, input_data, filter_shape,
+ filter_data, bias_shape, bias_data, output_shape,
+ output_data.data());
+
double sum_abs_diff = 0;
float max_abs_val = 0;
for (int i = 0; i < output_buffer_size; i++) {
@@ -59,27 +58,6 @@ void TestOneDepthwiseConv(const float* input_data, const Dims<4>& input_dims,
}
}
-void TestOneDepthwiseConv(FusedActivationFunctionType Ac,
- const float* input_data, const Dims<4>& input_dims,
- const float* filter_data, const Dims<4>& filter_dims,
- const float* bias_data, const Dims<4>& bias_dims,
- int stride, int pad_width, int pad_height,
- int depth_multiplier, const Dims<4>& output_dims) {
-#define TOCO_HANDLE_CASE(AC_TYPE) \
- if (AC_TYPE == Ac) { \
- TestOneDepthwiseConv<AC_TYPE>(input_data, input_dims, filter_data, \
- filter_dims, bias_data, bias_dims, stride, \
- pad_width, pad_height, depth_multiplier, \
- output_dims); \
- return; \
- }
- TOCO_HANDLE_CASE(FusedActivationFunctionType::kNone)
- TOCO_HANDLE_CASE(FusedActivationFunctionType::kRelu)
- TOCO_HANDLE_CASE(FusedActivationFunctionType::kRelu1)
- TOCO_HANDLE_CASE(FusedActivationFunctionType::kRelu6)
-#undef TOCO_HANDLE_CASE
-}
-
// This function picks some random DepthwiseConv params, which may or may not
// be legal. If they're not legal, it returns false. If they're legal,
// it runs the DepthwiseConv test and returns true. This allows the caller
@@ -99,6 +77,16 @@ bool TryTestOneDepthwiseConv() {
const int depth_multiplier = ExponentialRandomPositiveInt(0.8f, 6, 50);
const int stride = ExponentialRandomPositiveInt(0.9f, 3, 8);
const int output_depth = input_depth * depth_multiplier;
+ const int dilation_width_factor = RandomElement(std::vector<int>({1, 2, 4}));
+ const int dilation_height_factor = RandomElement(std::vector<int>({1, 2, 4}));
+ float output_activation_min, output_activation_max;
+ FusedActivationFunctionType ac =
+ RandomElement(std::vector<FusedActivationFunctionType>(
+ {FusedActivationFunctionType::kNone,
+ FusedActivationFunctionType::kRelu,
+ FusedActivationFunctionType::kRelu1,
+ FusedActivationFunctionType::kRelu6}));
+ GetActivationMinMax(ac, &output_activation_min, &output_activation_max);
// The optimized DepthwiseConv implementation currently uses a fixed-size
// accumulator buffer on the stack, with that size. This currently means
// that it does not support larger output depths. It CHECK's for it,
@@ -109,27 +97,23 @@ bool TryTestOneDepthwiseConv() {
if (output_depth > kMaxSupportedOutputDepth) {
return false;
}
- const auto ac = RandomElement(std::vector<FusedActivationFunctionType>(
- {FusedActivationFunctionType::kNone, FusedActivationFunctionType::kRelu,
- FusedActivationFunctionType::kRelu6,
- FusedActivationFunctionType::kRelu1}));
- Dims<4> input_dims_inference =
- MakeDimsForInference(input_depth, input_width, input_height, batch);
- Dims<4> output_dims_inference;
+ RuntimeShape input_shape_inference(
+ {batch, input_height, input_width, input_depth});
+ RuntimeShape output_shape_inference;
int pad_width, pad_height;
const auto padding_type =
UniformRandomInt(0, 1) ? PaddingType::kSame : PaddingType::kValid;
- if (!ComputeConvSizes(input_dims_inference, output_depth, filter_width,
- filter_height, stride, padding_type,
- &output_dims_inference, &pad_width, &pad_height)) {
+ if (!ComputeConvSizes(input_shape_inference, output_depth, filter_width,
+ filter_height, stride, dilation_width_factor,
+ dilation_height_factor, padding_type,
+ &output_shape_inference, &pad_width, &pad_height)) {
return false;
}
- Dims<4> filter_dims_inference =
- MakeDimsForInference(output_depth, filter_width, filter_height, 1);
- Dims<4> bias_dims_inference = MakeDimsForInference(output_depth, 1, 1, 1);
- const int input_buffer_size = RequiredBufferSizeForDims(input_dims_inference);
- const int filter_buffer_size =
- RequiredBufferSizeForDims(filter_dims_inference);
+ RuntimeShape filter_shape_inference(
+ {1, filter_height, filter_width, output_depth});
+ RuntimeShape bias_shape_inference({1, 1, 1, output_depth});
+ const int input_buffer_size = input_shape_inference.FlatSize();
+ const int filter_buffer_size = filter_shape_inference.FlatSize();
std::vector<float> input_data(input_buffer_size);
std::vector<float> filter_data(filter_buffer_size);
std::vector<float> bias_data(output_depth);
@@ -140,10 +124,21 @@ bool TryTestOneDepthwiseConv() {
FillRandom(&input_data, -input_amplitude, input_amplitude);
FillRandom(&filter_data, -filter_amplitude, filter_amplitude);
FillRandom(&bias_data, -bias_amplitude, bias_amplitude);
- TestOneDepthwiseConv(ac, input_data.data(), input_dims_inference,
- filter_data.data(), filter_dims_inference,
- bias_data.data(), bias_dims_inference, stride, pad_width,
- pad_height, depth_multiplier, output_dims_inference);
+ DepthwiseParams op_params;
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = pad_width;
+ op_params.padding_values.height = pad_height;
+ op_params.stride_width = stride;
+ op_params.stride_height = stride;
+ op_params.dilation_width_factor = dilation_width_factor;
+ op_params.dilation_height_factor = dilation_height_factor;
+ op_params.depth_multiplier = depth_multiplier;
+ op_params.float_activation_min = output_activation_min;
+ op_params.float_activation_max = output_activation_max;
+ TestOneDepthwiseConv(op_params, input_shape_inference, input_data.data(),
+ filter_shape_inference, filter_data.data(),
+ bias_shape_inference, bias_data.data(),
+ output_shape_inference);
return true;
}
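
The updated float test above switches from the reversed Dims<4> ordering to NHWC RuntimeShapes and gathers the per-op arguments into a tflite::DepthwiseParams struct. Below is a minimal, self-contained sketch of that calling convention; the shapes, values, and the helper name ExampleDepthwiseCall are illustrative assumptions, not part of the change.

// Sketch only (not part of the diff): calling the new float DepthwiseConv with
// NHWC RuntimeShapes and a DepthwiseParams struct, mirroring the updated test.
#include <limits>

#include "tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h"

void ExampleDepthwiseCall(const float* input, const float* filter,
                          const float* bias, float* output) {
  // RuntimeShape uses NHWC order {batch, height, width, depth}; the old
  // Dims<4> listed the same dimensions in reverse.
  tflite::RuntimeShape input_shape({1, 8, 8, 16});
  tflite::RuntimeShape filter_shape({1, 3, 3, 16});  // {1, fh, fw, out_depth}
  tflite::RuntimeShape bias_shape({1, 1, 1, 16});
  tflite::RuntimeShape output_shape({1, 8, 8, 16});

  tflite::DepthwiseParams op_params;
  op_params.padding_type = tflite::PaddingType::kSame;
  op_params.padding_values.width = 1;
  op_params.padding_values.height = 1;
  op_params.stride_width = 1;
  op_params.stride_height = 1;
  op_params.dilation_width_factor = 1;
  op_params.dilation_height_factor = 1;
  op_params.depth_multiplier = 1;
  op_params.float_activation_min = std::numeric_limits<float>::lowest();
  op_params.float_activation_max = std::numeric_limits<float>::max();
  tflite::optimized_ops::DepthwiseConv(op_params, input_shape, input,
                                       filter_shape, filter, bias_shape, bias,
                                       output_shape, output);
}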
diff --git a/tensorflow/contrib/lite/kernels/internal/depthwiseconv_quantized_test.cc b/tensorflow/contrib/lite/kernels/internal/depthwiseconv_quantized_test.cc
index 2c0fc8433e..9414e109c3 100644
--- a/tensorflow/contrib/lite/kernels/internal/depthwiseconv_quantized_test.cc
+++ b/tensorflow/contrib/lite/kernels/internal/depthwiseconv_quantized_test.cc
@@ -35,29 +35,40 @@ namespace {
// Runs the DepthwiseConv and compares against the reference implementation.
template <FusedActivationFunctionType Ac>
int TestOneDepthwiseConvWithGivenOutputShift(
- const std::uint8_t* input_data, const Dims<4>& input_dims,
+ const std::uint8_t* input_data, const RuntimeShape& input_shape,
std::int32_t input_offset, const std::uint8_t* filter_data,
- const Dims<4>& filter_dims, std::int32_t filter_offset,
- const std::int32_t* bias_data, const Dims<4>& bias_dims, int stride,
+ const RuntimeShape& filter_shape, std::int32_t filter_offset,
+ const std::int32_t* bias_data, const RuntimeShape& bias_shape, int stride,
int pad_width, int pad_height, int depth_multiplier,
std::int32_t output_offset, std::int32_t output_multiplier,
int output_shift, std::int32_t output_activation_min,
- std::int32_t output_activation_max, const Dims<4>& output_dims) {
- const int output_buffer_size = RequiredBufferSizeForDims(output_dims);
+ std::int32_t output_activation_max, const RuntimeShape& output_shape) {
+ const int output_buffer_size = output_shape.FlatSize();
std::vector<std::uint8_t> output_data(output_buffer_size);
std::vector<std::uint8_t> reference_output_data(output_buffer_size);
- reference_ops::DepthwiseConv<Ac>(
- input_data, input_dims, input_offset, filter_data, filter_dims,
- filter_offset, bias_data, bias_dims, stride, pad_width, pad_height,
- depth_multiplier, output_offset, output_multiplier, output_shift,
- output_activation_min, output_activation_max,
- reference_output_data.data(), output_dims);
- optimized_ops::DepthwiseConv<Ac>(
- input_data, input_dims, input_offset, filter_data, filter_dims,
- filter_offset, bias_data, bias_dims, stride, pad_width, pad_height,
- depth_multiplier, output_offset, output_multiplier, output_shift,
- output_activation_min, output_activation_max, output_data.data(),
- output_dims);
+
+ tflite::DepthwiseParams op_params;
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = pad_width;
+ op_params.padding_values.height = pad_height;
+ op_params.stride_width = stride;
+ op_params.stride_height = stride;
+ op_params.dilation_width_factor = 1;
+ op_params.dilation_height_factor = 1;
+ op_params.depth_multiplier = depth_multiplier;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+ op_params.input_offset = input_offset;
+ op_params.weights_offset = filter_offset;
+ op_params.output_offset = output_offset;
+ op_params.output_multiplier = output_multiplier;
+ op_params.output_shift = -output_shift;
+ reference_ops::DepthwiseConv(op_params, input_shape, input_data, filter_shape,
+ filter_data, bias_shape, bias_data, output_shape,
+ reference_output_data.data());
+ optimized_ops::DepthwiseConv(op_params, input_shape, input_data, filter_shape,
+ filter_data, bias_shape, bias_data, output_shape,
+ output_data.data());
int saturated_min = 0;
int saturated_max = 0;
std::vector<int> diff(output_buffer_size);
@@ -106,25 +117,25 @@ int TestOneDepthwiseConvWithGivenOutputShift(
// vacuous. So we just bisect our way to reasonable output_shift values.
template <FusedActivationFunctionType Ac>
void TestOneDepthwiseConvBisectOutputShift(
- const std::uint8_t* input_data, const Dims<4>& input_dims,
+ const std::uint8_t* input_data, const RuntimeShape& input_shape,
std::int32_t input_offset, const std::uint8_t* filter_data,
- const Dims<4>& filter_dims, std::int32_t filter_offset,
- const std::int32_t* bias_data, const Dims<4>& bias_dims, int stride,
+ const RuntimeShape& filter_shape, std::int32_t filter_offset,
+ const std::int32_t* bias_data, const RuntimeShape& bias_shape, int stride,
int pad_width, int pad_height, int depth_multiplier,
std::int32_t output_offset, std::int32_t output_multiplier,
int output_activation_bisect_start, int output_activation_bisect_end,
std::int32_t output_activation_min, std::int32_t output_activation_max,
- const Dims<4>& output_dims) {
+ const RuntimeShape& output_shape) {
ASSERT_LT(output_activation_bisect_start, output_activation_bisect_end)
<< "Bisection failed ?!?!";
int output_shift_bisect_midpoint =
(output_activation_bisect_start + output_activation_bisect_end) / 2;
int bisect_result = TestOneDepthwiseConvWithGivenOutputShift<Ac>(
- input_data, input_dims, input_offset, filter_data, filter_dims,
- filter_offset, bias_data, bias_dims, stride, pad_width, pad_height,
+ input_data, input_shape, input_offset, filter_data, filter_shape,
+ filter_offset, bias_data, bias_shape, stride, pad_width, pad_height,
depth_multiplier, output_offset, output_multiplier,
output_shift_bisect_midpoint, output_activation_min,
- output_activation_max, output_dims);
+ output_activation_max, output_shape);
// At this point we know that the test succeeded (otherwise it would have
// aborted).
if (bisect_result == 0) {
@@ -147,47 +158,47 @@ void TestOneDepthwiseConvBisectOutputShift(
? output_activation_bisect_end
: output_shift_bisect_midpoint;
TestOneDepthwiseConvBisectOutputShift<Ac>(
- input_data, input_dims, input_offset, filter_data, filter_dims,
- filter_offset, bias_data, bias_dims, stride, pad_width, pad_height,
+ input_data, input_shape, input_offset, filter_data, filter_shape,
+ filter_offset, bias_data, bias_shape, stride, pad_width, pad_height,
depth_multiplier, output_offset, output_multiplier,
new_output_activation_bisect_start, new_output_activation_bisect_end,
- output_activation_min, output_activation_max, output_dims);
+ output_activation_min, output_activation_max, output_shape);
}
template <FusedActivationFunctionType Ac>
void TestOneDepthwiseConv(
- const std::uint8_t* input_data, const Dims<4>& input_dims,
+ const std::uint8_t* input_data, const RuntimeShape& input_shape,
std::int32_t input_offset, const std::uint8_t* filter_data,
- const Dims<4>& filter_dims, std::int32_t filter_offset,
- const std::int32_t* bias_data, const Dims<4>& bias_dims, int stride,
+ const RuntimeShape& filter_shape, std::int32_t filter_offset,
+ const std::int32_t* bias_data, const RuntimeShape& bias_shape, int stride,
int pad_width, int pad_height, int depth_multiplier,
std::int32_t output_offset, std::int32_t output_multiplier,
std::int32_t output_activation_min, std::int32_t output_activation_max,
- const Dims<4>& output_dims) {
+ const RuntimeShape& output_shape) {
TestOneDepthwiseConvBisectOutputShift<Ac>(
- input_data, input_dims, input_offset, filter_data, filter_dims,
- filter_offset, bias_data, bias_dims, stride, pad_width, pad_height,
+ input_data, input_shape, input_offset, filter_data, filter_shape,
+ filter_offset, bias_data, bias_shape, stride, pad_width, pad_height,
depth_multiplier, output_offset, output_multiplier, 0, 32,
- output_activation_min, output_activation_max, output_dims);
+ output_activation_min, output_activation_max, output_shape);
}
void TestOneDepthwiseConv(
FusedActivationFunctionType Ac, const std::uint8_t* input_data,
- const Dims<4>& input_dims, std::int32_t input_offset,
- const std::uint8_t* filter_data, const Dims<4>& filter_dims,
+ const RuntimeShape& input_shape, std::int32_t input_offset,
+ const std::uint8_t* filter_data, const RuntimeShape& filter_shape,
std::int32_t filter_offset, const std::int32_t* bias_data,
- const Dims<4>& bias_dims, int stride, int pad_width, int pad_height,
+ const RuntimeShape& bias_shape, int stride, int pad_width, int pad_height,
int depth_multiplier, std::int32_t output_offset,
std::int32_t output_multiplier, std::int32_t output_activation_min,
- std::int32_t output_activation_max, const Dims<4>& output_dims) {
-#define TOCO_HANDLE_CASE(AC_TYPE) \
- if (AC_TYPE == Ac) { \
- TestOneDepthwiseConv<AC_TYPE>( \
- input_data, input_dims, input_offset, filter_data, filter_dims, \
- filter_offset, bias_data, bias_dims, stride, pad_width, pad_height, \
- depth_multiplier, output_offset, output_multiplier, \
- output_activation_min, output_activation_max, output_dims); \
- return; \
+ std::int32_t output_activation_max, const RuntimeShape& output_shape) {
+#define TOCO_HANDLE_CASE(AC_TYPE) \
+ if (AC_TYPE == Ac) { \
+ TestOneDepthwiseConv<AC_TYPE>( \
+ input_data, input_shape, input_offset, filter_data, filter_shape, \
+ filter_offset, bias_data, bias_shape, stride, pad_width, pad_height, \
+ depth_multiplier, output_offset, output_multiplier, \
+ output_activation_min, output_activation_max, output_shape); \
+ return; \
}
TOCO_HANDLE_CASE(FusedActivationFunctionType::kNone)
TOCO_HANDLE_CASE(FusedActivationFunctionType::kRelu)
@@ -199,6 +210,7 @@ void TestOneDepthwiseConv(
bool TryTestDepthwiseConv(int batch, int input_depth, int input_width,
int input_height, int filter_width, int filter_height,
int depth_multiplier, int stride,
+ int dilation_width_factor, int dilation_height_factor,
PaddingType padding_type) {
const int output_depth = input_depth * depth_multiplier;
// The optimized DepthwiseConv implementation currently uses a fixed-size
@@ -226,33 +238,33 @@ bool TryTestDepthwiseConv(int batch, int input_depth, int input_width,
const std::int32_t input_offset = UniformRandomInt(-256, 0);
const std::int32_t filter_offset = UniformRandomInt(-256, 0);
const std::int32_t output_offset = UniformRandomInt(-256, 0);
- Dims<4> input_dims_inference =
- MakeDimsForInference(input_depth, input_width, input_height, batch);
- Dims<4> output_dims_inference;
+ RuntimeShape input_shape_inference(
+ {batch, input_height, input_width, input_depth});
+ RuntimeShape output_shape_inference;
int pad_width, pad_height;
- if (!ComputeConvSizes(input_dims_inference, output_depth, filter_width,
- filter_height, stride, padding_type,
- &output_dims_inference, &pad_width, &pad_height)) {
+ if (!ComputeConvSizes(input_shape_inference, output_depth, filter_width,
+ filter_height, stride, dilation_width_factor,
+ dilation_height_factor, padding_type,
+ &output_shape_inference, &pad_width, &pad_height)) {
return false;
}
- Dims<4> filter_dims_inference =
- MakeDimsForInference(output_depth, filter_width, filter_height, 1);
- Dims<4> bias_dims_inference = MakeDimsForInference(output_depth, 1, 1, 1);
- const int input_buffer_size = RequiredBufferSizeForDims(input_dims_inference);
- const int filter_buffer_size =
- RequiredBufferSizeForDims(filter_dims_inference);
+ RuntimeShape filter_shape_inference(
+ {1, filter_height, filter_width, output_depth});
+ RuntimeShape bias_shape_inference({1, 1, 1, output_depth});
+ const int input_buffer_size = input_shape_inference.FlatSize();
+ const int filter_buffer_size = filter_shape_inference.FlatSize();
std::vector<std::uint8_t> input_data(input_buffer_size);
std::vector<std::uint8_t> filter_data(filter_buffer_size);
std::vector<std::int32_t> bias_data(output_depth);
FillRandom(&input_data);
FillRandom(&filter_data);
FillRandom(&bias_data, -10000, 10000);
- TestOneDepthwiseConv(ac, input_data.data(), input_dims_inference,
- input_offset, filter_data.data(), filter_dims_inference,
- filter_offset, bias_data.data(), bias_dims_inference,
+ TestOneDepthwiseConv(ac, input_data.data(), input_shape_inference,
+ input_offset, filter_data.data(), filter_shape_inference,
+ filter_offset, bias_data.data(), bias_shape_inference,
stride, pad_width, pad_height, depth_multiplier,
output_offset, output_multiplier, output_activation_min,
- output_activation_max, output_dims_inference);
+ output_activation_max, output_shape_inference);
return true;
}
@@ -274,12 +286,15 @@ bool TryTestOneDepthwiseConv() {
const int filter_height = ExponentialRandomPositiveInt(0.9f, 4, 10);
const int depth_multiplier = ExponentialRandomPositiveInt(0.8f, 6, 50);
const int stride = ExponentialRandomPositiveInt(0.9f, 3, 8);
+ const int dilation_width_factor = RandomElement(std::vector<int>({1, 2, 4}));
+ const int dilation_height_factor = RandomElement(std::vector<int>({1, 2, 4}));
const auto padding_type =
UniformRandomInt(0, 1) ? PaddingType::kSame : PaddingType::kValid;
return TryTestDepthwiseConv(batch, input_depth, input_width, input_height,
filter_width, filter_height, depth_multiplier,
- stride, padding_type);
+ stride, dilation_width_factor,
+ dilation_height_factor, padding_type);
}
// Tests parameters for the 3x3 filter kernel.
@@ -292,6 +307,9 @@ bool TryTestOneDepthwiseConv3x3Filter() {
const int filter_height = 3;
const int depth_multiplier = 1;
const int stride = UniformRandomInt(1, 2);
+ // We don't support dilations in the 3x3 filter.
+ const int dilation_width_factor = 1;
+ const int dilation_height_factor = 1;
// Although the kernel supports only kValid padding, we test that kSame
// is using the correct code path.
const auto padding_type =
@@ -299,7 +317,8 @@ bool TryTestOneDepthwiseConv3x3Filter() {
return TryTestDepthwiseConv(batch, input_depth, input_width, input_height,
filter_width, filter_height, depth_multiplier,
- stride, padding_type);
+ stride, dilation_width_factor,
+ dilation_height_factor, padding_type);
}
void TestOneDepthwiseConv() {
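
Both random-test drivers above now draw dilation factors from {1, 2, 4} and hand them to ComputeConvSizes, because dilation enlarges the span the filter covers. As a rough sketch of the arithmetic involved (the real helper lives in the shared test utilities and is not shown in this diff; names here are illustrative):

// Back-of-the-envelope sketch, not the real ComputeConvSizes: with dilation d,
// a filter of size f covers (f - 1) * d + 1 input cells, so the kValid output
// extent shrinks accordingly.
inline int DilatedFilterExtent(int filter_size, int dilation) {
  return (filter_size - 1) * dilation + 1;
}

inline int ValidOutputSize(int input_size, int filter_size, int stride,
                           int dilation) {
  const int extent = DilatedFilterExtent(filter_size, dilation);
  if (input_size < extent) return 0;  // parameters are not legal
  return (input_size - extent) / stride + 1;
}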
diff --git a/tensorflow/contrib/lite/kernels/internal/logsoftmax_quantized_test.cc b/tensorflow/contrib/lite/kernels/internal/logsoftmax_quantized_test.cc
index 3624c20ae3..2252ca1bcc 100644
--- a/tensorflow/contrib/lite/kernels/internal/logsoftmax_quantized_test.cc
+++ b/tensorflow/contrib/lite/kernels/internal/logsoftmax_quantized_test.cc
@@ -43,11 +43,15 @@ void RunLogSoftmaxFloatReference(const uint8* input_data,
// Reference data generated via Dequant of input into float, and then applying
// float LogSoftmax.
- reference_ops::Dequantize(
- input_data, ToRuntimeDims(shape_common), input_offset, input_scale,
- reference_dequant_data.data(), ToRuntimeDims(shape_common));
- optimized_ops::LogSoftmax(reference_dequant_data.data(), shape_common,
- reference_output_float_data.data(), shape_common);
+ DequantizationParams dq_params;
+ dq_params.zero_point = input_offset;
+ dq_params.scale = input_scale;
+ reference_ops::Dequantize(dq_params, shape_common, input_data, shape_common,
+ reference_dequant_data.data());
+ SoftmaxParams sm_params;
+ optimized_ops::LogSoftmax(sm_params, shape_common,
+ reference_dequant_data.data(), shape_common,
+ reference_output_float_data.data());
// Work with quantized scaling for LogSoftmax, under which 255 represents 0,
// and -16 gets nudged up to 0.
for (int i = 0; i < ref_buffer_size; i++) {
@@ -129,14 +133,16 @@ void RunOneLogSoftmaxTest(const uint8* input_data,
const int diff_min = -tflite::CalculateInputRadius(kScaledDiffIntegerBits,
input_beta_left_shift);
- optimized_ops::LogSoftmax(input_data, shape_common, input_beta_multiplier,
- input_beta_left_shift, reverse_scaling_divisor,
- reverse_scaling_right_shift, diff_min,
- optimized_logsoftmax_output.data(), shape_common);
- reference_ops::LogSoftmax(
- input_data, shape_common, input_beta_multiplier, input_beta_left_shift,
- reverse_scaling_divisor, reverse_scaling_right_shift, diff_min,
- reference_quant_logsoftmax_output.data(), shape_common);
+ SoftmaxParams params;
+ params.input_multiplier = input_beta_multiplier;
+ params.input_left_shift = input_beta_left_shift;
+ params.reverse_scaling_divisor = reverse_scaling_divisor;
+ params.reverse_scaling_right_shift = reverse_scaling_right_shift;
+ params.diff_min = diff_min;
+ optimized_ops::LogSoftmax(params, shape_common, input_data, shape_common,
+ optimized_logsoftmax_output.data());
+ reference_ops::LogSoftmax(params, shape_common, input_data, shape_common,
+ reference_quant_logsoftmax_output.data());
CheckOutputData(optimized_logsoftmax_output.data(),
reference_float_logsoftmax_output.data(), shape_common,
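
The reference path above now packages zero_point and scale into a DequantizationParams struct before calling reference_ops::Dequantize. For orientation, this is the usual affine dequantization rule; the helper below is only a restatement of that convention, not code taken from the kernels.

// Sketch of the affine dequantization assumed by DequantizationParams:
// real_value = scale * (quantized_value - zero_point).
#include <cstdint>

inline float DequantizeOne(std::uint8_t q, std::int32_t zero_point,
                           float scale) {
  return scale * (static_cast<std::int32_t>(q) - zero_point);
}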
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/cblas_conv.h b/tensorflow/contrib/lite/kernels/internal/optimized/cblas_conv.h
index 4a90e7e640..2d96da65c3 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/cblas_conv.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/cblas_conv.h
@@ -31,33 +31,50 @@ limitations under the License.
namespace tflite {
namespace cblas_ops {
-inline void Conv(const float* input_data, const Dims<4>& input_dims,
- const float* filter_data, const Dims<4>& filter_dims,
- const float* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, float output_activation_min,
- float output_activation_max, float* output_data,
- const Dims<4>& output_dims, float* im2col_data,
- const Dims<4>& im2col_dims) {
+inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
+ const float* input_data, const RuntimeShape& filter_shape,
+ const float* filter_data, const RuntimeShape& bias_shape,
+ const float* bias_data, const RuntimeShape& output_shape,
+ float* output_data, const RuntimeShape& im2col_shape,
+ float* im2col_data) {
+ const int stride_width = params.stride_width;
+ const int stride_height = params.stride_height;
+ const int pad_width = params.padding_values.width;
+ const int pad_height = params.padding_values.height;
+ const int dilation_width_factor = params.dilation_width_factor;
+ const int dilation_height_factor = params.dilation_height_factor;
+ const float output_activation_min = params.float_activation_min;
+ const float output_activation_max = params.float_activation_max;
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
gemmlowp::ScopedProfilingLabel label("Conv/cblas");
const float* gemm_input_data = nullptr;
- const Dims<4>* gemm_input_dims = nullptr;
- const int filter_width = ArraySize(filter_dims, 1);
- const int filter_height = ArraySize(filter_dims, 2);
+ const RuntimeShape* gemm_input_shape = nullptr;
+ const int filter_width = filter_shape.Dims(2);
+ const int filter_height = filter_shape.Dims(1);
const bool need_im2col = stride_width != 1 || stride_height != 1 ||
filter_width != 1 || filter_height != 1;
if (need_im2col) {
TFLITE_DCHECK(im2col_data);
- optimized_ops::Im2col(input_data, input_dims, stride_width, stride_height,
- pad_width, pad_height, filter_height, filter_width, 0,
- im2col_data, im2col_dims);
+ ConvParams op_params;
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = pad_width;
+ op_params.padding_values.height = pad_height;
+ op_params.stride_width = stride_width;
+ op_params.stride_height = stride_height;
+ op_params.dilation_width_factor = dilation_width_factor;
+ op_params.dilation_height_factor = dilation_height_factor;
+ optimized_ops::Im2col(op_params, filter_height, filter_width, 0,
+ input_shape, input_data, im2col_shape, im2col_data);
+
gemm_input_data = im2col_data;
- gemm_input_dims = &im2col_dims;
+ gemm_input_shape = &im2col_shape;
} else {
TFLITE_DCHECK(!im2col_data);
gemm_input_data = input_data;
- gemm_input_dims = &input_dims;
+ gemm_input_shape = &input_shape;
}
  // The following code computes matrix multiplication c = a * transpose(b)
@@ -69,10 +86,10 @@ inline void Conv(const float* input_data, const Dims<4>& input_dims,
const float* a = gemm_input_data;
const float* b = filter_data;
float* c = output_data;
- int m = gemm_input_dims->sizes[1] * gemm_input_dims->sizes[2] *
- gemm_input_dims->sizes[3];
- int n = output_dims.sizes[0];
- int k = gemm_input_dims->sizes[0];
+ const int gemm_input_dims = gemm_input_shape->DimensionsCount();
+ int m = FlatSizeSkipDim(*gemm_input_shape, gemm_input_dims - 1);
+ int n = output_shape.Dims(3);
+ int k = gemm_input_shape->Dims(gemm_input_dims - 1);
// The stride of matrix a, b and c respectively.
int stride_a = k;
int stride_b = k;
@@ -82,8 +99,8 @@ inline void Conv(const float* input_data, const Dims<4>& input_dims,
stride_a, b, stride_b, 0.0f, c, stride_c);
optimized_ops::AddBiasAndEvalActivationFunction(
- bias_data, bias_dims, output_data, output_dims, output_activation_min,
- output_activation_max);
+ output_activation_min, output_activation_max, bias_shape, bias_data,
+ output_shape, output_data);
}
} // namespace cblas_ops
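
After im2col, the cblas path above performs the whole convolution as a single GEMM: with NHWC shapes, each output pixel becomes one row of the left-hand matrix and each output channel one column of the result, which is where m, n and k come from. A small sketch of that mapping (the struct and function names are made up for illustration):

// Sketch of the conv-as-GEMM dimension mapping used above.
#include "tensorflow/contrib/lite/kernels/internal/types.h"

struct GemmDims { int m, n, k; };  // hypothetical helper, not in the kernels

inline GemmDims ConvAsGemmDims(const tflite::RuntimeShape& gemm_input_shape,
                               const tflite::RuntimeShape& output_shape) {
  const int dims = gemm_input_shape.DimensionsCount();
  GemmDims d;
  // Last dimension of the (possibly im2col'ed) input: fh * fw * input_depth.
  d.k = gemm_input_shape.Dims(dims - 1);
  // Everything else: batch * output_height * output_width output pixels.
  d.m = gemm_input_shape.FlatSize() / d.k;
  // One column per output channel.
  d.n = output_shape.Dims(3);
  return d;
}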
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h
index 70810ca784..d8dd7bba89 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h
@@ -761,7 +761,8 @@ struct FloatDepthwiseConvKernel<true, 4, 1> {
// Accumulates the effect of one row of the filter, on a segment of one row
// of the output, accessing the corresponding one row of the input.
template <bool kAllowStrided, int kFixedInputDepth, int kFixedDepthMultiplier>
-void FloatDepthwiseConvAccumRow(int stride, int input_depth, int input_width,
+void FloatDepthwiseConvAccumRow(int stride, int dilation_factor,
+ int input_depth, int input_width,
const float* input_data, int pad_width,
int depth_multiplier, int filter_width,
const float* filter_data,
@@ -835,10 +836,10 @@ void FloatDepthwiseConvAccumRow(int stride, int input_depth, int input_width,
// generic fallback of FloatDepthwiseConvAccumRow, portable, non-templatized.
inline void FloatDepthwiseConvAccumRowGeneric(
- int stride, int input_depth, int input_width, const float* input_data,
- int pad_width, int depth_multiplier, int filter_width,
- const float* filter_data, int out_x_buffer_start, int out_x_buffer_end,
- int output_depth, float* acc_buffer) {
+ int stride, int dilation_factor, int input_depth, int input_width,
+ const float* input_data, int pad_width, int depth_multiplier,
+ int filter_width, const float* filter_data, int out_x_buffer_start,
+ int out_x_buffer_end, int output_depth, float* acc_buffer) {
gemmlowp::ScopedProfilingLabel label("DepthwiseConvAccumRowGeneric (slow)");
#ifdef TFLITE_PREVENT_SLOW_GENERIC_DEPTHWISECONV_FALLBACK
#ifndef ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK
@@ -860,6 +861,7 @@ inline void FloatDepthwiseConvAccumRowGeneric(
<< "* stride = " << stride << "\n"
<< "* input_depth = " << input_depth << "\n"
<< "* depth_multiplier = " << depth_multiplier << "\n"
+ << "* dilation_factor = " << dilation_factor << "\n"
<< "*\n"
<< "* Please do not hesitate to contact benoitjacob@ with this\n"
<< "* information.\n"
@@ -869,14 +871,17 @@ inline void FloatDepthwiseConvAccumRowGeneric(
const float* filter_base_ptr = filter_data;
for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
const int out_x_loop_start = std::max(
- out_x_buffer_start, (pad_width - filter_x + stride - 1) / stride);
- const int out_x_loop_end =
- std::min(out_x_buffer_end,
- (pad_width + input_width - filter_x + stride - 1) / stride);
+ out_x_buffer_start,
+ (pad_width - dilation_factor * filter_x + stride - 1) / stride);
+ const int out_x_loop_end = std::min(
+ out_x_buffer_end,
+ (pad_width + input_width - dilation_factor * filter_x + stride - 1) /
+ stride);
float* acc_buffer_ptr =
acc_buffer + (out_x_loop_start - out_x_buffer_start) * output_depth;
- const int in_x_origin = (out_x_loop_start * stride) - pad_width + filter_x;
+ const int in_x_origin =
+ (out_x_loop_start * stride) - pad_width + dilation_factor * filter_x;
const float* input_ptr = input_data + in_x_origin * input_depth;
const int input_ptr_increment = (stride - 1) * input_depth;
for (int out_x = out_x_loop_start; out_x < out_x_loop_end; out_x++) {
@@ -907,25 +912,37 @@ inline void DepthwiseConvInitAccBuffer(int num_output_pixels, int output_depth,
}
}
-inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
- const float* filter_data, const Dims<4>& filter_dims,
- const float* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int depth_multiplier,
- float output_activation_min,
- float output_activation_max, float* output_data,
- const Dims<4>& output_dims) {
+inline void DepthwiseConv(
+ const DepthwiseParams& params, const RuntimeShape& input_shape,
+ const float* input_data, const RuntimeShape& filter_shape,
+ const float* filter_data, const RuntimeShape& bias_shape,
+ const float* bias_data, const RuntimeShape& output_shape,
+ float* output_data) {
gemmlowp::ScopedProfilingLabel label("DepthwiseConv");
- const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- const int output_depth = MatchingArraySize(filter_dims, 0, output_dims, 0);
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
- const int input_depth = ArraySize(input_dims, 0);
- const int filter_height = ArraySize(filter_dims, 2);
- const int filter_width = ArraySize(filter_dims, 1);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
- TFLITE_DCHECK(output_depth == input_depth * depth_multiplier);
+ const int stride_width = params.stride_width;
+ const int stride_height = params.stride_height;
+ const int pad_width = params.padding_values.width;
+ const int pad_height = params.padding_values.height;
+ const int depth_multiplier = params.depth_multiplier;
+ const float output_activation_min = params.float_activation_min;
+ const float output_activation_max = params.float_activation_max;
+ const int dilation_width_factor = params.dilation_width_factor;
+ const int dilation_height_factor = params.dilation_height_factor;
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+
+ const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+ const int output_depth = MatchingDim(filter_shape, 3, output_shape, 3);
+ const int input_height = input_shape.Dims(1);
+ const int input_width = input_shape.Dims(2);
+ const int input_depth = input_shape.Dims(3);
+ const int filter_height = filter_shape.Dims(1);
+ const int filter_width = filter_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_width = output_shape.Dims(2);
+ TFLITE_DCHECK_EQ(output_depth, input_depth * depth_multiplier);
+ TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);
static const int kAccBufferMaxSize = 2048;
float acc_buffer[kAccBufferMaxSize];
@@ -946,7 +963,8 @@ inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
FIXED_DEPTH_MULTIPLIER) \
if (!row_accum_func && (stride_width == 1 || ALLOW_STRIDED) && \
(input_depth == FIXED_INPUT_DEPTH || FIXED_INPUT_DEPTH == 0) && \
- depth_multiplier == FIXED_DEPTH_MULTIPLIER) { \
+ depth_multiplier == FIXED_DEPTH_MULTIPLIER && \
+ dilation_height_factor == 1 && dilation_width_factor == 1) { \
row_accum_func = \
FloatDepthwiseConvAccumRow<ALLOW_STRIDED, FIXED_INPUT_DEPTH, \
FIXED_DEPTH_MULTIPLIER>; \
@@ -990,14 +1008,22 @@ inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
row_accum_func = FloatDepthwiseConvAccumRowGeneric;
}
+ const int input_height_stride = input_shape.Dims(3) * input_shape.Dims(2);
+ const int input_batch_stride = input_height_stride * input_shape.Dims(1);
+ const int filter_height_stride = filter_shape.Dims(3) * filter_shape.Dims(2);
+
// Now that we have determined row_accum_func, we can start work.
float* output_ptr = output_data;
for (int b = 0; b < batches; ++b) {
for (int out_y = 0; out_y < output_height; ++out_y) {
const int in_y_origin = (out_y * stride_height) - pad_height;
- const int filter_y_start = std::max(0, -in_y_origin);
+ const int filter_y_start =
+ std::max(0, (-in_y_origin + dilation_height_factor - 1) /
+ dilation_height_factor);
const int filter_y_end =
- std::min(filter_height, input_height - in_y_origin);
+ std::min(filter_height,
+ (input_height - in_y_origin + dilation_height_factor - 1) /
+ dilation_height_factor);
for (int out_x_buffer_start = 0; out_x_buffer_start < output_width;
out_x_buffer_start += kOutputPixelsInAccBuffer) {
const int out_x_buffer_end = std::min(
@@ -1013,14 +1039,13 @@ inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
// Accumulation loop. Most of the time should be spent in here.
for (int filter_y = filter_y_start; filter_y < filter_y_end;
++filter_y) {
- const int in_y = in_y_origin + filter_y;
- row_accum_func(stride_width, input_depth, input_width,
- input_data + in_y * input_dims.strides[2] +
- b * input_dims.strides[3],
- pad_width, depth_multiplier, filter_width,
- filter_data + filter_y * filter_dims.strides[2],
- out_x_buffer_start, out_x_buffer_end, output_depth,
- acc_buffer);
+ const int in_y = in_y_origin + dilation_height_factor * filter_y;
+ row_accum_func(
+ stride_width, dilation_width_factor, input_depth, input_width,
+ input_data + in_y * input_height_stride + b * input_batch_stride,
+ pad_width, depth_multiplier, filter_width,
+ filter_data + filter_y * filter_height_stride, out_x_buffer_start,
+ out_x_buffer_end, output_depth, acc_buffer);
}
// Finished accumulating. Now store to destination.
const int num_output_values = output_depth * num_output_pixels;
@@ -1067,54 +1092,6 @@ inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
}
}
-inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
- const float* filter_data, const Dims<4>& filter_dims,
- const float* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height,
- int dilation_width_factor, int dilation_height_factor,
- int pad_width, int pad_height, int depth_multiplier,
- float output_activation_min,
- float output_activation_max, float* output_data,
- const Dims<4>& output_dims) {
- // TODO(suharshs): Optimized implementation of dilation depthwise conv need to
- // be implemented.
- TFLITE_DCHECK(dilation_width_factor == 1);
- TFLITE_DCHECK(dilation_height_factor == 1);
-
- DepthwiseConv(input_data, input_dims, filter_data, filter_dims, bias_data,
- bias_dims, stride_width, stride_height, pad_width, pad_height,
- depth_multiplier, output_activation_min, output_activation_max,
- output_data, output_dims);
-}
-
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
- const float* filter_data, const Dims<4>& filter_dims,
- const float* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int depth_multiplier, float* output_data,
- const Dims<4>& output_dims) {
- float output_activation_min, output_activation_max;
- GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
- DepthwiseConv(input_data, input_dims, filter_data, filter_dims, bias_data,
- bias_dims, stride_width, stride_height, pad_width, pad_height,
- depth_multiplier, output_activation_min, output_activation_max,
- output_data, output_dims);
-}
-
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
- const float* filter_data, const Dims<4>& filter_dims,
- const float* bias_data, const Dims<4>& bias_dims, int stride,
- int pad_width, int pad_height, int depth_multiplier,
- float* output_data, const Dims<4>& output_dims) {
- DepthwiseConv<Ac>(input_data, input_dims, filter_data, filter_dims, bias_data,
- bias_dims, stride, stride, pad_width, pad_height,
- depth_multiplier, output_data, output_dims);
-}
-
} // namespace optimized_ops
} // namespace tflite
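
The float kernel above now derives its filter_y bounds with a ceiling division so that only filter taps landing inside the input are visited once dilation is in play, and the row accumulators receive the width dilation factor for the same reason. A compact restatement of that bound computation, with names local to this sketch:

// Sketch of the dilation-aware bound computation used above. Tap filter_y
// reads input row in_y_origin + dilation * filter_y, so the first and last
// valid taps come from ceiling divisions rather than plain clamps.
#include <algorithm>

inline void DilatedFilterYRange(int in_y_origin, int input_height,
                                int filter_height, int dilation,
                                int* filter_y_start, int* filter_y_end) {
  *filter_y_start = std::max(0, (-in_y_origin + dilation - 1) / dilation);
  *filter_y_end = std::min(
      filter_height, (input_height - in_y_origin + dilation - 1) / dilation);
}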
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h
index f707279600..803eff292a 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h
@@ -1466,11 +1466,14 @@ struct QuantizedDepthwiseConvKernel<false, 12, 1> {
// Accumulates the effect of one row of the filter, on a segment of one row
// of the output, accessing the corresponding one row of the input.
template <bool kAllowStrided, int kFixedInputDepth, int kFixedDepthMultiplier>
-void QuantizedDepthwiseConvAccumRow(
- int stride, int input_depth, int input_width, const uint8* input_data,
- int16 input_offset, int pad_width, int depth_multiplier, int filter_width,
- const uint8* filter_data, int16 filter_offset, int out_x_buffer_start,
- int out_x_buffer_end, int output_depth, int32* acc_buffer) {
+void QuantizedDepthwiseConvAccumRow(int stride, int dilation_factor,
+ int input_depth, int input_width,
+ const uint8* input_data, int16 input_offset,
+ int pad_width, int depth_multiplier,
+ int filter_width, const uint8* filter_data,
+ int16 filter_offset, int out_x_buffer_start,
+ int out_x_buffer_end, int output_depth,
+ int32* acc_buffer) {
#ifdef GEMMLOWP_PROFILING
gemmlowp::ScopedProfilingLabel label(__PRETTY_FUNCTION__);
#endif
@@ -1537,10 +1540,11 @@ void QuantizedDepthwiseConvAccumRow(
// generic fallback of DepthwiseConvAccumRow, portable, non-templatized.
inline void QuantizedDepthwiseConvAccumRowGeneric(
- int stride, int input_depth, int input_width, const uint8* input_data,
- int16 input_offset, int pad_width, int depth_multiplier, int filter_width,
- const uint8* filter_data, int16 filter_offset, int out_x_buffer_start,
- int out_x_buffer_end, int output_depth, int32* acc_buffer) {
+ int stride, int dilation_factor, int input_depth, int input_width,
+ const uint8* input_data, int16 input_offset, int pad_width,
+ int depth_multiplier, int filter_width, const uint8* filter_data,
+ int16 filter_offset, int out_x_buffer_start, int out_x_buffer_end,
+ int output_depth, int32* acc_buffer) {
gemmlowp::ScopedProfilingLabel label("DepthwiseConvAccumRowGeneric (slow)");
#ifdef TFLITE_PREVENT_SLOW_GENERIC_DEPTHWISECONV_FALLBACK
#ifndef ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK
@@ -1562,6 +1566,7 @@ inline void QuantizedDepthwiseConvAccumRowGeneric(
<< "* stride = " << stride << "\n"
<< "* input_depth = " << input_depth << "\n"
<< "* depth_multiplier = " << depth_multiplier << "\n"
+ << "* dilation_factor = " << dilation_factor << "\n"
<< "*\n"
<< "* Please do not hesitate to contact benoitjacob@ with this\n"
<< "* information.\n"
@@ -1571,14 +1576,17 @@ inline void QuantizedDepthwiseConvAccumRowGeneric(
const uint8* filter_base_ptr = filter_data;
for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
const int out_x_loop_start = std::max(
- out_x_buffer_start, (pad_width - filter_x + stride - 1) / stride);
- const int out_x_loop_end =
- std::min(out_x_buffer_end,
- (pad_width + input_width - filter_x + stride - 1) / stride);
+ out_x_buffer_start,
+ (pad_width - dilation_factor * filter_x + stride - 1) / stride);
+ const int out_x_loop_end = std::min(
+ out_x_buffer_end,
+ (pad_width + input_width - dilation_factor * filter_x + stride - 1) /
+ stride);
int32* acc_buffer_ptr =
acc_buffer + (out_x_loop_start - out_x_buffer_start) * output_depth;
- const int in_x_origin = (out_x_loop_start * stride) - pad_width + filter_x;
+ const int in_x_origin =
+ (out_x_loop_start * stride) - pad_width + dilation_factor * filter_x;
const uint8* input_ptr = input_data + in_x_origin * input_depth;
const int input_ptr_increment = (stride - 1) * input_depth;
for (int out_x = out_x_loop_start; out_x < out_x_loop_end; out_x++) {
@@ -1669,33 +1677,48 @@ inline void DepthwiseConvInitAccBuffer(int num_output_pixels, int output_depth,
}
}
-inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
- int32 input_offset, const uint8* filter_data,
- const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int depth_multiplier,
- int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims) {
+inline void DepthwiseConv(
+ const DepthwiseParams& params, const RuntimeShape& input_shape,
+ const uint8* input_data, const RuntimeShape& filter_shape,
+ const uint8* filter_data, const RuntimeShape& bias_shape,
+ const int32* bias_data, const RuntimeShape& output_shape,
+ uint8* output_data) {
gemmlowp::ScopedProfilingLabel label("DepthwiseConv/8bit");
+ const int stride_width = params.stride_width;
+ const int stride_height = params.stride_height;
+ const int pad_width = params.padding_values.width;
+ const int pad_height = params.padding_values.height;
+ const int depth_multiplier = params.depth_multiplier;
+ const int32 output_activation_min = params.quantized_activation_min;
+ const int32 output_activation_max = params.quantized_activation_max;
+ const int32 input_offset = params.input_offset;
+ const int32 filter_offset = params.weights_offset;
+ const int32 output_offset = params.output_offset;
+ const int32 output_multiplier = params.output_multiplier;
+ const int output_shift = params.output_shift;
+ const int dilation_width_factor = params.dilation_width_factor;
+ const int dilation_height_factor = params.dilation_height_factor;
+ TFLITE_DCHECK_GE(dilation_width_factor, 1);
+ TFLITE_DCHECK_GE(dilation_height_factor, 1);
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
-
- const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- const int output_depth = MatchingArraySize(filter_dims, 0, output_dims, 0);
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
- const int input_depth = ArraySize(input_dims, 0);
- const int filter_height = ArraySize(filter_dims, 2);
- const int filter_width = ArraySize(filter_dims, 1);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
+ const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+ const int output_depth = MatchingDim(filter_shape, 3, output_shape, 3);
+ const int input_height = input_shape.Dims(1);
+ const int input_width = input_shape.Dims(2);
+ const int input_depth = input_shape.Dims(3);
+ const int filter_height = filter_shape.Dims(1);
+ const int filter_width = filter_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_width = output_shape.Dims(2);
#ifdef USE_NEON
- const bool shift_left = (output_shift <= 0);
- const int32 multiplier_power_of_two = shift_left ? (1 << -output_shift) : 1;
+ const bool shift_left = (output_shift > 0);
+ const int32 multiplier_power_of_two = shift_left ? (1 << output_shift) : 1;
#endif
- TFLITE_DCHECK(output_depth == input_depth * depth_multiplier);
+ TFLITE_DCHECK_EQ(output_depth, input_depth * depth_multiplier);
+ TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);
// Enable for arm64 except for the Nvidia Linux 4 Tegra (L4T) running on
// Jetson TX-2. This compiler does not support the offsetof() macro.
@@ -1703,14 +1726,12 @@ inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
// Call kernel optimized for depthwise convolutions using 3x3 filters if
// parameters are supported.
if (Fast3x3FilterKernelSupported(
- input_dims, filter_dims, stride_width, stride_height, pad_width,
- pad_height, depth_multiplier, output_dims, output_shift)) {
- DepthwiseConv3x3Filter(input_data, input_dims, input_offset, filter_data,
- filter_dims, filter_offset, bias_data, bias_dims,
- stride_width, stride_height, pad_width, pad_height,
- depth_multiplier, output_offset, output_multiplier,
- output_shift, output_activation_min,
- output_activation_max, output_data, output_dims);
+ input_shape, filter_shape, stride_width, stride_height,
+ dilation_width_factor, dilation_height_factor, pad_width, pad_height,
+ depth_multiplier, output_shape, output_shift)) {
+ DepthwiseConv3x3Filter(params, input_shape, input_data, filter_shape,
+ filter_data, bias_shape, bias_data, output_shape,
+ output_data);
return;
}
#endif
@@ -1734,7 +1755,8 @@ inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
FIXED_DEPTH_MULTIPLIER) \
if (!row_accum_func && (stride_width == 1 || ALLOW_STRIDED) && \
(input_depth == FIXED_INPUT_DEPTH || FIXED_INPUT_DEPTH == 0) && \
- depth_multiplier == FIXED_DEPTH_MULTIPLIER) { \
+ depth_multiplier == FIXED_DEPTH_MULTIPLIER && \
+ dilation_width_factor == 1 && dilation_height_factor == 1) { \
row_accum_func = \
QuantizedDepthwiseConvAccumRow<ALLOW_STRIDED, FIXED_INPUT_DEPTH, \
FIXED_DEPTH_MULTIPLIER>; \
@@ -1785,14 +1807,22 @@ inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
#undef TFMINI_USE_DEPTHWISECONV_KERNEL
+ const int input_height_stride = input_shape.Dims(3) * input_shape.Dims(2);
+ const int input_batch_stride = input_height_stride * input_shape.Dims(1);
+ const int filter_height_stride = filter_shape.Dims(3) * filter_shape.Dims(2);
+
// Now that we have determined row_accum_func, we can start work.
uint8* output_ptr = output_data;
for (int b = 0; b < batches; ++b) {
for (int out_y = 0; out_y < output_height; ++out_y) {
const int in_y_origin = (out_y * stride_height) - pad_height;
- const int filter_y_start = std::max(0, -in_y_origin);
+ const int filter_y_start =
+ std::max(0, (-in_y_origin + dilation_height_factor - 1) /
+ dilation_height_factor);
const int filter_y_end =
- std::min(filter_height, input_height - in_y_origin);
+ std::min(filter_height,
+ (input_height - in_y_origin + dilation_height_factor - 1) /
+ dilation_height_factor);
for (int out_x_buffer_start = 0; out_x_buffer_start < output_width;
out_x_buffer_start += kOutputPixelsInAccBuffer) {
const int out_x_buffer_end = std::min(
@@ -1808,13 +1838,12 @@ inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
// Accumulation loop. Most of the time should be spent in here.
for (int filter_y = filter_y_start; filter_y < filter_y_end;
++filter_y) {
- const int in_y = in_y_origin + filter_y;
+ const int in_y = in_y_origin + dilation_height_factor * filter_y;
row_accum_func(
- stride_width, input_depth, input_width,
- input_data + in_y * input_dims.strides[2] +
- b * input_dims.strides[3],
+ stride_width, dilation_width_factor, input_depth, input_width,
+ input_data + in_y * input_height_stride + b * input_batch_stride,
input_offset, pad_width, depth_multiplier, filter_width,
- filter_data + filter_y * filter_dims.strides[2], filter_offset,
+ filter_data + filter_y * filter_height_stride, filter_offset,
out_x_buffer_start, out_x_buffer_end, output_depth, acc_buffer);
}
// Finished accumulating int32 values. Now need to convert them to
@@ -1845,7 +1874,7 @@ inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
acc[j] = vqrdmulhq_n_s32(acc[j], output_multiplier);
}
for (int j = 0; j < 4; j++) {
- acc[j] = RoundingDivideByPOT(acc[j], output_shift);
+ acc[j] = RoundingDivideByPOT(acc[j], -output_shift);
}
} else {
// Fixed-point multiplication.
@@ -1889,8 +1918,8 @@ inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
acc0 = vqrdmulhq_n_s32(acc0, output_multiplier);
acc1 = vqrdmulhq_n_s32(acc1, output_multiplier);
// Rounding right shift.
- acc0 = RoundingDivideByPOT(acc0, output_shift);
- acc1 = RoundingDivideByPOT(acc1, output_shift);
+ acc0 = RoundingDivideByPOT(acc0, -output_shift);
+ acc1 = RoundingDivideByPOT(acc1, -output_shift);
} else {
// Fixed-point multiplication.
acc0 = vmulq_n_s32(acc0, multiplier_power_of_two);
@@ -1926,7 +1955,7 @@ inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
// Fixed-point multiplication.
acc = vqrdmulhq_n_s32(acc, output_multiplier);
// Rounding right shift.
- acc = RoundingDivideByPOT(acc, output_shift);
+ acc = RoundingDivideByPOT(acc, -output_shift);
} else {
// Fixed-point multiplication.
acc = vmulq_n_s32(acc, multiplier_power_of_two);
@@ -1953,7 +1982,7 @@ inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
for (; i < num_output_values; i++) {
int32 acc = acc_buffer[i];
acc = MultiplyByQuantizedMultiplier(acc, output_multiplier,
- -output_shift);
+ output_shift);
acc += output_offset;
acc = std::max(acc, output_activation_min);
acc = std::min(acc, output_activation_max);
@@ -1964,72 +1993,6 @@ inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
}
}
-inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
- int32 input_offset, const uint8* filter_data,
- const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height,
- int dilation_width_factor, int dilation_height_factor,
- int pad_width, int pad_height, int depth_multiplier,
- int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims) {
- // TODO(suharshs): Optimized implementation of dilation depthwise is not
- // supported yet.
- TFLITE_DCHECK(dilation_width_factor == 1);
- TFLITE_DCHECK(dilation_height_factor == 1);
-
- DepthwiseConv(input_data, input_dims, input_offset, filter_data, filter_dims,
- filter_offset, bias_data, bias_dims, stride_width,
- stride_height, pad_width, pad_height, depth_multiplier,
- output_offset, output_multiplier, output_shift,
- output_activation_min, output_activation_max, output_data,
- output_dims);
-}
-
-// Legacy, for compatibility with old checked-in code.
-template <FusedActivationFunctionType Ac>
-void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
- int32 input_offset, const uint8* filter_data,
- const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int depth_multiplier, int32 output_offset,
- int32 output_multiplier, int output_shift,
- int32 output_activation_min, int32 output_activation_max,
- uint8* output_data, const Dims<4>& output_dims) {
- if (Ac == FusedActivationFunctionType::kNone) {
- TFLITE_DCHECK_EQ(output_activation_min, 0);
- TFLITE_DCHECK_EQ(output_activation_max, 255);
- }
- DepthwiseConv(input_data, input_dims, input_offset, filter_data, filter_dims,
- filter_offset, bias_data, bias_dims, stride_width,
- stride_height, pad_width, pad_height, depth_multiplier,
- output_offset, output_multiplier, output_shift,
- output_activation_min, output_activation_max, output_data,
- output_dims);
-}
-
-// Legacy, for compatibility with old checked-in code.
-template <FusedActivationFunctionType Ac>
-void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
- int32 input_offset, const uint8* filter_data,
- const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data, const Dims<4>& bias_dims, int stride,
- int pad_width, int pad_height, int depth_multiplier,
- int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims) {
- DepthwiseConv<Ac>(input_data, input_dims, input_offset, filter_data,
- filter_dims, filter_offset, bias_data, bias_dims, stride,
- stride, pad_width, pad_height, depth_multiplier,
- output_offset, output_multiplier, output_shift,
- output_activation_min, output_activation_max, output_data,
- output_dims);
-}
-
} // namespace optimized_ops
} // namespace tflite
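
A second change threaded through the quantized kernel above is the sign convention for output_shift: DepthwiseParams now stores it as a left shift when positive (the tests negate the legacy right-shift value when filling the struct), so the NEON paths pass -output_shift to RoundingDivideByPOT while the scalar tail hands output_shift straight to MultiplyByQuantizedMultiplier. A scalar sketch of one requantized output under that convention, with the helper name assumed for illustration:

// Scalar sketch of the new convention: output_shift > 0 means shift left,
// output_shift < 0 means rounding shift right, matching what
// MultiplyByQuantizedMultiplier expects in the updated code paths.
#include <algorithm>
#include <cstdint>

#include "tensorflow/contrib/lite/kernels/internal/common.h"

inline std::int32_t RequantizeOne(std::int32_t acc, std::int32_t multiplier,
                                  int output_shift, std::int32_t output_offset,
                                  std::int32_t act_min, std::int32_t act_max) {
  acc = tflite::MultiplyByQuantizedMultiplier(acc, multiplier, output_shift);
  acc += output_offset;
  return std::min(std::max(acc, act_min), act_max);
}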
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h
index 0ce64f8c70..4809ddd02a 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h
@@ -49,7 +49,7 @@ struct DepthwiseConvParams {
int32 output_multiplier;
int32 output_activation_min;
int32 output_activation_max;
- int32 output_shift;
+ int32 output_right_shift;
int32 input_width;
int32 input_height;
int32 stride_width;
@@ -75,7 +75,7 @@ struct DepthwiseConvParams {
#define OFFSET_OUTPUT_MULTIPLIER 52
#define OFFSET_OUTPUT_ACTIVATION_MIN 56
#define OFFSET_OUTPUT_ACTIVATION_MAX 60
-#define OFFSET_OUTPUT_SHIFT 64
+#define OFFSET_OUTPUT_RIGHT_SHIFT 64
#define OFFSET_INPUT_WIDTH 68
#define OFFSET_INPUT_HEIGHT 72
#define OFFSET_STRIDE_WIDTH 76
@@ -105,8 +105,8 @@ static_assert(offsetof(DepthwiseConvParams, output_activation_min) ==
OFFSET_OUTPUT_ACTIVATION_MIN, "");
static_assert(offsetof(DepthwiseConvParams, output_activation_max) ==
OFFSET_OUTPUT_ACTIVATION_MAX, "");
-static_assert(offsetof(DepthwiseConvParams, output_shift) ==
- OFFSET_OUTPUT_SHIFT, "");
+static_assert(offsetof(DepthwiseConvParams, output_right_shift) ==
+ OFFSET_OUTPUT_RIGHT_SHIFT, "");
static_assert(offsetof(DepthwiseConvParams, input_width) ==
OFFSET_INPUT_WIDTH, "");
static_assert(offsetof(DepthwiseConvParams, input_height) ==
@@ -189,7 +189,7 @@ struct DepthwiseConvWindow<8, 1, 1> {
"ldr w9, [%[params_ptr], #" STR(OFFSET_OUTPUT_MULTIPLIER) "]\n"
"ldr w2, [%[params_ptr], #" STR(OFFSET_OUTPUT_OFFSET) "]\n"
"dup v27.4s, w9\n"
- "ldr w9, [%[params_ptr], #" STR(OFFSET_OUTPUT_SHIFT) "]\n"
+ "ldr w9, [%[params_ptr], #" STR(OFFSET_OUTPUT_RIGHT_SHIFT) "]\n"
"dup v29.4s, w2\n"
"ldr w4, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MIN) "]\n"
"dup v30.4s, w4\n"
@@ -1166,7 +1166,7 @@ struct DepthwiseConvWindow<8, 2, 2> {
// values from time to time when there are not enough NEON registers.
// We use x9--x15 general purpose registers as they are caller-saved
// temporary registers (see http://infocenter.arm.com/help/topic/com.arm.doc.ihi0055b/IHI0055B_aapcs64.pdf). // NOLINT
- "ldr w9, [%[params_ptr], #" STR(OFFSET_OUTPUT_SHIFT) "]\n"
+ "ldr w9, [%[params_ptr], #" STR(OFFSET_OUTPUT_RIGHT_SHIFT) "]\n"
"ldr w0, [%[params_ptr], #" STR(OFFSET_INPUT_OFFSET) "]\n"
"cmp %w[output_window_height], #2\n"
"dup v28.8h, w0\n"
@@ -2216,7 +2216,7 @@ struct DepthwiseConvPartial<EdgeType::kCenter, 1, 1> {
"dup v27.4s, w10\n"
"ld1 {v0.8b}, [%[filter_ptr]], #8\n"
"cmp x11, #16\n"
- "ldr w10, [%[params_ptr], #" STR(OFFSET_OUTPUT_SHIFT) "]\n"
+ "ldr w10, [%[params_ptr], #" STR(OFFSET_OUTPUT_RIGHT_SHIFT) "]\n"
"dup v28.4s, w9\n"
"ldr w9, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MIN) "]\n"
"neg w10, w10\n"
@@ -2355,7 +2355,7 @@ struct DepthwiseConvPartial<EdgeType::kCorner, 1, 1> {
"dup v26.8h, w6\n"
"ldr w6, [%[params_ptr], #" STR(OFFSET_OUTPUT_OFFSET) "]\n"
"dup v27.4s, w7\n"
- "ldr w7, [%[params_ptr], #" STR(OFFSET_OUTPUT_SHIFT) "]\n"
+ "ldr w7, [%[params_ptr], #" STR(OFFSET_OUTPUT_RIGHT_SHIFT) "]\n"
"dup v28.4s, w6\n"
"ldr w6, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MIN) "]\n"
"neg w7, w7\n"
@@ -2532,7 +2532,7 @@ struct DepthwiseConvPartial<EdgeType::kHorizontal, 1, 1> {
"dup v26.8h, w12\n"
"ldr w12, [%[params_ptr], #" STR(OFFSET_OUTPUT_OFFSET) "]\n"
"dup v27.4s, w13\n"
- "ldr w13, [%[params_ptr], #" STR(OFFSET_OUTPUT_SHIFT) "]\n"
+ "ldr w13, [%[params_ptr], #" STR(OFFSET_OUTPUT_RIGHT_SHIFT) "]\n"
"dup v28.4s, w12\n"
"ldr w12, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MIN) "]\n"
"neg w13, w13\n"
@@ -2739,7 +2739,7 @@ struct DepthwiseConvPartial<EdgeType::kVertical, 1, 1> {
"dup v26.8h, w12\n"
"ldr w12, [%[params_ptr], #" STR(OFFSET_OUTPUT_OFFSET) "]\n"
"dup v27.4s, w13\n"
- "ldr w13, [%[params_ptr], #" STR(OFFSET_OUTPUT_SHIFT) "]\n"
+ "ldr w13, [%[params_ptr], #" STR(OFFSET_OUTPUT_RIGHT_SHIFT) "]\n"
"dup v28.4s, w12\n"
"ldr w12, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MIN) "]\n"
"neg w13, w13\n"
@@ -2910,7 +2910,7 @@ struct DepthwiseConvPartial<EdgeType::kVertical, 1, 1> {
#undef OFFSET_OUTPUT_MULTIPLIER
#undef OFFSET_OUTPUT_ACTIVATION_MIN
#undef OFFSET_OUTPUT_ACTIVATION_MAX
-#undef OFFSET_OUTPUT_SHIFT
+#undef OFFSET_OUTPUT_RIGHT_SHIFT
#undef OFFSET_INPUT_WIDTH
#undef OFFSET_INPUT_HEIGHT
#undef OFFSET_OUTPUT_WIDTH
@@ -3175,16 +3175,18 @@ inline void DepthwiseConvHandlePadding(const uint8* input_data,
}
inline bool Fast3x3FilterKernelSupported(
- const Dims<4>& input_dims, const Dims<4>& filter_dims, int32 stride_width,
- int32 stride_height, int32 pad_width, int32 pad_height,
- int32 depth_multiplier, const Dims<4>& output_dims, int32 output_shift) {
- const int32 input_height = ArraySize(input_dims, 2);
- const int32 input_width = ArraySize(input_dims, 1);
- const int32 input_depth = ArraySize(input_dims, 0);
- const int32 filter_height = ArraySize(filter_dims, 2);
- const int32 filter_width = ArraySize(filter_dims, 1);
- const int32 output_height = ArraySize(output_dims, 2);
- const int32 output_width = ArraySize(output_dims, 1);
+ const RuntimeShape& input_shape, const RuntimeShape& filter_shape,
+ int32 stride_width, int32 stride_height, int32 dilation_width_factor,
+ int32 dilation_height_factor, int32 pad_width, int32 pad_height,
+ int32 depth_multiplier, const RuntimeShape& output_shape,
+ int32 output_shift) {
+ const int32 input_height = input_shape.Dims(1);
+ const int32 input_width = input_shape.Dims(2);
+ const int32 input_depth = input_shape.Dims(3);
+ const int32 filter_height = filter_shape.Dims(1);
+ const int32 filter_width = filter_shape.Dims(2);
+ const int32 output_height = output_shape.Dims(1);
+ const int32 output_width = output_shape.Dims(2);
bool supported =
filter_width == 3 && filter_height == 3 && depth_multiplier == 1 &&
@@ -3192,7 +3194,8 @@ inline bool Fast3x3FilterKernelSupported(
(stride_height == 1 || stride_height == 2) &&
(stride_width == stride_height) && (pad_width == 0 || pad_width == 1) &&
(pad_height == 0 || pad_height == 1) && (pad_width == pad_height) &&
- (input_depth % 8) == 0 && (output_shift > 0);
+ (input_depth % 8) == 0 && (output_shift <= 0) &&
+ dilation_width_factor == 1 && dilation_height_factor == 1;
if (!supported) {
return false;
@@ -3234,36 +3237,47 @@ inline bool Fast3x3FilterKernelSupported(
}
inline void DepthwiseConv3x3Filter(
- const uint8* input_data, const Dims<4>& input_dims, int32 input_offset,
- const uint8* filter_data, const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data, const Dims<4>& bias_dims, int32 stride_width,
- int32 stride_height, int32 pad_width, int32 pad_height,
- int32 depth_multiplier, int32 output_offset, int32 output_multiplier,
- int32 output_shift, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims) {
+ const DepthwiseParams& rt_params, const RuntimeShape& input_shape,
+ const uint8* input_data, const RuntimeShape& filter_shape,
+ const uint8* filter_data, const RuntimeShape& bias_shape,
+ const int32* bias_data, const RuntimeShape& output_shape,
+ uint8* output_data) {
gemmlowp::ScopedProfilingLabel label(__PRETTY_FUNCTION__);
DepthwiseConvParams params;
- params.input_depth = ArraySize(input_dims, 0);
- params.input_width = ArraySize(input_dims, 1);
- params.input_height = ArraySize(input_dims, 2);
+
+ const int32 stride_width = rt_params.stride_width;
+ const int32 stride_height = rt_params.stride_height;
+ const int32 pad_width = rt_params.padding_values.width;
+ const int32 pad_height = rt_params.padding_values.height;
+ const int32 depth_multiplier = rt_params.depth_multiplier;
+ const int32 output_activation_min = rt_params.quantized_activation_min;
+ const int32 output_activation_max = rt_params.quantized_activation_max;
+ const int32 input_offset = rt_params.input_offset;
+ const int32 filter_offset = rt_params.weights_offset;
+ const int32 output_offset = rt_params.output_offset;
+ const int32 output_multiplier = rt_params.output_multiplier;
+ const int32 output_shift = rt_params.output_shift;
+
+ params.input_depth = input_shape.Dims(3);
+ params.input_width = input_shape.Dims(2);
+ params.input_height = input_shape.Dims(1);
params.input_row_size = params.input_depth * params.input_width;
params.input_offset = input_offset;
params.stride_width = stride_width;
params.stride_height = stride_height;
- params.output_depth = MatchingArraySize(filter_dims, 0, output_dims, 0);
- params.output_width = ArraySize(output_dims, 1);
- params.output_height = ArraySize(output_dims, 2);
+ params.output_depth = MatchingDim(filter_shape, 3, output_shape, 3);
+ params.output_width = output_shape.Dims(2);
+ params.output_height = output_shape.Dims(1);
params.output_row_size = params.output_depth * params.output_width;
params.output_offset = output_offset;
params.filter_offset = filter_offset;
params.output_multiplier = output_multiplier;
- params.output_shift = output_shift;
+ params.output_right_shift = -output_shift;
params.output_activation_min = output_activation_min;
params.output_activation_max = output_activation_max;
- const int32 filter_height = ArraySize(filter_dims, 2);
- const int32 filter_width = ArraySize(filter_dims, 1);
+ const int32 filter_height = filter_shape.Dims(1);
+ const int32 filter_width = filter_shape.Dims(2);
params.filter_row_size = params.output_depth * filter_width;
// Algorithm assumes below constraints. It is optimized for depth
@@ -3279,7 +3293,7 @@ inline void DepthwiseConv3x3Filter(
TFLITE_DCHECK(pad_width == 0 || pad_width == 1);
TFLITE_DCHECK(pad_width == pad_height);
- const int32 batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int32 batches = MatchingDim(input_shape, 0, output_shape, 0);
const int64_t input_batch_size = params.input_row_size * params.input_height;
const int64_t output_batch_size =
params.output_row_size * params.output_height;
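Taken together, the shift-related edits above change only a sign convention: DepthwiseParams::output_shift now uses positive-means-left, the kernel-side struct stores the negated value as output_right_shift, and the 3x3 fast path accepts only output_shift <= 0 (a pure right shift) with unit dilation. A minimal sketch of the convention, ignoring the rounding fixed-point multiply the real kernels apply; the helper name below is hypothetical and not part of this patch:

#include <cstdint>

// Positive exponent: shift left; non-positive: shift right by the magnitude,
// which is exactly what the kernel stores as output_right_shift = -output_shift.
// (Assumes an arithmetic right shift for negative accumulators.)
inline int32_t ApplyOutputShift(int32_t acc, int output_shift) {
  if (output_shift >= 0) {
    return acc * (int32_t{1} << output_shift);
  }
  return acc >> (-output_shift);
}
// A legacy right shift of 4 therefore arrives here as output_shift = -4.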
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h
index b6151c40b3..4218be20a4 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h
@@ -19,6 +19,8 @@ limitations under the License.
#include <sys/types.h>
#include "tensorflow/contrib/lite/kernels/internal/common.h"
+#include "tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h"
+#include "tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/types.h"
@@ -28,9 +30,857 @@ namespace optimized_ops {
// Unoptimized reference ops:
using reference_ops::ArgMax;
+using reference_ops::ArgMinMax;
+using reference_ops::Broadcast4DSlowGreater;
+using reference_ops::Broadcast4DSlowGreaterEqual;
+using reference_ops::Broadcast4DSlowGreaterEqualWithScaling;
+using reference_ops::Broadcast4DSlowGreaterWithScaling;
+using reference_ops::Broadcast4DSlowLess;
+using reference_ops::Broadcast4DSlowLessEqual;
+using reference_ops::Broadcast4DSlowLessEqualWithScaling;
+using reference_ops::Broadcast4DSlowLessWithScaling;
+using reference_ops::BroadcastAdd4DSlow;
+using reference_ops::BroadcastGreater;
+using reference_ops::BroadcastGreaterEqual;
+using reference_ops::BroadcastLess;
+using reference_ops::BroadcastLessEqual;
+using reference_ops::BroadcastMul4DSlow;
+using reference_ops::BroadcastSub4DSlow;
+using reference_ops::Concatenation;
+using reference_ops::ConcatenationWithScaling;
+using reference_ops::DepthConcatenation;
+using reference_ops::Dequantize;
+using reference_ops::Div;
+using reference_ops::FakeQuant;
+using reference_ops::Gather;
+using reference_ops::Greater;
+using reference_ops::GreaterEqual;
+using reference_ops::GreaterEqualWithScaling;
+using reference_ops::GreaterWithScaling;
+using reference_ops::Less;
+using reference_ops::LessEqual;
+using reference_ops::LessEqualWithScaling;
+using reference_ops::LessWithScaling;
+using reference_ops::Mean;
+using reference_ops::RankOneSelect;
using reference_ops::Relu1;
using reference_ops::Relu6;
+using reference_ops::ReluX;
+using reference_ops::Select;
using reference_ops::SpaceToBatchND;
+using reference_ops::Split;
+using reference_ops::StridedSlice;
+using reference_ops::TensorFlowSplit;
+using reference_ops::Transpose;
+
+static constexpr int kDepthwiseReverseShift = -1;
+
+template <typename Scalar, int N>
+VectorMap<Scalar> MapAsVector(Scalar* data, const Dims<N>& dims) {
+ const int size = FlatSize(dims);
+ return VectorMap<Scalar>(data, size, 1);
+}
+
+template <typename Scalar, int N>
+MatrixMap<Scalar> MapAsMatrixWithFirstDimAsRows(Scalar* data,
+ const Dims<N>& dims) {
+ const int rows = dims.sizes[0];
+ int cols = 1;
+ for (int d = 1; d < N; d++) {
+ cols *= dims.sizes[d];
+ }
+ return MatrixMap<Scalar>(data, rows, cols);
+}
+
+template <typename Scalar, int N>
+MatrixMap<Scalar> MapAsMatrixWithLastDimAsCols(Scalar* data,
+ const Dims<N>& dims) {
+ const int cols = dims.sizes[N - 1];
+ int rows = 1;
+ for (int d = 0; d < N - 1; d++) {
+ rows *= dims.sizes[d];
+ }
+ return MatrixMap<Scalar>(data, rows, cols);
+}
+
+template <typename Scalar, int N>
+ArrayMap<Scalar> MapAsArrayWithFirstDimAsRows(Scalar* data,
+ const Dims<N>& dims) {
+ const int rows = dims.sizes[0];
+ int cols = 1;
+ for (int d = 1; d < N; d++) {
+ cols *= dims.sizes[d];
+ }
+ return ArrayMap<Scalar>(data, rows, cols);
+}
+
+// TODO(b/62193649): this function is only needed as long
+// as we have the --variable_batch hack.
+template <typename Scalar, int N>
+MatrixMap<Scalar> MapAsMatrixWithGivenNumberOfRows(Scalar* data,
+ const Dims<N>& dims,
+ int rows) {
+ const int flatsize = FlatSize(dims);
+ TFLITE_DCHECK((flatsize % rows) == 0);
+ const int cols = flatsize / rows;
+ return MatrixMap<Scalar>(data, rows, cols);
+}
+
+inline bool AreSameDims(const Dims<4>& dims1, const Dims<4>& dims2) {
+ for (int i = 0; i < 4; i++) {
+ if (dims1.sizes[i] != dims2.sizes[i]) {
+ return false;
+ }
+ }
+ return true;
+}
+
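// Illustration (not part of this patch): in the legacy Dims<4> layout, sizes[0]
// is the innermost, fastest-varying extent, so MapAsMatrixWithFirstDimAsRows
// views a buffer as a sizes[0]-by-(product of the rest) column-major matrix.
// With hypothetical extents {8, 4, 4, 1} (depth 8, width 4, height 4, batch 1),
// the equivalent plain Eigen view would be
//   Eigen::Map<Eigen::MatrixXf> view(data, 8, 16);  // rows = sizes[0], cols = 4 * 4 * 1
// while MapAsMatrixWithLastDimAsCols gives the complementary shape, here 128 rows by 1 column.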
+inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height,
+ int dilation_width_factor, int dilation_height_factor,
+ int pad_width, int pad_height, int depth_multiplier,
+ float output_activation_min,
+ float output_activation_max, float* output_data,
+ const Dims<4>& output_dims) {
+ tflite::DepthwiseParams op_params;
+ // Padding type is ignored, but still set.
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = pad_width;
+ op_params.padding_values.height = pad_height;
+ op_params.stride_width = stride_width;
+ op_params.stride_height = stride_height;
+ op_params.dilation_width_factor = dilation_width_factor;
+ op_params.dilation_height_factor = dilation_height_factor;
+ op_params.depth_multiplier = depth_multiplier;
+ op_params.float_activation_min = output_activation_min;
+ op_params.float_activation_max = output_activation_max;
+
+ DepthwiseConv(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
+ bias_data, DimsToShape(output_dims), output_data);
+}
+
+inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int depth_multiplier,
+ float output_activation_min,
+ float output_activation_max, float* output_data,
+ const Dims<4>& output_dims) {
+ DepthwiseConv(input_data, input_dims, filter_data, filter_dims, bias_data,
+ bias_dims, stride_width, stride_height, 1, 1, pad_width,
+ pad_height, depth_multiplier, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int depth_multiplier, float* output_data,
+ const Dims<4>& output_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+ DepthwiseConv(input_data, input_dims, filter_data, filter_dims, bias_data,
+ bias_dims, stride_width, stride_height, pad_width, pad_height,
+ depth_multiplier, output_activation_min, output_activation_max,
+ output_data, output_dims);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims, int stride,
+ int pad_width, int pad_height, int depth_multiplier,
+ float* output_data, const Dims<4>& output_dims) {
+ DepthwiseConv<Ac>(input_data, input_dims, filter_data, filter_dims, bias_data,
+ bias_dims, stride, stride, pad_width, pad_height,
+ depth_multiplier, output_data, output_dims);
+}
+
+inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height,
+ int dilation_width_factor, int dilation_height_factor,
+ int pad_width, int pad_height, int depth_multiplier,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ tflite::DepthwiseParams op_params;
+ // Padding type is ignored, but still set.
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = pad_width;
+ op_params.padding_values.height = pad_height;
+ op_params.stride_width = stride_width;
+ op_params.stride_height = stride_height;
+ op_params.dilation_width_factor = dilation_width_factor;
+ op_params.dilation_height_factor = dilation_height_factor;
+ op_params.depth_multiplier = depth_multiplier;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+ op_params.input_offset = input_offset;
+ op_params.weights_offset = filter_offset;
+ op_params.output_offset = output_offset;
+ op_params.output_multiplier = output_multiplier;
+ // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
+ op_params.output_shift = kDepthwiseReverseShift * output_shift;
+
+ DepthwiseConv(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
+ bias_data, DimsToShape(output_dims), output_data);
+}
+
+inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int depth_multiplier,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ DepthwiseConv(input_data, input_dims, input_offset, filter_data, filter_dims,
+ filter_offset, bias_data, bias_dims, stride_width,
+ stride_height, 1, 1, pad_width, pad_height, depth_multiplier,
+ output_offset, output_multiplier, output_shift,
+ output_activation_min, output_activation_max, output_data,
+ output_dims);
+}
+
+// Legacy, for compatibility with old checked-in code.
+template <FusedActivationFunctionType Ac>
+void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int depth_multiplier, int32 output_offset,
+ int32 output_multiplier, int output_shift,
+ int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims) {
+ if (Ac == FusedActivationFunctionType::kNone) {
+ TFLITE_DCHECK_EQ(output_activation_min, 0);
+ TFLITE_DCHECK_EQ(output_activation_max, 255);
+ }
+ DepthwiseConv(input_data, input_dims, input_offset, filter_data, filter_dims,
+ filter_offset, bias_data, bias_dims, stride_width,
+ stride_height, pad_width, pad_height, depth_multiplier,
+ output_offset, output_multiplier, output_shift,
+ output_activation_min, output_activation_max, output_data,
+ output_dims);
+}
+
+// Legacy, for compatibility with old checked-in code.
+template <FusedActivationFunctionType Ac>
+void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims, int stride,
+ int pad_width, int pad_height, int depth_multiplier,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ DepthwiseConv<Ac>(input_data, input_dims, input_offset, filter_data,
+ filter_dims, filter_offset, bias_data, bias_dims, stride,
+ stride, pad_width, pad_height, depth_multiplier,
+ output_offset, output_multiplier, output_shift,
+ output_activation_min, output_activation_max, output_data,
+ output_dims);
+}
+
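// The uint8 wrappers above translate only the shift convention; with a legacy
// (right-shift, positive) value of, say, output_shift = 5, the line
// op_params.output_shift = kDepthwiseReverseShift * output_shift stores -5,
// which the new DepthwiseParams-based kernel reads as "shift right by 5", so
// results match the old entry points. All other parameters pass through as-is.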
+inline void AddBiasAndEvalActivationFunction(const float* bias_data,
+ const Dims<4>& bias_dims,
+ float* array_data,
+ const Dims<4>& array_dims,
+ float output_activation_min,
+ float output_activation_max) {
+ AddBiasAndEvalActivationFunction(output_activation_min, output_activation_max,
+ DimsToShape(bias_dims), bias_data,
+ DimsToShape(array_dims), array_data);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void AddBiasAndEvalActivationFunction(const float* bias_data,
+ const Dims<4>& bias_dims,
+ float* array_data,
+ const Dims<4>& array_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+ AddBiasAndEvalActivationFunction(bias_data, bias_dims, array_data, array_dims,
+ output_activation_min,
+ output_activation_max);
+}
+
+inline void FullyConnected(const float* input_data, const Dims<4>& input_dims,
+ const float* weights_data,
+ const Dims<4>& weights_dims, const float* bias_data,
+ const Dims<4>& bias_dims,
+ float output_activation_min,
+ float output_activation_max, float* output_data,
+ const Dims<4>& output_dims) {
+ tflite::FullyConnectedParams op_params;
+ op_params.float_activation_min = output_activation_min;
+ op_params.float_activation_max = output_activation_max;
+
+ FullyConnected(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(weights_dims), weights_data,
+ DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
+ output_data);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void FullyConnected(const float* input_data, const Dims<4>& input_dims,
+ const float* weights_data, const Dims<4>& weights_dims,
+ const float* bias_data, const Dims<4>& bias_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+ FullyConnected(input_data, input_dims, weights_data, weights_dims, bias_data,
+ bias_dims, output_activation_min, output_activation_max,
+ output_data, output_dims);
+}
+
+inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims,
+ gemmlowp::GemmContext* gemm_context) {
+ tflite::FullyConnectedParams op_params;
+ op_params.input_offset = input_offset;
+ op_params.weights_offset = filter_offset;
+ op_params.output_offset = output_offset;
+ op_params.output_multiplier = output_multiplier;
+ // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
+ op_params.output_shift = kReverseShift * output_shift;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+
+ FullyConnected(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
+ bias_data, DimsToShape(output_dims), output_data,
+ gemm_context);
+}
+
+inline void FullyConnected(
+ const uint8* input_data, const Dims<4>& input_dims, int32 input_offset,
+ const uint8* filter_data, const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data_int32, const Dims<4>& bias_dims, int32 output_offset,
+ int32 output_multiplier, int output_shift, int32 output_activation_min,
+ int32 output_activation_max, int16* output_data, const Dims<4>& output_dims,
+ gemmlowp::GemmContext* gemm_context) {
+ tflite::FullyConnectedParams op_params;
+ op_params.input_offset = input_offset;
+ op_params.weights_offset = filter_offset;
+ op_params.output_offset = output_offset;
+ op_params.output_multiplier = output_multiplier;
+ // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
+ op_params.output_shift = kReverseShift * output_shift;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+
+ FullyConnected(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
+ bias_data_int32, DimsToShape(output_dims), output_data,
+ gemm_context);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims,
+ gemmlowp::GemmContext* gemm_context) {
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ FullyConnected(input_data, input_dims, input_offset, filter_data, filter_dims,
+ filter_offset, bias_data, bias_dims, output_offset,
+ output_multiplier, output_shift, output_activation_min,
+ output_activation_max, output_data, output_dims, gemm_context);
+}
+
+inline void ShuffledFullyConnected(
+ const uint8* input_data, const Dims<4>& input_dims,
+ const uint8* shuffled_weights_data, const Dims<4>& weights_dims,
+ const int32* bias_data, const Dims<4>& bias_dims, int32 output_multiplier,
+ int output_shift, int32 output_activation_min, int32 output_activation_max,
+ int16* output_data, const Dims<4>& output_dims,
+ uint8* shuffled_input_workspace_data, gemmlowp::GemmContext* gemm_context) {
+ tflite::FullyConnectedParams op_params;
+ op_params.output_multiplier = output_multiplier;
+ // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
+ op_params.output_shift = kReverseShift * output_shift;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+
+ ShuffledFullyConnected(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(weights_dims), shuffled_weights_data,
+ DimsToShape(bias_dims), bias_data,
+ DimsToShape(output_dims), output_data,
+ shuffled_input_workspace_data, gemm_context);
+}
+
+template <typename T>
+inline void ExtractPatchIntoBufferColumn(
+ const Dims<4>& input_dims, int w, int h, int b, int kheight, int kwidth,
+ int stride_width, int stride_height, int pad_width, int pad_height,
+ int in_width, int in_height, int in_depth, int single_buffer_length,
+ int buffer_id, const T* in_data, T* conv_buffer_data, uint8 zero_byte) {
+ ExtractPatchIntoBufferColumn(
+ DimsToShape(input_dims), w, h, b, kheight, kwidth, stride_width,
+ stride_height, pad_width, pad_height, in_width, in_height, in_depth,
+ single_buffer_length, buffer_id, in_data, conv_buffer_data, zero_byte);
+}
+
+template <typename T>
+void DilatedIm2col(const T* input_data, const Dims<4>& input_dims,
+ const Dims<4>& filter_dims, int stride_width,
+ int stride_height, int dilation_width_factor,
+ int dilation_height_factor, int pad_width, int pad_height,
+ const Dims<4>& output_dims, uint8 zero_byte,
+ T* im2col_data) {
+ tflite::ConvParams op_params;
+ // Padding type is ignored, but still set.
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = pad_width;
+ op_params.padding_values.height = pad_height;
+ op_params.stride_width = stride_width;
+ op_params.stride_height = stride_height;
+ op_params.dilation_width_factor = dilation_width_factor;
+ op_params.dilation_height_factor = dilation_height_factor;
+
+ DilatedIm2col(op_params, zero_byte, DimsToShape(input_dims), input_data,
+ DimsToShape(filter_dims), DimsToShape(output_dims),
+ im2col_data);
+}
+
+template <typename T>
+void Im2col(const T* input_data, const Dims<4>& input_dims, int stride_width,
+ int stride_height, int pad_width, int pad_height, int kheight,
+ int kwidth, uint8 zero_byte, T* output_data,
+ const Dims<4>& output_dims) {
+ tflite::ConvParams op_params;
+ // Padding type is ignored, but still set.
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = pad_width;
+ op_params.padding_values.height = pad_height;
+ op_params.stride_width = stride_width;
+ op_params.stride_height = stride_height;
+ op_params.dilation_width_factor = 1;
+ op_params.dilation_height_factor = 1;
+
+ Im2col(op_params, kheight, kwidth, zero_byte, DimsToShape(input_dims),
+ input_data, DimsToShape(output_dims), output_data);
+}
+
+// legacy, for compatibility with old checked-in code
+template <typename T>
+void Im2col(const T* input_data, const Dims<4>& input_dims, int stride,
+ int pad_width, int pad_height, int kheight, int kwidth,
+ uint8 zero_byte, T* output_data, const Dims<4>& output_dims) {
+ Im2col(input_data, input_dims, stride, stride, pad_width, pad_height, kheight,
+ kwidth, zero_byte, output_data, output_dims);
+}
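// General reference, not taken from this patch: the spatial output extent implied
// by these stride/pad/dilation arguments follows the usual convolution size
// formula. The helper name below is hypothetical.
inline int ConvOutputSize(int in, int filter, int stride, int pad, int dilation) {
  const int effective_filter = dilation * (filter - 1) + 1;
  return (in + 2 * pad - effective_filter) / stride + 1;
}
// Example: in = 8, filter = 3, stride = 2, pad = 1, dilation = 1 -> (8 + 2 - 3) / 2 + 1 = 4.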
+
+inline void Conv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int dilation_width_factor,
+ int dilation_height_factor, int pad_width, int pad_height,
+ float output_activation_min, float output_activation_max,
+ float* output_data, const Dims<4>& output_dims,
+ float* im2col_data, const Dims<4>& im2col_dims) {
+ tflite::ConvParams op_params;
+ // Padding type is ignored, but still set.
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = pad_width;
+ op_params.padding_values.height = pad_height;
+ op_params.stride_width = stride_width;
+ op_params.stride_height = stride_height;
+ op_params.dilation_width_factor = dilation_width_factor;
+ op_params.dilation_height_factor = dilation_height_factor;
+ op_params.float_activation_min = output_activation_min;
+ op_params.float_activation_max = output_activation_max;
+
+ Conv(op_params, DimsToShape(input_dims), input_data, DimsToShape(filter_dims),
+ filter_data, DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
+ output_data, DimsToShape(im2col_dims), im2col_data);
+}
+
+inline void HybridConv(const int8_t* input_data, const Dims<4>& input_dims,
+ const int8_t* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, float* scaling_factors_ptr,
+ float output_activation_min, float output_activation_max,
+ float* output_data, const Dims<4>& output_dims,
+ int8_t* im2col_data, const Dims<4>& im2col_dims) {
+ tflite::ConvParams op_params;
+ // Padding type is ignored, but still set.
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = pad_width;
+ op_params.padding_values.height = pad_height;
+ op_params.stride_width = stride_width;
+ op_params.stride_height = stride_height;
+ op_params.float_activation_min = output_activation_min;
+ op_params.float_activation_max = output_activation_max;
+
+ HybridConv(op_params, scaling_factors_ptr, DimsToShape(input_dims),
+ input_data, DimsToShape(filter_dims), filter_data,
+ DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
+ output_data, DimsToShape(im2col_dims), im2col_data);
+}
+
+template <FusedActivationFunctionType Ac>
+void Conv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims, int stride_width,
+ int stride_height, int dilation_width_factor,
+ int dilation_height_factor, int pad_width, int pad_height,
+ float* output_data, const Dims<4>& output_dims, float* im2col_data,
+ const Dims<4>& im2col_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+ Conv(input_data, input_dims, filter_data, filter_dims, bias_data, bias_dims,
+ stride_width, stride_height, dilation_width_factor,
+ dilation_height_factor, pad_width, pad_height, output_activation_min,
+ output_activation_max, output_data, output_dims, im2col_data,
+ im2col_dims);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void Conv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims, int stride_width,
+ int stride_height, int pad_width, int pad_height, float* output_data,
+ const Dims<4>& output_dims, float* im2col_data,
+ const Dims<4>& im2col_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+ Conv(input_data, input_dims, filter_data, filter_dims, bias_data, bias_dims,
+ stride_width, stride_height, 1, 1, pad_width, pad_height,
+ output_activation_min, output_activation_max, output_data, output_dims,
+ im2col_data, im2col_dims);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void Conv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims, int stride,
+ int pad_width, int pad_height, float* output_data,
+ const Dims<4>& output_dims, float* im2col_data,
+ const Dims<4>& im2col_dims) {
+ Conv<Ac>(input_data, input_dims, filter_data, filter_dims, bias_data,
+ bias_dims, stride, stride, 1, 1, pad_width, pad_height, output_data,
+ output_dims, im2col_data, im2col_dims);
+}
+
+inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int dilation_width_factor,
+ int dilation_height_factor, int pad_width, int pad_height,
+ int32 output_offset, int32 output_multiplier, int output_shift,
+ int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims,
+ uint8* im2col_data, const Dims<4>& im2col_dims,
+ gemmlowp::GemmContext* gemm_context) {
+ tflite::ConvParams op_params;
+ // Padding type is ignored, but still set.
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = pad_width;
+ op_params.padding_values.height = pad_height;
+ op_params.stride_width = stride_width;
+ op_params.stride_height = stride_height;
+ op_params.dilation_width_factor = dilation_width_factor;
+ op_params.dilation_height_factor = dilation_height_factor;
+ op_params.input_offset = input_offset;
+ op_params.weights_offset = filter_offset;
+ op_params.output_offset = output_offset;
+ op_params.output_multiplier = output_multiplier;
+ // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
+ op_params.output_shift = kReverseShift * output_shift;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+
+ Conv(op_params, DimsToShape(input_dims), input_data, DimsToShape(filter_dims),
+ filter_data, DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
+ output_data, DimsToShape(im2col_dims), im2col_data, gemm_context);
+}
+
+inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims, uint8* im2col_data,
+ const Dims<4>& im2col_dims,
+ gemmlowp::GemmContext* gemm_context) {
+ Conv(input_data, input_dims, input_offset, filter_data, filter_dims,
+ filter_offset, bias_data, bias_dims, stride_width, stride_height, 1, 1,
+ pad_width, pad_height, output_offset, output_multiplier, output_shift,
+ output_activation_min, output_activation_max, output_data, output_dims,
+ im2col_data, im2col_dims, gemm_context);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims, uint8* im2col_data,
+ const Dims<4>& im2col_dims,
+ gemmlowp::GemmContext* gemm_context) {
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ if (Ac == FusedActivationFunctionType::kNone) {
+ TFLITE_DCHECK_EQ(output_activation_min, 0);
+ TFLITE_DCHECK_EQ(output_activation_max, 255);
+ }
+ Conv(input_data, input_dims, input_offset, filter_data, filter_dims,
+ filter_offset, bias_data, bias_dims, stride_width, stride_height,
+ pad_width, pad_height, output_offset, output_multiplier, output_shift,
+ output_activation_min, output_activation_max, output_data, output_dims,
+ im2col_data, im2col_dims, gemm_context);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void Conv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims, int stride,
+ int pad_width, int pad_height, int32 output_offset,
+ int32 output_multiplier, int output_shift,
+ int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims, uint8* im2col_data,
+ const Dims<4>& im2col_dims, gemmlowp::GemmContext* gemm_context) {
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ Conv(input_data, input_dims, input_offset, filter_data, filter_dims,
+ filter_offset, bias_data, bias_dims, stride, stride, pad_width,
+ pad_height, output_offset, output_multiplier, output_shift,
+ output_activation_min, output_activation_max, output_data, output_dims,
+ im2col_data, im2col_dims, gemm_context);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac, typename T>
+void Im2col(const T* input_data, const Dims<4>& input_dims, int stride,
+ int pad_width, int pad_height, int kheight, int kwidth,
+ uint8 zero_byte, T* output_data, const Dims<4>& output_dims) {
+ Im2col(input_data, input_dims, stride, stride, pad_width, pad_height, kheight,
+ kwidth, zero_byte, output_data, output_dims);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void ConvAsGemm(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("ConvAsGemm");
+
+ const auto input_matrix_map =
+ MapAsMatrixWithFirstDimAsRows(input_data, input_dims);
+ const auto filter_matrix_map =
+ MapAsMatrixWithLastDimAsCols(filter_data, filter_dims);
+ auto output_matrix_map =
+ MapAsMatrixWithFirstDimAsRows(output_data, output_dims);
+
+ Gemm(filter_matrix_map.transpose(), input_matrix_map, &output_matrix_map);
+
+ AddBiasAndEvalActivationFunction<Ac>(bias_data, bias_dims, output_data,
+ output_dims);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void ConvAsGemm(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int32 output_offset, int32 output_multiplier, int output_shift,
+ int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims,
+ gemmlowp::GemmContext* gemm_context) {
+ gemmlowp::ScopedProfilingLabel label("ConvAsGemm/8bit");
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ const int input_rows = input_dims.sizes[0];
+ const int input_cols = FlatSizeSkipDim(input_dims, 0);
+ const int filter_rows = filter_dims.sizes[3];
+ const int filter_cols = FlatSizeSkipDim(filter_dims, 3);
+ const int output_rows = output_dims.sizes[0];
+ const int output_cols = FlatSizeSkipDim(output_dims, 0);
+ TFLITE_DCHECK_EQ(output_rows, filter_rows);
+ TFLITE_DCHECK_EQ(output_cols, input_cols);
+ TFLITE_DCHECK_EQ(filter_cols, input_rows);
+ TFLITE_DCHECK_EQ(bias_dims.sizes[0], output_rows);
+ TFLITE_DCHECK_EQ(bias_dims.sizes[1], 1);
+ TFLITE_DCHECK_EQ(bias_dims.sizes[2], 1);
+ TFLITE_DCHECK_EQ(bias_dims.sizes[3], 1);
+ gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::RowMajor> filter_matrix(
+ filter_data, output_rows, filter_cols, filter_cols);
+ gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::ColMajor> input_matrix(
+ input_data, filter_cols, output_cols, filter_cols);
+ gemmlowp::MatrixMap<uint8, gemmlowp::MapOrder::ColMajor> output_matrix(
+ output_data, output_rows, output_cols, output_rows);
+ const auto& output_pipeline = GemmlowpOutputPipeline::MakeExp(
+ bias_data, output_rows, output_offset, output_multiplier, -output_shift,
+ output_activation_min, output_activation_max);
+ gemmlowp::GemmWithOutputPipeline<uint8, uint8,
+ gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
+ gemm_context, filter_matrix, input_matrix, &output_matrix, filter_offset,
+ input_offset, output_pipeline);
+}
+
+inline void TransposeConv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, float* output_data,
+ const Dims<4>& output_dims, float* im2col_data,
+ const Dims<4>& im2col_dims) {
+ tflite::ConvParams op_params;
+ // Padding type is ignored, but still set.
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = pad_width;
+ op_params.padding_values.height = pad_height;
+ op_params.stride_width = stride_width;
+ op_params.stride_height = stride_height;
+
+ TransposeConv(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(filter_dims), filter_data, DimsToShape(output_dims),
+ output_data, DimsToShape(im2col_dims), im2col_data);
+}
+
+template <typename T>
+void TransposeIm2col(const T* input_data, const Dims<4>& input_dims,
+ const Dims<4>& filter_dims, int stride_width,
+ int stride_height, int pad_width, int pad_height,
+ const Dims<4>& output_dims, uint8 zero_byte,
+ T* im2col_data) {
+ tflite::ConvParams op_params;
+ // Padding type is ignored, but still set.
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = pad_width;
+ op_params.padding_values.height = pad_height;
+ op_params.stride_width = stride_width;
+ op_params.stride_height = stride_height;
+
+ TransposeIm2col(op_params, zero_byte, DimsToShape(input_dims), input_data,
+ DimsToShape(filter_dims), DimsToShape(output_dims),
+ im2col_data);
+}
+
+inline void LstmCell(const float* input_data, const Dims<4>& input_dims,
+ const float* prev_activ_data,
+ const Dims<4>& prev_activ_dims, const float* weights_data,
+ const Dims<4>& weights_dims, const float* bias_data,
+ const Dims<4>& bias_dims, const float* prev_state_data,
+ const Dims<4>& prev_state_dims, float* output_state_data,
+ const Dims<4>& output_state_dims, float* output_activ_data,
+ const Dims<4>& output_activ_dims, float* concat_temp_data,
+ const Dims<4>& concat_temp_dims, float* activ_temp_data,
+ const Dims<4>& activ_temp_dims) {
+ tflite::LstmCellParams op_params;
+ // Float LSTM cell does not need parameters to be set: leave untouched.
+
+ LstmCell(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(prev_activ_dims), prev_activ_data,
+ DimsToShape(weights_dims), weights_data, DimsToShape(bias_dims),
+ bias_data, DimsToShape(prev_state_dims), prev_state_data,
+ DimsToShape(output_state_dims), output_state_data,
+ DimsToShape(output_activ_dims), output_activ_data,
+ DimsToShape(concat_temp_dims), concat_temp_data,
+ DimsToShape(activ_temp_dims), activ_temp_data);
+}
+
+template <int StateIntegerBits>
+void LstmCell(const uint8* input_data_uint8, const Dims<4>& input_dims,
+ const uint8* prev_activ_data_uint8,
+ const Dims<4>& prev_activ_dims, const uint8* weights_data_uint8,
+ const Dims<4>& weights_dims, const int32* bias_data_int32,
+ const Dims<4>& bias_dims, const int16* prev_state_data_int16,
+ const Dims<4>& prev_state_dims, int16* output_state_data_int16,
+ const Dims<4>& output_state_dims, uint8* output_activ_data_uint8,
+ const Dims<4>& output_activ_dims, uint8* concat_temp_data_uint8,
+ const Dims<4>& concat_temp_dims, int16* activ_temp_data_int16,
+ const Dims<4>& activ_temp_dims, int32 weights_zero_point,
+ int32 accum_multiplier, int accum_shift,
+ gemmlowp::GemmContext* gemm_context) {
+ tflite::LstmCellParams op_params;
+ op_params.weights_zero_point = weights_zero_point;
+ op_params.accum_multiplier = accum_multiplier;
+ op_params.accum_shift = accum_shift;
+
+ LstmCell<StateIntegerBits>(
+ op_params, DimsToShape(input_dims), input_data_uint8,
+ DimsToShape(prev_activ_dims), prev_activ_data_uint8,
+ DimsToShape(weights_dims), weights_data_uint8, DimsToShape(bias_dims),
+ bias_data_int32, DimsToShape(prev_state_dims), prev_state_data_int16,
+ DimsToShape(output_state_dims), output_state_data_int16,
+ DimsToShape(output_activ_dims), output_activ_data_uint8,
+ DimsToShape(concat_temp_dims), concat_temp_data_uint8,
+ DimsToShape(activ_temp_dims), activ_temp_data_int16, gemm_context);
+}
+
+template <typename T>
+void BroadcastDiv(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ T output_activation_min, T output_activation_max,
+ T* output_data, const Dims<4>& output_dims) {
+ tflite::ArithmeticParams op_params;
+ SetActivationParams(output_activation_min, output_activation_max, &op_params);
+
+ BroadcastDiv4DSlow(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data);
+}
template <FusedActivationFunctionType Ac>
void L2Normalization(const float* input_data, const RuntimeShape& input_shape,
@@ -574,6 +1424,14 @@ void L2Pool(const float* input_data, const Dims<4>& input_dims, int stride,
filter_width, filter_height, output_data, output_dims);
}
+inline void Softmax(const float* input_data, const RuntimeShape& input_shape,
+ float beta, float* output_data,
+ const RuntimeShape& output_shape) {
+ SoftmaxParams params;
+ params.beta = beta;
+ Softmax(params, input_shape, input_data, output_shape, output_data);
+}
+
inline void Softmax(const float* input_data, const Dims<4>& input_dims,
float beta, float* output_data,
const Dims<4>& output_dims) {
@@ -581,6 +1439,16 @@ inline void Softmax(const float* input_data, const Dims<4>& input_dims,
DimsToShape(output_dims));
}
+inline void Softmax(const uint8* input_data, const RuntimeShape& input_shape,
+ int32 input_beta_multiplier, int32 input_beta_left_shift,
+ int diff_min, uint8* output_data,
+ const RuntimeShape& output_shape) {
+ SoftmaxParams params;
+ params.input_multiplier = input_beta_multiplier;
+ params.input_left_shift = input_beta_left_shift;
+ params.diff_min = diff_min;
+ Softmax(params, input_shape, input_data, output_shape, output_data);
+}
inline void Softmax(const uint8* input_data, const Dims<4>& input_dims,
int32 input_beta_multiplier, int32 input_beta_left_shift,
int diff_min, uint8* output_data,
@@ -590,12 +1458,33 @@ inline void Softmax(const uint8* input_data, const Dims<4>& input_dims,
DimsToShape(output_dims));
}
+inline void LogSoftmax(const float* input_data, const RuntimeShape& input_shape,
+ float* output_data, const RuntimeShape& output_shape) {
+ SoftmaxParams params;
+ // No params currently used for float LogSoftmax.
+ LogSoftmax(params, input_shape, input_data, output_shape, output_data);
+}
+
inline void LogSoftmax(const float* input_data, const Dims<4>& input_dims,
float* output_data, const Dims<4>& output_dims) {
LogSoftmax(input_data, DimsToShape(input_dims), output_data,
DimsToShape(output_dims));
}
+inline void LogSoftmax(const uint8* input_data, const RuntimeShape& input_shape,
+ int32 input_multiplier, int32 input_left_shift,
+ int32 reverse_scaling_divisor,
+ int32 reverse_scaling_right_shift, int diff_min,
+ uint8* output_data, const RuntimeShape& output_shape) {
+ SoftmaxParams params;
+ params.input_multiplier = input_multiplier;
+ params.input_left_shift = input_left_shift;
+ params.reverse_scaling_divisor = reverse_scaling_divisor;
+ params.reverse_scaling_right_shift = reverse_scaling_right_shift;
+ params.diff_min = diff_min;
+ LogSoftmax(params, input_shape, input_data, output_shape, output_data);
+}
+
inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims,
int32 input_multiplier, int32 input_left_shift,
int32 reverse_scaling_divisor,
@@ -607,6 +1496,18 @@ inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims,
DimsToShape(output_dims));
}
+inline void Logistic(const uint8* input_data, const RuntimeShape& input_shape,
+ int32 input_zero_point, int32 input_range_radius,
+ int32 input_multiplier, int input_left_shift,
+ uint8* output_data, const RuntimeShape& output_shape) {
+ LogisticParams params;
+ params.input_zero_point = input_zero_point;
+ params.input_range_radius = input_range_radius;
+ params.input_multiplier = input_multiplier;
+ params.input_left_shift = input_left_shift;
+ Logistic(params, input_shape, input_data, output_shape, output_data);
+}
+
inline void Logistic(const float* input_data, const Dims<4>& input_dims,
float* output_data, const Dims<4>& output_dims) {
Logistic(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
@@ -622,6 +1523,20 @@ inline void Logistic(const uint8* input_data, const Dims<4>& input_dims,
DimsToShape(output_dims));
}
+inline void Logistic(const RuntimeShape& input_shape, const int16* input_data,
+ const RuntimeShape& output_shape, int16* output_data) {
+ LogisticParams params;
+ // No params currently needed by int16 Logistic.
+ Logistic(params, input_shape, input_data, output_shape, output_data);
+}
+
+inline void Logistic(const int16* input_data, const RuntimeShape& input_shape,
+ int16* output_data, const RuntimeShape& output_shape) {
+ LogisticParams params;
+ // No params currently needed by int16 Logistic.
+ Logistic(params, input_shape, input_data, output_shape, output_data);
+}
+
inline void Logistic(const int16* input_data, const Dims<4>& input_dims,
int16* output_data, const Dims<4>& output_dims) {
Logistic(input_data, DimsToShape(input_dims), output_data,
@@ -634,6 +1549,18 @@ inline void Tanh(const float* input_data, const Dims<4>& input_dims,
output_data);
}
+inline void Tanh(const uint8* input_data, const RuntimeShape& input_shape,
+ int32 input_zero_point, int32 input_range_radius,
+ int32 input_multiplier, int input_left_shift,
+ uint8* output_data, const RuntimeShape& output_shape) {
+ TanhParams params;
+ params.input_zero_point = input_zero_point;
+ params.input_range_radius = input_range_radius;
+ params.input_multiplier = input_multiplier;
+ params.input_left_shift = input_left_shift;
+ Tanh(params, input_shape, input_data, output_shape, output_data);
+}
+
inline void Tanh(const uint8* input_data, const Dims<4>& input_dims,
int32 input_zero_point, int32 input_range_radius,
int32 input_multiplier, int input_left_shift,
@@ -643,6 +1570,14 @@ inline void Tanh(const uint8* input_data, const Dims<4>& input_dims,
DimsToShape(output_dims));
}
+inline void Tanh(const int16* input_data, const RuntimeShape& input_shape,
+ int input_left_shift, int16* output_data,
+ const RuntimeShape& output_shape) {
+ TanhParams params;
+ params.input_left_shift = input_left_shift;
+ Tanh(params, input_shape, input_data, output_shape, output_data);
+}
+
inline void Tanh(const int16* input_data, const Dims<4>& input_dims,
int input_left_shift, int16* output_data,
const Dims<4>& output_dims) {
@@ -777,7 +1712,6 @@ inline void BroadcastMul(const float* input1_data, const Dims<4>& input1_dims,
DimsToShape(output_dims), output_data);
}
-// Legacy Dims<4>.
inline void LocalResponseNormalization(const float* input_data,
const Dims<4>& input_dims, int range,
float bias, float alpha, float beta,
@@ -793,7 +1727,6 @@ inline void LocalResponseNormalization(const float* input_data,
DimsToShape(output_dims), output_data);
}
-// Legacy Dims<4> version.
template <typename SrcT, typename DstT>
void Cast(const SrcT* input_data, const Dims<4>& input_dims, DstT* output_data,
const Dims<4>& output_dims) {
@@ -801,14 +1734,12 @@ void Cast(const SrcT* input_data, const Dims<4>& input_dims, DstT* output_data,
output_data);
}
-// Legacy Dims<4> version.
inline void Floor(const float* input_data, const Dims<4>& input_dims,
float* output_data, const Dims<4>& output_dims) {
Floor(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
output_data);
}
-// Legacy Dims<4>
inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims,
const int32* output_size_data,
const Dims<4>& output_size_dims, float* output_data,
@@ -820,7 +1751,6 @@ inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims,
DimsToShape(output_dims), output_data);
}
-// Legacy Dims<4>
inline void ResizeBilinear(const uint8* input_data, const Dims<4>& input_dims,
const int32* output_size_data,
const Dims<4>& output_size_dims, uint8* output_data,
@@ -850,7 +1780,6 @@ inline void ResizeBilinear(const uint8* input_data, const Dims<4>& input_dims,
output_data, output_dims, /*align_corners=*/false);
}
-// Legacy Dims<4>.
template <typename T>
inline void BatchToSpaceND(const T* input_data, const Dims<4>& input_dims,
const int32* block_shape_data,
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h b/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h
index 59f0e3c927..4139cf4eba 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h
@@ -69,13 +69,13 @@ struct MatMulConvFunctor {
template <class T>
class EigenTensorConvFunctor {
private:
- Eigen::PaddingType TfLitePadding2EigenPadding(TfLitePadding padding) {
+ Eigen::PaddingType RuntimePadding2EigenPadding(PaddingType padding) {
switch (padding) {
- case kTfLitePaddingValid:
+ case PaddingType::kValid:
return Eigen::PADDING_VALID;
- case kTfLitePaddingSame:
+ case PaddingType::kSame:
return Eigen::PADDING_SAME;
- case kTfLitePaddingUnknown:
+ case PaddingType::kNone:
assert(false); // should never get here.
return Eigen::PADDING_VALID;
}
@@ -89,7 +89,7 @@ class EigenTensorConvFunctor {
int input_width, int input_depth, const T* filter_data,
int filter_height, int filter_width, int filter_count,
int stride_rows, int stride_cols, int pad_width,
- int pad_height, TfLitePadding padding, T* output_data,
+ int pad_height, PaddingType padding, T* output_data,
int output_height, int output_width) {
const bool is_1x1_kernel = (filter_height == 1 && filter_width == 1 &&
stride_rows == 1 && stride_cols == 1);
@@ -127,28 +127,38 @@ class EigenTensorConvFunctor {
input_depth, filter_count);
output.device(device) =
Eigen::SpatialConvolution(input, filter, stride_cols, stride_rows,
- TfLitePadding2EigenPadding(padding));
+ RuntimePadding2EigenPadding(padding));
}
}
};
-inline void Conv(const Eigen::ThreadPoolDevice& device, const float* input_data,
- const Dims<4>& input_dims, const float* filter_data,
- const Dims<4>& filter_dims, const float* bias_data,
- const Dims<4>& bias_dims, int stride_width, int stride_height,
- int pad_width, int pad_height, TfLitePadding padding,
- float output_activation_min, float output_activation_max,
- float* output_data, const Dims<4>& output_dims,
- float* im2col_data, const Dims<4>& im2col_dims) {
- const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 0);
- const int output_depth = MatchingArraySize(filter_dims, 3, output_dims, 0);
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
- const int filter_height = ArraySize(filter_dims, 2);
- const int filter_width = ArraySize(filter_dims, 1);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
+inline void Conv(const Eigen::ThreadPoolDevice& device,
+ const ConvParams& params, const RuntimeShape& input_shape,
+ const float* input_data, const RuntimeShape& filter_shape,
+ const float* filter_data, const RuntimeShape& bias_shape,
+ const float* bias_data, const RuntimeShape& output_shape,
+ float* output_data, const RuntimeShape& im2col_shape,
+ float* im2col_data) {
+ const int stride_width = params.stride_width;
+ const int stride_height = params.stride_height;
+ const PaddingType padding = params.padding_type;
+ const int pad_width = params.padding_values.width;
+ const int pad_height = params.padding_values.height;
+ const float output_activation_min = params.float_activation_min;
+ const float output_activation_max = params.float_activation_max;
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+
+ const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+ const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3);
+ const int output_depth = MatchingDim(filter_shape, 0, output_shape, 3);
+ const int input_height = input_shape.Dims(1);
+ const int input_width = input_shape.Dims(2);
+ const int filter_height = filter_shape.Dims(1);
+ const int filter_width = filter_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_width = output_shape.Dims(2);
EigenTensorConvFunctor<float> conv_functor;
conv_functor(device, input_data, im2col_data, batches, input_height,
input_width, input_depth, filter_data, filter_height,
@@ -157,8 +167,8 @@ inline void Conv(const Eigen::ThreadPoolDevice& device, const float* input_data,
output_width);
optimized_ops::AddBiasAndEvalActivationFunction(
- bias_data, bias_dims, output_data, output_dims, output_activation_min,
- output_activation_max);
+ output_activation_min, output_activation_max, bias_shape, bias_data,
+ output_shape, output_data);
}
} // namespace multithreaded_ops
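One behavioral note on the new multithreaded Conv signature: unlike the legacy adapters earlier in this patch, which set ConvParams::padding_type only as a placeholder, this path actually consumes it (RuntimePadding2EigenPadding switches on it), so callers must fill it with the real padding. A hedged sketch of a call site, with argument names assumed from the surrounding code rather than taken from the patch:

tflite::ConvParams op_params;
op_params.padding_type = PaddingType::kSame;   // consumed here, not ignored
op_params.padding_values.width = pad_width;
op_params.padding_values.height = pad_height;
op_params.stride_width = stride_width;
op_params.stride_height = stride_height;
op_params.float_activation_min = output_activation_min;
op_params.float_activation_max = output_activation_max;
multithreaded_ops::Conv(device, op_params, DimsToShape(input_dims), input_data,
                        DimsToShape(filter_dims), filter_data,
                        DimsToShape(bias_dims), bias_data,
                        DimsToShape(output_dims), output_data,
                        DimsToShape(im2col_dims), im2col_data);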
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc
index 27418178fd..36c15dbc57 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc
@@ -457,7 +457,7 @@ void NeonSymmetricQuantizeFloats(const float* values, const int size,
return;
}
*scaling_factor = range / kScale;
- const float scaling_factor_inv = 1.0f / *scaling_factor;
+ const float scaling_factor_inv = kScale / range;
const int postamble_start =
size - (size & (2 * kFloatWeightsPerNeonLane - 1));
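The scaling_factor_inv change above is an algebraic substitution: since *scaling_factor is assigned range / kScale on the preceding line, kScale / range is the same quantity, computed with one division from the original operands instead of dividing by the already-rounded stored scale. A minimal sketch of the symmetric quantization step, assuming kScale = 127.0f (an assumption here; the real constant is defined in the surrounding function):

#include <algorithm>
#include <cmath>
#include <cstdint>

// Hedged sketch of per-value symmetric quantization, not the library routine.
inline int8_t SymmetricQuantize(float value, float range /* max absolute value */) {
  const float scaling_factor_inv = 127.0f / range;  // same quantity as kScale / range
  const int32_t q = static_cast<int32_t>(std::round(value * scaling_factor_inv));
  return static_cast<int8_t>(std::min(127, std::max(-127, q)));
}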
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
index 2fa5d6445e..77f84e0c1c 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
@@ -52,13 +52,10 @@ using reference_ops::Broadcast4DSlowLessEqual;
using reference_ops::Broadcast4DSlowLessEqualWithScaling;
using reference_ops::Broadcast4DSlowLessWithScaling;
using reference_ops::BroadcastAdd4DSlow;
-using reference_ops::BroadcastGreater;
-using reference_ops::BroadcastGreaterEqual;
-using reference_ops::BroadcastLess;
-using reference_ops::BroadcastLessEqual;
using reference_ops::BroadcastMul4DSlow;
using reference_ops::BroadcastSub4DSlow;
using reference_ops::Concatenation;
+using reference_ops::ConcatenationWithScaling;
using reference_ops::DepthConcatenation;
using reference_ops::Dequantize;
using reference_ops::Div;
@@ -81,14 +78,13 @@ using reference_ops::Select;
using reference_ops::SpaceToBatchND;
using reference_ops::Split;
using reference_ops::StridedSlice;
-using reference_ops::TensorFlowSplit;
using reference_ops::Transpose;
// TODO(b/80247582) Remove this constant.
// This will be phased out as the shifts are revised with more thought. Use of a
// constant enables us to track progress on this work.
//
-// Used mainly to convert from old-style shifts (right) to new-style (left).
+// Used to convert from old-style shifts (right) to new-style (left).
static constexpr int kReverseShift = -1;
// Make a local VectorMap typedef allowing to map a float array
@@ -111,12 +107,6 @@ VectorMap<Scalar> MapAsVector(Scalar* data, const RuntimeShape& shape) {
return VectorMap<Scalar>(data, size, 1);
}
-template <typename Scalar, int N>
-VectorMap<Scalar> MapAsVector(Scalar* data, const Dims<N>& dims) {
- const int size = FlatSize(dims);
- return VectorMap<Scalar>(data, size, 1);
-}
-
// Make a local VectorMap typedef allowing to map a float array
// as a Eigen matrix expression. The same explanation as for VectorMap
// above also applies here.
@@ -144,28 +134,6 @@ MatrixMap<Scalar> MapAsMatrixWithFirstDimAsCols(Scalar* data,
return MatrixMap<Scalar>(data, rows, cols);
}
-template <typename Scalar, int N>
-MatrixMap<Scalar> MapAsMatrixWithFirstDimAsRows(Scalar* data,
- const Dims<N>& dims) {
- const int rows = dims.sizes[0];
- int cols = 1;
- for (int d = 1; d < N; d++) {
- cols *= dims.sizes[d];
- }
- return MatrixMap<Scalar>(data, rows, cols);
-}
-
-template <typename Scalar, int N>
-MatrixMap<Scalar> MapAsMatrixWithLastDimAsCols(Scalar* data,
- const Dims<N>& dims) {
- const int cols = dims.sizes[N - 1];
- int rows = 1;
- for (int d = 0; d < N - 1; d++) {
- rows *= dims.sizes[d];
- }
- return MatrixMap<Scalar>(data, rows, cols);
-}
-
template <typename Scalar>
using ArrayMap = typename std::conditional<
std::is_const<Scalar>::value,
@@ -173,17 +141,6 @@ using ArrayMap = typename std::conditional<
Eigen::Dynamic, Eigen::Dynamic>>,
Eigen::Map<Eigen::Array<Scalar, Eigen::Dynamic, Eigen::Dynamic>>>::type;
-template <typename Scalar, int N>
-ArrayMap<Scalar> MapAsArrayWithFirstDimAsRows(Scalar* data,
- const Dims<N>& dims) {
- const int rows = dims.sizes[0];
- int cols = 1;
- for (int d = 1; d < N; d++) {
- cols *= dims.sizes[d];
- }
- return ArrayMap<Scalar>(data, rows, cols);
-}
-
template <typename Scalar>
ArrayMap<Scalar> MapAsArrayWithLastDimAsRows(Scalar* data,
const RuntimeShape& shape) {
@@ -205,20 +162,6 @@ struct TTypes {
UnalignedConstMatrix;
};
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-// TODO(b/62193649): this function is only needed as long
-// as we have the --variable_batch hack.
-template <typename Scalar, int N>
-MatrixMap<Scalar> MapAsMatrixWithGivenNumberOfRows(Scalar* data,
- const Dims<N>& dims,
- int rows) {
- const int flatsize = FlatSize(dims);
- TFLITE_DCHECK((flatsize % rows) == 0);
- const int cols = flatsize / rows;
- return MatrixMap<Scalar>(data, rows, cols);
-}
-
// TODO(b/62193649): this function is only needed as long
// as we have the --variable_batch hack.
template <typename Scalar>
@@ -270,15 +213,6 @@ SaturatingRoundingMultiplyByPOTParam(
SaturatingRoundingMultiplyByPOTParam(a.raw(), exponent));
}
-inline bool AreSameDims(const Dims<4>& dims1, const Dims<4>& dims2) {
- for (int i = 0; i < 4; i++) {
- if (dims1.sizes[i] != dims2.sizes[i]) {
- return false;
- }
- }
- return true;
-}
-
inline void AddBiasAndEvalActivationFunction(float output_activation_min,
float output_activation_max,
const RuntimeShape& bias_shape,
@@ -352,33 +286,6 @@ inline void AddBiasAndEvalActivationFunction(float output_activation_min,
#endif
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-inline void AddBiasAndEvalActivationFunction(const float* bias_data,
- const Dims<4>& bias_dims,
- float* array_data,
- const Dims<4>& array_dims,
- float output_activation_min,
- float output_activation_max) {
- AddBiasAndEvalActivationFunction(output_activation_min, output_activation_max,
- DimsToShape(bias_dims), bias_data,
- DimsToShape(array_dims), array_data);
-}
-
-// Note: This to be converted to RuntimeShapes along with Conv.
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void AddBiasAndEvalActivationFunction(const float* bias_data,
- const Dims<4>& bias_dims,
- float* array_data,
- const Dims<4>& array_dims) {
- float output_activation_min, output_activation_max;
- GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
- AddBiasAndEvalActivationFunction(bias_data, bias_dims, array_data, array_dims,
- output_activation_min,
- output_activation_max);
-}
-
template <typename Lhs, typename Rhs, typename Result>
void Gemm(const Eigen::MatrixBase<Lhs>& lhs, const Eigen::MatrixBase<Rhs>& rhs,
Eigen::MatrixBase<Result>* result) {
@@ -925,38 +832,6 @@ inline void FullyConnected(
output_data);
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-inline void FullyConnected(const float* input_data, const Dims<4>& input_dims,
- const float* weights_data,
- const Dims<4>& weights_dims, const float* bias_data,
- const Dims<4>& bias_dims,
- float output_activation_min,
- float output_activation_max, float* output_data,
- const Dims<4>& output_dims) {
- tflite::FullyConnectedParams op_params;
- op_params.float_activation_min = output_activation_min;
- op_params.float_activation_max = output_activation_max;
-
- FullyConnected(op_params, DimsToShape(input_dims), input_data,
- DimsToShape(weights_dims), weights_data,
- DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
- output_data);
-}
-
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void FullyConnected(const float* input_data, const Dims<4>& input_dims,
- const float* weights_data, const Dims<4>& weights_dims,
- const float* bias_data, const Dims<4>& bias_dims,
- float* output_data, const Dims<4>& output_dims) {
- float output_activation_min, output_activation_max;
- GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
- FullyConnected(input_data, input_dims, weights_data, weights_dims, bias_data,
- bias_dims, output_activation_min, output_activation_max,
- output_data, output_dims);
-}
-
#ifdef USE_NEON
inline void FullyConnectedAsGEMV(
const RuntimeShape& input_shape, const uint8* input_data,
@@ -977,7 +852,7 @@ inline void FullyConnectedAsGEMV(
const int output_size = MatchingDim(filter_shape, filter_dim_count - 2,
output_shape, output_dim_count - 1);
static constexpr int kPeel = 4;
- const bool shift_left = (output_shift <= 0);
+ const bool shift_left = (output_shift > 0);
for (int k = 0; k < input_size; k += 64) {
optimized_ops_preload_l1_stream(input_data + k);
}
@@ -1090,7 +965,7 @@ inline void FullyConnectedAsGEMV(
bias_ptr += 4;
reduced = vaddq_s32(reduced, bias_vec);
if (shift_left) {
- const int32 multiplier_power_of_two = 1 << -output_shift;
+ const int32 multiplier_power_of_two = 1 << output_shift;
reduced = vmulq_n_s32(reduced, multiplier_power_of_two);
reduced = vqrdmulhq_n_s32(reduced, output_multiplier);
} else {
@@ -1098,7 +973,7 @@ inline void FullyConnectedAsGEMV(
reduced = vqrdmulhq_n_s32(reduced, output_multiplier);
// Rounding-shift-right.
using gemmlowp::RoundingDivideByPOT;
- reduced = RoundingDivideByPOT(reduced, output_shift);
+ reduced = RoundingDivideByPOT(reduced, -output_shift);
}
// Add the output offset.
const int32x4_t output_offset_vec = vdupq_n_s32(output_offset);
@@ -1195,7 +1070,7 @@ inline void FullyConnected(
gemmlowp::MatrixMap<uint8, gemmlowp::MapOrder::ColMajor> output_matrix(
output_data, output_rows, batches, output_rows);
const auto& output_pipeline = GemmlowpOutputPipeline::MakeExp(
- bias_data, output_rows, output_offset, output_multiplier, -output_shift,
+ bias_data, output_rows, output_offset, output_multiplier, output_shift,
output_activation_min, output_activation_max);
gemmlowp::GemmWithOutputPipeline<uint8, uint8,
gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
@@ -1203,32 +1078,6 @@ inline void FullyConnected(
input_offset, output_pipeline);
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
- int32 input_offset, const uint8* filter_data,
- const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data, const Dims<4>& bias_dims,
- int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims,
- gemmlowp::GemmContext* gemm_context) {
- tflite::FullyConnectedParams op_params;
- op_params.input_offset = input_offset;
- op_params.weights_offset = filter_offset;
- op_params.output_offset = output_offset;
- op_params.output_multiplier = output_multiplier;
- op_params.output_shift = output_shift;
- op_params.quantized_activation_min = output_activation_min;
- op_params.quantized_activation_max = output_activation_max;
-
- FullyConnected(op_params, DimsToShape(input_dims), input_data,
- DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
- bias_data, DimsToShape(output_dims), output_data,
- gemm_context);
-}
-
inline void FullyConnected(
const FullyConnectedParams& params, const RuntimeShape& input_shape,
const uint8* input_data, const RuntimeShape& filter_shape,
@@ -1274,14 +1123,14 @@ inline void FullyConnected(
if (filter_offset == -128 && !(output_depth % 4) && !(accum_depth % 64)) {
GEMVForLstmCellWithSymmetricRange(
input_shape, input_data, filter_shape, filter_data, bias_shape,
- bias_data_int32, output_multiplier, -output_shift, output_shape,
+ bias_data_int32, output_multiplier, output_shift, output_shape,
output_data);
return;
}
if (!(output_depth % 4) && !(accum_depth % 8)) {
GEMVForLstmCell(input_shape, input_data, filter_shape, filter_data,
filter_offset, bias_shape, bias_data_int32,
- output_multiplier, -output_shift, output_shape,
+ output_multiplier, output_shift, output_shape,
output_data);
return;
}
@@ -1302,7 +1151,7 @@ inline void FullyConnected(
scale_stage.result_offset_after_shift = 0;
scale_stage.result_fixedpoint_multiplier = output_multiplier;
// Note that this shift is negated wrt ordinary FC.
- scale_stage.result_exponent = -output_shift;
+ scale_stage.result_exponent = output_shift;
gemmlowp::OutputStageClamp clamp_stage;
clamp_stage.min = output_activation_min;
clamp_stage.max = output_activation_max;
@@ -1316,53 +1165,6 @@ inline void FullyConnected(
input_offset, output_pipeline);
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-inline void FullyConnected(
- const uint8* input_data, const Dims<4>& input_dims, int32 input_offset,
- const uint8* filter_data, const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data_int32, const Dims<4>& bias_dims, int32 output_offset,
- int32 output_multiplier, int output_shift, int32 output_activation_min,
- int32 output_activation_max, int16* output_data, const Dims<4>& output_dims,
- gemmlowp::GemmContext* gemm_context) {
- tflite::FullyConnectedParams op_params;
- op_params.input_offset = input_offset;
- op_params.weights_offset = filter_offset;
- op_params.output_offset = output_offset;
- op_params.output_multiplier = output_multiplier;
- op_params.output_shift = output_shift;
- op_params.quantized_activation_min = output_activation_min;
- op_params.quantized_activation_max = output_activation_max;
-
- FullyConnected(op_params, DimsToShape(input_dims), input_data,
- DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
- bias_data_int32, DimsToShape(output_dims), output_data,
- gemm_context);
-}
-
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
- int32 input_offset, const uint8* filter_data,
- const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data, const Dims<4>& bias_dims,
- int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims,
- gemmlowp::GemmContext* gemm_context) {
- static_assert(Ac == FusedActivationFunctionType::kNone ||
- Ac == FusedActivationFunctionType::kRelu ||
- Ac == FusedActivationFunctionType::kRelu6 ||
- Ac == FusedActivationFunctionType::kRelu1,
- "");
- FullyConnected(input_data, input_dims, input_offset, filter_data, filter_dims,
- filter_offset, bias_data, bias_dims, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_data, output_dims, gemm_context);
-}
-
// Internal function doing the actual arithmetic work for
// ShuffledFullyConnected.
// May be called either directly by it (single-threaded case) or may be used
@@ -1376,8 +1178,8 @@ inline void ShuffledFullyConnectedWorkerImpl(
#if defined USE_NEON
const int8* shuffled_weights_ptr = shuffled_weights_data;
if (batches == 1) {
- const int right_shift = output_shift > 0 ? output_shift : 0;
- const int left_shift = output_shift > 0 ? 0 : -output_shift;
+ const int right_shift = output_shift > 0 ? 0 : -output_shift;
+ const int left_shift = output_shift > 0 ? output_shift : 0;
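The right_shift/left_shift pair above simply splits the signed new-style shift into its two halves; a trivial sketch of the same decomposition, for illustration only:

    // For a new-style shift s, exactly one of the two parts is nonzero:
    // the left part multiplies by 2^s when s > 0, the right part performs a
    // rounding shift by -s when s < 0.
    inline void SplitShift(int s, int* left_shift, int* right_shift) {
      *left_shift = s > 0 ? s : 0;
      *right_shift = s > 0 ? 0 : -s;
    }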
for (int c = 0; c < output_depth; c += 4) {
// Accumulation loop.
int32x4_t row_accum0 = vdupq_n_s32(0);
@@ -1443,8 +1245,8 @@ inline void ShuffledFullyConnectedWorkerImpl(
vst1_s16(output_data + c, res16);
}
} else if (batches == 4) {
- const int right_shift = output_shift > 0 ? output_shift : 0;
- const int left_shift = output_shift > 0 ? 0 : -output_shift;
+ const int right_shift = output_shift > 0 ? 0 : -output_shift;
+ const int left_shift = output_shift > 0 ? output_shift : 0;
for (int c = 0; c < output_depth; c += 4) {
const int8* shuffled_input_ptr =
reinterpret_cast<const int8*>(shuffled_input_workspace_data);
@@ -1575,8 +1377,8 @@ inline void ShuffledFullyConnectedWorkerImpl(
// (16-bit, typically 3 integer bits) fixed-point format. The quantized
// multiplier and shift here have been pre-computed offline
// (e.g. by toco).
- acc = MultiplyByQuantizedMultiplier(acc, output_multiplier,
- -output_shift);
+ acc =
+ MultiplyByQuantizedMultiplier(acc, output_multiplier, output_shift);
// Saturate, cast to int16, and store to output array.
acc = std::max(acc, -32768);
acc = std::min(acc, 32767);
@@ -1627,7 +1429,7 @@ inline void ShuffledFullyConnectedWorkerImpl(
// quantized multiplier and shift here have been pre-computed offline
// (e.g. by toco).
acc = MultiplyByQuantizedMultiplier(acc, output_multiplier,
- -output_shift);
+ output_shift);
// Saturate, cast to int16, and store to output array.
acc = std::max(acc, -32768);
acc = std::min(acc, 32767);
@@ -1807,28 +1609,6 @@ inline void ShuffledFullyConnected(
gemm_context->workers_pool()->Execute(tasks);
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-inline void ShuffledFullyConnected(
- const uint8* input_data, const Dims<4>& input_dims,
- const uint8* shuffled_weights_data, const Dims<4>& weights_dims,
- const int32* bias_data, const Dims<4>& bias_dims, int32 output_multiplier,
- int output_shift, int32 output_activation_min, int32 output_activation_max,
- int16* output_data, const Dims<4>& output_dims,
- uint8* shuffled_input_workspace_data, gemmlowp::GemmContext* gemm_context) {
- tflite::FullyConnectedParams op_params;
- op_params.output_multiplier = output_multiplier;
- op_params.output_shift = output_shift;
- op_params.quantized_activation_min = output_activation_min;
- op_params.quantized_activation_max = output_activation_max;
-
- ShuffledFullyConnected(op_params, DimsToShape(input_dims), input_data,
- DimsToShape(weights_dims), shuffled_weights_data,
- DimsToShape(bias_dims), bias_data,
- DimsToShape(output_dims), output_data,
- shuffled_input_workspace_data, gemm_context);
-}
-
template <typename T>
inline void ExtractPatchIntoBufferColumn(const RuntimeShape& input_shape, int w,
int h, int b, int kheight, int kwidth,
@@ -1919,20 +1699,6 @@ inline void ExtractPatchIntoBufferColumn(const RuntimeShape& input_shape, int w,
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-template <typename T>
-inline void ExtractPatchIntoBufferColumn(
- const Dims<4>& input_dims, int w, int h, int b, int kheight, int kwidth,
- int stride_width, int stride_height, int pad_width, int pad_height,
- int in_width, int in_height, int in_depth, int single_buffer_length,
- int buffer_id, const T* in_data, T* conv_buffer_data, uint8 zero_byte) {
- ExtractPatchIntoBufferColumn(
- DimsToShape(input_dims), w, h, b, kheight, kwidth, stride_width,
- stride_height, pad_width, pad_height, in_width, in_height, in_depth,
- single_buffer_length, buffer_id, in_data, conv_buffer_data, zero_byte);
-}
-
template <typename T>
void DilatedIm2col(const ConvParams& params, uint8 zero_byte,
const RuntimeShape& input_shape, const T* input_data,
@@ -2016,30 +1782,6 @@ void DilatedIm2col(const ConvParams& params, uint8 zero_byte,
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-template <typename T>
-void DilatedIm2col(const T* input_data, const Dims<4>& input_dims,
- const Dims<4>& filter_dims, int stride_width,
- int stride_height, int dilation_width_factor,
- int dilation_height_factor, int pad_width, int pad_height,
- const Dims<4>& output_dims, uint8 zero_byte,
- T* im2col_data) {
- tflite::ConvParams op_params;
- // Padding type is ignored, but still set.
- op_params.padding_type = PaddingType::kSame;
- op_params.padding_values.width = pad_width;
- op_params.padding_values.height = pad_height;
- op_params.stride_width = stride_width;
- op_params.stride_height = stride_height;
- op_params.dilation_width_factor = dilation_width_factor;
- op_params.dilation_height_factor = dilation_height_factor;
-
- DilatedIm2col(op_params, zero_byte, DimsToShape(input_dims), input_data,
- DimsToShape(filter_dims), DimsToShape(output_dims),
- im2col_data);
-}
-
template <typename T>
void Im2col(const ConvParams& params, int kheight, int kwidth, uint8 zero_byte,
const RuntimeShape& input_shape, const T* input_data,
@@ -2075,36 +1817,6 @@ void Im2col(const ConvParams& params, int kheight, int kwidth, uint8 zero_byte,
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-template <typename T>
-void Im2col(const T* input_data, const Dims<4>& input_dims, int stride_width,
- int stride_height, int pad_width, int pad_height, int kheight,
- int kwidth, uint8 zero_byte, T* output_data,
- const Dims<4>& output_dims) {
- tflite::ConvParams op_params;
- // Padding type is ignored, but still set.
- op_params.padding_type = PaddingType::kSame;
- op_params.padding_values.width = pad_width;
- op_params.padding_values.height = pad_height;
- op_params.stride_width = stride_width;
- op_params.stride_height = stride_height;
- op_params.dilation_width_factor = 1;
- op_params.dilation_height_factor = 1;
-
- Im2col(op_params, kheight, kwidth, zero_byte, DimsToShape(input_dims),
- input_data, DimsToShape(output_dims), output_data);
-}
-
-// legacy, for compatibility with old checked-in code
-template <typename T>
-void Im2col(const T* input_data, const Dims<4>& input_dims, int stride,
- int pad_width, int pad_height, int kheight, int kwidth,
- uint8 zero_byte, T* output_data, const Dims<4>& output_dims) {
- Im2col(input_data, input_dims, stride, stride, pad_width, pad_height, kheight,
- kwidth, zero_byte, output_data, output_dims);
-}
-
inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
const float* input_data, const RuntimeShape& filter_shape,
const float* filter_data, const RuntimeShape& bias_shape,
@@ -2168,33 +1880,6 @@ inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
output_data);
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-inline void Conv(const float* input_data, const Dims<4>& input_dims,
- const float* filter_data, const Dims<4>& filter_dims,
- const float* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height, int dilation_width_factor,
- int dilation_height_factor, int pad_width, int pad_height,
- float output_activation_min, float output_activation_max,
- float* output_data, const Dims<4>& output_dims,
- float* im2col_data, const Dims<4>& im2col_dims) {
- tflite::ConvParams op_params;
- // Padding type is ignored, but still set.
- op_params.padding_type = PaddingType::kSame;
- op_params.padding_values.width = pad_width;
- op_params.padding_values.height = pad_height;
- op_params.stride_width = stride_width;
- op_params.stride_height = stride_height;
- op_params.dilation_width_factor = dilation_width_factor;
- op_params.dilation_height_factor = dilation_height_factor;
- op_params.float_activation_min = output_activation_min;
- op_params.float_activation_max = output_activation_max;
-
- Conv(op_params, DimsToShape(input_dims), input_data, DimsToShape(filter_dims),
- filter_data, DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
- output_data, DimsToShape(im2col_dims), im2col_data);
-}
-
inline void HybridConv(const ConvParams& params, float* scaling_factors_ptr,
const RuntimeShape& input_shape,
const int8_t* input_data,
@@ -2210,7 +1895,6 @@ inline void HybridConv(const ConvParams& params, float* scaling_factors_ptr,
TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
- TFLITE_DCHECK_EQ(im2col_shape.DimensionsCount(), 4);
const int batch_size = input_shape.Dims(0);
const int filter_width = filter_shape.Dims(2);
@@ -2254,10 +1938,7 @@ inline void HybridConv(const ConvParams& params, float* scaling_factors_ptr,
const int output_rows = FlatSizeSkipDim(output_shape, 3);
TFLITE_DCHECK_EQ(output_cols, filter_rows);
TFLITE_DCHECK_EQ(output_rows, gemm_input_rows);
- TFLITE_DCHECK_EQ(bias_shape.Dims(3), output_cols);
- TFLITE_DCHECK_EQ(bias_shape.Dims(2), 1);
- TFLITE_DCHECK_EQ(bias_shape.Dims(1), 1);
- TFLITE_DCHECK_EQ(bias_shape.Dims(0), 1);
+ TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_cols);
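The single FlatSize() check above matches the four per-dimension checks it replaces whenever the bias shape is {1, 1, 1, N}; a small illustrative sketch (the initializer-list construction is an assumption about the surrounding RuntimeShape API):

    // FlatSize() is the product of all dimensions, so for a {1, 1, 1, 64}
    // bias it evaluates to 64 -- the value the old per-dim checks implied.
    RuntimeShape bias_shape({1, 1, 1, 64});
    // bias_shape.FlatSize() == 64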
// MatrixBatchVectorMultiplyAccumulate assumes that each row of the second
// input matrix has its own scale factor. This code duplicates the scale
@@ -2279,82 +1960,6 @@ inline void HybridConv(const ConvParams& params, float* scaling_factors_ptr,
output_data);
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-inline void HybridConv(const int8_t* input_data, const Dims<4>& input_dims,
- const int8_t* filter_data, const Dims<4>& filter_dims,
- const float* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, float* scaling_factors_ptr,
- float output_activation_min, float output_activation_max,
- float* output_data, const Dims<4>& output_dims,
- int8_t* im2col_data, const Dims<4>& im2col_dims) {
- tflite::ConvParams op_params;
- // Padding type is ignored, but still set.
- op_params.padding_type = PaddingType::kSame;
- op_params.padding_values.width = pad_width;
- op_params.padding_values.height = pad_height;
- op_params.stride_width = stride_width;
- op_params.stride_height = stride_height;
- op_params.float_activation_min = output_activation_min;
- op_params.float_activation_max = output_activation_max;
-
- HybridConv(op_params, scaling_factors_ptr, DimsToShape(input_dims),
- input_data, DimsToShape(filter_dims), filter_data,
- DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
- output_data, DimsToShape(im2col_dims), im2col_data);
-}
-
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-template <FusedActivationFunctionType Ac>
-void Conv(const float* input_data, const Dims<4>& input_dims,
- const float* filter_data, const Dims<4>& filter_dims,
- const float* bias_data, const Dims<4>& bias_dims, int stride_width,
- int stride_height, int dilation_width_factor,
- int dilation_height_factor, int pad_width, int pad_height,
- float* output_data, const Dims<4>& output_dims, float* im2col_data,
- const Dims<4>& im2col_dims) {
- float output_activation_min, output_activation_max;
- GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
- Conv(input_data, input_dims, filter_data, filter_dims, bias_data, bias_dims,
- stride_width, stride_height, dilation_width_factor,
- dilation_height_factor, pad_width, pad_height, output_activation_min,
- output_activation_max, output_data, output_dims, im2col_data,
- im2col_dims);
-}
-
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void Conv(const float* input_data, const Dims<4>& input_dims,
- const float* filter_data, const Dims<4>& filter_dims,
- const float* bias_data, const Dims<4>& bias_dims, int stride_width,
- int stride_height, int pad_width, int pad_height, float* output_data,
- const Dims<4>& output_dims, float* im2col_data,
- const Dims<4>& im2col_dims) {
- float output_activation_min, output_activation_max;
- GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
- Conv(input_data, input_dims, filter_data, filter_dims, bias_data, bias_dims,
- stride_width, stride_height, 1, 1, pad_width, pad_height,
- output_activation_min, output_activation_max, output_data, output_dims,
- im2col_data, im2col_dims);
-}
-
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void Conv(const float* input_data, const Dims<4>& input_dims,
- const float* filter_data, const Dims<4>& filter_dims,
- const float* bias_data, const Dims<4>& bias_dims, int stride,
- int pad_width, int pad_height, float* output_data,
- const Dims<4>& output_dims, float* im2col_data,
- const Dims<4>& im2col_dims) {
- Conv<Ac>(input_data, input_dims, filter_data, filter_dims, bias_data,
- bias_dims, stride, stride, 1, 1, pad_width, pad_height, output_data,
- output_dims, im2col_data, im2col_dims);
-}
-
inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
const uint8* input_data, const RuntimeShape& filter_shape,
const uint8* filter_data, const RuntimeShape& bias_shape,
@@ -2376,7 +1981,6 @@ inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
- TFLITE_DCHECK_EQ(im2col_shape.DimensionsCount(), 4);
const uint8* gemm_input_data = nullptr;
const RuntimeShape* gemm_input_shape = nullptr;
@@ -2439,192 +2043,7 @@ inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
gemmlowp::MatrixMap<uint8, gemmlowp::MapOrder::ColMajor> output_matrix(
output_data, output_rows, output_cols);
const auto& output_pipeline = GemmlowpOutputPipeline::MakeExp(
- bias_data, output_rows, output_offset, output_multiplier, -output_shift,
- output_activation_min, output_activation_max);
- gemmlowp::GemmWithOutputPipeline<uint8, uint8,
- gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
- gemm_context, filter_matrix, input_matrix, &output_matrix, filter_offset,
- input_offset, output_pipeline);
-}
-
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
- int32 input_offset, const uint8* filter_data,
- const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height, int dilation_width_factor,
- int dilation_height_factor, int pad_width, int pad_height,
- int32 output_offset, int32 output_multiplier, int output_shift,
- int32 output_activation_min, int32 output_activation_max,
- uint8* output_data, const Dims<4>& output_dims,
- uint8* im2col_data, const Dims<4>& im2col_dims,
- gemmlowp::GemmContext* gemm_context) {
- tflite::ConvParams op_params;
- // Padding type is ignored, but still set.
- op_params.padding_type = PaddingType::kSame;
- op_params.padding_values.width = pad_width;
- op_params.padding_values.height = pad_height;
- op_params.stride_width = stride_width;
- op_params.stride_height = stride_height;
- op_params.dilation_width_factor = dilation_width_factor;
- op_params.dilation_height_factor = dilation_height_factor;
- op_params.input_offset = input_offset;
- op_params.weights_offset = filter_offset;
- op_params.output_offset = output_offset;
- op_params.output_multiplier = output_multiplier;
- op_params.output_shift = output_shift;
- op_params.quantized_activation_min = output_activation_min;
- op_params.quantized_activation_max = output_activation_max;
-
- Conv(op_params, DimsToShape(input_dims), input_data, DimsToShape(filter_dims),
- filter_data, DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
- output_data, DimsToShape(im2col_dims), im2col_data, gemm_context);
-}
-
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
- int32 input_offset, const uint8* filter_data,
- const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims, uint8* im2col_data,
- const Dims<4>& im2col_dims,
- gemmlowp::GemmContext* gemm_context) {
- Conv(input_data, input_dims, input_offset, filter_data, filter_dims,
- filter_offset, bias_data, bias_dims, stride_width, stride_height, 1, 1,
- pad_width, pad_height, output_offset, output_multiplier, output_shift,
- output_activation_min, output_activation_max, output_data, output_dims,
- im2col_data, im2col_dims, gemm_context);
-}
-
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
- int32 input_offset, const uint8* filter_data,
- const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims, uint8* im2col_data,
- const Dims<4>& im2col_dims,
- gemmlowp::GemmContext* gemm_context) {
- static_assert(Ac == FusedActivationFunctionType::kNone ||
- Ac == FusedActivationFunctionType::kRelu ||
- Ac == FusedActivationFunctionType::kRelu6 ||
- Ac == FusedActivationFunctionType::kRelu1,
- "");
- if (Ac == FusedActivationFunctionType::kNone) {
- TFLITE_DCHECK_EQ(output_activation_min, 0);
- TFLITE_DCHECK_EQ(output_activation_max, 255);
- }
- Conv(input_data, input_dims, input_offset, filter_data, filter_dims,
- filter_offset, bias_data, bias_dims, stride_width, stride_height,
- pad_width, pad_height, output_offset, output_multiplier, output_shift,
- output_activation_min, output_activation_max, output_data, output_dims,
- im2col_data, im2col_dims, gemm_context);
-}
-
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void Conv(const uint8* input_data, const Dims<4>& input_dims,
- int32 input_offset, const uint8* filter_data,
- const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data, const Dims<4>& bias_dims, int stride,
- int pad_width, int pad_height, int32 output_offset,
- int32 output_multiplier, int output_shift,
- int32 output_activation_min, int32 output_activation_max,
- uint8* output_data, const Dims<4>& output_dims, uint8* im2col_data,
- const Dims<4>& im2col_dims, gemmlowp::GemmContext* gemm_context) {
- static_assert(Ac == FusedActivationFunctionType::kNone ||
- Ac == FusedActivationFunctionType::kRelu ||
- Ac == FusedActivationFunctionType::kRelu6 ||
- Ac == FusedActivationFunctionType::kRelu1,
- "");
- Conv(input_data, input_dims, input_offset, filter_data, filter_dims,
- filter_offset, bias_data, bias_dims, stride, stride, pad_width,
- pad_height, output_offset, output_multiplier, output_shift,
- output_activation_min, output_activation_max, output_data, output_dims,
- im2col_data, im2col_dims, gemm_context);
-}
-
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac, typename T>
-void Im2col(const T* input_data, const Dims<4>& input_dims, int stride,
- int pad_width, int pad_height, int kheight, int kwidth,
- uint8 zero_byte, T* output_data, const Dims<4>& output_dims) {
- Im2col(input_data, input_dims, stride, stride, pad_width, pad_height, kheight,
- kwidth, zero_byte, output_data, output_dims);
-}
-
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void ConvAsGemm(const float* input_data, const Dims<4>& input_dims,
- const float* filter_data, const Dims<4>& filter_dims,
- const float* bias_data, const Dims<4>& bias_dims,
- float* output_data, const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("ConvAsGemm");
-
- const auto input_matrix_map =
- MapAsMatrixWithFirstDimAsRows(input_data, input_dims);
- const auto filter_matrix_map =
- MapAsMatrixWithLastDimAsCols(filter_data, filter_dims);
- auto output_matrix_map =
- MapAsMatrixWithFirstDimAsRows(output_data, output_dims);
-
- Gemm(filter_matrix_map.transpose(), input_matrix_map, &output_matrix_map);
-
- AddBiasAndEvalActivationFunction<Ac>(bias_data, bias_dims, output_data,
- output_dims);
-}
-
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void ConvAsGemm(const uint8* input_data, const Dims<4>& input_dims,
- int32 input_offset, const uint8* filter_data,
- const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data, const Dims<4>& bias_dims,
- int32 output_offset, int32 output_multiplier, int output_shift,
- int32 output_activation_min, int32 output_activation_max,
- uint8* output_data, const Dims<4>& output_dims,
- gemmlowp::GemmContext* gemm_context) {
- gemmlowp::ScopedProfilingLabel label("ConvAsGemm/8bit");
- static_assert(Ac == FusedActivationFunctionType::kNone ||
- Ac == FusedActivationFunctionType::kRelu ||
- Ac == FusedActivationFunctionType::kRelu6 ||
- Ac == FusedActivationFunctionType::kRelu1,
- "");
- const int input_rows = input_dims.sizes[0];
- const int input_cols = FlatSizeSkipDim(input_dims, 0);
- const int filter_rows = filter_dims.sizes[3];
- const int filter_cols = FlatSizeSkipDim(filter_dims, 3);
- const int output_rows = output_dims.sizes[0];
- const int output_cols = FlatSizeSkipDim(output_dims, 0);
- TFLITE_DCHECK_EQ(output_rows, filter_rows);
- TFLITE_DCHECK_EQ(output_cols, input_cols);
- TFLITE_DCHECK_EQ(filter_cols, input_rows);
- TFLITE_DCHECK_EQ(bias_dims.sizes[0], output_rows);
- TFLITE_DCHECK_EQ(bias_dims.sizes[1], 1);
- TFLITE_DCHECK_EQ(bias_dims.sizes[2], 1);
- TFLITE_DCHECK_EQ(bias_dims.sizes[3], 1);
- gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::RowMajor> filter_matrix(
- filter_data, output_rows, filter_cols, filter_cols);
- gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::ColMajor> input_matrix(
- input_data, filter_cols, output_cols, filter_cols);
- gemmlowp::MatrixMap<uint8, gemmlowp::MapOrder::ColMajor> output_matrix(
- output_data, output_rows, output_cols, output_rows);
- const auto& output_pipeline = GemmlowpOutputPipeline::MakeExp(
- bias_data, output_rows, output_offset, output_multiplier, -output_shift,
+ bias_data, output_rows, output_offset, output_multiplier, output_shift,
output_activation_min, output_activation_max);
gemmlowp::GemmWithOutputPipeline<uint8, uint8,
gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
@@ -2794,6 +2213,7 @@ inline void GetInvSqrtQuantizedMultiplierExp(int32 input,
*output_inv_sqrt <<= -*output_shift;
*output_shift = 0;
}
+ // Convert right shift (right is positive) to left shift.
*output_shift *= kReverseShift;
}
@@ -3547,21 +2967,6 @@ void BroadcastDiv4DSlow(const ArithmeticParams& params,
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy Dims<4>.
-template <typename T>
-void BroadcastDiv(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, const Dims<4>& input2_dims,
- T output_activation_min, T output_activation_max,
- T* output_data, const Dims<4>& output_dims) {
- tflite::ArithmeticParams op_params;
- SetActivationParams(output_activation_min, output_activation_max, &op_params);
-
- BroadcastDiv4DSlow(op_params, DimsToShape(input1_dims), input1_data,
- DimsToShape(input2_dims), input2_data,
- DimsToShape(output_dims), output_data);
-}
-
// TODO(aselle): This is not actually optimized yet.
inline void SubNonBroadcast(const ArithmeticParams& params,
const RuntimeShape& input1_shape,
@@ -3755,31 +3160,6 @@ inline void LstmCell(
output_state_map.tanh();
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-inline void LstmCell(const float* input_data, const Dims<4>& input_dims,
- const float* prev_activ_data,
- const Dims<4>& prev_activ_dims, const float* weights_data,
- const Dims<4>& weights_dims, const float* bias_data,
- const Dims<4>& bias_dims, const float* prev_state_data,
- const Dims<4>& prev_state_dims, float* output_state_data,
- const Dims<4>& output_state_dims, float* output_activ_data,
- const Dims<4>& output_activ_dims, float* concat_temp_data,
- const Dims<4>& concat_temp_dims, float* activ_temp_data,
- const Dims<4>& activ_temp_dims) {
- tflite::LstmCellParams op_params;
- // Float LSTM cell does not need parameters to be set: leave untouched.
-
- LstmCell(op_params, DimsToShape(input_dims), input_data,
- DimsToShape(prev_activ_dims), prev_activ_data,
- DimsToShape(weights_dims), weights_data, DimsToShape(bias_dims),
- bias_data, DimsToShape(prev_state_dims), prev_state_data,
- DimsToShape(output_state_dims), output_state_data,
- DimsToShape(output_activ_dims), output_activ_data,
- DimsToShape(concat_temp_dims), concat_temp_data,
- DimsToShape(activ_temp_dims), activ_temp_data);
-}
-
// Quantized LSTM cell. Currently just a copy of the reference impl in
// reference_ops.h. See the big function comment there, not replicating it
// here.
@@ -3801,11 +3181,11 @@ inline void LstmCell(
uint8* concat_temp_data_uint8,
const RuntimeShape& unextended_activ_temp_shape,
int16* activ_temp_data_int16, gemmlowp::GemmContext* gemm_context) {
+ gemmlowp::ScopedProfilingLabel label(
+ "LstmCell/quantized (8bit external, 16bit internal)");
int32 weights_zero_point = params.weights_zero_point;
int32 accum_multiplier = params.accum_multiplier;
int accum_shift = params.accum_shift;
- gemmlowp::ScopedProfilingLabel label(
- "LstmCell/quantized (8bit external, 16bit internal)");
TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_prev_activ_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_bias_shape.DimensionsCount(), 4);
@@ -4070,37 +3450,6 @@ inline void LstmCell(
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-template <int StateIntegerBits>
-void LstmCell(const uint8* input_data_uint8, const Dims<4>& input_dims,
- const uint8* prev_activ_data_uint8,
- const Dims<4>& prev_activ_dims, const uint8* weights_data_uint8,
- const Dims<4>& weights_dims, const int32* bias_data_int32,
- const Dims<4>& bias_dims, const int16* prev_state_data_int16,
- const Dims<4>& prev_state_dims, int16* output_state_data_int16,
- const Dims<4>& output_state_dims, uint8* output_activ_data_uint8,
- const Dims<4>& output_activ_dims, uint8* concat_temp_data_uint8,
- const Dims<4>& concat_temp_dims, int16* activ_temp_data_int16,
- const Dims<4>& activ_temp_dims, int32 weights_zero_point,
- int32 accum_multiplier, int accum_shift,
- gemmlowp::GemmContext* gemm_context) {
- tflite::LstmCellParams op_params;
- op_params.weights_zero_point = weights_zero_point;
- op_params.accum_multiplier = accum_multiplier;
- op_params.accum_shift = accum_shift;
-
- LstmCell<StateIntegerBits>(
- op_params, DimsToShape(input_dims), input_data_uint8,
- DimsToShape(prev_activ_dims), prev_activ_data_uint8,
- DimsToShape(weights_dims), weights_data_uint8, DimsToShape(bias_dims),
- bias_data_int32, DimsToShape(prev_state_dims), prev_state_data_int16,
- DimsToShape(output_state_dims), output_state_data_int16,
- DimsToShape(output_activ_dims), output_activ_data_uint8,
- DimsToShape(concat_temp_dims), concat_temp_data_uint8,
- DimsToShape(activ_temp_dims), activ_temp_data_int16, gemm_context);
-}
-
inline int NodeOffset(int b, int h, int w, int height, int width) {
return (b * height + h) * width + w;
}
@@ -4560,16 +3909,6 @@ inline void Softmax(const SoftmaxParams& params,
out_mat.array().rowwise() *= scale;
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-inline void Softmax(const float* input_data, const RuntimeShape& input_shape,
- float beta, float* output_data,
- const RuntimeShape& output_shape) {
- SoftmaxParams params;
- params.beta = beta;
- Softmax(params, input_shape, input_data, output_shape, output_data);
-}
-
inline void Softmax(const SoftmaxParams& params,
const RuntimeShape& input_shape, const uint8* input_data,
const RuntimeShape& output_shape, uint8* output_data) {
@@ -4781,19 +4120,6 @@ inline void Softmax(const SoftmaxParams& params,
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-inline void Softmax(const uint8* input_data, const RuntimeShape& input_shape,
- int32 input_beta_multiplier, int32 input_beta_left_shift,
- int diff_min, uint8* output_data,
- const RuntimeShape& output_shape) {
- SoftmaxParams params;
- params.input_multiplier = input_beta_multiplier;
- params.input_left_shift = input_beta_left_shift;
- params.diff_min = diff_min;
- Softmax(params, input_shape, input_data, output_shape, output_data);
-}
-
// TODO(myenik): This is the same as the reference implementation, not actually
// optimized yet.
inline void LogSoftmax(const SoftmaxParams& params,
@@ -4831,15 +4157,6 @@ inline void LogSoftmax(const SoftmaxParams& params,
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy
-inline void LogSoftmax(const float* input_data, const RuntimeShape& input_shape,
- float* output_data, const RuntimeShape& output_shape) {
- SoftmaxParams params;
- // No params currently used for float LogSoftmax.
- LogSoftmax(params, input_shape, input_data, output_shape, output_data);
-}
-
template <int OutputIntegerBits, int InputIntegerBits>
inline gemmlowp::FixedPoint<int32, OutputIntegerBits>
log_x_for_x_greater_than_or_equal_to_1_impl(
@@ -5020,7 +4337,7 @@ inline void LogSoftmax(const SoftmaxParams& params,
std::max(diff_min - 1, // Note use of > below instead of >= above.
MultiplyByQuantizedMultiplierSmallerThanOneExp(
rescaled_diff_min, reverse_scaling_divisor,
- kReverseShift * reverse_scaling_right_shift));
+ -reverse_scaling_right_shift));
for (int c = 0; c < depth; ++c) {
int32 input_diff = static_cast<int32>(block_input_data[c]) - max_in_row;
@@ -5044,24 +4361,7 @@ inline void LogSoftmax(const SoftmaxParams& params,
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-inline void LogSoftmax(const uint8* input_data, const RuntimeShape& input_shape,
- int32 input_multiplier, int32 input_left_shift,
- int32 reverse_scaling_divisor,
- int32 reverse_scaling_right_shift, int diff_min,
- uint8* output_data, const RuntimeShape& output_shape) {
- SoftmaxParams params;
- params.input_multiplier = input_multiplier;
- params.input_left_shift = input_left_shift;
- params.reverse_scaling_divisor = reverse_scaling_divisor;
- params.reverse_scaling_right_shift = reverse_scaling_right_shift;
- params.diff_min = diff_min;
- LogSoftmax(params, input_shape, input_data, output_shape, output_data);
-}
-
-inline void Logistic(const LogisticParams& params,
- const RuntimeShape& input_shape, const float* input_data,
+inline void Logistic(const RuntimeShape& input_shape, const float* input_data,
const RuntimeShape& output_shape, float* output_data) {
gemmlowp::ScopedProfilingLabel label("Logistic");
auto input_map = MapAsVector(input_data, input_shape);
@@ -5070,13 +4370,13 @@ inline void Logistic(const LogisticParams& params,
input_map.array().unaryExpr(Eigen::internal::scalar_sigmoid_op<float>());
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-inline void Logistic(const RuntimeShape& input_shape, const float* input_data,
- const RuntimeShape& output_shape, float* output_data) {
- LogisticParams params;
- // No params currently needed by float Logistic.
- Logistic(params, input_shape, input_data, output_shape, output_data);
+// Convenience version that allows, for example, generated-code calls to be
+// uniform between data types.
+inline void Logistic(const LogisticParams&, const RuntimeShape& input_shape,
+ const float* input_data, const RuntimeShape& output_shape,
+ float* output_data) {
+ // Drop params: not needed.
+ Logistic(input_shape, input_data, output_shape, output_data);
}
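A hypothetical call site showing why the params-taking float overload is kept even though it ignores its first argument: generated code can pass a LogisticParams uniformly across data types (all names and buffers below are placeholders):

    LogisticParams params;          // unused by the float path
    RuntimeShape shape({1, 8});
    float in[8] = {};
    float out[8];
    // Both calls are equivalent for float data.
    Logistic(params, shape, in, shape, out);
    Logistic(shape, in, shape, out);

The Tanh overloads further down get the same treatment.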
inline void Logistic(const LogisticParams& params,
@@ -5219,20 +4519,6 @@ inline void Logistic(const LogisticParams& params,
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-inline void Logistic(const uint8* input_data, const RuntimeShape& input_shape,
- int32 input_zero_point, int32 input_range_radius,
- int32 input_multiplier, int input_left_shift,
- uint8* output_data, const RuntimeShape& output_shape) {
- LogisticParams params;
- params.input_zero_point = input_zero_point;
- params.input_range_radius = input_range_radius;
- params.input_multiplier = input_multiplier;
- params.input_left_shift = input_left_shift;
- Logistic(params, input_shape, input_data, output_shape, output_data);
-}
-
inline void Logistic(const LogisticParams& params,
const RuntimeShape& input_shape, const int16* input_data,
const RuntimeShape& output_shape, int16* output_data) {
@@ -5294,40 +4580,21 @@ inline void Logistic(const LogisticParams& params,
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy version.
-inline void Logistic(const RuntimeShape& input_shape, const int16* input_data,
- const RuntimeShape& output_shape, int16* output_data) {
- LogisticParams params;
- // No params currently needed by int16 Logistic.
- Logistic(params, input_shape, input_data, output_shape, output_data);
-}
-
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy version.
-inline void Logistic(const int16* input_data, const RuntimeShape& input_shape,
- int16* output_data, const RuntimeShape& output_shape) {
- LogisticParams params;
- // No params currently needed by int16 Logistic.
- Logistic(params, input_shape, input_data, output_shape, output_data);
-}
-
-inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape,
- const float* input_data, const RuntimeShape& output_shape,
- float* output_data) {
+inline void Tanh(const RuntimeShape& input_shape, const float* input_data,
+ const RuntimeShape& output_shape, float* output_data) {
gemmlowp::ScopedProfilingLabel label("Tanh");
auto input_map = MapAsVector(input_data, input_shape);
auto output_map = MapAsVector(output_data, output_shape);
output_map.array() = input_map.array().tanh();
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-inline void Tanh(const RuntimeShape& input_shape, const float* input_data,
- const RuntimeShape& output_shape, float* output_data) {
- TanhParams params;
- // Currently no params needed for float Tanh.
- Tanh(params, input_shape, input_data, output_shape, output_data);
+// Convenience version that allows, for example, generated-code calls to be
+// uniform between data types.
+inline void Tanh(const TanhParams&, const RuntimeShape& input_shape,
+ const float* input_data, const RuntimeShape& output_shape,
+ float* output_data) {
+ // Drop params: not needed.
+ Tanh(input_shape, input_data, output_shape, output_data);
}
inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape,
@@ -5480,20 +4747,6 @@ inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape,
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-inline void Tanh(const uint8* input_data, const RuntimeShape& input_shape,
- int32 input_zero_point, int32 input_range_radius,
- int32 input_multiplier, int input_left_shift,
- uint8* output_data, const RuntimeShape& output_shape) {
- TanhParams params;
- params.input_zero_point = input_zero_point;
- params.input_range_radius = input_range_radius;
- params.input_multiplier = input_multiplier;
- params.input_left_shift = input_left_shift;
- Tanh(params, input_shape, input_data, output_shape, output_data);
-}
-
inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape,
const int16* input_data, const RuntimeShape& output_shape,
int16* output_data) {
@@ -5595,16 +4848,6 @@ inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape,
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-inline void Tanh(const int16* input_data, const RuntimeShape& input_shape,
- int input_left_shift, int16* output_data,
- const RuntimeShape& output_shape) {
- TanhParams params;
- params.input_left_shift = input_left_shift;
- Tanh(params, input_shape, input_data, output_shape, output_data);
-}
-
template <typename SrcT, typename DstT>
inline void Cast(const RuntimeShape& input_shape, const SrcT* input_data,
const RuntimeShape& output_shape, DstT* output_data) {
@@ -6382,6 +5625,16 @@ void Minimum(const RuntimeShape& input1_shape, const T* input1_data,
output_map.array() = input1_map.array().min(min_value);
}
+// Convenience version that allows, for example, generated-code calls to be
+// the same as other binary ops.
+template <typename T>
+inline void Minimum(const RuntimeShape& input1_shape, const T* input1_data,
+ const RuntimeShape&, const T* input2_data,
+ const RuntimeShape& output_shape, T* output_data) {
+ // Drop shape of second input: not needed.
+ Minimum(input1_shape, input1_data, input2_data, output_shape, output_data);
+}
+
template <typename T>
void Maximum(const RuntimeShape& input1_shape, const T* input1_data,
const T* input2_data, const RuntimeShape& output_shape,
@@ -6393,6 +5646,16 @@ void Maximum(const RuntimeShape& input1_shape, const T* input1_data,
output_map.array() = input1_map.array().max(max_value);
}
+// Convenience version that allows, for example, generated-code calls to be
+// the same as other binary ops.
+template <typename T>
+inline void Maximum(const RuntimeShape& input1_shape, const T* input1_data,
+ const RuntimeShape&, const T* input2_data,
+ const RuntimeShape& output_shape, T* output_data) {
+ // Drop shape of second input: not needed.
+ Maximum(input1_shape, input1_data, input2_data, output_shape, output_data);
+}
+
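A hypothetical call showing the uniform binary-op signature these convenience overloads enable; the second shape argument is accepted and ignored, and this optimized path reads the second operand as a scalar (its first element), so the buffer below is deliberately constant:

    RuntimeShape shape({1, 4});
    float a[4] = {1, 5, 2, 7};
    float b[4] = {3, 3, 3, 3};
    float out[4];
    // Same argument pattern as other binary ops; input2's shape is unused.
    Maximum(shape, a, shape, b, shape, out);   // out = {3, 5, 3, 7}
    Minimum(shape, a, shape, b, shape, out);   // out = {1, 3, 2, 3}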
template <typename T>
void TransposeIm2col(const ConvParams& params, uint8 zero_byte,
const RuntimeShape& input_shape, const T* input_data,
@@ -6467,27 +5730,6 @@ void TransposeIm2col(const ConvParams& params, uint8 zero_byte,
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-template <typename T>
-void TransposeIm2col(const T* input_data, const Dims<4>& input_dims,
- const Dims<4>& filter_dims, int stride_width,
- int stride_height, int pad_width, int pad_height,
- const Dims<4>& output_dims, uint8 zero_byte,
- T* im2col_data) {
- tflite::ConvParams op_params;
- // Padding type is ignored, but still set.
- op_params.padding_type = PaddingType::kSame;
- op_params.padding_values.width = pad_width;
- op_params.padding_values.height = pad_height;
- op_params.stride_width = stride_width;
- op_params.stride_height = stride_height;
-
- TransposeIm2col(op_params, zero_byte, DimsToShape(input_dims), input_data,
- DimsToShape(filter_dims), DimsToShape(output_dims),
- im2col_data);
-}
-
inline void TransposeConv(
const ConvParams& params, const RuntimeShape& input_shape,
const float* input_data, const RuntimeShape& filter_shape,
@@ -6511,27 +5753,6 @@ inline void TransposeConv(
Gemm(filter_matrix_map.transpose(), im2col_matrix_map, &output_matrix_map);
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-inline void TransposeConv(const float* input_data, const Dims<4>& input_dims,
- const float* filter_data, const Dims<4>& filter_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, float* output_data,
- const Dims<4>& output_dims, float* im2col_data,
- const Dims<4>& im2col_dims) {
- tflite::ConvParams op_params;
- // Padding type is ignored, but still set.
- op_params.padding_type = PaddingType::kSame;
- op_params.padding_values.width = pad_width;
- op_params.padding_values.height = pad_height;
- op_params.stride_width = stride_width;
- op_params.stride_height = stride_height;
-
- TransposeConv(op_params, DimsToShape(input_dims), input_data,
- DimsToShape(filter_dims), filter_data, DimsToShape(output_dims),
- output_data, DimsToShape(im2col_dims), im2col_data);
-}
-
} // namespace optimized_ops
} // namespace tflite
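A rough sketch of a new-style float Conv invocation with ConvParams and RuntimeShape, mirroring the fields the removed Dims<4> wrappers used to populate; the function name, shapes, and buffers are placeholders, and the 1x1/stride-1 setup is chosen so no im2col buffer is required (assumes optimized_ops.h is included and the code sits inside namespace tflite):

    void RunPointwiseConvExample() {
      float input_data[1 * 8 * 8 * 3] = {};
      float filter_data[16 * 1 * 1 * 3] = {};
      float bias_data[16] = {};
      float output_data[1 * 8 * 8 * 16];

      ConvParams op_params;
      op_params.padding_type = PaddingType::kSame;  // ignored by this kernel
      op_params.padding_values.width = 0;
      op_params.padding_values.height = 0;
      op_params.stride_width = 1;
      op_params.stride_height = 1;
      op_params.dilation_width_factor = 1;
      op_params.dilation_height_factor = 1;
      op_params.float_activation_min = -1e30f;  // effectively no clamp
      op_params.float_activation_max = 1e30f;

      RuntimeShape input_shape({1, 8, 8, 3});    // NHWC
      RuntimeShape filter_shape({16, 1, 1, 3});  // 1x1 filter
      RuntimeShape bias_shape({16});
      RuntimeShape output_shape({1, 8, 8, 16});

      optimized_ops::Conv(op_params, input_shape, input_data, filter_shape,
                          filter_data, bias_shape, bias_data, output_shape,
                          output_data, RuntimeShape(), /*im2col_data=*/nullptr);
    }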
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h
index bb5d590775..11224270a4 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h
@@ -22,25 +22,36 @@ limitations under the License.
namespace tflite {
namespace reference_ops {
-inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
- const float* filter_data, const Dims<4>& filter_dims,
- const float* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height,
- int dilation_width_factor, int dilation_height_factor,
- int pad_width, int pad_height, int depth_multiplier,
- float output_activation_min,
- float output_activation_max, float* output_data,
- const Dims<4>& output_dims) {
- const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- const int output_depth = MatchingArraySize(filter_dims, 0, output_dims, 0);
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
- const int input_depth = ArraySize(input_dims, 0);
- const int filter_height = ArraySize(filter_dims, 2);
- const int filter_width = ArraySize(filter_dims, 1);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
- TFLITE_DCHECK(output_depth == input_depth * depth_multiplier);
+inline void DepthwiseConv(
+ const DepthwiseParams& params, const RuntimeShape& input_shape,
+ const float* input_data, const RuntimeShape& filter_shape,
+ const float* filter_data, const RuntimeShape& bias_shape,
+ const float* bias_data, const RuntimeShape& output_shape,
+ float* output_data) {
+ const int stride_width = params.stride_width;
+ const int stride_height = params.stride_height;
+ const int dilation_width_factor = params.dilation_width_factor;
+ const int dilation_height_factor = params.dilation_height_factor;
+ const int pad_width = params.padding_values.width;
+ const int pad_height = params.padding_values.height;
+ const int depth_multiplier = params.depth_multiplier;
+ const float output_activation_min = params.float_activation_min;
+ const float output_activation_max = params.float_activation_max;
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+
+ const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+ const int output_depth = MatchingDim(filter_shape, 3, output_shape, 3);
+ const int input_height = input_shape.Dims(1);
+ const int input_width = input_shape.Dims(2);
+ const int input_depth = input_shape.Dims(3);
+ const int filter_height = filter_shape.Dims(1);
+ const int filter_width = filter_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_width = output_shape.Dims(2);
+ TFLITE_DCHECK_EQ(output_depth, input_depth * depth_multiplier);
+ TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);
for (int b = 0; b < batches; ++b) {
for (int out_y = 0; out_y < output_height; ++out_y) {
@@ -61,18 +72,18 @@ inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) &&
(in_y < input_height)) {
float input_value =
- input_data[Offset(input_dims, ic, in_x, in_y, b)];
+ input_data[Offset(input_shape, b, in_y, in_x, ic)];
float filter_value = filter_data[Offset(
- filter_dims, oc, filter_x, filter_y, 0)];
+ filter_shape, 0, filter_y, filter_x, oc)];
total += (input_value * filter_value);
}
}
}
float bias_value = 0.0f;
if (bias_data) {
- bias_value = bias_data[Offset(bias_dims, oc, 0, 0, 0)];
+ bias_value = bias_data[oc];
}
- output_data[Offset(output_dims, oc, out_x, out_y, b)] =
+ output_data[Offset(output_shape, b, out_y, out_x, oc)] =
ActivationFunctionWithMinMax(total + bias_value,
output_activation_min,
output_activation_max);
@@ -83,48 +94,6 @@ inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
}
}
-inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
- const float* filter_data, const Dims<4>& filter_dims,
- const float* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int depth_multiplier,
- float output_activation_min,
- float output_activation_max, float* output_data,
- const Dims<4>& output_dims) {
- DepthwiseConv(input_data, input_dims, filter_data, filter_dims, bias_data,
- bias_dims, stride_width, stride_height, 1, 1, pad_width,
- pad_height, depth_multiplier, output_activation_min,
- output_activation_max, output_data, output_dims);
-}
-
-// Legacy, for compatibility with old checked-in code.
-template <FusedActivationFunctionType Ac>
-void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
- const float* filter_data, const Dims<4>& filter_dims,
- const float* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int depth_multiplier, float* output_data,
- const Dims<4>& output_dims) {
- float output_activation_min, output_activation_max;
- GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
- DepthwiseConv(input_data, input_dims, filter_data, filter_dims, bias_data,
- bias_dims, stride_width, stride_height, pad_width, pad_height,
- depth_multiplier, output_activation_min, output_activation_max,
- output_data, output_dims);
-}
-
-// Legacy, for compatibility with old checked-in code.
-template <FusedActivationFunctionType Ac>
-void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
- const float* filter_data, const Dims<4>& filter_dims,
- const float* bias_data, const Dims<4>& bias_dims, int stride,
- int pad_width, int pad_height, int depth_multiplier,
- float* output_data, const Dims<4>& output_dims) {
- DepthwiseConv<Ac>(input_data, input_dims, filter_data, filter_dims, bias_data,
- bias_dims, stride, stride, pad_width, pad_height,
- depth_multiplier, output_data, output_dims);
-}
-
} // end namespace reference_ops
} // end namespace tflite
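
The indexing change in this file replaces the old Dims<4>-based calls such as Offset(dims, c, x, y, b) with RuntimeShape-based Offset(shape, b, y, x, c), so arguments are now given in NHWC order, outermost dimension first. A minimal standalone sketch of that flat-index computation, using a plain array instead of the real RuntimeShape type (the helper name nhwc_offset is illustrative, not part of the library):

#include <cassert>
#include <cstdio>

// Flat index into a dense NHWC tensor; dims are given outermost-first:
// dims = {batches, height, width, channels}.
inline int nhwc_offset(const int dims[4], int b, int y, int x, int c) {
  assert(b >= 0 && b < dims[0]);
  assert(y >= 0 && y < dims[1]);
  assert(x >= 0 && x < dims[2]);
  assert(c >= 0 && c < dims[3]);
  return ((b * dims[1] + y) * dims[2] + x) * dims[3] + c;
}

int main() {
  const int dims[4] = {2, 3, 4, 5};  // N=2, H=3, W=4, C=5
  // Channels vary fastest, batches slowest.
  std::printf("%d\n", nhwc_offset(dims, 0, 0, 0, 1));  // prints 1
  std::printf("%d\n", nhwc_offset(dims, 1, 2, 3, 4));  // prints 119 (last element)
  return 0;
}
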
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h
index 5e3e8997fc..eab28e6c84 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h
@@ -18,7 +18,6 @@ limitations under the License.
#include <algorithm>
#include "fixedpoint/fixedpoint.h"
-#include "public/gemmlowp.h"
#include "tensorflow/contrib/lite/kernels/internal/common.h"
#include "tensorflow/contrib/lite/kernels/internal/compatibility.h"
#include "tensorflow/contrib/lite/kernels/internal/types.h"
@@ -26,27 +25,42 @@ limitations under the License.
namespace tflite {
namespace reference_ops {
-inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
- int32 input_offset, const uint8* filter_data,
- const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height,
- int dilation_width_factor, int dilation_height_factor,
- int pad_width, int pad_height, int depth_multiplier,
- int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims) {
- const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- const int output_depth = MatchingArraySize(filter_dims, 0, output_dims, 0);
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
- const int input_depth = ArraySize(input_dims, 0);
- const int filter_height = ArraySize(filter_dims, 2);
- const int filter_width = ArraySize(filter_dims, 1);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
- TFLITE_DCHECK(output_depth == input_depth * depth_multiplier);
+inline void DepthwiseConv(
+ const DepthwiseParams& params, const RuntimeShape& input_shape,
+ const uint8* input_data, const RuntimeShape& filter_shape,
+ const uint8* filter_data, const RuntimeShape& bias_shape,
+ const int32* bias_data, const RuntimeShape& output_shape,
+ uint8* output_data) {
+ const int stride_width = params.stride_width;
+ const int stride_height = params.stride_height;
+ const int dilation_width_factor = params.dilation_width_factor;
+ const int dilation_height_factor = params.dilation_height_factor;
+ const int pad_width = params.padding_values.width;
+ const int pad_height = params.padding_values.height;
+ const int depth_multiplier = params.depth_multiplier;
+ const int32 output_activation_min = params.quantized_activation_min;
+ const int32 output_activation_max = params.quantized_activation_max;
+ const int32 input_offset = params.input_offset;
+ const int32 filter_offset = params.weights_offset;
+ const int32 output_offset = params.output_offset;
+ const int32 output_multiplier = params.output_multiplier;
+ const int output_shift = params.output_shift;
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+
+ TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+ const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+ const int output_depth = MatchingDim(filter_shape, 3, output_shape, 3);
+ const int input_height = input_shape.Dims(1);
+ const int input_width = input_shape.Dims(2);
+ const int input_depth = input_shape.Dims(3);
+ const int filter_height = filter_shape.Dims(1);
+ const int filter_width = filter_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_width = output_shape.Dims(2);
+ TFLITE_DCHECK_EQ(output_depth, input_depth * depth_multiplier);
+ TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);
for (int b = 0; b < batches; ++b) {
for (int out_y = 0; out_y < output_height; ++out_y) {
@@ -67,23 +81,23 @@ inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) &&
(in_y < input_height)) {
int32 input_val =
- input_data[Offset(input_dims, ic, in_x, in_y, b)];
- int32 filter_val = filter_data[Offset(filter_dims, oc,
- filter_x, filter_y, 0)];
+ input_data[Offset(input_shape, b, in_y, in_x, ic)];
+ int32 filter_val = filter_data[Offset(
+ filter_shape, 0, filter_y, filter_x, oc)];
acc +=
(filter_val + filter_offset) * (input_val + input_offset);
}
}
}
if (bias_data) {
- acc += bias_data[Offset(bias_dims, oc, 0, 0, 0)];
+ acc += bias_data[oc];
}
acc = MultiplyByQuantizedMultiplier(acc, output_multiplier,
- -output_shift);
+ output_shift);
acc += output_offset;
acc = std::max(acc, output_activation_min);
acc = std::min(acc, output_activation_max);
- output_data[Offset(output_dims, oc, out_x, out_y, b)] =
+ output_data[Offset(output_shape, b, out_y, out_x, oc)] =
static_cast<uint8>(acc);
}
}
@@ -92,66 +106,6 @@ inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
}
}
-inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
- int32 input_offset, const uint8* filter_data,
- const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int depth_multiplier,
- int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims) {
- DepthwiseConv(input_data, input_dims, input_offset, filter_data, filter_dims,
- filter_offset, bias_data, bias_dims, stride_width,
- stride_height, 1, 1, pad_width, pad_height, depth_multiplier,
- output_offset, output_multiplier, output_shift,
- output_activation_min, output_activation_max, output_data,
- output_dims);
-}
-
-// Legacy, for compatibility with old checked-in code.
-template <FusedActivationFunctionType Ac>
-void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
- int32 input_offset, const uint8* filter_data,
- const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int depth_multiplier, int32 output_offset,
- int32 output_multiplier, int output_shift,
- int32 output_activation_min, int32 output_activation_max,
- uint8* output_data, const Dims<4>& output_dims) {
- if (Ac == FusedActivationFunctionType::kNone) {
- TFLITE_DCHECK_EQ(output_activation_min, 0);
- TFLITE_DCHECK_EQ(output_activation_max, 255);
- }
- DepthwiseConv(input_data, input_dims, input_offset, filter_data, filter_dims,
- filter_offset, bias_data, bias_dims, stride_width,
- stride_height, pad_width, pad_height, depth_multiplier,
- output_offset, output_multiplier, output_shift,
- output_activation_min, output_activation_max, output_data,
- output_dims);
-}
-
-// Legacy, for compatibility with old checked-in code.
-template <FusedActivationFunctionType Ac>
-void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
- int32 input_offset, const uint8* filter_data,
- const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data, const Dims<4>& bias_dims, int stride,
- int pad_width, int pad_height, int depth_multiplier,
- int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims) {
- DepthwiseConv<Ac>(input_data, input_dims, input_offset, filter_data,
- filter_dims, filter_offset, bias_data, bias_dims, stride,
- stride, pad_width, pad_height, depth_multiplier,
- output_offset, output_multiplier, output_shift,
- output_activation_min, output_activation_max, output_data,
- output_dims);
-}
-
} // end namespace reference_ops
} // end namespace tflite
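
In the quantized kernel above, MultiplyByQuantizedMultiplier(acc, output_multiplier, output_shift) now receives the shift exactly as stored in the params (positive means shift left), where the old code negated it at the call site; the legacy wrappers later compensate with a reverse-shift factor. A rough standalone sketch of what such a fixed-point rescale does, assuming the multiplier is a Q31 value in [2^30, 2^31) and ignoring the saturation handling of the real gemmlowp-based implementation (requantize_sketch and its rounding details are illustrative assumptions, not the library code):

#include <cstdint>
#include <cstdio>

// Rescale a non-negative int32 accumulator by a real-valued scale encoded as
// (multiplier, shift), with shift counted as "positive means left shift" and
// assumed to stay well below 31. Plain 64-bit arithmetic with round-half-up
// stands in for the saturating doubling-high-multiply of the production code.
inline std::int32_t requantize_sketch(std::int32_t acc,
                                      std::int32_t multiplier, int shift) {
  const std::int64_t prod = static_cast<std::int64_t>(acc) * multiplier;
  const int total_right_shift = 31 - shift;  // larger shift -> less right shift
  const std::int64_t rounding = std::int64_t{1} << (total_right_shift - 1);
  return static_cast<std::int32_t>((prod + rounding) >> total_right_shift);
}

int main() {
  // Encode a scale of 0.75: multiplier = 0.75 * 2^31, shift = 0.
  const std::int32_t mult = static_cast<std::int32_t>(0.75 * (1ll << 31));
  std::printf("%d\n", requantize_sketch(1000, mult, 0));   // prints 750
  std::printf("%d\n", requantize_sketch(1000, mult, -1));  // prints 375
  return 0;
}
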
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/fully_connected.h b/tensorflow/contrib/lite/kernels/internal/reference/fully_connected.h
new file mode 100644
index 0000000000..3c7fd29256
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/reference/fully_connected.h
@@ -0,0 +1,326 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_FULLY_CONNECTED_H_
+#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_FULLY_CONNECTED_H_
+
+#include "fixedpoint/fixedpoint.h"
+#include "tensorflow/contrib/lite/kernels/internal/common.h"
+#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/contrib/lite/kernels/internal/round.h"
+#include "tensorflow/contrib/lite/kernels/internal/types.h"
+
+namespace tflite {
+namespace reference_ops {
+
+const int kReverseShift = -1;
+
+inline void FullyConnected(
+ const FullyConnectedParams& params, const RuntimeShape& input_shape,
+ const float* input_data, const RuntimeShape& weights_shape,
+ const float* weights_data, const RuntimeShape& bias_shape,
+ const float* bias_data, const RuntimeShape& output_shape,
+ float* output_data) {
+ const float output_activation_min = params.float_activation_min;
+ const float output_activation_max = params.float_activation_max;
+ // TODO(benoitjacob): This really should be:
+ // const int batches = ArraySize(output_dims, 1);
+ // but the current --variable_batch hack consists in overwriting the 3rd
+ // dimension with the runtime batch size, as we don't keep track for each
+ // array of which dimension is the batch dimension in it.
+ const int output_dims_count = output_shape.DimensionsCount();
+ const int weights_dims_count = weights_shape.DimensionsCount();
+ const int batches = FlatSizeSkipDim(output_shape, output_dims_count - 1);
+ const int output_depth = MatchingDim(weights_shape, weights_dims_count - 2,
+ output_shape, output_dims_count - 1);
+ const int accum_depth = weights_shape.Dims(weights_dims_count - 1);
+ for (int b = 0; b < batches; ++b) {
+ for (int out_c = 0; out_c < output_depth; ++out_c) {
+ float total = 0.f;
+ for (int d = 0; d < accum_depth; ++d) {
+ total += input_data[b * accum_depth + d] *
+ weights_data[out_c * accum_depth + d];
+ }
+ float bias_value = 0.0f;
+ if (bias_data) {
+ bias_value = bias_data[out_c];
+ }
+ output_data[out_c + output_depth * b] = ActivationFunctionWithMinMax(
+ total + bias_value, output_activation_min, output_activation_max);
+ }
+ }
+}
+
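
As the TODO above notes, the batch count is derived as the flat size of the output shape with its last dimension skipped: every leading dimension is folded into "batches" and only the final dimension is treated as output depth. A small standalone illustration of that reduction, with a plain loop standing in for FlatSizeSkipDim (the sketch names are not the library helpers):

#include <cstdio>
#include <vector>

// Product of all dimensions except the one at skip_dim; this is the quantity
// the fully-connected kernels use to derive a batch count from an N-D shape.
int FlatSizeSkipDimSketch(const std::vector<int>& dims, int skip_dim) {
  int flat = 1;
  for (int i = 0; i < static_cast<int>(dims.size()); ++i) {
    if (i != skip_dim) flat *= dims[i];
  }
  return flat;
}

int main() {
  // Output shape [2, 1, 3, 16]: the last dim is output_depth, the rest folds
  // into "batches".
  std::vector<int> output_shape = {2, 1, 3, 16};
  int batches = FlatSizeSkipDimSketch(
      output_shape, static_cast<int>(output_shape.size()) - 1);
  std::printf("batches = %d, output_depth = %d\n", batches,
              output_shape.back());  // batches = 6, output_depth = 16
  return 0;
}
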
+inline void FullyConnected(
+ const FullyConnectedParams& params, const RuntimeShape& input_shape,
+ const uint8* input_data, const RuntimeShape& filter_shape,
+ const uint8* filter_data, const RuntimeShape& bias_shape,
+ const int32* bias_data, const RuntimeShape& output_shape,
+ uint8* output_data, void* gemm_context) {
+ (void)gemm_context; // only used in optimized code.
+ const int32 input_offset = params.input_offset;
+ const int32 filter_offset = params.weights_offset;
+ const int32 output_offset = params.output_offset;
+ const int32 output_multiplier = params.output_multiplier;
+ const int output_shift = params.output_shift;
+ const int32 output_activation_min = params.quantized_activation_min;
+ const int32 output_activation_max = params.quantized_activation_max;
+ TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2);
+ TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
+
+ TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+ // TODO(benoitjacob): This really should be:
+ // const int batches = ArraySize(output_dims, 1);
+ // but the current --variable_batch hack consists in overwriting the 3rd
+ // dimension with the runtime batch size, as we don't keep track for each
+ // array of which dimension is the batch dimension in it.
+ const int output_dim_count = output_shape.DimensionsCount();
+ const int filter_dim_count = filter_shape.DimensionsCount();
+ const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
+ const int output_depth = MatchingDim(filter_shape, filter_dim_count - 2,
+ output_shape, output_dim_count - 1);
+ const int accum_depth = filter_shape.Dims(filter_dim_count - 1);
+ for (int b = 0; b < batches; ++b) {
+ for (int out_c = 0; out_c < output_depth; ++out_c) {
+ int32 acc = 0;
+ for (int d = 0; d < accum_depth; ++d) {
+ int32 input_val = input_data[b * accum_depth + d];
+ int32 filter_val = filter_data[out_c * accum_depth + d];
+ acc += (filter_val + filter_offset) * (input_val + input_offset);
+ }
+ if (bias_data) {
+ acc += bias_data[out_c];
+ }
+ acc = MultiplyByQuantizedMultiplier(acc, output_multiplier, output_shift);
+ acc += output_offset;
+ acc = std::max(acc, output_activation_min);
+ acc = std::min(acc, output_activation_max);
+ output_data[out_c + output_depth * b] = static_cast<uint8>(acc);
+ }
+ }
+}
+
+inline void FullyConnected(
+ const FullyConnectedParams& params, const RuntimeShape& input_shape,
+ const uint8* input_data, const RuntimeShape& filter_shape,
+ const uint8* filter_data, const RuntimeShape& bias_shape,
+ const int32* bias_data, const RuntimeShape& output_shape,
+ int16* output_data, void* gemm_context) {
+ (void)gemm_context; // only used in optimized code.
+ const int32 input_offset = params.input_offset;
+ const int32 filter_offset = params.weights_offset;
+ const int32 output_offset = params.output_offset;
+ const int32 output_multiplier = params.output_multiplier;
+ const int output_shift = params.output_shift;
+ const int32 output_activation_min = params.quantized_activation_min;
+ const int32 output_activation_max = params.quantized_activation_max;
+
+ TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+ TFLITE_DCHECK_EQ(output_offset, 0);
+ // TODO(benoitjacob): This really should be:
+ // const int batches = ArraySize(output_dims, 1);
+ // but the current --variable_batch hack consists in overwriting the 3rd
+ // dimension with the runtime batch size, as we don't keep track for each
+ // array of which dimension is the batch dimension in it.
+ const int output_dim_count = output_shape.DimensionsCount();
+ const int filter_dim_count = filter_shape.DimensionsCount();
+ const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
+ const int output_depth = MatchingDim(filter_shape, filter_dim_count - 2,
+ output_shape, output_dim_count - 1);
+ const int accum_depth = filter_shape.Dims(filter_dim_count - 1);
+ for (int b = 0; b < batches; ++b) {
+ for (int out_c = 0; out_c < output_depth; ++out_c) {
+ // Internal accumulation.
+ // Initialize accumulator with the bias-value.
+ int32 accum = bias_data[out_c];
+ // Accumulation loop.
+ for (int d = 0; d < accum_depth; ++d) {
+ int16 input_val = input_data[b * accum_depth + d] + input_offset;
+ int16 filter_val = filter_data[out_c * accum_depth + d] + filter_offset;
+ accum += filter_val * input_val;
+ }
+ // Down-scale the final int32 accumulator to the scale used by our
+ // (16-bit, typically 3 integer bits) fixed-point format. The quantized
+ // multiplier and shift here have been pre-computed offline
+ // (e.g. by toco).
+ accum =
+ MultiplyByQuantizedMultiplier(accum, output_multiplier, output_shift);
+ // Saturate, cast to int16, and store to output array.
+ accum = std::max(accum, output_activation_min - output_offset);
+ accum = std::min(accum, output_activation_max - output_offset);
+ accum += output_offset;
+ output_data[out_c + output_depth * b] = accum;
+ }
+ }
+}
+
+inline void ShuffledFullyConnected(
+ const FullyConnectedParams& params, const RuntimeShape& input_shape,
+ const uint8* input_data, const RuntimeShape& weights_shape,
+ const uint8* shuffled_weights_data, const RuntimeShape& bias_shape,
+ const int32* bias_data, const RuntimeShape& output_shape,
+ int16* output_data, uint8* shuffled_input_workspace_data,
+ void* gemm_context) {
+ (void)gemm_context; // only used in optimized code.
+ const int32 output_multiplier = params.output_multiplier;
+ const int output_shift = params.output_shift;
+ const int32 output_activation_min = params.quantized_activation_min;
+ const int32 output_activation_max = params.quantized_activation_max;
+ TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+
+ TFLITE_DCHECK_GE(input_shape.DimensionsCount(), 1);
+ TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
+ TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
+ // TODO(benoitjacob): This really should be:
+ // const int batches = ArraySize(output_dims, 1);
+ // but the current --variable_batch hack consists in overwriting the 3rd
+ // dimension with the runtime batch size, as we don't keep track for each
+ // array of which dimension is the batch dimension in it.
+ const int output_dim_count = output_shape.DimensionsCount();
+ const int weights_dim_count = weights_shape.DimensionsCount();
+ const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
+ const int output_depth = MatchingDim(weights_shape, weights_dim_count - 2,
+ output_shape, output_dim_count - 1);
+ const int accum_depth = weights_shape.Dims(weights_dim_count - 1);
+ TFLITE_DCHECK((accum_depth % 16) == 0);
+ TFLITE_DCHECK((output_depth % 4) == 0);
+
+ // Shuffling and xoring of input activations into the workspace buffer
+ uint8* shuffled_input_workspace_ptr = shuffled_input_workspace_data;
+ if (batches == 1) {
+ for (int i = 0; i < accum_depth; i++) {
+ shuffled_input_workspace_data[i] = input_data[i] ^ 0x80;
+ }
+ } else if (batches == 4) {
+ for (int c = 0; c < accum_depth; c += 16) {
+ for (int b = 0; b < 4; b++) {
+ const uint8* src_data_ptr = input_data + b * accum_depth + c;
+ for (int j = 0; j < 16; j++) {
+ uint8 src_val = *src_data_ptr++;
+ // Flip the sign bit, so that the kernel will only need to
+ // reinterpret these uint8 values as int8, getting for free the
+ // subtraction of the zero_point value 128.
+ uint8 dst_val = src_val ^ 0x80;
+ *shuffled_input_workspace_ptr++ = dst_val;
+ }
+ }
+ }
+ } else {
+ TFLITE_DCHECK(false);
+ return;
+ }
+
+ // Actual computation
+ if (batches == 1) {
+ int16* output_ptr = output_data;
+ // Shuffled weights have had their sign bit (0x80) pre-flipped (xor'd)
+ // so that just reinterpreting them as int8 values is equivalent to
+ // subtracting 128 from them, thus implementing for free the subtraction of
+ // the zero_point value 128.
+ const int8* shuffled_weights_ptr =
+ reinterpret_cast<const int8*>(shuffled_weights_data);
+ // Likewise, we preshuffled and pre-xored the input data above.
+ const int8* shuffled_input_data =
+ reinterpret_cast<const int8*>(shuffled_input_workspace_data);
+ for (int c = 0; c < output_depth; c += 4) {
+      // Internal accumulation. Accumulators start at zero; the bias value
+      // is added after the accumulation loop below.
+ int32 accum[4] = {0};
+ // Accumulation loop.
+ for (int d = 0; d < accum_depth; d += 16) {
+ for (int i = 0; i < 4; i++) {
+ for (int j = 0; j < 16; j++) {
+ int8 input_val = shuffled_input_data[d + j];
+ int8 weights_val = *shuffled_weights_ptr++;
+ accum[i] += weights_val * input_val;
+ }
+ }
+ }
+ for (int i = 0; i < 4; i++) {
+ // Add bias value
+ int32 acc = accum[i] + bias_data[c + i];
+ // Down-scale the final int32 accumulator to the scale used by our
+ // (16-bit, typically 3 integer bits) fixed-point format. The quantized
+ // multiplier and shift here have been pre-computed offline
+ // (e.g. by toco).
+ acc =
+ MultiplyByQuantizedMultiplier(acc, output_multiplier, output_shift);
+ // Saturate, cast to int16, and store to output array.
+ acc = std::max(acc, output_activation_min);
+ acc = std::min(acc, output_activation_max);
+ output_ptr[c + i] = acc;
+ }
+ }
+ } else if (batches == 4) {
+ int16* output_ptr = output_data;
+ // Shuffled weights have had their sign bit (0x80) pre-flipped (xor'd)
+ // so that just reinterpreting them as int8 values is equivalent to
+ // subtracting 128 from them, thus implementing for free the subtraction of
+ // the zero_point value 128.
+ const int8* shuffled_weights_ptr =
+ reinterpret_cast<const int8*>(shuffled_weights_data);
+ // Likewise, we preshuffled and pre-xored the input data above.
+ const int8* shuffled_input_data =
+ reinterpret_cast<const int8*>(shuffled_input_workspace_data);
+ for (int c = 0; c < output_depth; c += 4) {
+ const int8* shuffled_input_ptr = shuffled_input_data;
+      // Internal accumulation. Accumulators start at zero; the bias value
+      // is added after the accumulation loop below.
+ int32 accum[4][4];
+ for (int i = 0; i < 4; i++) {
+ for (int b = 0; b < 4; b++) {
+ accum[i][b] = 0;
+ }
+ }
+ for (int d = 0; d < accum_depth; d += 16) {
+ for (int i = 0; i < 4; i++) {
+ for (int b = 0; b < 4; b++) {
+ for (int j = 0; j < 16; j++) {
+ int8 input_val = shuffled_input_ptr[16 * b + j];
+ int8 weights_val = shuffled_weights_ptr[16 * i + j];
+ accum[i][b] += weights_val * input_val;
+ }
+ }
+ }
+ shuffled_input_ptr += 64;
+ shuffled_weights_ptr += 64;
+ }
+ for (int i = 0; i < 4; i++) {
+ for (int b = 0; b < 4; b++) {
+ // Add bias value
+ int32 acc = accum[i][b] + bias_data[c + i];
+ // Down-scale the final int32 accumulator to the scale used by our
+ // (16-bit, typically 3 integer bits) fixed-point format. The
+ // quantized multiplier and shift here have been pre-computed offline
+ // (e.g. by toco).
+ acc = MultiplyByQuantizedMultiplier(acc, output_multiplier,
+ output_shift);
+ // Saturate, cast to int16, and store to output array.
+ acc = std::max(acc, output_activation_min);
+ acc = std::min(acc, output_activation_max);
+ output_ptr[b * output_depth + c + i] = acc;
+ }
+ }
+ }
+ } else {
+ TFLITE_DCHECK(false);
+ return;
+ }
+}
+
+} // namespace reference_ops
+} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_FULLY_CONNECTED_H_
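
The shuffled fully-connected path above relies on the fact that, for uint8 data quantized with a zero point of 128, flipping the sign bit and reinterpreting the byte as int8 is the same as subtracting 128. A quick standalone check of that identity, separate from the kernel code (two's-complement narrowing is assumed, which is what the kernel itself relies on):

#include <cstdint>
#include <cstdio>

int main() {
  for (int v = 0; v < 256; ++v) {
    const std::uint8_t u = static_cast<std::uint8_t>(v);
    // Flip the sign bit, then reinterpret the same byte as a signed value.
    const std::int8_t flipped = static_cast<std::int8_t>(u ^ 0x80);
    // This matches subtracting the zero point 128 directly.
    if (static_cast<int>(flipped) != v - 128) {
      std::printf("mismatch at %d\n", v);
      return 1;
    }
  }
  std::printf("u ^ 0x80 reinterpreted as int8 == u - 128 for all bytes\n");
  return 0;
}
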
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h
index 683ccdc74d..be99240b1f 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h
@@ -19,6 +19,8 @@ limitations under the License.
#include <sys/types.h>
#include "tensorflow/contrib/lite/kernels/internal/common.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/types.h"
@@ -26,6 +28,1070 @@ namespace tflite {
namespace reference_ops {
+static constexpr int kDepthwiseReverseShift = -1;
+
+inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height,
+ int dilation_width_factor, int dilation_height_factor,
+ int pad_width, int pad_height, int depth_multiplier,
+ float output_activation_min,
+ float output_activation_max, float* output_data,
+ const Dims<4>& output_dims) {
+ tflite::DepthwiseParams op_params;
+ // Padding type is ignored, but still set.
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = pad_width;
+ op_params.padding_values.height = pad_height;
+ op_params.stride_width = stride_width;
+ op_params.stride_height = stride_height;
+ op_params.dilation_width_factor = dilation_width_factor;
+ op_params.dilation_height_factor = dilation_height_factor;
+ op_params.depth_multiplier = depth_multiplier;
+ op_params.float_activation_min = output_activation_min;
+ op_params.float_activation_max = output_activation_max;
+
+ DepthwiseConv(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
+ bias_data, DimsToShape(output_dims), output_data);
+}
+
+inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int depth_multiplier,
+ float output_activation_min,
+ float output_activation_max, float* output_data,
+ const Dims<4>& output_dims) {
+ DepthwiseConv(input_data, input_dims, filter_data, filter_dims, bias_data,
+ bias_dims, stride_width, stride_height, 1, 1, pad_width,
+ pad_height, depth_multiplier, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+// Legacy, for compatibility with old checked-in code.
+template <FusedActivationFunctionType Ac>
+void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int depth_multiplier, float* output_data,
+ const Dims<4>& output_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+ DepthwiseConv(input_data, input_dims, filter_data, filter_dims, bias_data,
+ bias_dims, stride_width, stride_height, pad_width, pad_height,
+ depth_multiplier, output_activation_min, output_activation_max,
+ output_data, output_dims);
+}
+
+// Legacy, for compatibility with old checked-in code.
+template <FusedActivationFunctionType Ac>
+void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims, int stride,
+ int pad_width, int pad_height, int depth_multiplier,
+ float* output_data, const Dims<4>& output_dims) {
+ DepthwiseConv<Ac>(input_data, input_dims, filter_data, filter_dims, bias_data,
+ bias_dims, stride, stride, pad_width, pad_height,
+ depth_multiplier, output_data, output_dims);
+}
+
+inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height,
+ int dilation_width_factor, int dilation_height_factor,
+ int pad_width, int pad_height, int depth_multiplier,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ tflite::DepthwiseParams op_params;
+ // Padding type is ignored, but still set.
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = pad_width;
+ op_params.padding_values.height = pad_height;
+ op_params.stride_width = stride_width;
+ op_params.stride_height = stride_height;
+ op_params.dilation_width_factor = dilation_width_factor;
+ op_params.dilation_height_factor = dilation_height_factor;
+ op_params.depth_multiplier = depth_multiplier;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+ op_params.input_offset = input_offset;
+ op_params.weights_offset = filter_offset;
+ op_params.output_offset = output_offset;
+ op_params.output_multiplier = output_multiplier;
+ // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
+ op_params.output_shift = kDepthwiseReverseShift * output_shift;
+
+ DepthwiseConv(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
+ bias_data, DimsToShape(output_dims), output_data);
+}
+
+inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int depth_multiplier,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ DepthwiseConv(input_data, input_dims, input_offset, filter_data, filter_dims,
+ filter_offset, bias_data, bias_dims, stride_width,
+ stride_height, 1, 1, pad_width, pad_height, depth_multiplier,
+ output_offset, output_multiplier, output_shift,
+ output_activation_min, output_activation_max, output_data,
+ output_dims);
+}
+
+// Legacy, for compatibility with old checked-in code.
+template <FusedActivationFunctionType Ac>
+void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int depth_multiplier, int32 output_offset,
+ int32 output_multiplier, int output_shift,
+ int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims) {
+ if (Ac == FusedActivationFunctionType::kNone) {
+ TFLITE_DCHECK_EQ(output_activation_min, 0);
+ TFLITE_DCHECK_EQ(output_activation_max, 255);
+ }
+ DepthwiseConv(input_data, input_dims, input_offset, filter_data, filter_dims,
+ filter_offset, bias_data, bias_dims, stride_width,
+ stride_height, pad_width, pad_height, depth_multiplier,
+ output_offset, output_multiplier, output_shift,
+ output_activation_min, output_activation_max, output_data,
+ output_dims);
+}
+
+// Legacy, for compatibility with old checked-in code.
+template <FusedActivationFunctionType Ac>
+void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims, int stride,
+ int pad_width, int pad_height, int depth_multiplier,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ DepthwiseConv<Ac>(input_data, input_dims, input_offset, filter_data,
+ filter_dims, filter_offset, bias_data, bias_dims, stride,
+ stride, pad_width, pad_height, depth_multiplier,
+ output_offset, output_multiplier, output_shift,
+ output_activation_min, output_activation_max, output_data,
+ output_dims);
+}
+
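
These wrappers keep the old Dims<4>-based entry points and forward to the params/RuntimeShape kernels via DimsToShape. Dims<4> stores its sizes innermost-first (index 0 is depth/channels), so the conversion is essentially a reversal into NHWC order. A rough sketch of that mapping with plain arrays, showing only the reversal and not the real DimsToShape helper:

#include <cstdio>

// Old-style layout: sizes[0] is the innermost (channel) dimension,
// sizes[3] the outermost (batch) dimension.
struct LegacyDims4 {
  int sizes[4];
};

// Reverse into outermost-first NHWC order, which is what the new
// RuntimeShape-based kernels expect.
void DimsToNhwcSketch(const LegacyDims4& dims, int shape[4]) {
  for (int i = 0; i < 4; ++i) {
    shape[i] = dims.sizes[3 - i];
  }
}

int main() {
  LegacyDims4 dims = {{8, 5, 4, 2}};  // depth=8, width=5, height=4, batches=2
  int shape[4];
  DimsToNhwcSketch(dims, shape);
  std::printf("%d %d %d %d\n", shape[0], shape[1], shape[2], shape[3]);
  // Prints: 2 4 5 8  (N, H, W, C)
  return 0;
}
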
+inline void Conv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int dilation_width_factor,
+ int dilation_height_factor, int pad_width, int pad_height,
+ float output_activation_min, float output_activation_max,
+ float* output_data, const Dims<4>& output_dims,
+ float* im2col_data, const Dims<4>& im2col_dims) {
+ tflite::ConvParams op_params;
+ // Padding type is ignored, but still set.
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = pad_width;
+ op_params.padding_values.height = pad_height;
+ op_params.stride_width = stride_width;
+ op_params.stride_height = stride_height;
+ op_params.dilation_width_factor = dilation_width_factor;
+ op_params.dilation_height_factor = dilation_height_factor;
+ op_params.float_activation_min = output_activation_min;
+ op_params.float_activation_max = output_activation_max;
+
+ Conv(op_params, DimsToShape(input_dims), input_data, DimsToShape(filter_dims),
+ filter_data, DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
+ output_data, DimsToShape(im2col_dims), im2col_data);
+}
+
+template <FusedActivationFunctionType Ac>
+void Conv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims, int stride_width,
+ int stride_height, int dilation_width_factor,
+ int dilation_height_factor, int pad_width, int pad_height,
+ float* output_data, const Dims<4>& output_dims, float* im2col_data,
+ const Dims<4>& im2col_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+ Conv(input_data, input_dims, filter_data, filter_dims, bias_data, bias_dims,
+ stride_width, stride_height, dilation_width_factor,
+ dilation_height_factor, pad_width, pad_height, output_activation_min,
+ output_activation_max, output_data, output_dims, im2col_data,
+ im2col_dims);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void Conv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims, int stride_width,
+ int stride_height, int pad_width, int pad_height, float* output_data,
+ const Dims<4>& output_dims, float* im2col_data,
+ const Dims<4>& im2col_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+ Conv(input_data, input_dims, filter_data, filter_dims, bias_data, bias_dims,
+ stride_width, stride_height, 1, 1, pad_width, pad_height,
+ output_activation_min, output_activation_max, output_data, output_dims,
+ im2col_data, im2col_dims);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void Conv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims, int stride,
+ int pad_width, int pad_height, float* output_data,
+ const Dims<4>& output_dims, float* im2col_data,
+ const Dims<4>& im2col_dims) {
+ Conv<Ac>(input_data, input_dims, filter_data, filter_dims, bias_data,
+ bias_dims, stride, stride, 1, 1, pad_width, pad_height, output_data,
+ output_dims, im2col_data, im2col_dims);
+}
+
+inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int dilation_width_factor,
+ int dilation_height_factor, int pad_width, int pad_height,
+ int32 output_offset, int32 output_multiplier, int output_shift,
+ int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims,
+ uint8* im2col_data, const Dims<4>& im2col_dims,
+ gemmlowp::GemmContext* gemm_context) {
+ tflite::ConvParams op_params;
+ // Padding type is ignored, but still set.
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = pad_width;
+ op_params.padding_values.height = pad_height;
+ op_params.stride_width = stride_width;
+ op_params.stride_height = stride_height;
+ op_params.dilation_width_factor = dilation_width_factor;
+ op_params.dilation_height_factor = dilation_height_factor;
+ op_params.input_offset = input_offset;
+ op_params.weights_offset = filter_offset;
+ op_params.output_offset = output_offset;
+ op_params.output_multiplier = output_multiplier;
+ // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
+ op_params.output_shift = kReverseShift * output_shift;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+
+ Conv(op_params, DimsToShape(input_dims), input_data, DimsToShape(filter_dims),
+ filter_data, DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
+ output_data, DimsToShape(im2col_dims), im2col_data, gemm_context);
+}
+
+inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims, uint8* im2col_data,
+ const Dims<4>& im2col_dims,
+ gemmlowp::GemmContext* gemm_context) {
+ Conv(input_data, input_dims, input_offset, filter_data, filter_dims,
+ filter_offset, bias_data, bias_dims, stride_width, stride_height, 1, 1,
+ pad_width, pad_height, output_offset, output_multiplier, output_shift,
+ output_activation_min, output_activation_max, output_data, output_dims,
+ im2col_data, im2col_dims, gemm_context);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims, uint8* im2col_data,
+ const Dims<4>& im2col_dims,
+ gemmlowp::GemmContext* gemm_context) {
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ if (Ac == FusedActivationFunctionType::kNone) {
+ TFLITE_DCHECK_EQ(output_activation_min, 0);
+ TFLITE_DCHECK_EQ(output_activation_max, 255);
+ }
+ Conv(input_data, input_dims, input_offset, filter_data, filter_dims,
+ filter_offset, bias_data, bias_dims, stride_width, stride_height,
+ pad_width, pad_height, output_offset, output_multiplier, output_shift,
+ output_activation_min, output_activation_max, output_data, output_dims,
+ im2col_data, im2col_dims, gemm_context);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void Conv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims, int stride,
+ int pad_width, int pad_height, int32 output_offset,
+ int32 output_multiplier, int output_shift,
+ int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims, uint8* im2col_data,
+ const Dims<4>& im2col_dims, gemmlowp::GemmContext* gemm_context) {
+ Conv<Ac>(input_data, input_dims, input_offset, filter_data, filter_dims,
+ filter_offset, bias_data, bias_dims, stride, stride, pad_width,
+ pad_height, output_offset, output_multiplier, output_shift,
+ output_activation_min, output_activation_max, output_data,
+ output_dims, im2col_data, im2col_dims, gemm_context);
+}
+
+inline void TransposeConv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, float* output_data,
+ const Dims<4>& output_dims, float* im2col_data,
+ const Dims<4>& im2col_dims) {
+ tflite::ConvParams op_params;
+ // Padding type is ignored, but still set.
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = pad_width;
+ op_params.padding_values.height = pad_height;
+ op_params.stride_width = stride_width;
+ op_params.stride_height = stride_height;
+
+ TransposeConv(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(filter_dims), filter_data, DimsToShape(output_dims),
+ output_data, DimsToShape(im2col_dims), im2col_data);
+}
+
+inline void FullyConnected(const float* input_data, const Dims<4>& input_dims,
+ const float* weights_data,
+ const Dims<4>& weights_dims, const float* bias_data,
+ const Dims<4>& bias_dims,
+ float output_activation_min,
+ float output_activation_max, float* output_data,
+ const Dims<4>& output_dims) {
+ tflite::FullyConnectedParams op_params;
+ op_params.float_activation_min = output_activation_min;
+ op_params.float_activation_max = output_activation_max;
+
+ FullyConnected(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(weights_dims), weights_data,
+ DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
+ output_data);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void FullyConnected(const float* input_data, const Dims<4>& input_dims,
+ const float* weights_data, const Dims<4>& weights_dims,
+ const float* bias_data, const Dims<4>& bias_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+ FullyConnected(input_data, input_dims, weights_data, weights_dims, bias_data,
+ bias_dims, output_activation_min, output_activation_max,
+ output_data, output_dims);
+}
+
+inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims,
+ gemmlowp::GemmContext* gemm_context) {
+ tflite::FullyConnectedParams op_params;
+ op_params.input_offset = input_offset;
+ op_params.weights_offset = filter_offset;
+ op_params.output_offset = output_offset;
+ op_params.output_multiplier = output_multiplier;
+ // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
+ op_params.output_shift = kReverseShift * output_shift;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+
+ FullyConnected(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
+ bias_data, DimsToShape(output_dims), output_data,
+ gemm_context);
+}
+
+inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, int16* output_data,
+ const Dims<4>& output_dims,
+ gemmlowp::GemmContext* gemm_context) {
+ tflite::FullyConnectedParams op_params;
+ op_params.input_offset = input_offset;
+ op_params.weights_offset = filter_offset;
+ op_params.output_offset = output_offset;
+ op_params.output_multiplier = output_multiplier;
+ // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
+ op_params.output_shift = kReverseShift * output_shift;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+
+ FullyConnected(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
+ bias_data, DimsToShape(output_dims), output_data,
+ gemm_context);
+}
+
+inline void ShuffledFullyConnected(
+ const uint8* input_data, const Dims<4>& input_dims,
+ const uint8* shuffled_weights_data, const Dims<4>& weights_dims,
+ const int32* bias_data, const Dims<4>& bias_dims, int32 output_multiplier,
+ int output_shift, int32 output_activation_min, int32 output_activation_max,
+ int16* output_data, const Dims<4>& output_dims,
+ uint8* shuffled_input_workspace_data, gemmlowp::GemmContext* gemm_context) {
+ tflite::FullyConnectedParams op_params;
+ op_params.output_multiplier = output_multiplier;
+ // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
+ op_params.output_shift = kReverseShift * output_shift;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+
+ ShuffledFullyConnected(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(weights_dims), shuffled_weights_data,
+ DimsToShape(bias_dims), bias_data,
+ DimsToShape(output_dims), output_data,
+ shuffled_input_workspace_data, gemm_context);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims,
+ gemmlowp::GemmContext* gemm_context) {
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ if (Ac == FusedActivationFunctionType::kNone) {
+ TFLITE_DCHECK_EQ(output_activation_min, 0);
+ TFLITE_DCHECK_EQ(output_activation_max, 255);
+ }
+ FullyConnected(input_data, input_dims, input_offset, filter_data, filter_dims,
+ filter_offset, bias_data, bias_dims, output_offset,
+ output_multiplier, output_shift, output_activation_min,
+ output_activation_max, output_data, output_dims, gemm_context);
+}
+
+inline void LstmCell(const float* input_data, const Dims<4>& input_dims,
+ const float* prev_activ_data,
+ const Dims<4>& prev_activ_dims, const float* weights_data,
+ const Dims<4>& weights_dims, const float* bias_data,
+ const Dims<4>& bias_dims, const float* prev_state_data,
+ const Dims<4>& prev_state_dims, float* output_state_data,
+ const Dims<4>& output_state_dims, float* output_activ_data,
+ const Dims<4>& output_activ_dims, float* concat_temp_data,
+ const Dims<4>& concat_temp_dims, float* activ_temp_data,
+ const Dims<4>& activ_temp_dims) {
+ tflite::LstmCellParams op_params;
+ // Float LSTM cell does not need parameters to be set: leave untouched.
+
+ LstmCell(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(prev_activ_dims), prev_activ_data,
+ DimsToShape(weights_dims), weights_data, DimsToShape(bias_dims),
+ bias_data, DimsToShape(prev_state_dims), prev_state_data,
+ DimsToShape(output_state_dims), output_state_data,
+ DimsToShape(output_activ_dims), output_activ_data,
+ DimsToShape(concat_temp_dims), concat_temp_data,
+ DimsToShape(activ_temp_dims), activ_temp_data);
+}
+
+template <int StateIntegerBits>
+void LstmCell(const uint8* input_data_uint8, const Dims<4>& input_dims,
+ const uint8* prev_activ_data_uint8,
+ const Dims<4>& prev_activ_dims, const uint8* weights_data_uint8,
+ const Dims<4>& weights_dims, const int32* bias_data_int32,
+ const Dims<4>& bias_dims, const int16* prev_state_data_int16,
+ const Dims<4>& prev_state_dims, int16* output_state_data_int16,
+ const Dims<4>& output_state_dims, uint8* output_activ_data_uint8,
+ const Dims<4>& output_activ_dims, uint8* concat_temp_data_uint8,
+ const Dims<4>& concat_temp_dims, int16* activ_temp_data_int16,
+ const Dims<4>& activ_temp_dims, int32 weights_zero_point,
+ int32 accum_multiplier, int accum_shift,
+ gemmlowp::GemmContext* gemm_context) {
+ tflite::LstmCellParams op_params;
+ op_params.weights_zero_point = weights_zero_point;
+ op_params.accum_multiplier = accum_multiplier;
+ op_params.accum_shift = accum_shift;
+
+ LstmCell<StateIntegerBits>(
+ op_params, DimsToShape(input_dims), input_data_uint8,
+ DimsToShape(prev_activ_dims), prev_activ_data_uint8,
+ DimsToShape(weights_dims), weights_data_uint8, DimsToShape(bias_dims),
+ bias_data_int32, DimsToShape(prev_state_dims), prev_state_data_int16,
+ DimsToShape(output_state_dims), output_state_data_int16,
+ DimsToShape(output_activ_dims), output_activ_data_uint8,
+ DimsToShape(concat_temp_dims), concat_temp_data_uint8,
+ DimsToShape(activ_temp_dims), activ_temp_data_int16, gemm_context);
+}
+
+template <typename T>
+void BroadcastDiv(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ T output_activation_min, T output_activation_max,
+ T* output_data, const Dims<4>& output_dims) {
+ tflite::ArithmeticParams op_params;
+ SetActivationParams(output_activation_min, output_activation_max, &op_params);
+
+ BroadcastDiv4DSlow(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <typename T>
+inline void Div(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ T output_activation_min, T output_activation_max,
+ T* output_data, const Dims<4>& output_dims) {
+ tflite::ArithmeticParams op_params;
+ SetActivationParams(output_activation_min, output_activation_max, &op_params);
+
+ Div(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
+template <FusedActivationFunctionType Ac, typename Scalar>
+inline void Concatenation(int concat_dim, const Scalar* const* input_data,
+ const Dims<4>* const* input_dims, int inputs_count,
+ Scalar* output_data, const Dims<4>& output_dims) {
+ // For now we don't have a model with a Concatenation with fused activation.
+ TFLITE_DCHECK_EQ(Ac, FusedActivationFunctionType::kNone);
+
+ std::vector<RuntimeShape> input_shapes(inputs_count);
+ std::vector<const RuntimeShape*> input_shapes_indirect(inputs_count);
+ for (int i = 0; i < inputs_count; ++i) {
+ ShapeFromDims(*input_dims[i], &input_shapes[i]);
+ input_shapes_indirect[i] = &input_shapes[i];
+ }
+ tflite::ConcatenationParams op_params;
+ op_params.axis = 3 - concat_dim;
+ op_params.inputs_count = inputs_count;
+
+ Concatenation(op_params, input_shapes_indirect.data(), input_data,
+ DimsToShape(output_dims), output_data);
+}
+
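
The same reversal explains the axis translation above: a concat_dim given in the legacy innermost-first numbering maps to axis 3 - concat_dim in the NHWC-ordered RuntimeShape numbering, so the legacy default of concat_dim 0 lands on the channel axis. A tiny illustrative check of that correspondence:

#include <cstdio>

int main() {
  // Legacy numbering: 0 = channels, 1 = width, 2 = height, 3 = batches.
  // NHWC numbering:   0 = batches,  1 = height, 2 = width, 3 = channels.
  const char* nhwc_names[4] = {"batches", "height", "width", "channels"};
  for (int concat_dim = 0; concat_dim < 4; ++concat_dim) {
    const int axis = 3 - concat_dim;  // same mapping as the wrapper above
    std::printf("legacy concat_dim %d -> NHWC axis %d (%s)\n", concat_dim,
                axis, nhwc_names[axis]);
  }
  return 0;
}
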
+inline void Concatenation(int concat_dim, const uint8* const* input_data,
+ const Dims<4>* const* input_dims,
+ const int32* input_zeropoint,
+ const float* input_scale, int inputs_count,
+ uint8* output_data, const Dims<4>& output_dims,
+ const int32 output_zeropoint,
+ const float output_scale) {
+ std::vector<RuntimeShape> input_shapes(inputs_count);
+ std::vector<const RuntimeShape*> input_shapes_indirect(inputs_count);
+ for (int i = 0; i < inputs_count; ++i) {
+ ShapeFromDims(*input_dims[i], &input_shapes[i]);
+ input_shapes_indirect[i] = &input_shapes[i];
+ }
+ tflite::ConcatenationParams op_params;
+ op_params.axis = 3 - concat_dim;
+ op_params.input_zeropoint = input_zeropoint;
+ op_params.input_scale = input_scale;
+ op_params.inputs_count = inputs_count;
+ op_params.output_zeropoint = output_zeropoint;
+ op_params.output_scale = output_scale;
+
+ ConcatenationWithScaling(op_params, input_shapes_indirect.data(), input_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <FusedActivationFunctionType Ac, typename Scalar>
+void DepthConcatenation(const Scalar* const* input_data,
+ const Dims<4>* const* input_dims, int inputs_count,
+ Scalar* output_data, const Dims<4>& output_dims) {
+ // For now we don't have a model with a Concatenation with fused activation.
+ TFLITE_DCHECK_EQ(Ac, FusedActivationFunctionType::kNone);
+
+ std::vector<RuntimeShape> input_shapes(inputs_count);
+ std::vector<const RuntimeShape*> input_shapes_indirect(inputs_count);
+ for (int i = 0; i < inputs_count; ++i) {
+ ShapeFromDims(*input_dims[i], &input_shapes[i]);
+ input_shapes_indirect[i] = &input_shapes[i];
+ }
+ tflite::ConcatenationParams op_params;
+ op_params.inputs_count = inputs_count;
+
+ DepthConcatenation(op_params, input_shapes_indirect.data(), input_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <typename Scalar>
+void TensorFlowSplit(const Scalar* input_data, const Dims<4>& input_dims,
+ int axis, int outputs_count, Scalar* const* output_data,
+ const Dims<4>* const* output_dims) {
+ std::vector<RuntimeShape> output_shapes(outputs_count);
+ std::vector<const RuntimeShape*> output_shapes_indirect(outputs_count);
+ for (int i = 0; i < outputs_count; ++i) {
+ ShapeFromDims(*output_dims[i], &output_shapes[i]);
+ output_shapes_indirect[i] = &output_shapes[i];
+ }
+ tflite::SplitParams op_params;
+ op_params.axis = 3 - axis;
+ op_params.num_split = outputs_count;
+
+ Split(op_params, DimsToShape(input_dims), input_data,
+ output_shapes_indirect.data(), output_data);
+}
+
+template <FusedActivationFunctionType Ac, typename Scalar>
+void TensorFlowSplit(const Scalar* input_data, const Dims<4>& input_dims,
+ int outputs_count, Scalar* const* output_data,
+ const Dims<4>* const* output_dims) {
+ TFLITE_DCHECK_GE(outputs_count, 1);
+ for (int i = 0; i < outputs_count; i++) {
+ /* batches = */ MatchingArraySize(*output_dims[i], 3, input_dims, 3);
+ /* height = */ MatchingArraySize(*output_dims[i], 2, input_dims, 2);
+ /* width = */ MatchingArraySize(*output_dims[i], 1, input_dims, 1);
+ }
+ // For now we don't have a model with a Split with fused activation.
+ TFLITE_DCHECK_EQ(Ac, FusedActivationFunctionType::kNone);
+
+ TensorFlowSplit(input_data, input_dims, /*axis=*/0, outputs_count,
+ output_data, output_dims);
+}
+
+inline void Softmax(const float* input_data, const RuntimeShape& input_shape,
+ float beta, float* output_data,
+ const RuntimeShape& output_shape) {
+ SoftmaxParams params;
+ params.beta = beta;
+ Softmax(params, input_shape, input_data, output_shape, output_data);
+}
+
+inline void Softmax(const uint8* input_data, const RuntimeShape& input_shape,
+ int32 input_beta_multiplier, int32 input_beta_left_shift,
+ int diff_min, uint8* output_data,
+ const RuntimeShape& output_shape) {
+ SoftmaxParams params;
+ params.input_multiplier = input_beta_multiplier;
+ params.input_left_shift = input_beta_left_shift;
+ params.diff_min = diff_min;
+ Softmax(params, input_shape, input_data, output_shape, output_data);
+}
+
+inline void LogSoftmax(const float* input_data, const RuntimeShape& input_shape,
+ float* output_data, const RuntimeShape& output_shape) {
+ SoftmaxParams params;
+ // No params currently used for float LogSoftmax.
+ LogSoftmax(params, input_shape, input_data, output_shape, output_data);
+}
+
+inline void LogSoftmax(const uint8* input_data, const RuntimeShape& input_shape,
+ int32 input_multiplier, int32 input_left_shift,
+ int32 reverse_scaling_divisor,
+ int32 reverse_scaling_right_shift, int diff_min,
+ uint8* output_data, const RuntimeShape& output_shape) {
+ SoftmaxParams params;
+ params.input_multiplier = input_multiplier;
+ params.input_left_shift = input_left_shift;
+ params.reverse_scaling_divisor = reverse_scaling_divisor;
+ params.reverse_scaling_right_shift = reverse_scaling_right_shift;
+ params.diff_min = diff_min;
+ LogSoftmax(params, input_shape, input_data, output_shape, output_data);
+}
+
+inline void Logistic(const uint8* input_data, const RuntimeShape& input_shape,
+ int32 input_zero_point, int32 input_range_radius,
+ int32 input_multiplier, int input_left_shift,
+ uint8* output_data, const RuntimeShape& output_shape) {
+ LogisticParams params;
+ params.input_zero_point = input_zero_point;
+ params.input_range_radius = input_range_radius;
+ params.input_multiplier = input_multiplier;
+ params.input_left_shift = input_left_shift;
+ Logistic(params, input_shape, input_data, output_shape, output_data);
+}
+
+inline void Logistic(const RuntimeShape& input_shape, const int16* input_data,
+ const RuntimeShape& output_shape, int16* output_data) {
+ LogisticParams params;
+ // No params currently needed by int16 Logistic.
+ Logistic(params, input_shape, input_data, output_shape, output_data);
+}
+
+inline void Tanh(const uint8* input_data, const RuntimeShape& input_shape,
+ int32 input_zero_point, int32 input_range_radius,
+ int32 input_multiplier, int input_left_shift,
+ uint8* output_data, const RuntimeShape& output_shape) {
+ TanhParams params;
+ params.input_zero_point = input_zero_point;
+ params.input_range_radius = input_range_radius;
+ params.input_multiplier = input_multiplier;
+ params.input_left_shift = input_left_shift;
+ Tanh(params, input_shape, input_data, output_shape, output_data);
+}
+
+inline void Tanh(const int16* input_data, const RuntimeShape& input_shape,
+ int input_left_shift, int16* output_data,
+ const RuntimeShape& output_shape) {
+ TanhParams params;
+ params.input_left_shift = input_left_shift;
+ Tanh(params, input_shape, input_data, output_shape, output_data);
+}
+
+inline void Dequantize(const uint8* input_data, const Dims<4>& input_dims,
+ int32 zero_point, double scale, float* output_data,
+ const Dims<4>& output_dims) {
+ tflite::DequantizationParams op_params;
+ op_params.zero_point = zero_point;
+ op_params.scale = scale;
+
+ Dequantize(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_dims), output_data);
+}
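For reference, the Dequantize wrapper above applies the usual affine mapping

  output = scale * (static_cast<int32>(input) - zero_point)

A worked example with assumed values: scale = 0.5, zero_point = 128 and input = 130 gives output = 0.5f * (130 - 128) = 1.0f.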
+
+inline void FakeQuant(const float* input_data, const Dims<4>& input_dims,
+ float rmin, float rmax, int num_bits, float* output_data,
+ const Dims<4>& output_dims) {
+ tflite::FakeQuantParams op_params;
+ op_params.num_bits = num_bits;
+ op_params.minmax.min = rmin;
+ op_params.minmax.max = rmax;
+
+ FakeQuant(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <typename T>
+inline void Gather(const T* input_data, const Dims<4>& input_dims,
+ int input_rank, const int32* coords_data,
+ const Dims<4>& coords_dims, T* output_data,
+ const Dims<4>& output_dims) {
+ tflite::GatherParams op_params;
+ op_params.input_rank = input_rank;
+
+ Gather(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(coords_dims), coords_data, DimsToShape(output_dims),
+ output_data);
+}
+
+inline uint32 LegacyReverseBits32(uint32 n) {
+ n = ((n >> 1) & 0x55555555) | ((n & 0x55555555) << 1);
+ n = ((n >> 2) & 0x33333333) | ((n & 0x33333333) << 2);
+ n = ((n >> 4) & 0x0F0F0F0F) | ((n & 0x0F0F0F0F) << 4);
+ return (((n & 0xFF) << 24) | ((n & 0xFF00) << 8) | ((n & 0xFF0000) >> 8) |
+ ((n & 0xFF000000) >> 24));
+}
+
+inline void StridedSliceReverseIndices(tflite::StridedSliceParams* p) {
+ TFLITE_CHECK_EQ(p->start_indices_count, p->stop_indices_count);
+ TFLITE_CHECK_EQ(p->stop_indices_count, p->strides_count);
+
+ std::reverse(p->start_indices, p->start_indices + p->start_indices_count);
+ std::reverse(p->stop_indices, p->stop_indices + p->stop_indices_count);
+ std::reverse(p->strides, p->strides + p->strides_count);
+
+ p->begin_mask = LegacyReverseBits32(static_cast<uint32>(p->begin_mask)) >>
+ (32 - p->start_indices_count);
+ p->ellipsis_mask =
+ LegacyReverseBits32(static_cast<uint32>(p->ellipsis_mask)) >>
+ (32 - p->start_indices_count);
+ p->end_mask = LegacyReverseBits32(static_cast<uint32>(p->end_mask)) >>
+ (32 - p->start_indices_count);
+ p->new_axis_mask =
+ LegacyReverseBits32(static_cast<uint32>(p->new_axis_mask)) >>
+ (32 - p->start_indices_count);
+ p->shrink_axis_mask =
+ LegacyReverseBits32(static_cast<uint32>(p->shrink_axis_mask)) >>
+ (32 - p->start_indices_count);
+}
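Worked example of the mask handling above, assuming start_indices_count == 4: a legacy begin_mask of 0b0011 (legacy dims 0 and 1) becomes

  LegacyReverseBits32(0b0011) = 0xC0000000
  0xC0000000 >> (32 - 4)      = 0b1100

which names dims 2 and 3 in the reversed ordering, i.e. the same dimensions once the start/stop/stride arrays have also been reversed.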
+
+template <typename T>
+inline void StridedSlice(const T* input_data, const Dims<4>& input_dims,
+ int begin_mask, int end_mask, int shrink_axis_mask,
+ const std::vector<int>& start_indices,
+ const std::vector<int>& stop_indices,
+ const std::vector<int>& strides, T* output_data,
+ const Dims<4>& output_dims) {
+ TFLITE_DCHECK_EQ(start_indices.size(), 4);
+ auto op_params = strided_slice::BuildStridedSliceParams(
+ begin_mask, end_mask, shrink_axis_mask, start_indices, stop_indices,
+ strides);
+ StridedSliceReverseIndices(&op_params);
+
+ StridedSlice(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <typename T>
+inline void Mean(const T* input_data, const Dims<4>& input_dims,
+ const std::vector<int>& reduction_indices, T* output_data,
+ const Dims<4>& output_dims) {
+ tflite::MeanParams op_params;
+ op_params.axis_count = reduction_indices.size();
+ for (int i = 0; i < op_params.axis_count; ++i) {
+ op_params.axis[i] = reduction_indices[op_params.axis_count - 1 - i];
+ }
+
+ Mean(op_params, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
+ output_data);
+}
+
+template <typename T>
+void Transpose(const T* input, const Dims<4>& input_dims, T* output,
+ const Dims<4>& output_dims, const int* permuted_axes) {
+ TransposeParams params;
+ params.perm_count = 4;
+ for (int i = 0; i < 4; ++i) {
+ params.perm[i] = 3 - permuted_axes[3 - i];
+ }
+ Transpose(params, DimsToShape(input_dims), input, DimsToShape(output_dims),
+ output);
+}
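The double reversal above rewrites a legacy permutation in the new dimension order. Worked example with an assumed permutation: legacy permuted_axes = {0, 2, 1, 3} (swap legacy height and width) yields perm[i] = 3 - permuted_axes[3 - i] = {0, 2, 1, 3}, which swaps axes 1 and 2 of the NHWC RuntimeShape, i.e. the same height/width transpose.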
+
+template <typename T, ComparisonFn<T> F>
+inline void Comparison(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ bool* output_data, const Dims<4>& output_dims) {
+ ComparisonParams op_params;
+ // No parameters needed.
+ ComparisonImpl<T, F>(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <typename T, ComparisonFn<int32> F>
+inline void Comparison(int left_shift, const T* input1_data,
+ const Dims<4>& input1_dims, int32 input1_offset,
+ int32 input1_multiplier, int input1_shift,
+ const T* input2_data, const Dims<4>& input2_dims,
+ int32 input2_offset, int32 input2_multiplier,
+ int input2_shift, bool* output_data,
+ const Dims<4>& output_dims) {
+ tflite::ComparisonParams op_params;
+ op_params.left_shift = left_shift;
+ op_params.input1_offset = input1_offset;
+ op_params.input1_multiplier = input1_multiplier;
+ // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
+ op_params.input1_shift = kReverseShift * input1_shift;
+ op_params.input2_offset = input2_offset;
+ op_params.input2_multiplier = input2_multiplier;
+ // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
+ op_params.input2_shift = kReverseShift * input2_shift;
+
+ ComparisonWithScaling<T, F>(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <typename T, ComparisonFn<T> F>
+inline void BroadcastComparison(const T* input1_data,
+ const Dims<4>& input1_dims,
+ const T* input2_data,
+ const Dims<4>& input2_dims, bool* output_data,
+ const Dims<4>& output_dims) {
+ ComparisonParams op_params;
+ // No parameters needed.
+ BroadcastComparison4DSlowImpl<T, F>(op_params, DimsToShape(input1_dims),
+ input1_data, DimsToShape(input2_dims),
+ input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
+template <typename T, ComparisonFn<int32> F>
+inline void BroadcastComparison(int left_shift, const T* input1_data,
+ const Dims<4>& input1_dims, int32 input1_offset,
+ int32 input1_multiplier, int input1_shift,
+ const T* input2_data,
+ const Dims<4>& input2_dims, int32 input2_offset,
+ int32 input2_multiplier, int input2_shift,
+ bool* output_data, const Dims<4>& output_dims) {
+ ComparisonParams op_params;
+
+ op_params.left_shift = left_shift;
+ op_params.input1_offset = input1_offset;
+ op_params.input1_multiplier = input1_multiplier;
+ // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
+ op_params.input1_shift = kReverseShift * input1_shift;
+ op_params.input2_offset = input2_offset;
+ op_params.input2_multiplier = input2_multiplier;
+ // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
+ op_params.input2_shift = kReverseShift * input2_shift;
+
+ BroadcastComparison4DSlowWithScaling<T, F>(
+ op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
+#define TFLITE_LEGACY_COMPARISON_OP(name) \
+ template <typename T> \
+ inline void name(const T* input1_data, const Dims<4>& input1_dims, \
+ const T* input2_data, const Dims<4>& input2_dims, \
+ bool* output_data, const Dims<4>& output_dims) { \
+ gemmlowp::ScopedProfilingLabel label(#name); \
+ Comparison<T, name##Fn>(input1_data, input1_dims, input2_data, \
+ input2_dims, output_data, output_dims); \
+ } \
+ template <typename T> \
+ inline void name( \
+ int left_shift, const T* input1_data, const Dims<4>& input1_dims, \
+ int32 input1_offset, int32 input1_multiplier, int input1_shift, \
+ const T* input2_data, const Dims<4>& input2_dims, int32 input2_offset, \
+ int32 input2_multiplier, int input2_shift, bool* output_data, \
+ const Dims<4>& output_dims) { \
+ gemmlowp::ScopedProfilingLabel label(#name "/8bit"); \
+ Comparison<T, name##Fn>(left_shift, input1_data, input1_dims, \
+ input1_offset, input1_multiplier, input1_shift, \
+ input2_data, input2_dims, input2_offset, \
+ input2_multiplier, input2_shift, output_data, \
+ output_dims); \
+ } \
+ template <typename T> \
+ inline void Broadcast##name( \
+ const T* input1_data, const Dims<4>& input1_dims, const T* input2_data, \
+ const Dims<4>& input2_dims, bool* output_data, \
+ const Dims<4>& output_dims) { \
+ gemmlowp::ScopedProfilingLabel label("Broadcast" #name); \
+ BroadcastComparison<T, name##Fn>(input1_data, input1_dims, input2_data, \
+ input2_dims, output_data, output_dims); \
+ } \
+ template <typename T> \
+ inline void Broadcast##name( \
+ int left_shift, const T* input1_data, const Dims<4>& input1_dims, \
+ int32 input1_offset, int32 input1_multiplier, int input1_shift, \
+ const T* input2_data, const Dims<4>& input2_dims, int32 input2_offset, \
+ int32 input2_multiplier, int input2_shift, bool* output_data, \
+ const Dims<4>& output_dims) { \
+ gemmlowp::ScopedProfilingLabel label("Broadcast" #name "/8bit"); \
+ BroadcastComparison<T, name##Fn>(left_shift, input1_data, input1_dims, \
+ input1_offset, input1_multiplier, \
+ input1_shift, input2_data, input2_dims, \
+ input2_offset, input2_multiplier, \
+ input2_shift, output_data, output_dims); \
+ }
+TFLITE_LEGACY_COMPARISON_OP(Equal);
+TFLITE_LEGACY_COMPARISON_OP(NotEqual);
+TFLITE_LEGACY_COMPARISON_OP(Greater);
+TFLITE_LEGACY_COMPARISON_OP(GreaterEqual);
+TFLITE_LEGACY_COMPARISON_OP(Less);
+TFLITE_LEGACY_COMPARISON_OP(LessEqual);
+#undef TFLITE_LEGACY_COMPARISON_OP
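A hypothetical call-site sketch for one of the generated overloads, with illustrative packed 1x1x1x4 dims (the values below are not taken from any real model):

  tflite::Dims<4> dims;
  dims.sizes[0] = 4;   dims.sizes[1] = dims.sizes[2] = dims.sizes[3] = 1;
  dims.strides[0] = 1; dims.strides[1] = dims.strides[2] = dims.strides[3] = 4;
  const float a[4] = {1.f, 2.f, 3.f, 4.f};
  const float b[4] = {4.f, 3.f, 2.f, 1.f};
  bool out[4];
  Greater(a, dims, b, dims, out, dims);  // out = {false, false, true, true}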
+
+template <typename D, typename T>
+inline void Select(const D* input_condition_data,
+ const Dims<4>& input_condition_dims, const T* input_x_data,
+ const Dims<4>& input_x_dims, const T* input_y_data,
+ const Dims<4>& input_y_dims, T* output_data,
+ const Dims<4>& output_dims) {
+ Select(DimsToShape(input_condition_dims), input_condition_data,
+ DimsToShape(input_x_dims), input_x_data, DimsToShape(input_y_dims),
+ input_y_data, DimsToShape(output_dims), output_data);
+}
+
+template <typename D, typename T>
+inline void RankOneSelect(const D* input_condition_data,
+ const Dims<4>& input_condition_dims,
+ const T* input_x_data, const Dims<4>& input_x_dims,
+ const T* input_y_data, const Dims<4>& input_y_dims,
+ T* output_data, const Dims<4>& output_dims) {
+ RankOneSelect(DimsToShape(input_condition_dims), input_condition_data,
+ DimsToShape(input_x_dims), input_x_data,
+ DimsToShape(input_y_dims), input_y_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <typename T, typename TI>
+inline void SparseToDense(const std::vector<std::vector<TI>>& indices,
+ const T* values, T default_value, T* output_data,
+ const Dims<4>& output_dims, bool value_is_scalar) {
+ SparseToDense(indices, values, default_value, value_is_scalar,
+ DimsToShape(output_dims), output_data);
+}
+
+template <typename Scalar>
+void Pack(int dim, const Scalar* const* input_data,
+ const Dims<4>* const* input_dims, int inputs_count,
+ Scalar* output_data, const Dims<4>& output_dims) {
+ std::vector<RuntimeShape> input_shapes(inputs_count);
+ std::vector<const RuntimeShape*> input_shapes_indirect(inputs_count);
+ for (int i = 0; i < inputs_count; ++i) {
+ ShapeFromDims(*input_dims[i], &input_shapes[i]);
+ input_shapes_indirect[i] = &input_shapes[i];
+ }
+ tflite::PackParams op_params;
+ op_params.axis = 3 - dim;
+ op_params.inputs_count = inputs_count;
+
+ Pack(op_params, input_shapes_indirect.data(), input_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <typename Scalar>
+void Unpack(int axis, const Scalar* input_data, const Dims<4>& input_dims,
+ int dimensions, int outputs_count, Scalar* const* output_datas,
+ const Dims<4>& output_dims) {
+ tflite::UnpackParams op_params;
+ op_params.axis = 3 - axis;
+ op_params.num_split = outputs_count;
+
+ Unpack(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_dims), output_datas);
+}
+
+template <typename Scalar>
+void Pack(int dim, const Scalar* const* input_data,
+ const Dims<4>* const* input_dims, const int32* input_zeropoint,
+ const float* input_scale, int inputs_count, Scalar* output_data,
+ const Dims<4>& output_dims, const int32 output_zeropoint,
+ const float output_scale) {
+ std::vector<RuntimeShape> input_shapes(inputs_count);
+ std::vector<const RuntimeShape*> input_shapes_indirect(inputs_count);
+ for (int i = 0; i < inputs_count; ++i) {
+ ShapeFromDims(*input_dims[i], &input_shapes[i]);
+ input_shapes_indirect[i] = &input_shapes[i];
+ }
+ tflite::PackParams op_params;
+ op_params.axis = 3 - dim;
+ op_params.input_zeropoint = input_zeropoint;
+ op_params.input_scale = input_scale;
+ op_params.inputs_count = inputs_count;
+ op_params.output_zeropoint = output_zeropoint;
+ op_params.output_scale = output_scale;
+
+ PackWithScaling(op_params, input_shapes_indirect.data(), input_data,
+ DimsToShape(output_dims), output_data);
+}
+
template <FusedActivationFunctionType Ac>
void L2Normalization(const float* input_data, const RuntimeShape& input_shape,
float* output_data, const RuntimeShape& output_shape) {
@@ -342,7 +1408,6 @@ inline void AveragePool(const float* input_data, const Dims<4>& input_dims,
DimsToShape(output_dims), output_data);
}
-// Legacy.
// Transitional version that will be moved shortly to legacy_reference_ops, as
// part of RuntimeShape revisions.
inline void BroadcastMul4DSlow(const uint8* input1_data,
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc
index 77e60adc18..70d25c4bd9 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc
+++ b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc
@@ -55,7 +55,7 @@ void PortableSymmetricQuantizeFloats(const float* values, const int size,
return;
}
*scaling_factor = range / kScale;
- const float scaling_factor_inv = 1.0f / *scaling_factor;
+ const float scaling_factor_inv = kScale / range;
for (int i = 0; i < size; ++i) {
const int32_t quantized_value =
static_cast<int32_t>(TfLiteRound(values[i] * scaling_factor_inv));
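The replacement is algebraically the same quantity computed directly from the inputs: since *scaling_factor was just set to range / kScale, 1.0f / *scaling_factor equals kScale / range, but the new form avoids dividing by an already-rounded float. A minimal sketch with assumed values (kScale taken to be the symmetric 8-bit bound this function uses, e.g. 127.0f):

  const float range = 2.54f;
  const float scaling_factor = range / 127.0f;          // what *scaling_factor holds
  const float inv_via_factor = 1.0f / scaling_factor;   // old expression
  const float inv_direct = 127.0f / range;              // new expression, same value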
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
index 77927af227..59f17ae854 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
@@ -28,6 +28,8 @@ limitations under the License.
#include "public/gemmlowp.h"
#include "tensorflow/contrib/lite/kernels/internal/common.h"
#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/fully_connected.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/softmax.h"
#include "tensorflow/contrib/lite/kernels/internal/round.h"
#include "tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h"
#include "tensorflow/contrib/lite/kernels/internal/types.h"
@@ -98,13 +100,6 @@ gemmlowp::FixedPoint<tRawType, tIntegerBits> SaturatingSub(
namespace reference_ops {
-// TODO(b/80247582) Remove this constant.
-// This will be phased out as the shifts are revised with more thought. Use of a
-// constant enables us to track progress on this work.
-//
-// Used mainly to convert from old-style shifts (right) to new-style (left).
-static constexpr int kReverseShift = -1;
-
inline void ShapeFromDims(const tflite::Dims<4>& dims, RuntimeShape* shape) {
shape->BuildFrom(
{dims.sizes[3], dims.sizes[2], dims.sizes[1], dims.sizes[0]});
@@ -163,28 +158,38 @@ SaturatingRoundingMultiplyByPOTParam(
SaturatingRoundingMultiplyByPOTParam(a.raw(), exponent));
}
-inline void Conv(const float* input_data, const Dims<4>& input_dims,
- const float* filter_data, const Dims<4>& filter_dims,
- const float* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height, int dilation_width_factor,
- int dilation_height_factor, int pad_width, int pad_height,
- float output_activation_min, float output_activation_max,
- float* output_data, const Dims<4>& output_dims,
- float* im2col_data, const Dims<4>& im2col_dims) {
- (void)im2col_data; // only used in optimized code.
- (void)im2col_dims; // only used in optimized code.
- const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 0);
- const int output_depth = MatchingArraySize(filter_dims, 3, output_dims, 0);
+inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
+ const float* input_data, const RuntimeShape& filter_shape,
+ const float* filter_data, const RuntimeShape& bias_shape,
+ const float* bias_data, const RuntimeShape& output_shape,
+ float* output_data, const RuntimeShape& im2col_shape,
+ float* im2col_data) {
+ const int stride_width = params.stride_width;
+ const int stride_height = params.stride_height;
+ const int dilation_width_factor = params.dilation_width_factor;
+ const int dilation_height_factor = params.dilation_height_factor;
+ const int pad_width = params.padding_values.width;
+ const int pad_height = params.padding_values.height;
+ const float output_activation_min = params.float_activation_min;
+ const float output_activation_max = params.float_activation_max;
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+
+ (void)im2col_data; // only used in optimized code.
+ (void)im2col_shape; // only used in optimized code.
+ const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+ const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3);
+ const int output_depth = MatchingDim(filter_shape, 0, output_shape, 3);
if (bias_data) {
- TFLITE_DCHECK_EQ(ArraySize(filter_dims, 3), ArraySize(bias_dims, 0));
- }
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
- const int filter_height = ArraySize(filter_dims, 2);
- const int filter_width = ArraySize(filter_dims, 1);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
+ TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);
+ }
+ const int input_height = input_shape.Dims(1);
+ const int input_width = input_shape.Dims(2);
+ const int filter_height = filter_shape.Dims(1);
+ const int filter_width = filter_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_width = output_shape.Dims(2);
for (int batch = 0; batch < batches; ++batch) {
for (int out_y = 0; out_y < output_height; ++out_y) {
for (int out_x = 0; out_x < output_width; ++out_x) {
@@ -202,11 +207,11 @@ inline void Conv(const float* input_data, const Dims<4>& input_dims,
// use zero as a default value.
if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) &&
(in_y < input_height)) {
- float input_value = input_data[Offset(input_dims, in_channel,
- in_x, in_y, batch)];
+ float input_value = input_data[Offset(
+ input_shape, batch, in_y, in_x, in_channel)];
float filter_value =
- filter_data[Offset(filter_dims, in_channel, filter_x,
- filter_y, out_channel)];
+ filter_data[Offset(filter_shape, out_channel, filter_y,
+ filter_x, in_channel)];
total += (input_value * filter_value);
}
}
@@ -214,9 +219,9 @@ inline void Conv(const float* input_data, const Dims<4>& input_dims,
}
float bias_value = 0.0f;
if (bias_data) {
- bias_value = bias_data[Offset(bias_dims, out_channel, 0, 0, 0)];
+ bias_value = bias_data[out_channel];
}
- output_data[Offset(output_dims, out_channel, out_x, out_y, batch)] =
+ output_data[Offset(output_shape, batch, out_y, out_x, out_channel)] =
ActivationFunctionWithMinMax(total + bias_value,
output_activation_min,
output_activation_max);
@@ -226,77 +231,45 @@ inline void Conv(const float* input_data, const Dims<4>& input_dims,
}
}
-template <FusedActivationFunctionType Ac>
-void Conv(const float* input_data, const Dims<4>& input_dims,
- const float* filter_data, const Dims<4>& filter_dims,
- const float* bias_data, const Dims<4>& bias_dims, int stride_width,
- int stride_height, int dilation_width_factor,
- int dilation_height_factor, int pad_width, int pad_height,
- float* output_data, const Dims<4>& output_dims, float* im2col_data,
- const Dims<4>& im2col_dims) {
- float output_activation_min, output_activation_max;
- GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
- Conv(input_data, input_dims, filter_data, filter_dims, bias_data, bias_dims,
- stride_width, stride_height, dilation_width_factor,
- dilation_height_factor, pad_width, pad_height, output_activation_min,
- output_activation_max, output_data, output_dims, im2col_data,
- im2col_dims);
-}
-
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void Conv(const float* input_data, const Dims<4>& input_dims,
- const float* filter_data, const Dims<4>& filter_dims,
- const float* bias_data, const Dims<4>& bias_dims, int stride_width,
- int stride_height, int pad_width, int pad_height, float* output_data,
- const Dims<4>& output_dims, float* im2col_data,
- const Dims<4>& im2col_dims) {
- float output_activation_min, output_activation_max;
- GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
- Conv(input_data, input_dims, filter_data, filter_dims, bias_data, bias_dims,
- stride_width, stride_height, 1, 1, pad_width, pad_height,
- output_activation_min, output_activation_max, output_data, output_dims,
- im2col_data, im2col_dims);
-}
-
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void Conv(const float* input_data, const Dims<4>& input_dims,
- const float* filter_data, const Dims<4>& filter_dims,
- const float* bias_data, const Dims<4>& bias_dims, int stride,
- int pad_width, int pad_height, float* output_data,
- const Dims<4>& output_dims, float* im2col_data,
- const Dims<4>& im2col_dims) {
- Conv<Ac>(input_data, input_dims, filter_data, filter_dims, bias_data,
- bias_dims, stride, stride, 1, 1, pad_width, pad_height, output_data,
- output_dims, im2col_data, im2col_dims);
-}
-
-inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
- int32 input_offset, const uint8* filter_data,
- const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height, int dilation_width_factor,
- int dilation_height_factor, int pad_width, int pad_height,
- int32 output_offset, int32 output_multiplier, int output_shift,
- int32 output_activation_min, int32 output_activation_max,
- uint8* output_data, const Dims<4>& output_dims,
- uint8* im2col_data, const Dims<4>& im2col_dims,
- gemmlowp::GemmContext* gemm_context) {
+inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
+ const uint8* input_data, const RuntimeShape& filter_shape,
+ const uint8* filter_data, const RuntimeShape& bias_shape,
+ const int32* bias_data, const RuntimeShape& output_shape,
+ uint8* output_data, const RuntimeShape& im2col_shape,
+ uint8* im2col_data, gemmlowp::GemmContext* gemm_context) {
(void)im2col_data; // only used in optimized code.
- (void)im2col_dims; // only used in optimized code.
+ (void)im2col_shape; // only used in optimized code.
(void)gemm_context; // only used in optimized code.
+ const int stride_width = params.stride_width;
+ const int stride_height = params.stride_height;
+ const int dilation_width_factor = params.dilation_width_factor;
+ const int dilation_height_factor = params.dilation_height_factor;
+ const int pad_width = params.padding_values.width;
+ const int pad_height = params.padding_values.height;
+ const int32 input_offset = params.input_offset;
+ const int32 filter_offset = params.weights_offset;
+ const int32 output_offset = params.output_offset;
+ const int32 output_multiplier = params.output_multiplier;
+ const int output_shift = params.output_shift;
+ const int32 output_activation_min = params.quantized_activation_min;
+ const int32 output_activation_max = params.quantized_activation_max;
TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
- const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 0);
- const int output_depth =
- MatchingArraySize(filter_dims, 3, bias_dims, 0, output_dims, 0);
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
- const int filter_height = ArraySize(filter_dims, 2);
- const int filter_width = ArraySize(filter_dims, 1);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
+
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+ const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+ const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3);
+ const int output_depth = MatchingDim(filter_shape, 0, output_shape, 3);
+ if (bias_data) {
+ TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);
+ }
+ const int input_height = input_shape.Dims(1);
+ const int input_width = input_shape.Dims(2);
+ const int filter_height = filter_shape.Dims(1);
+ const int filter_width = filter_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_width = output_shape.Dims(2);
for (int batch = 0; batch < batches; ++batch) {
for (int out_y = 0; out_y < output_height; ++out_y) {
for (int out_x = 0; out_x < output_width; ++out_x) {
@@ -314,11 +287,11 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
// use zero as a default value.
if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) &&
(in_y < input_height)) {
- int32 input_val = input_data[Offset(input_dims, in_channel,
- in_x, in_y, batch)];
+ int32 input_val = input_data[Offset(input_shape, batch, in_y,
+ in_x, in_channel)];
int32 filter_val =
- filter_data[Offset(filter_dims, in_channel, filter_x,
- filter_y, out_channel)];
+ filter_data[Offset(filter_shape, out_channel, filter_y,
+ filter_x, in_channel)];
acc +=
(filter_val + filter_offset) * (input_val + input_offset);
}
@@ -326,14 +299,14 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
}
}
if (bias_data) {
- acc += bias_data[Offset(bias_dims, out_channel, 0, 0, 0)];
+ acc += bias_data[out_channel];
}
acc = MultiplyByQuantizedMultiplier(acc, output_multiplier,
- kReverseShift * output_shift);
+ output_shift);
acc += output_offset;
acc = std::max(acc, output_activation_min);
acc = std::min(acc, output_activation_max);
- output_data[Offset(output_dims, out_channel, out_x, out_y, batch)] =
+ output_data[Offset(output_shape, batch, out_y, out_x, out_channel)] =
static_cast<uint8>(acc);
}
}
@@ -341,71 +314,6 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
}
}
-inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
- int32 input_offset, const uint8* filter_data,
- const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims, uint8* im2col_data,
- const Dims<4>& im2col_dims,
- gemmlowp::GemmContext* gemm_context) {
- Conv(input_data, input_dims, input_offset, filter_data, filter_dims,
- filter_offset, bias_data, bias_dims, stride_width, stride_height, 1, 1,
- pad_width, pad_height, output_offset, output_multiplier, output_shift,
- output_activation_min, output_activation_max, output_data, output_dims,
- im2col_data, im2col_dims, gemm_context);
-}
-
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
- int32 input_offset, const uint8* filter_data,
- const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims, uint8* im2col_data,
- const Dims<4>& im2col_dims,
- gemmlowp::GemmContext* gemm_context) {
- static_assert(Ac == FusedActivationFunctionType::kNone ||
- Ac == FusedActivationFunctionType::kRelu ||
- Ac == FusedActivationFunctionType::kRelu6 ||
- Ac == FusedActivationFunctionType::kRelu1,
- "");
- if (Ac == FusedActivationFunctionType::kNone) {
- TFLITE_DCHECK_EQ(output_activation_min, 0);
- TFLITE_DCHECK_EQ(output_activation_max, 255);
- }
- Conv(input_data, input_dims, input_offset, filter_data, filter_dims,
- filter_offset, bias_data, bias_dims, stride_width, stride_height,
- pad_width, pad_height, output_offset, output_multiplier, output_shift,
- output_activation_min, output_activation_max, output_data, output_dims,
- im2col_data, im2col_dims, gemm_context);
-}
-
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void Conv(const uint8* input_data, const Dims<4>& input_dims,
- int32 input_offset, const uint8* filter_data,
- const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data, const Dims<4>& bias_dims, int stride,
- int pad_width, int pad_height, int32 output_offset,
- int32 output_multiplier, int output_shift,
- int32 output_activation_min, int32 output_activation_max,
- uint8* output_data, const Dims<4>& output_dims, uint8* im2col_data,
- const Dims<4>& im2col_dims, gemmlowp::GemmContext* gemm_context) {
- Conv<Ac>(input_data, input_dims, input_offset, filter_data, filter_dims,
- filter_offset, bias_data, bias_dims, stride, stride, pad_width,
- pad_height, output_offset, output_multiplier, output_shift,
- output_activation_min, output_activation_max, output_data,
- output_dims, im2col_data, im2col_dims, gemm_context);
-}
-
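With these wrappers gone, callers build a ConvParams and pass RuntimeShape operands directly. A minimal sketch of the new-style float call, with assumed shapes and buffers (NHWC input, OHWI filter; input_data, filter_data, bias_data and output_data are hypothetical pointers):

  ConvParams params;
  params.stride_width = 1;
  params.stride_height = 1;
  params.dilation_width_factor = 1;
  params.dilation_height_factor = 1;
  params.padding_values.width = 0;
  params.padding_values.height = 0;
  params.float_activation_min = std::numeric_limits<float>::lowest();
  params.float_activation_max = std::numeric_limits<float>::max();

  RuntimeShape input_shape({1, 8, 8, 3});    // NHWC
  RuntimeShape filter_shape({16, 3, 3, 3});  // output depth, H, W, input depth
  RuntimeShape bias_shape({16});
  RuntimeShape output_shape({1, 6, 6, 16});  // 8 - 3 + 1 = 6 with stride 1, no padding
  Conv(params, input_shape, input_data, filter_shape, filter_data, bias_shape,
       bias_data, output_shape, output_data,
       /*im2col_shape=*/RuntimeShape(), /*im2col_data=*/nullptr);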
template <typename T>
inline void DepthToSpace(const tflite::DepthToSpaceParams& op_params,
const RuntimeShape& unextended_input_shape,
@@ -511,320 +419,6 @@ inline void SpaceToDepth(const tflite::SpaceToDepthParams& op_params,
}
}
-inline void FullyConnected(const float* input_data, const Dims<4>& input_dims,
- const float* weights_data,
- const Dims<4>& weights_dims, const float* bias_data,
- const Dims<4>& bias_dims,
- float output_activation_min,
- float output_activation_max, float* output_data,
- const Dims<4>& output_dims) {
- // TODO(benoitjacob): This really should be:
- // const int batches = ArraySize(output_dims, 1);
- // but the current --variable_batch hack consists in overwriting the 3rd
- // dimension with the runtime batch size, as we don't keep track for each
- // array of which dimension is the batch dimension in it.
- const int batches = ArraySize(output_dims, 1) * ArraySize(output_dims, 2) *
- ArraySize(output_dims, 3);
- const int output_depth = MatchingArraySize(weights_dims, 1, output_dims, 0);
- const int accum_depth = ArraySize(weights_dims, 0);
- TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(weights_dims));
- for (int b = 0; b < batches; ++b) {
- for (int out_c = 0; out_c < output_depth; ++out_c) {
- float total = 0.f;
- for (int d = 0; d < accum_depth; ++d) {
- total += input_data[b * accum_depth + d] *
- weights_data[out_c * accum_depth + d];
- }
- float bias_value = 0.0f;
- if (bias_data) {
- bias_value = bias_data[Offset(bias_dims, out_c, 0, 0, 0)];
- }
- output_data[out_c + output_depth * b] = ActivationFunctionWithMinMax(
- total + bias_value, output_activation_min, output_activation_max);
- }
- }
-}
-
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void FullyConnected(const float* input_data, const Dims<4>& input_dims,
- const float* weights_data, const Dims<4>& weights_dims,
- const float* bias_data, const Dims<4>& bias_dims,
- float* output_data, const Dims<4>& output_dims) {
- float output_activation_min, output_activation_max;
- GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
- FullyConnected(input_data, input_dims, weights_data, weights_dims, bias_data,
- bias_dims, output_activation_min, output_activation_max,
- output_data, output_dims);
-}
-
-inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
- int32 input_offset, const uint8* filter_data,
- const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data, const Dims<4>& bias_dims,
- int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims,
- gemmlowp::GemmContext* gemm_context) {
- (void)gemm_context; // only used in optimized code.
- TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
- // TODO(benoitjacob): This really should be:
- // const int batches = ArraySize(output_dims, 1);
- // but the current --variable_batch hack consists in overwriting the 3rd
- // dimension with the runtime batch size, as we don't keep track for each
- // array of which dimension is the batch dimension in it.
- const int batches = ArraySize(output_dims, 1) * ArraySize(output_dims, 2) *
- ArraySize(output_dims, 3);
- const int output_depth = MatchingArraySize(filter_dims, 1, output_dims, 0);
- const int accum_depth = ArraySize(filter_dims, 0);
- TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(filter_dims));
- for (int b = 0; b < batches; ++b) {
- for (int out_c = 0; out_c < output_depth; ++out_c) {
- int32 acc = 0;
- for (int d = 0; d < accum_depth; ++d) {
- int32 input_val = input_data[b * accum_depth + d];
- int32 filter_val = filter_data[out_c * accum_depth + d];
- acc += (filter_val + filter_offset) * (input_val + input_offset);
- }
- if (bias_data) {
- acc += bias_data[Offset(bias_dims, out_c, 0, 0, 0)];
- }
- acc = MultiplyByQuantizedMultiplier(acc, output_multiplier,
- kReverseShift * output_shift);
- acc += output_offset;
- acc = std::max(acc, output_activation_min);
- acc = std::min(acc, output_activation_max);
- output_data[out_c + output_depth * b] = static_cast<uint8>(acc);
- }
- }
-}
-
-inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
- int32 input_offset, const uint8* filter_data,
- const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data, const Dims<4>& bias_dims,
- int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min,
- int32 output_activation_max, int16* output_data,
- const Dims<4>& output_dims,
- gemmlowp::GemmContext* gemm_context) {
- (void)gemm_context; // only used in optimized code.
- TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
- TFLITE_DCHECK_EQ(output_offset, 0);
- // TODO(benoitjacob): This really should be:
- // const int batches = ArraySize(output_dims, 1);
- // but the current --variable_batch hack consists in overwriting the 3rd
- // dimension with the runtime batch size, as we don't keep track for each
- // array of which dimension is the batch dimension in it.
- const int batches = ArraySize(output_dims, 1) * ArraySize(output_dims, 2) *
- ArraySize(output_dims, 3);
- const int output_depth = MatchingArraySize(filter_dims, 1, output_dims, 0);
- const int accum_depth = ArraySize(filter_dims, 0);
- TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(filter_dims));
- for (int b = 0; b < batches; ++b) {
- for (int out_c = 0; out_c < output_depth; ++out_c) {
- // Internal accumulation.
- // Initialize accumulator with the bias-value.
- int32 accum = bias_data[out_c];
- // Accumulation loop.
- for (int d = 0; d < accum_depth; ++d) {
- int16 input_val = input_data[b * accum_depth + d] + input_offset;
- int16 filter_val = filter_data[out_c * accum_depth + d] + filter_offset;
- accum += filter_val * input_val;
- }
- // Down-scale the final int32 accumulator to the scale used by our
- // (16-bit, typically 3 integer bits) fixed-point format. The quantized
- // multiplier and shift here have been pre-computed offline
- // (e.g. by toco).
- accum = MultiplyByQuantizedMultiplier(accum, output_multiplier,
- -output_shift);
- // Saturate, cast to int16, and store to output array.
- accum = std::max(accum, output_activation_min - output_offset);
- accum = std::min(accum, output_activation_max - output_offset);
- accum += output_offset;
- output_data[out_c + output_depth * b] = accum;
- }
- }
-}
-
-inline void ShuffledFullyConnected(
- const uint8* input_data, const Dims<4>& input_dims,
- const uint8* shuffled_weights_data, const Dims<4>& weights_dims,
- const int32* bias_data, const Dims<4>& bias_dims, int32 output_multiplier,
- int output_shift, int32 output_activation_min, int32 output_activation_max,
- int16* output_data, const Dims<4>& output_dims,
- uint8* shuffled_input_workspace_data, gemmlowp::GemmContext* gemm_context) {
- (void)gemm_context; // only used in optimized code.
-
- TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
- // TODO(benoitjacob): This really should be:
- // const int batches = ArraySize(output_dims, 1);
- // but the current --variable_batch hack consists in overwriting the 3rd
- // dimension with the runtime batch size, as we don't keep track for each
- // array of which dimension is the batch dimension in it.
- const int batches = ArraySize(output_dims, 1) * ArraySize(output_dims, 2) *
- ArraySize(output_dims, 3);
- const int output_depth = MatchingArraySize(weights_dims, 1, output_dims, 0);
- const int accum_depth = ArraySize(weights_dims, 0);
- TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(weights_dims));
- TFLITE_DCHECK((accum_depth % 16) == 0);
- TFLITE_DCHECK((output_depth % 4) == 0);
-
- // Shuffling and xoring of input activations into the workspace buffer
- uint8* shuffled_input_workspace_ptr = shuffled_input_workspace_data;
- if (batches == 1) {
- for (int i = 0; i < accum_depth; i++) {
- shuffled_input_workspace_data[i] = input_data[i] ^ 0x80;
- }
- } else if (batches == 4) {
- for (int c = 0; c < accum_depth; c += 16) {
- for (int b = 0; b < 4; b++) {
- const uint8* src_data_ptr = input_data + b * accum_depth + c;
- for (int j = 0; j < 16; j++) {
- uint8 src_val = *src_data_ptr++;
- // Flip the sign bit, so that the kernel will only need to
- // reinterpret these uint8 values as int8, getting for free the
- // subtraction of the zero_point value 128.
- uint8 dst_val = src_val ^ 0x80;
- *shuffled_input_workspace_ptr++ = dst_val;
- }
- }
- }
- } else {
- TFLITE_DCHECK(false);
- return;
- }
-
- // Actual computation
- if (batches == 1) {
- int16* output_ptr = output_data;
- // Shuffled weights have had their sign bit (0x80) pre-flipped (xor'd)
- // so that just reinterpreting them as int8 values is equivalent to
- // subtracting 128 from them, thus implementing for free the subtraction of
- // the zero_point value 128.
- const int8* shuffled_weights_ptr =
- reinterpret_cast<const int8*>(shuffled_weights_data);
- // Likewise, we preshuffled and pre-xored the input data above.
- const int8* shuffled_input_data =
- reinterpret_cast<const int8*>(shuffled_input_workspace_data);
- for (int c = 0; c < output_depth; c += 4) {
- // Internal accumulation.
- // Initialize accumulator with the bias-value.
- int32 accum[4] = {0};
- // Accumulation loop.
- for (int d = 0; d < accum_depth; d += 16) {
- for (int i = 0; i < 4; i++) {
- for (int j = 0; j < 16; j++) {
- int8 input_val = shuffled_input_data[d + j];
- int8 weights_val = *shuffled_weights_ptr++;
- accum[i] += weights_val * input_val;
- }
- }
- }
- for (int i = 0; i < 4; i++) {
- // Add bias value
- int acc = accum[i] + bias_data[c + i];
- // Down-scale the final int32 accumulator to the scale used by our
- // (16-bit, typically 3 integer bits) fixed-point format. The quantized
- // multiplier and shift here have been pre-computed offline
- // (e.g. by toco).
- acc = MultiplyByQuantizedMultiplier(acc, output_multiplier,
- -output_shift);
- // Saturate, cast to int16, and store to output array.
- acc = std::max(acc, output_activation_min);
- acc = std::min(acc, output_activation_max);
- output_ptr[c + i] = acc;
- }
- }
- } else if (batches == 4) {
- int16* output_ptr = output_data;
- // Shuffled weights have had their sign bit (0x80) pre-flipped (xor'd)
- // so that just reinterpreting them as int8 values is equivalent to
- // subtracting 128 from them, thus implementing for free the subtraction of
- // the zero_point value 128.
- const int8* shuffled_weights_ptr =
- reinterpret_cast<const int8*>(shuffled_weights_data);
- // Likewise, we preshuffled and pre-xored the input data above.
- const int8* shuffled_input_data =
- reinterpret_cast<const int8*>(shuffled_input_workspace_data);
- for (int c = 0; c < output_depth; c += 4) {
- const int8* shuffled_input_ptr = shuffled_input_data;
- // Accumulation loop.
- // Internal accumulation.
- // Initialize accumulator with the bias-value.
- int32 accum[4][4];
- for (int i = 0; i < 4; i++) {
- for (int b = 0; b < 4; b++) {
- accum[i][b] = 0;
- }
- }
- for (int d = 0; d < accum_depth; d += 16) {
- for (int i = 0; i < 4; i++) {
- for (int b = 0; b < 4; b++) {
- for (int j = 0; j < 16; j++) {
- int8 input_val = shuffled_input_ptr[16 * b + j];
- int8 weights_val = shuffled_weights_ptr[16 * i + j];
- accum[i][b] += weights_val * input_val;
- }
- }
- }
- shuffled_input_ptr += 64;
- shuffled_weights_ptr += 64;
- }
- for (int i = 0; i < 4; i++) {
- for (int b = 0; b < 4; b++) {
- // Add bias value
- int acc = accum[i][b] + bias_data[c + i];
- // Down-scale the final int32 accumulator to the scale used by our
- // (16-bit, typically 3 integer bits) fixed-point format. The
- // quantized multiplier and shift here have been pre-computed offline
- // (e.g. by toco).
- acc = MultiplyByQuantizedMultiplier(acc, output_multiplier,
- -output_shift);
- // Saturate, cast to int16, and store to output array.
- acc = std::max(acc, output_activation_min);
- acc = std::min(acc, output_activation_max);
- output_ptr[b * output_depth + c + i] = acc;
- }
- }
- }
- } else {
- TFLITE_DCHECK(false);
- return;
- }
-}
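The XOR-with-0x80 trick used throughout ShuffledFullyConnected is a branch-free way to subtract the zero point of 128: for any uint8 value v, reinterpreting (v ^ 0x80) as int8 yields v - 128. Worked example with an assumed value: v = 200 -> 200 ^ 0x80 = 72, and int8(72) = 72 = 200 - 128.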
-
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
- int32 input_offset, const uint8* filter_data,
- const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data, const Dims<4>& bias_dims,
- int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims,
- gemmlowp::GemmContext* gemm_context) {
- static_assert(Ac == FusedActivationFunctionType::kNone ||
- Ac == FusedActivationFunctionType::kRelu ||
- Ac == FusedActivationFunctionType::kRelu6 ||
- Ac == FusedActivationFunctionType::kRelu1,
- "");
- if (Ac == FusedActivationFunctionType::kNone) {
- TFLITE_DCHECK_EQ(output_activation_min, 0);
- TFLITE_DCHECK_EQ(output_activation_max, 255);
- }
- FullyConnected(input_data, input_dims, input_offset, filter_data, filter_dims,
- filter_offset, bias_data, bias_dims, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_data, output_dims, gemm_context);
-}
-
inline void Relu(const RuntimeShape& input_shape, const float* input_data,
const RuntimeShape& output_shape, float* output_data) {
const int flat_size = MatchingFlatSize(input_shape, output_shape);
@@ -945,6 +539,7 @@ inline void GetInvSqrtQuantizedMultiplierExp(int32 input,
*output_inv_sqrt <<= -*output_shift;
*output_shift = 0;
}
+ // Convert right shift (right is positive) to left shift.
*output_shift *= kReverseShift;
}
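kReverseShift is -1 (as documented by the constant removed from this header earlier in this diff), so multiplying by it converts the legacy right-shift-is-positive exponent into the new positive-means-left convention. A sketch with an assumed value:

  int legacy_right_shift = 3;                                // old style: shift right by 3
  int new_style_shift = kReverseShift * legacy_right_shift;  // -3: still a right shift by 3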
@@ -1608,21 +1203,6 @@ void BroadcastDiv4DSlow(const ArithmeticParams& params,
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy Dims<4>.
-template <typename T>
-void BroadcastDiv(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, const Dims<4>& input2_dims,
- T output_activation_min, T output_activation_max,
- T* output_data, const Dims<4>& output_dims) {
- tflite::ArithmeticParams op_params;
- SetActivationParams(output_activation_min, output_activation_max, &op_params);
-
- BroadcastDiv4DSlow(op_params, DimsToShape(input1_dims), input1_data,
- DimsToShape(input2_dims), input2_data,
- DimsToShape(output_dims), output_data);
-}
-
template <typename T>
inline void Div(const ArithmeticParams& params,
const RuntimeShape& input1_shape, const T* input1_data,
@@ -1641,21 +1221,6 @@ inline void Div(const ArithmeticParams& params,
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy Dims<4>.
-template <typename T>
-inline void Div(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, const Dims<4>& input2_dims,
- T output_activation_min, T output_activation_max,
- T* output_data, const Dims<4>& output_dims) {
- tflite::ArithmeticParams op_params;
- SetActivationParams(output_activation_min, output_activation_max, &op_params);
-
- Div(op_params, DimsToShape(input1_dims), input1_data,
- DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
- output_data);
-}
-
inline void SubNonBroadcast(const ArithmeticParams& params,
const RuntimeShape& input1_shape,
const float* input1_data,
@@ -1703,7 +1268,7 @@ inline void BroadcastSub4DSlow(const ArithmeticParams& params,
const float* input2_data,
const RuntimeShape& output_shape,
float* output_data) {
- gemmlowp::ScopedProfilingLabel label("BroadcastAdd4DSlow/float");
+ gemmlowp::ScopedProfilingLabel label("BroadcastSub4DSlow/float");
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
@@ -1744,7 +1309,7 @@ inline void BroadcastSub4DSlow(const ArithmeticParams& params,
const uint8* input2_data,
const RuntimeShape& output_shape,
uint8* output_data) {
- gemmlowp::ScopedProfilingLabel label("BroadcastAdd4DSlow/uint8");
+ gemmlowp::ScopedProfilingLabel label("BroadcastSub4DSlow/uint8");
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
@@ -1808,7 +1373,7 @@ inline void BroadcastSub4DSlow(const ArithmeticParams& params,
const int32* input2_data,
const RuntimeShape& output_shape,
int32* output_data) {
- gemmlowp::ScopedProfilingLabel label("BroadcastAdd4DSlow/int32");
+ gemmlowp::ScopedProfilingLabel label("BroadcastSub4DSlow/int32");
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
@@ -1848,7 +1413,7 @@ void BroadcastSub4DSlow(const ArithmeticParams& params,
const RuntimeShape& input1_shape, const T* input1_data,
const RuntimeShape& input2_shape, const T* input2_data,
const RuntimeShape& output_shape, T* output_data) {
- gemmlowp::ScopedProfilingLabel label("BroadcastAdd4DSlow/templated");
+ gemmlowp::ScopedProfilingLabel label("BroadcastSub4DSlow/templated");
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
@@ -1995,35 +1560,10 @@ inline void Concatenation(const ConcatenationParams& params,
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy Dims<4>.
-template <FusedActivationFunctionType Ac, typename Scalar>
-inline void Concatenation(int concat_dim, const Scalar* const* input_data,
- const Dims<4>* const* input_dims, int inputs_count,
- Scalar* output_data, const Dims<4>& output_dims) {
- // For now we don't have a model with a Concatenation with fused activation.
- TFLITE_DCHECK_EQ(Ac, FusedActivationFunctionType::kNone);
-
- std::vector<RuntimeShape> input_shapes(inputs_count);
- std::vector<const RuntimeShape*> input_shapes_indirect(inputs_count);
- for (int i = 0; i < inputs_count; ++i) {
- ShapeFromDims(*input_dims[i], &input_shapes[i]);
- input_shapes_indirect[i] = &input_shapes[i];
- }
- tflite::ConcatenationParams op_params;
- op_params.axis = 3 - concat_dim;
- op_params.inputs_count = inputs_count;
-
- Concatenation(op_params, input_shapes_indirect.data(), input_data,
- DimsToShape(output_dims), output_data);
-}
-
// TODO(prabhumk): This is the same as the optimized implementation.
// TODO(prabhumk): The quantized implementation of concatenation isn't fully
// quantized as it takes scale as a floating point value. This should be fixed
// when optimizing this routine further.

-
-// template <>
inline void ConcatenationWithScaling(const ConcatenationParams& params,
const RuntimeShape* const* input_shapes,
const uint8* const* input_data,
@@ -2036,15 +1576,13 @@ inline void ConcatenationWithScaling(const ConcatenationParams& params,
const int32 output_zeropoint = params.output_zeropoint;
const float output_scale = params.output_scale;
- // The arguments input_zeropoint and input_scale are expected to be an array
- // that have the quantization parameters for all the inputs to the concat
- // operator.
- TFLITE_DCHECK_GT(inputs_count, 1);
- TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+ const int concat_dimensions = output_shape.DimensionsCount();
+ TFLITE_DCHECK_LT(axis, concat_dimensions);
+
int64_t concat_size = 0;
for (int i = 0; i < inputs_count; i++) {
- TFLITE_DCHECK_EQ(input_shapes[i]->DimensionsCount(), 4);
- for (int j = 0; j < 4; j++) {
+ TFLITE_DCHECK_EQ(input_shapes[i]->DimensionsCount(), concat_dimensions);
+ for (int j = 0; j < concat_dimensions; j++) {
if (j != axis) {
MatchingDim(*input_shapes[i], j, output_shape, j);
}
@@ -2059,9 +1597,10 @@ inline void ConcatenationWithScaling(const ConcatenationParams& params,
// For all input arrays,
// FlatSize() = outer_size * Dims(axis) * base_inner_size;
int64_t base_inner_size = 1;
- for (int i = axis + 1; i < 4; ++i) {
+ for (int i = axis + 1; i < concat_dimensions; ++i) {
base_inner_size *= output_shape.Dims(i);
}
+
const float inverse_output_scale = 1.f / output_scale;
uint8* output_ptr = output_data;
for (int k = 0; k < outer_size; k++) {
@@ -2087,65 +1626,52 @@ inline void ConcatenationWithScaling(const ConcatenationParams& params,
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy Dims<4>.
-inline void Concatenation(int concat_dim, const uint8* const* input_data,
- const Dims<4>* const* input_dims,
- const int32* input_zeropoint,
- const float* input_scale, int inputs_count,
- uint8* output_data, const Dims<4>& output_dims,
- const int32 output_zeropoint,
- const float output_scale) {
- std::vector<RuntimeShape> input_shapes(inputs_count);
- std::vector<const RuntimeShape*> input_shapes_indirect(inputs_count);
- for (int i = 0; i < inputs_count; ++i) {
- ShapeFromDims(*input_dims[i], &input_shapes[i]);
- input_shapes_indirect[i] = &input_shapes[i];
- }
- tflite::ConcatenationParams op_params;
- op_params.axis = 3 - concat_dim;
- op_params.input_zeropoint = input_zeropoint;
- op_params.input_scale = input_scale;
- op_params.inputs_count = inputs_count;
- op_params.output_zeropoint = output_zeropoint;
- op_params.output_scale = output_scale;
-
- ConcatenationWithScaling(op_params, input_shapes_indirect.data(), input_data,
- DimsToShape(output_dims), output_data);
-}
-
template <typename Scalar>
-void Pack(int dim, const Scalar* const* input_data,
- const Dims<4>* const* input_dims, int inputs_count,
- Scalar* output_data, const Dims<4>& output_dims) {
- TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
+void Pack(const PackParams& params, const RuntimeShape* const* input_shapes,
+ const Scalar* const* input_data, const RuntimeShape& output_shape,
+ Scalar* output_data) {
+ const int dimensions = output_shape.DimensionsCount();
+ int axis = params.axis;
+ int inputs_count = params.inputs_count;
+
int outer_size = 1;
- for (int i = dim + 1; i < 4; i++) {
- outer_size *= output_dims.sizes[i];
+ for (int i = 0; i < axis; i++) {
+ outer_size *= output_shape.Dims(i);
}
- Scalar* output_ptr = output_data;
- const int copy_size = FlatSize(**input_dims) / outer_size;
- for (int k = 0; k < outer_size; k++) {
- for (int i = 0; i < inputs_count; ++i) {
- memcpy(output_ptr, input_data[i] + k * copy_size,
- copy_size * sizeof(Scalar));
- output_ptr += copy_size;
+ int copy_size = 1;
+ for (int i = params.axis + 1; i < dimensions; i++) {
+ copy_size *= output_shape.Dims(i);
+ }
+ TFLITE_DCHECK_EQ((**input_shapes).FlatSize(), copy_size * outer_size);
+
+ for (int i = 0; i < inputs_count; ++i) {
+ for (int k = 0; k < outer_size; k++) {
+ const Scalar* input_ptr = input_data[i] + copy_size * k;
+ int loc = k * inputs_count * copy_size + i * copy_size;
+ memcpy(output_data + loc, input_ptr, copy_size * sizeof(Scalar));
}
}
}
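A worked shape example for the rewritten Pack, with assumed sizes: packing inputs_count = 4 tensors of shape {2, 3} along axis = 1 produces an output of shape {2, 4, 3}; then outer_size = 2, copy_size = 3, and input i's k-th slice of 3 values lands at offset k * 4 * 3 + i * 3 in the output.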
template <typename Scalar>
-void Unpack(int axis, const Scalar* input_data, const Dims<4>& input_dims,
- int dimensions, int outputs_count, Scalar* const* output_datas,
- const Dims<4>& output_dims) {
+void Unpack(const UnpackParams& params, const RuntimeShape& input_shape,
+ const Scalar* input_data, const RuntimeShape& output_shape,
+ Scalar* const* output_datas) {
+ const int dimensions = input_shape.DimensionsCount();
+ const int outputs_count = params.num_split;
+
int outer_size = 1;
- for (int i = dimensions - axis; i < 4; i++) {
- outer_size *= input_dims.sizes[i];
+ for (int i = 0; i < params.axis; i++) {
+ outer_size *= input_shape.Dims(i);
}
+ int copy_size = 1;
+ for (int i = params.axis + 1; i < dimensions; i++) {
+ copy_size *= input_shape.Dims(i);
+ }
+ TFLITE_DCHECK_EQ(output_shape.FlatSize(), copy_size * outer_size);
- const int copy_size = FlatSize(input_dims) / outer_size / outputs_count;
- for (int k = 0; k < outer_size; k++) {
- for (int i = 0; i < outputs_count; ++i) {
+ for (int i = 0; i < outputs_count; ++i) {
+ for (int k = 0; k < outer_size; k++) {
Scalar* output_ptr = output_datas[i] + copy_size * k;
int loc = k * outputs_count * copy_size + i * copy_size;
memcpy(output_ptr, input_data + loc, copy_size * sizeof(Scalar));
@@ -2154,18 +1680,29 @@ void Unpack(int axis, const Scalar* input_data, const Dims<4>& input_dims,
}
template <typename Scalar>
-void Pack(int dim, const Scalar* const* input_data,
- const Dims<4>* const* input_dims, const int32* input_zeropoint,
- const float* input_scale, int inputs_count, Scalar* output_data,
- const Dims<4>& output_dims, const int32 output_zeropoint,
- const float output_scale) {
- TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
+void PackWithScaling(const PackParams& params,
+ const RuntimeShape* const* input_shapes,
+ const uint8* const* input_data,
+ const RuntimeShape& output_shape, uint8* output_data) {
+ const int dimensions = output_shape.DimensionsCount();
+ int axis = params.axis;
+ const int32* input_zeropoint = params.input_zeropoint;
+ const float* input_scale = params.input_scale;
+ int inputs_count = params.inputs_count;
+ const int32 output_zeropoint = params.output_zeropoint;
+ const float output_scale = params.output_scale;
+
int outer_size = 1;
- for (int i = dim + 1; i < 4; i++) {
- outer_size *= output_dims.sizes[i];
+ for (int i = 0; i < axis; i++) {
+ outer_size *= output_shape.Dims(i);
}
+ int copy_size = 1;
+ for (int i = axis + 1; i < dimensions; i++) {
+ copy_size *= output_shape.Dims(i);
+ }
+ TFLITE_DCHECK_EQ((**input_shapes).FlatSize(), copy_size * outer_size);
+
Scalar* output_ptr = output_data;
- const int copy_size = FlatSize(**input_dims) / outer_size;
const float inverse_output_scale = 1.f / output_scale;
for (int k = 0; k < outer_size; k++) {
for (int i = 0; i < inputs_count; ++i) {
@@ -2191,64 +1728,101 @@ void Pack(int dim, const Scalar* const* input_data,
}
}
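For inputs whose zero point or scale differs from the output's, the per-element rescaling (elided from this hunk) presumably follows the usual requantization identity, using the inverse_output_scale precomputed above; a hedged sketch of the math, clamped to the uint8 range:

  // q_out = clamp(round((q_in - input_zeropoint[i]) * input_scale[i] / output_scale)
  //               + output_zeropoint, 0, 255)
  // e.g. q_in = 200, in_zp = 128, in_scale = 0.5, out_zp = 0, out_scale = 1.0:
  //   real value = (200 - 128) * 0.5 = 36.0, so q_out = 36

Inputs whose quantization already matches the output's can presumably be copied through unchanged, as in plain Pack.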
-template <FusedActivationFunctionType Ac, typename Scalar>
-void DepthConcatenation(const Scalar* const* input_data,
- const Dims<4>* const* input_dims, int inputs_count,
- Scalar* output_data, const Dims<4>& output_dims) {
- Concatenation<Ac, Scalar>(0, input_data, input_dims, inputs_count,
- output_data, output_dims);
-}
-
-inline void LstmCell(const float* input_data, const Dims<4>& input_dims,
- const float* prev_activ_data,
- const Dims<4>& prev_activ_dims, const float* weights_data,
- const Dims<4>& weights_dims, const float* bias_data,
- const Dims<4>& bias_dims, const float* prev_state_data,
- const Dims<4>& prev_state_dims, float* output_state_data,
- const Dims<4>& output_state_dims, float* output_activ_data,
- const Dims<4>& output_activ_dims, float* concat_temp_data,
- const Dims<4>& concat_temp_dims, float* activ_temp_data,
- const Dims<4>& activ_temp_dims) {
+template <typename Scalar>
+void DepthConcatenation(const ConcatenationParams& params,
+ const RuntimeShape* const* input_shapes,
+ const Scalar* const* input_data,
+ const RuntimeShape& output_shape, Scalar* output_data) {
+ auto params_copy = params;
+ params_copy.axis = 3;
+ Concatenation(params_copy, input_shapes, input_data, output_shape,
+ output_data);
+}
+
+inline void LstmCell(
+ const LstmCellParams& params, const RuntimeShape& unextended_input_shape,
+ const float* input_data, const RuntimeShape& unextended_prev_activ_shape,
+ const float* prev_activ_data, const RuntimeShape& weights_shape,
+ const float* weights_data, const RuntimeShape& unextended_bias_shape,
+ const float* bias_data, const RuntimeShape& unextended_prev_state_shape,
+ const float* prev_state_data,
+ const RuntimeShape& unextended_output_state_shape, float* output_state_data,
+ const RuntimeShape& unextended_output_activ_shape, float* output_activ_data,
+ const RuntimeShape& unextended_concat_temp_shape, float* concat_temp_data,
+ const RuntimeShape& unextended_activ_temp_shape, float* activ_temp_data) {
+ TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_prev_activ_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_bias_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_prev_state_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_state_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_activ_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_concat_temp_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_activ_temp_shape.DimensionsCount(), 4);
+ const RuntimeShape input_shape =
+ RuntimeShape::ExtendedShape(4, unextended_input_shape);
+ const RuntimeShape prev_activ_shape =
+ RuntimeShape::ExtendedShape(4, unextended_prev_activ_shape);
+ const RuntimeShape bias_shape =
+ RuntimeShape::ExtendedShape(4, unextended_bias_shape);
+ const RuntimeShape prev_state_shape =
+ RuntimeShape::ExtendedShape(4, unextended_prev_state_shape);
+ const RuntimeShape output_state_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_state_shape);
+ const RuntimeShape output_activ_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_activ_shape);
+ const RuntimeShape concat_temp_shape =
+ RuntimeShape::ExtendedShape(4, unextended_concat_temp_shape);
+ const RuntimeShape activ_temp_shape =
+ RuntimeShape::ExtendedShape(4, unextended_activ_temp_shape);
+ TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
+
+ const int weights_dim_count = weights_shape.DimensionsCount();
const int batches =
- MatchingArraySize(input_dims, 3, prev_activ_dims, 3, prev_state_dims, 3,
- output_state_dims, 3, output_activ_dims, 3);
+ MatchingDim(input_shape, 0, prev_activ_shape, 0, prev_state_shape, 0,
+ output_state_shape, 0, output_activ_shape, 0);
const int height =
- MatchingArraySize(input_dims, 2, prev_activ_dims, 2, prev_state_dims, 2,
- output_state_dims, 2, output_activ_dims, 2);
+ MatchingDim(input_shape, 1, prev_activ_shape, 1, prev_state_shape, 1,
+ output_state_shape, 1, output_activ_shape, 1);
const int width =
- MatchingArraySize(input_dims, 1, prev_activ_dims, 1, prev_state_dims, 1,
- output_state_dims, 1, output_activ_dims, 1);
- TFLITE_CHECK_EQ(ArraySize(weights_dims, 2), 1);
- TFLITE_CHECK_EQ(ArraySize(weights_dims, 3), 1);
- const int input_depth = ArraySize(input_dims, 0);
- const int prev_activ_depth = ArraySize(prev_activ_dims, 0);
+ MatchingDim(input_shape, 2, prev_activ_shape, 2, prev_state_shape, 2,
+ output_state_shape, 2, output_activ_shape, 2);
+ const int input_depth = input_shape.Dims(3);
+ const int prev_activ_depth = prev_activ_shape.Dims(3);
const int total_input_depth = prev_activ_depth + input_depth;
- TFLITE_CHECK_EQ(ArraySize(weights_dims, 0), total_input_depth);
- TFLITE_CHECK_EQ(MatchingArraySize(bias_dims, 1, bias_dims, 2, bias_dims, 3),
- 1);
+ TFLITE_DCHECK_EQ(weights_shape.Dims(weights_dim_count - 1),
+ total_input_depth);
+ TFLITE_DCHECK_EQ(FlatSizeSkipDim(bias_shape, 3), 1);
const int intern_activ_depth =
- MatchingArraySize(weights_dims, 1, bias_dims, 0);
- TFLITE_CHECK_EQ(intern_activ_depth % 4, 0);
+ MatchingDim(weights_shape, weights_dim_count - 2, bias_shape, 3);
+ TFLITE_DCHECK_EQ(weights_shape.FlatSize(),
+ intern_activ_depth * total_input_depth);
+ TFLITE_DCHECK_EQ(intern_activ_depth % 4, 0);
const int output_depth =
- MatchingArraySize(prev_state_dims, 0, prev_activ_dims, 0,
- output_state_dims, 0, output_activ_dims, 0);
- TFLITE_CHECK_EQ(output_depth, intern_activ_depth / 4);
+ MatchingDim(prev_state_shape, 3, prev_activ_shape, 3, output_state_shape,
+ 3, output_activ_shape, 3);
+ TFLITE_DCHECK_EQ(output_depth, intern_activ_depth / 4);
// Concatenate prev_activ and input data together
std::vector<float const*> concat_input_arrays_data;
- std::vector<Dims<4> const*> concat_input_arrays_dims;
+ std::vector<RuntimeShape const*> concat_input_arrays_shapes;
concat_input_arrays_data.push_back(input_data);
concat_input_arrays_data.push_back(prev_activ_data);
- concat_input_arrays_dims.push_back(&input_dims);
- concat_input_arrays_dims.push_back(&prev_activ_dims);
- Concatenation<FusedActivationFunctionType::kNone, float>(
- 0, &(concat_input_arrays_data[0]), &(concat_input_arrays_dims[0]),
- concat_input_arrays_data.size(), concat_temp_data, concat_temp_dims);
+ concat_input_arrays_shapes.push_back(&input_shape);
+ concat_input_arrays_shapes.push_back(&prev_activ_shape);
+ tflite::ConcatenationParams concat_params;
+ concat_params.axis = 3;
+ concat_params.inputs_count = concat_input_arrays_data.size();
+ Concatenation(concat_params, &(concat_input_arrays_shapes[0]),
+ &(concat_input_arrays_data[0]), concat_temp_shape,
+ concat_temp_data);
// Fully connected
- FullyConnected<FusedActivationFunctionType::kNone>(
- concat_temp_data, concat_temp_dims, weights_data, weights_dims, bias_data,
- bias_dims, activ_temp_data, activ_temp_dims);
+ tflite::FullyConnectedParams fc_params;
+ fc_params.float_activation_min = std::numeric_limits<float>::lowest();
+ fc_params.float_activation_max = std::numeric_limits<float>::max();
+ FullyConnected(fc_params, concat_temp_shape, concat_temp_data, weights_shape,
+ weights_data, bias_shape, bias_data, activ_temp_shape,
+ activ_temp_data);
// Memory state update (the LSTM "guts")
for (int b = 0; b < batches; ++b) {
@@ -2257,24 +1831,24 @@ inline void LstmCell(const float* input_data, const Dims<4>& input_dims,
for (int c = 0; c < output_depth; ++c) {
const float input_gate =
1.f /
- (1.f + std::exp(-activ_temp_data[Offset(
- activ_temp_dims, 0 * output_depth + c, w, h, b)]));
+ (1.f + std::exp(-activ_temp_data[Offset(activ_temp_shape, b, h, w,
+ 0 * output_depth + c)]));
const float new_input = std::tanh(activ_temp_data[Offset(
- activ_temp_dims, 1 * output_depth + c, w, h, b)]);
+ activ_temp_shape, b, h, w, 1 * output_depth + c)]);
const float forget_gate =
1.f /
- (1.f + std::exp(-activ_temp_data[Offset(
- activ_temp_dims, 2 * output_depth + c, w, h, b)]));
+ (1.f + std::exp(-activ_temp_data[Offset(activ_temp_shape, b, h, w,
+ 2 * output_depth + c)]));
const float output_gate =
1.f /
- (1.f + std::exp(-activ_temp_data[Offset(
- activ_temp_dims, 3 * output_depth + c, w, h, b)]));
+ (1.f + std::exp(-activ_temp_data[Offset(activ_temp_shape, b, h, w,
+ 3 * output_depth + c)]));
const float new_state =
input_gate * new_input +
forget_gate *
- prev_state_data[Offset(prev_state_dims, c, w, h, b)];
- output_state_data[Offset(output_state_dims, c, w, h, b)] = new_state;
- output_activ_data[Offset(output_activ_dims, c, w, h, b)] =
+ prev_state_data[Offset(prev_state_shape, b, h, w, c)];
+ output_state_data[Offset(output_state_shape, b, h, w, c)] = new_state;
+ output_activ_data[Offset(output_activ_shape, b, h, w, c)] =
output_gate * std::tanh(new_state);
}
}
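Written out, the float update above is the standard LSTM cell recurrence, with the four gate pre-activations packed along the depth of activ_temp in blocks of width D = output_depth, in the order (input gate, new input, forget gate, output gate):

  i = sigmoid(a[0 * D + c])    g = tanh(a[1 * D + c])
  f = sigmoid(a[2 * D + c])    o = sigmoid(a[3 * D + c])
  state[c]        = i * g + f * prev_state[c]
  output_activ[c] = o * tanh(state[c])

where a denotes activ_temp at the current (b, h, w) position.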
@@ -2367,52 +1941,90 @@ inline void LstmCell(const float* input_data, const Dims<4>& input_dims,
// aiming for 16-bit fixed-point quantization of these internal nodes here.
//
template <int StateIntegerBits>
-void LstmCell(const uint8* input_data_uint8, const Dims<4>& input_dims,
- const uint8* prev_activ_data_uint8,
- const Dims<4>& prev_activ_dims, const uint8* weights_data_uint8,
- const Dims<4>& weights_dims, const int32* bias_data_int32,
- const Dims<4>& bias_dims, const int16* prev_state_data_int16,
- const Dims<4>& prev_state_dims, int16* output_state_data_int16,
- const Dims<4>& output_state_dims, uint8* output_activ_data_uint8,
- const Dims<4>& output_activ_dims, uint8* concat_temp_data_uint8,
- const Dims<4>& concat_temp_dims, int16* activ_temp_data_int16,
- const Dims<4>& activ_temp_dims, int32 weights_zero_point,
- int32 accum_multiplier, int accum_shift,
- gemmlowp::GemmContext* gemm_context) {
+inline void LstmCell(
+ const LstmCellParams& params, const RuntimeShape& unextended_input_shape,
+ const uint8* input_data_uint8,
+ const RuntimeShape& unextended_prev_activ_shape,
+ const uint8* prev_activ_data_uint8, const RuntimeShape& weights_shape,
+ const uint8* weights_data_uint8, const RuntimeShape& unextended_bias_shape,
+ const int32* bias_data_int32,
+ const RuntimeShape& unextended_prev_state_shape,
+ const int16* prev_state_data_int16,
+ const RuntimeShape& unextended_output_state_shape,
+ int16* output_state_data_int16,
+ const RuntimeShape& unextended_output_activ_shape,
+ uint8* output_activ_data_uint8,
+ const RuntimeShape& unextended_concat_temp_shape,
+ uint8* concat_temp_data_uint8,
+ const RuntimeShape& unextended_activ_temp_shape,
+ int16* activ_temp_data_int16, gemmlowp::GemmContext* gemm_context) {
(void)gemm_context; // only used in optimized code.
+ int32 weights_zero_point = params.weights_zero_point;
+ int32 accum_multiplier = params.accum_multiplier;
+ int accum_shift = params.accum_shift;
+ TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_prev_activ_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_bias_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_prev_state_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_state_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_activ_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_concat_temp_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_activ_temp_shape.DimensionsCount(), 4);
+ const RuntimeShape input_shape =
+ RuntimeShape::ExtendedShape(4, unextended_input_shape);
+ const RuntimeShape prev_activ_shape =
+ RuntimeShape::ExtendedShape(4, unextended_prev_activ_shape);
+ const RuntimeShape bias_shape =
+ RuntimeShape::ExtendedShape(4, unextended_bias_shape);
+ const RuntimeShape prev_state_shape =
+ RuntimeShape::ExtendedShape(4, unextended_prev_state_shape);
+ const RuntimeShape output_state_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_state_shape);
+ const RuntimeShape output_activ_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_activ_shape);
+ const RuntimeShape concat_temp_shape =
+ RuntimeShape::ExtendedShape(4, unextended_concat_temp_shape);
+ const RuntimeShape activ_temp_shape =
+ RuntimeShape::ExtendedShape(4, unextended_activ_temp_shape);
+ TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
// Gather dimensions information, and perform consistency checks.
- const int outer_size =
- MatchingFlatSizeSkipDim(input_dims, 0, prev_activ_dims, prev_state_dims,
- output_state_dims, output_activ_dims);
- TFLITE_CHECK_EQ(ArraySize(weights_dims, 2), 1);
- TFLITE_CHECK_EQ(ArraySize(weights_dims, 3), 1);
- const int input_depth = ArraySize(input_dims, 0);
- const int prev_activ_depth = ArraySize(prev_activ_dims, 0);
+ const int weights_dim_count = weights_shape.DimensionsCount();
+ const int outer_size = MatchingFlatSizeSkipDim(
+ input_shape, 3, prev_activ_shape, prev_state_shape, output_state_shape,
+ output_activ_shape);
+ const int input_depth = input_shape.Dims(3);
+ const int prev_activ_depth = prev_activ_shape.Dims(3);
const int total_input_depth = prev_activ_depth + input_depth;
- TFLITE_CHECK_EQ(ArraySize(weights_dims, 0), total_input_depth);
- TFLITE_CHECK_EQ(MatchingArraySize(bias_dims, 1, bias_dims, 2, bias_dims, 3),
- 1);
+ TFLITE_DCHECK_EQ(weights_shape.Dims(weights_dim_count - 1),
+ total_input_depth);
const int intern_activ_depth =
- MatchingArraySize(weights_dims, 1, bias_dims, 0);
- TFLITE_CHECK_EQ(intern_activ_depth % 4, 0);
+ MatchingDim(weights_shape, weights_dim_count - 2, bias_shape, 3);
+ TFLITE_DCHECK_EQ(weights_shape.FlatSize(),
+ intern_activ_depth * total_input_depth);
+ TFLITE_DCHECK_EQ(FlatSizeSkipDim(bias_shape, 3), 1);
+ TFLITE_DCHECK_EQ(intern_activ_depth % 4, 0);
const int output_depth =
- MatchingArraySize(prev_state_dims, 0, prev_activ_dims, 0,
- output_state_dims, 0, output_activ_dims, 0);
- TFLITE_CHECK_EQ(output_depth, intern_activ_depth / 4);
- const int fc_batches = FlatSizeSkipDim(activ_temp_dims, 0);
+ MatchingDim(prev_state_shape, 3, prev_activ_shape, 3, output_state_shape,
+ 3, output_activ_shape, 3);
+ TFLITE_DCHECK_EQ(output_depth, intern_activ_depth / 4);
+ const int fc_batches = FlatSizeSkipDim(activ_temp_shape, 3);
const int fc_output_depth =
- MatchingArraySize(weights_dims, 1, activ_temp_dims, 0);
- const int fc_accum_depth = ArraySize(weights_dims, 0);
- TFLITE_CHECK_EQ(fc_output_depth, 4 * output_depth);
+ MatchingDim(weights_shape, weights_dim_count - 2, activ_temp_shape, 3);
+ const int fc_accum_depth = total_input_depth;
+ TFLITE_DCHECK_EQ(fc_output_depth, 4 * output_depth);
// Depth-concatenate prev_activ and input data together.
uint8 const* concat_input_arrays_data[2] = {input_data_uint8,
prev_activ_data_uint8};
- Dims<4> const* concat_input_arrays_dims[2] = {&input_dims, &prev_activ_dims};
- Concatenation<FusedActivationFunctionType::kNone, uint8>(
- 0, concat_input_arrays_data, concat_input_arrays_dims, 2,
- concat_temp_data_uint8, concat_temp_dims);
+ const RuntimeShape* concat_input_arrays_shapes[2] = {&input_shape,
+ &prev_activ_shape};
+ tflite::ConcatenationParams concat_params;
+ concat_params.axis = 3;
+ concat_params.inputs_count = 2;
+ Concatenation(concat_params, concat_input_arrays_shapes,
+ concat_input_arrays_data, concat_temp_shape,
+ concat_temp_data_uint8);
// Implementation of the fully connected node inside the LSTM cell.
// The operands are 8-bit integers, the accumulators are internally 32bit
@@ -2560,45 +2172,6 @@ void Split(const SplitParams& params, const RuntimeShape& input_shape,
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy Dims<4>.
-template <typename Scalar>
-void TensorFlowSplit(const Scalar* input_data, const Dims<4>& input_dims,
- int axis, int outputs_count, Scalar* const* output_data,
- const Dims<4>* const* output_dims) {
- std::vector<RuntimeShape> output_shapes(outputs_count);
- std::vector<const RuntimeShape*> output_shapes_indirect(outputs_count);
- for (int i = 0; i < outputs_count; ++i) {
- ShapeFromDims(*output_dims[i], &output_shapes[i]);
- output_shapes_indirect[i] = &output_shapes[i];
- }
- tflite::SplitParams op_params;
- op_params.axis = 3 - axis;
- op_params.num_split = outputs_count;
-
- Split(op_params, DimsToShape(input_dims), input_data,
- output_shapes_indirect.data(), output_data);
-}
-
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy Dims<4>.
-template <FusedActivationFunctionType Ac, typename Scalar>
-void TensorFlowSplit(const Scalar* input_data, const Dims<4>& input_dims,
- int outputs_count, Scalar* const* output_data,
- const Dims<4>* const* output_dims) {
- TFLITE_DCHECK_GE(outputs_count, 1);
- for (int i = 0; i < outputs_count; i++) {
- /* batches = */ MatchingArraySize(*output_dims[i], 3, input_dims, 3);
- /* height = */ MatchingArraySize(*output_dims[i], 2, input_dims, 2);
- /* width = */ MatchingArraySize(*output_dims[i], 1, input_dims, 1);
- }
- // For now we don't have a model with a Split with fused activation.
- TFLITE_DCHECK_EQ(Ac, FusedActivationFunctionType::kNone);
-
- TensorFlowSplit(input_data, input_dims, /*axis=*/0, outputs_count,
- output_data, output_dims);
-}
-
inline int NodeOffset(int b, int h, int w, int height, int width) {
return (b * height + h) * width + w;
}
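NodeOffset is simply the row-major flattening of a (batch, y, x) coordinate over a height-by-width grid, with depth handled separately by the caller; for example NodeOffset(1, 2, 3, 4, 5) == (1 * 4 + 2) * 5 + 3 == 33.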
@@ -2897,144 +2470,6 @@ inline void LocalResponseNormalization(
}
}
-inline void Softmax(const SoftmaxParams& params,
- const RuntimeShape& input_shape, const float* input_data,
- const RuntimeShape& output_shape, float* output_data) {
- const int trailing_dim = input_shape.DimensionsCount() - 1;
- const int outer_size =
- MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
- const int depth =
- MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
-
- for (int i = 0; i < outer_size; ++i) {
- // Find max element value which we'll use to ensure numerical stability
- // taking advantage of the following equality:
- // exp(x[i])/sum(exp(x[i])) == exp(x[i]+C)/sum(exp(x[i]+C))
- float max = std::numeric_limits<float>::lowest();
- for (int c = 0; c < depth; ++c) {
- max = std::max(max, input_data[i * depth + c]);
- }
-
- // Compute sum.
- float sum = 0.f;
- for (int c = 0; c < depth; ++c) {
- sum += std::exp((input_data[i * depth + c] - max) * params.beta);
- }
-
- // Compute result.
- for (int c = 0; c < depth; ++c) {
- output_data[i * depth + c] =
- std::exp((input_data[i * depth + c] - max) * params.beta) / sum;
- }
- }
-}
-
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-inline void Softmax(const float* input_data, const RuntimeShape& input_shape,
- float beta, float* output_data,
- const RuntimeShape& output_shape) {
- SoftmaxParams params;
- params.beta = beta;
- Softmax(params, input_shape, input_data, output_shape, output_data);
-}
-
-inline void Softmax(const SoftmaxParams& params,
- const RuntimeShape& input_shape, const uint8* input_data,
- const RuntimeShape& output_shape, uint8* output_data) {
- const int32 input_beta_multiplier = params.input_multiplier;
- const int32 input_beta_left_shift = params.input_left_shift;
- const int diff_min = params.diff_min;
- // The representation chosen for the input to the exp() function is Q5.26.
- // We need to leave extra space since values that we skip might be as large as
- // -32 before multiplying by input_beta_multiplier, and therefore as large as
- // -16 afterwards. Note that exp(-8) is definitely not insignificant to
- // accumulation, but exp(-16) definitely is.
- static const int kScaledDiffIntegerBits = 5;
- static const int kAccumulationIntegerBits = 12;
- using FixedPointScaledDiff =
- gemmlowp::FixedPoint<int32, kScaledDiffIntegerBits>;
- using FixedPointAccum = gemmlowp::FixedPoint<int32, kAccumulationIntegerBits>;
- using FixedPoint0 = gemmlowp::FixedPoint<int32, 0>;
-
- const int trailing_dim = input_shape.DimensionsCount() - 1;
- const int outer_size =
- MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
- const int depth =
- MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
-
- for (int i = 0; i < outer_size; ++i) {
- uint8 max_in_row = 0;
- for (int c = 0; c < depth; ++c) {
- max_in_row = std::max(max_in_row, input_data[i * depth + c]);
- }
-
- FixedPointAccum sum_of_exps = FixedPointAccum::Zero();
- for (int c = 0; c < depth; ++c) {
- int32 input_diff =
- static_cast<int32>(input_data[i * depth + c]) - max_in_row;
- if (input_diff >= diff_min) {
- const int32 input_diff_rescaled =
- MultiplyByQuantizedMultiplierGreaterThanOne(
- input_diff, input_beta_multiplier, input_beta_left_shift);
- const FixedPointScaledDiff scaled_diff_f8 =
- FixedPointScaledDiff::FromRaw(input_diff_rescaled);
- sum_of_exps = sum_of_exps + gemmlowp::Rescale<kAccumulationIntegerBits>(
- exp_on_negative_values(scaled_diff_f8));
- }
- }
-
- int32 fixed_sum_of_exps = sum_of_exps.raw();
- int headroom_plus_one =
- CountLeadingZeros(static_cast<uint32>(fixed_sum_of_exps));
- // This is the number of bits to the left of the binary point above 1.0.
- // Consider fixed_sum_of_exps=1.25. In that case shifted_scale=0.8 and
- // no later adjustment will be needed.
- int num_bits_over_unit = kAccumulationIntegerBits - headroom_plus_one;
- int32 shifted_sum_minus_one = static_cast<int32>(
- (static_cast<uint32>(fixed_sum_of_exps) << headroom_plus_one) -
- (static_cast<uint32>(1) << 31));
-
- FixedPoint0 shifted_scale = gemmlowp::one_over_one_plus_x_for_x_in_0_1(
- FixedPoint0::FromRaw(shifted_sum_minus_one));
-
- for (int c = 0; c < depth; ++c) {
- int32 input_diff =
- static_cast<int32>(input_data[i * depth + c]) - max_in_row;
- if (input_diff >= diff_min) {
- const int32 input_diff_rescaled =
- MultiplyByQuantizedMultiplierGreaterThanOne(
- input_diff, input_beta_multiplier, input_beta_left_shift);
- const FixedPointScaledDiff scaled_diff_f8 =
- FixedPointScaledDiff::FromRaw(input_diff_rescaled);
-
- FixedPoint0 exp_in_0 = exp_on_negative_values(scaled_diff_f8);
- int32 unsat_output = gemmlowp::RoundingDivideByPOT(
- (shifted_scale * exp_in_0).raw(), num_bits_over_unit + 31 - 8);
-
- output_data[i * depth + c] = static_cast<uint8>(
- std::max(std::min(unsat_output, static_cast<int32>(255)), 0));
-
- } else {
- output_data[i * depth + c] = 0;
- }
- }
- }
-}
-
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy
-inline void Softmax(const uint8* input_data, const RuntimeShape& input_shape,
- int32 input_beta_multiplier, int32 input_beta_left_shift,
- int diff_min, uint8* output_data,
- const RuntimeShape& output_shape) {
- SoftmaxParams params;
- params.input_multiplier = input_beta_multiplier;
- params.input_left_shift = input_beta_left_shift;
- params.diff_min = diff_min;
- Softmax(params, input_shape, input_data, output_shape, output_data);
-}
-
inline void LogSoftmax(const SoftmaxParams& params,
const RuntimeShape& input_shape, const float* input_data,
const RuntimeShape& output_shape, float* output_data) {
@@ -3067,15 +2502,6 @@ inline void LogSoftmax(const SoftmaxParams& params,
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy
-inline void LogSoftmax(const float* input_data, const RuntimeShape& input_shape,
- float* output_data, const RuntimeShape& output_shape) {
- SoftmaxParams params;
- // No params currently used for float LogSoftmax.
- LogSoftmax(params, input_shape, input_data, output_shape, output_data);
-}
-
// Although currently the name of this function says that it cannot handle
// values less than 1, in practice it can handle as low as 1/x_max, where
// x_max is the largest representable input. In other words, the output range
@@ -3255,7 +2681,7 @@ inline void LogSoftmax(const SoftmaxParams& params,
std::max(diff_min - 1, // Note use of > below instead of >= above.
MultiplyByQuantizedMultiplierSmallerThanOneExp(
rescaled_diff_min, reverse_scaling_divisor,
- kReverseShift * reverse_scaling_right_shift));
+ -reverse_scaling_right_shift));
for (int c = 0; c < depth; ++c) {
int32 input_diff =
@@ -3280,24 +2706,7 @@ inline void LogSoftmax(const SoftmaxParams& params,
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-inline void LogSoftmax(const uint8* input_data, const RuntimeShape& input_shape,
- int32 input_multiplier, int32 input_left_shift,
- int32 reverse_scaling_divisor,
- int32 reverse_scaling_right_shift, int diff_min,
- uint8* output_data, const RuntimeShape& output_shape) {
- SoftmaxParams params;
- params.input_multiplier = input_multiplier;
- params.input_left_shift = input_left_shift;
- params.reverse_scaling_divisor = reverse_scaling_divisor;
- params.reverse_scaling_right_shift = reverse_scaling_right_shift;
- params.diff_min = diff_min;
- LogSoftmax(params, input_shape, input_data, output_shape, output_data);
-}
-
-inline void Logistic(const LogisticParams& params,
- const RuntimeShape& input_shape, const float* input_data,
+inline void Logistic(const RuntimeShape& input_shape, const float* input_data,
const RuntimeShape& output_shape, float* output_data) {
const int flat_size = MatchingFlatSize(input_shape, output_shape);
@@ -3308,13 +2717,13 @@ inline void Logistic(const LogisticParams& params,
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-inline void Logistic(const RuntimeShape& input_shape, const float* input_data,
- const RuntimeShape& output_shape, float* output_data) {
- LogisticParams params;
- // No params currently needed by float Logistic.
- Logistic(params, input_shape, input_data, output_shape, output_data);
+// Convenience version that allows, for example, generated-code calls to be
+// uniform between data types.
+inline void Logistic(const LogisticParams&, const RuntimeShape& input_shape,
+ const float* input_data, const RuntimeShape& output_shape,
+ float* output_data) {
+ // Drop params: not needed.
+ Logistic(input_shape, input_data, output_shape, output_data);
}
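The same pattern appears below for Tanh, Minimum, Maximum and ArgMax: generated code can pass a params struct and both shapes uniformly for every type, and overloads that do not need them simply drop the extra arguments. A hedged sketch, with shape, float_in and float_out as placeholder names:

  LogisticParams ignored;  // no fields consumed by the float path
  Logistic(ignored, shape, float_in, shape, float_out);  // forwards to Logistic(shape, ...)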
inline void Logistic(const LogisticParams& params,
@@ -3358,20 +2767,6 @@ inline void Logistic(const LogisticParams& params,
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-inline void Logistic(const uint8* input_data, const RuntimeShape& input_shape,
- int32 input_zero_point, int32 input_range_radius,
- int32 input_multiplier, int input_left_shift,
- uint8* output_data, const RuntimeShape& output_shape) {
- LogisticParams params;
- params.input_zero_point = input_zero_point;
- params.input_range_radius = input_range_radius;
- params.input_multiplier = input_multiplier;
- params.input_left_shift = input_left_shift;
- Logistic(params, input_shape, input_data, output_shape, output_data);
-}
-
inline void Logistic(const LogisticParams& params,
const RuntimeShape& input_shape, const int16* input_data,
const RuntimeShape& output_shape, int16* output_data) {
@@ -3391,18 +2786,8 @@ inline void Logistic(const LogisticParams& params,
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-inline void Logistic(const RuntimeShape& input_shape, const int16* input_data,
- const RuntimeShape& output_shape, int16* output_data) {
- LogisticParams params;
- // No params currently needed by int16 Logistic.
- Logistic(params, input_shape, input_data, output_shape, output_data);
-}
-
-inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape,
- const float* input_data, const RuntimeShape& output_shape,
- float* output_data) {
+inline void Tanh(const RuntimeShape& input_shape, const float* input_data,
+ const RuntimeShape& output_shape, float* output_data) {
const int flat_size = MatchingFlatSize(input_shape, output_shape);
for (int i = 0; i < flat_size; i++) {
@@ -3412,13 +2797,13 @@ inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape,
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-inline void Tanh(const RuntimeShape& input_shape, const float* input_data,
- const RuntimeShape& output_shape, float* output_data) {
- TanhParams params;
- // Currently no params needed for float Tanh.
- Tanh(params, input_shape, input_data, output_shape, output_data);
+// Convenience version that allows, for example, generated-code calls to be
+// uniform between data types.
+inline void Tanh(const TanhParams&, const RuntimeShape& input_shape,
+ const float* input_data, const RuntimeShape& output_shape,
+ float* output_data) {
+ // Drop params: not needed.
+ Tanh(input_shape, input_data, output_shape, output_data);
}
inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape,
@@ -3464,20 +2849,6 @@ inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape,
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-inline void Tanh(const uint8* input_data, const RuntimeShape& input_shape,
- int32 input_zero_point, int32 input_range_radius,
- int32 input_multiplier, int input_left_shift,
- uint8* output_data, const RuntimeShape& output_shape) {
- TanhParams params;
- params.input_zero_point = input_zero_point;
- params.input_range_radius = input_range_radius;
- params.input_multiplier = input_multiplier;
- params.input_left_shift = input_left_shift;
- Tanh(params, input_shape, input_data, output_shape, output_data);
-}
-
inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape,
const int16* input_data, const RuntimeShape& output_shape,
int16* output_data) {
@@ -3512,16 +2883,6 @@ inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape,
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-inline void Tanh(const int16* input_data, const RuntimeShape& input_shape,
- int input_left_shift, int16* output_data,
- const RuntimeShape& output_shape) {
- TanhParams params;
- params.input_left_shift = input_left_shift;
- Tanh(params, input_shape, input_data, output_shape, output_data);
-}
-
inline void Dequantize(const tflite::DequantizationParams& op_params,
const RuntimeShape& input_shape, const uint8* input_data,
const RuntimeShape& output_shape, float* output_data) {
@@ -3536,19 +2897,6 @@ inline void Dequantize(const tflite::DequantizationParams& op_params,
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy Dims<4>.
-inline void Dequantize(const uint8* input_data, const Dims<4>& input_dims,
- int32 zero_point, double scale, float* output_data,
- const Dims<4>& output_dims) {
- tflite::DequantizationParams op_params;
- op_params.zero_point = zero_point;
- op_params.scale = scale;
-
- Dequantize(op_params, DimsToShape(input_dims), input_data,
- DimsToShape(output_dims), output_data);
-}
-
inline void FakeQuant(const tflite::FakeQuantParams& op_params,
const RuntimeShape& input_shape, const float* input_data,
const RuntimeShape& output_shape, float* output_data) {
@@ -3572,20 +2920,6 @@ inline void FakeQuant(const tflite::FakeQuantParams& op_params,
output_data, flat_size);
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy Dims<4>.
-inline void FakeQuant(const float* input_data, const Dims<4>& input_dims,
- float rmin, float rmax, int num_bits, float* output_data,
- const Dims<4>& output_dims) {
- tflite::FakeQuantParams op_params;
- op_params.num_bits = num_bits;
- op_params.minmax.min = rmin;
- op_params.minmax.max = rmax;
-
- FakeQuant(op_params, DimsToShape(input_dims), input_data,
- DimsToShape(output_dims), output_data);
-}
-
template <typename SrcT, typename DstT>
inline void Cast(const RuntimeShape& input_shape, const SrcT* input_data,
const RuntimeShape& output_shape, DstT* output_data) {
@@ -3609,15 +2943,21 @@ inline void Floor(const RuntimeShape& input_shape, const float* input_data,
template <typename T>
inline void Gather(const tflite::GatherParams& op_params,
- const RuntimeShape& input_shape, const T* input_data,
- const RuntimeShape& coords_shape, const int32* coords_data,
- const RuntimeShape& output_shape, T* output_data) {
- // Enable these checks when moving legacy ops to legacy_reference_ops.
- //
- // TFLITE_DCHECK_EQ(coords_shape.DimensionsCount(), 1);
+ const RuntimeShape& unextended_input_shape,
+ const T* input_data, const RuntimeShape& coords_shape,
+ const int32* coords_data,
+ const RuntimeShape& unextended_output_shape,
+ T* output_data) {
+ TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ const RuntimeShape input_shape =
+ RuntimeShape::ExtendedShape(4, unextended_input_shape);
+ const RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
+
const int input_rank = op_params.input_rank;
const int gather_dimensions = output_shape.DimensionsCount();
- TFLITE_DCHECK_LE(input_shape.DimensionsCount(), gather_dimensions);
+ TFLITE_DCHECK_GE(input_shape.DimensionsCount(), gather_dimensions);
const int axis = gather_dimensions - input_rank;
TFLITE_DCHECK_LT(axis, gather_dimensions);
TFLITE_DCHECK_GE(axis, 0);
@@ -3639,23 +2979,6 @@ inline void Gather(const tflite::GatherParams& op_params,
}
}
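Semantically, the reference Gather copies whole slices of the input along the gather axis: for each position i in coords, output[..., i, ...] = input[..., coords_data[i], ...], where the axis is gather_dimensions - input_rank after both shapes are extended to 4-D. For instance, gathering rows {r0, r1, r2} of a {3, 2} input with coords = {2, 0} yields {r2, r0}. The flipped DCHECK simply encodes that the extended input now has at least as many dimensions as the gathered output.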
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy Dims<4> version.
-// When moving legacy ops to legacy_reference_ops, replace content with looser
-// implementation.
-template <typename T>
-inline void Gather(const T* input_data, const Dims<4>& input_dims,
- int input_rank, const int32* coords_data,
- const Dims<4>& coords_dims, T* output_data,
- const Dims<4>& output_dims) {
- tflite::GatherParams op_params;
- op_params.input_rank = input_rank;
-
- Gather(op_params, DimsToShape(input_dims), input_data,
- DimsToShape(coords_dims), coords_data, DimsToShape(output_dims),
- output_data);
-}
-
template <typename T>
inline void ResizeBilinear(const tflite::ResizeBilinearParams& op_params,
const RuntimeShape& unextended_input_shape,
@@ -3985,58 +3308,6 @@ inline void StridedSlice(const tflite::StridedSliceParams& op_params,
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-inline uint32 LegacyReverseBits32(uint32 n) {
- n = ((n >> 1) & 0x55555555) | ((n & 0x55555555) << 1);
- n = ((n >> 2) & 0x33333333) | ((n & 0x33333333) << 2);
- n = ((n >> 4) & 0x0F0F0F0F) | ((n & 0x0F0F0F0F) << 4);
- return (((n & 0xFF) << 24) | ((n & 0xFF00) << 8) | ((n & 0xFF0000) >> 8) |
- ((n & 0xFF000000) >> 24));
-}
-
-inline void StridedSliceReverseIndices(tflite::StridedSliceParams* p) {
- TFLITE_CHECK_EQ(p->start_indices_count, p->stop_indices_count);
- TFLITE_CHECK_EQ(p->stop_indices_count, p->strides_count);
-
- std::reverse(p->start_indices, p->start_indices + p->start_indices_count);
- std::reverse(p->stop_indices, p->stop_indices + p->stop_indices_count);
- std::reverse(p->strides, p->strides + p->strides_count);
-
- p->begin_mask = LegacyReverseBits32(static_cast<uint32>(p->begin_mask)) >>
- (32 - p->start_indices_count);
- p->ellipsis_mask =
- LegacyReverseBits32(static_cast<uint32>(p->ellipsis_mask)) >>
- (32 - p->start_indices_count);
- p->end_mask = LegacyReverseBits32(static_cast<uint32>(p->end_mask)) >>
- (32 - p->start_indices_count);
- p->new_axis_mask =
- LegacyReverseBits32(static_cast<uint32>(p->new_axis_mask)) >>
- (32 - p->start_indices_count);
- p->shrink_axis_mask =
- LegacyReverseBits32(static_cast<uint32>(p->shrink_axis_mask)) >>
- (32 - p->start_indices_count);
-}
-
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-template <typename T>
-inline void StridedSlice(const T* input_data, const Dims<4>& input_dims,
- int begin_mask, int end_mask, int shrink_axis_mask,
- const std::vector<int>& start_indices,
- const std::vector<int>& stop_indices,
- const std::vector<int>& strides, T* output_data,
- const Dims<4>& output_dims) {
- TFLITE_DCHECK_EQ(start_indices.size(), 4);
- auto op_params = strided_slice::BuildStridedSliceParams(
- begin_mask, end_mask, shrink_axis_mask, start_indices, stop_indices,
- strides);
- StridedSliceReverseIndices(&op_params);
-
- StridedSlice(op_params, DimsToShape(input_dims), input_data,
- DimsToShape(output_dims), output_data);
-}
-
template <typename T>
inline void Slice(const tflite::SliceParams& op_params,
const RuntimeShape& input_shape, const T* input_data,
@@ -4302,32 +3573,19 @@ inline void Mean(const tflite::MeanParams& op_params,
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy Dims<4>.
-template <typename T>
-inline void Mean(const T* input_data, const Dims<4>& input_dims,
- const std::vector<int>& reduction_indices, T* output_data,
- const Dims<4>& output_dims) {
- tflite::MeanParams op_params;
- op_params.axis_count = reduction_indices.size();
- for (int i = 0; i < op_params.axis_count; ++i) {
- op_params.axis[i] = reduction_indices[op_params.axis_count - 1 - i];
- }
-
- Mean(op_params, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
- output_data);
-}
-
// Computes the mean of elements across dimensions given in axis.
// It does so in two stages: first it calculates the sum of elements along the
// axis, then it divides that sum by the number of elements in the axis, for
// quantized values.
template <typename T, typename U>
-inline bool Mean(const T* input_data, int32 input_zero_point, float input_scale,
- const int* input_dims, const int input_num_dims,
- T* output_data, int32 output_zero_point, float output_scale,
- const int* output_dims, const int output_num_dims,
- const int* axis, const int num_axis_dimensions, bool keep_dims,
- int* temp_index, int* resolved_axis, U* temp_sum) {
+inline bool QuantizedMeanOrSum(const T* input_data, int32 input_zero_point,
+ float input_scale, const int* input_dims,
+ const int input_num_dims, T* output_data,
+ int32 output_zero_point, float output_scale,
+ const int* output_dims,
+ const int output_num_dims, const int* axis,
+ const int num_axis_dimensions, bool keep_dims,
+ int* temp_index, int* resolved_axis, U* temp_sum,
+ bool compute_sum) {
// Reset output data.
size_t num_outputs = 1;
for (int idx = 0; idx < output_num_dims; ++idx) {
@@ -4369,14 +3627,24 @@ inline bool Mean(const T* input_data, int32 input_zero_point, float input_scale,
if (num_elements_in_axis > 0) {
const float scale = input_scale / output_scale;
- const float bias = -input_zero_point * scale;
- for (size_t idx = 0; idx < num_outputs; ++idx) {
- float float_mean = static_cast<float>(temp_sum[idx]) /
- static_cast<float>(num_elements_in_axis);
-
- // Convert to float value.
- output_data[idx] =
- static_cast<T>(round(float_mean * scale + bias)) + output_zero_point;
+ if (compute_sum) {
+ // TODO(b/116341117): Eliminate float and do this completely in 8bit.
+ const float bias = -input_zero_point * scale * num_elements_in_axis + 0.5;
+ for (size_t idx = 0; idx < num_outputs; ++idx) {
+ const U value = static_cast<U>(round(temp_sum[idx] * scale + bias)) +
+ output_zero_point;
+ output_data[idx] = static_cast<T>(value);
+ }
+ } else {
+ const float bias = -input_zero_point * scale + 0.5;
+ for (size_t idx = 0; idx < num_outputs; ++idx) {
+ float float_mean = static_cast<float>(temp_sum[idx]) /
+ static_cast<float>(num_elements_in_axis);
+
+ // Convert to float value.
+ output_data[idx] = static_cast<T>(round(float_mean * scale + bias)) +
+ output_zero_point;
+ }
}
}
return true;
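Both branches implement the usual dequantize-then-requantize identity real = (q - zero_point) * scale, folded into one multiply-add (the 0.5 added to bias acts as a rounding nudge); roughly:

  mean: q_out ≈ round((sum / N - in_zp) * in_scale / out_scale) + out_zp
  sum:  q_out ≈ round((sum - N * in_zp) * in_scale / out_scale) + out_zp

which is why the compute_sum branch scales the zero-point correction by num_elements_in_axis (N above).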
@@ -4394,6 +3662,16 @@ void Minimum(const RuntimeShape& input1_shape, const T* input1_data,
}
}
+// Convenience version that allows, for example, generated-code calls to be
+// the same as other binary ops.
+template <typename T>
+inline void Minimum(const RuntimeShape& input1_shape, const T* input1_data,
+ const RuntimeShape&, const T* input2_data,
+ const RuntimeShape& output_shape, T* output_data) {
+ // Drop shape of second input: not needed.
+ Minimum(input1_shape, input1_data, input2_data, output_shape, output_data);
+}
+
template <typename T>
void Maximum(const RuntimeShape& input1_shape, const T* input1_data,
const T* input2_data, const RuntimeShape& output_shape,
@@ -4406,6 +3684,16 @@ void Maximum(const RuntimeShape& input1_shape, const T* input1_data,
}
}
+// Convenience version that allows, for example, generated-code calls to be
+// the same as other binary ops.
+template <typename T>
+inline void Maximum(const RuntimeShape& input1_shape, const T* input1_data,
+ const RuntimeShape&, const T* input2_data,
+ const RuntimeShape& output_shape, T* output_data) {
+ // Drop shape of second input: not needed.
+ Maximum(input1_shape, input1_data, input2_data, output_shape, output_data);
+}
+
template <typename T, typename Op>
void MaximumMinimumBroadcast4DSlow(const RuntimeShape& unextended_input1_shape,
const T* input1_data,
@@ -4481,6 +3769,16 @@ void ArgMax(const RuntimeShape& input1_shape, const T1* input1_data,
std::greater<T1>());
}
+// Convenience version that allows, for example, generated-code calls to be
+// the same as other binary ops.
+template <typename T1, typename T2, typename T3>
+inline void ArgMax(const RuntimeShape& input1_shape, const T1* input1_data,
+ const RuntimeShape& input2_shape, const T3* input2_data,
+ const RuntimeShape& output_shape, T2* output_data) {
+ // Drop shape of second input: not needed.
+ ArgMax(input1_shape, input1_data, input2_data, output_shape, output_data);
+}
+
template <typename T>
void Transpose(const TransposeParams& params,
const RuntimeShape& unextended_input_shape, const T* input_data,
@@ -4532,35 +3830,30 @@ void Transpose(const TransposeParams& params,
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-template <typename T>
-void Transpose(const T* input, const Dims<4>& input_dims, T* output,
- const Dims<4>& output_dims, const int* permuted_axes) {
- TransposeParams params;
- params.perm_count = 4;
- for (int i = 0; i < 4; ++i) {
- params.perm[i] = 3 - permuted_axes[3 - i];
- }
- Transpose(params, DimsToShape(input_dims), input, DimsToShape(output_dims),
- output);
-}
-
-inline void TransposeConv(const float* input_data, const Dims<4>& input_dims,
- const float* filter_data, const Dims<4>& filter_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, float* output_data,
- const Dims<4>& output_dims, float* /*im2col_data*/,
- const Dims<4>& /*im2col_dims*/) {
- const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 0);
- const int output_depth = MatchingArraySize(filter_dims, 3, output_dims, 0);
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
- const int filter_height = ArraySize(filter_dims, 2);
- const int filter_width = ArraySize(filter_dims, 1);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
+inline void TransposeConv(
+ const ConvParams& params, const RuntimeShape& input_shape,
+ const float* input_data, const RuntimeShape& filter_shape,
+ const float* filter_data, const RuntimeShape& output_shape,
+ float* output_data, const RuntimeShape& im2col_shape, float* im2col_data) {
+ const int stride_width = params.stride_width;
+ const int stride_height = params.stride_height;
+ const int pad_width = params.padding_values.width;
+ const int pad_height = params.padding_values.height;
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+ (void)im2col_data; // only used in optimized code.
+ (void)im2col_shape; // only used in optimized code.
+
+ const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+ const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3);
+ const int output_depth = MatchingDim(filter_shape, 0, output_shape, 3);
+ const int input_height = input_shape.Dims(1);
+ const int input_width = input_shape.Dims(2);
+ const int filter_height = filter_shape.Dims(1);
+ const int filter_width = filter_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_width = output_shape.Dims(2);
// Although transpose convolution simplifies to convolution with transposed
// weights for strides of 1, non-unitary striding complicates matters. To
@@ -4569,7 +3862,7 @@ inline void TransposeConv(const float* input_data, const Dims<4>& input_dims,
// computing their influence on the output, rather than looping through the
// output elements in the typical "gather" access pattern of a conv. We
// therefore must initialize the output array to zero.
- const int num_elements = FlatSize(output_dims);
+ const int num_elements = output_shape.FlatSize();
for (int i = 0; i < num_elements; i++) {
output_data[i] = 0.0f;
}
@@ -4592,13 +3885,14 @@ inline void TransposeConv(const float* input_data, const Dims<4>& input_dims,
// We cannot accumulate out of bounds
if ((out_x >= 0) && (out_x < output_width) && (out_y >= 0) &&
(out_y < output_height)) {
- float input_value = input_data[Offset(input_dims, in_channel,
- in_x, in_y, batch)];
+ float input_value = input_data[Offset(
+ input_shape, batch, in_y, in_x, in_channel)];
float filter_value =
- filter_data[Offset(filter_dims, in_channel, filter_x,
- filter_y, out_channel)];
- output_data[Offset(output_dims, out_channel, out_x, out_y,
- batch)] += input_value * filter_value;
+ filter_data[Offset(filter_shape, out_channel, filter_y,
+ filter_x, in_channel)];
+ output_data[Offset(output_shape, batch, out_y, out_x,
+ out_channel)] +=
+ input_value * filter_value;
}
}
}
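A note on the scatter traversal: each input element at (batch, in_y, in_x, in_channel) is visited once and its product with every filter tap is accumulated into the output position it influences, presumably out_y = in_y * stride_height - pad_height + filter_y and out_x = in_x * stride_width - pad_width + filter_x (the loop nest computing these indices is elided from this hunk). Since several input positions can land on the same output element, the output buffer has to be zero-initialized first, as the comment above explains.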
@@ -4662,19 +3956,6 @@ inline void Comparison(const ComparisonParams& op_params,
input2_data, output_shape, output_data);
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-template <typename T, ComparisonFn<T> F>
-inline void Comparison(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, const Dims<4>& input2_dims,
- bool* output_data, const Dims<4>& output_dims) {
- ComparisonParams op_params;
- // No parameters needed.
- ComparisonImpl<T, F>(op_params, DimsToShape(input1_dims), input1_data,
- DimsToShape(input2_dims), input2_data,
- DimsToShape(output_dims), output_data);
-}
-
template <typename T, ComparisonFn<int32> F>
inline void ComparisonWithScaling(
const ComparisonParams& op_params, const RuntimeShape& input1_shape,
@@ -4705,30 +3986,6 @@ inline void ComparisonWithScaling(
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-template <typename T, ComparisonFn<int32> F>
-inline void Comparison(int left_shift, const T* input1_data,
- const Dims<4>& input1_dims, int32 input1_offset,
- int32 input1_multiplier, int input1_shift,
- const T* input2_data, const Dims<4>& input2_dims,
- int32 input2_offset, int32 input2_multiplier,
- int input2_shift, bool* output_data,
- const Dims<4>& output_dims) {
- tflite::ComparisonParams op_params;
- op_params.left_shift = left_shift;
- op_params.input1_offset = input1_offset;
- op_params.input1_multiplier = input1_multiplier;
- op_params.input1_shift = kReverseShift * input1_shift;
- op_params.input2_offset = input2_offset;
- op_params.input2_multiplier = input2_multiplier;
- op_params.input2_shift = kReverseShift * input2_shift;
-
- ComparisonWithScaling<T, F>(op_params, DimsToShape(input1_dims), input1_data,
- DimsToShape(input2_dims), input2_data,
- DimsToShape(output_dims), output_data);
-}
-
template <typename T, ComparisonFn<T> F>
inline void BroadcastComparison4DSlowImpl(
const ComparisonParams& op_params,
@@ -4772,22 +4029,6 @@ inline void BroadcastComparison4DSlow(const ComparisonParams& op_params,
output_shape, output_data);
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-template <typename T, ComparisonFn<T> F>
-inline void BroadcastComparison(const T* input1_data,
- const Dims<4>& input1_dims,
- const T* input2_data,
- const Dims<4>& input2_dims, bool* output_data,
- const Dims<4>& output_dims) {
- ComparisonParams op_params;
- // No parameters needed.
- BroadcastComparison4DSlowImpl<T, F>(op_params, DimsToShape(input1_dims),
- input1_data, DimsToShape(input2_dims),
- input2_data, DimsToShape(output_dims),
- output_data);
-}
-
template <typename T, ComparisonFn<int32> F>
inline void BroadcastComparison4DSlowWithScaling(
const ComparisonParams& op_params,
@@ -4838,78 +4079,7 @@ inline void BroadcastComparison4DSlowWithScaling(
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-template <typename T, ComparisonFn<int32> F>
-inline void BroadcastComparison(int left_shift, const T* input1_data,
- const Dims<4>& input1_dims, int32 input1_offset,
- int32 input1_multiplier, int input1_shift,
- const T* input2_data,
- const Dims<4>& input2_dims, int32 input2_offset,
- int32 input2_multiplier, int input2_shift,
- bool* output_data, const Dims<4>& output_dims) {
- ComparisonParams op_params;
-
- op_params.left_shift = left_shift;
- op_params.input1_offset = input1_offset;
- op_params.input1_multiplier = input1_multiplier;
- op_params.input1_shift = kReverseShift * input1_shift;
- op_params.input2_offset = input2_offset;
- op_params.input2_multiplier = input2_multiplier;
- op_params.input2_shift = kReverseShift * input2_shift;
-
- BroadcastComparison4DSlowWithScaling<T, F>(
- op_params, DimsToShape(input1_dims), input1_data,
- DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
- output_data);
-}
-
#define TFLITE_COMPARISON_OP(name) \
- template <typename T> \
- inline void name(const T* input1_data, const Dims<4>& input1_dims, \
- const T* input2_data, const Dims<4>& input2_dims, \
- bool* output_data, const Dims<4>& output_dims) { \
- gemmlowp::ScopedProfilingLabel label(#name); \
- Comparison<T, name##Fn>(input1_data, input1_dims, input2_data, \
- input2_dims, output_data, output_dims); \
- } \
- template <typename T> \
- inline void name( \
- int left_shift, const T* input1_data, const Dims<4>& input1_dims, \
- int32 input1_offset, int32 input1_multiplier, int input1_shift, \
- const T* input2_data, const Dims<4>& input2_dims, int32 input2_offset, \
- int32 input2_multiplier, int input2_shift, bool* output_data, \
- const Dims<4>& output_dims) { \
- gemmlowp::ScopedProfilingLabel label(#name "/8bit"); \
- Comparison<T, name##Fn>(left_shift, input1_data, input1_dims, \
- input1_offset, input1_multiplier, input1_shift, \
- input2_data, input2_dims, input2_offset, \
- input2_multiplier, input2_shift, output_data, \
- output_dims); \
- } \
- template <typename T> \
- inline void Broadcast##name( \
- const T* input1_data, const Dims<4>& input1_dims, const T* input2_data, \
- const Dims<4>& input2_dims, bool* output_data, \
- const Dims<4>& output_dims) { \
- gemmlowp::ScopedProfilingLabel label("Broadcast" #name); \
- BroadcastComparison<T, name##Fn>(input1_data, input1_dims, input2_data, \
- input2_dims, output_data, output_dims); \
- } \
- template <typename T> \
- inline void Broadcast##name( \
- int left_shift, const T* input1_data, const Dims<4>& input1_dims, \
- int32 input1_offset, int32 input1_multiplier, int input1_shift, \
- const T* input2_data, const Dims<4>& input2_dims, int32 input2_offset, \
- int32 input2_multiplier, int input2_shift, bool* output_data, \
- const Dims<4>& output_dims) { \
- gemmlowp::ScopedProfilingLabel label("Broadcast" #name "/8bit"); \
- BroadcastComparison<T, name##Fn>(left_shift, input1_data, input1_dims, \
- input1_offset, input1_multiplier, \
- input1_shift, input2_data, input2_dims, \
- input2_offset, input2_multiplier, \
- input2_shift, output_data, output_dims); \
- } \
inline void name(const ComparisonParams& op_params, \
const RuntimeShape& input1_shape, const float* input1_data, \
const RuntimeShape& input2_shape, const float* input2_data, \
@@ -4919,22 +4089,44 @@ inline void BroadcastComparison(int left_shift, const T* input1_data,
input2_data, output_shape, output_data); \
} \
template <typename T> \
+ inline void name##NoScaling( \
+ const ComparisonParams& op_params, const RuntimeShape& input1_shape, \
+ const T* input1_data, const RuntimeShape& input2_shape, \
+ const T* input2_data, const RuntimeShape& output_shape, \
+ bool* output_data) { \
+ gemmlowp::ScopedProfilingLabel label(#name "NoScaling"); \
+ ComparisonImpl<T, name##Fn>(op_params, input1_shape, input1_data, \
+ input2_shape, input2_data, output_shape, \
+ output_data); \
+ } \
+ template <typename T> \
inline void name##WithScaling( \
const ComparisonParams& op_params, const RuntimeShape& input1_shape, \
const T* input1_data, const RuntimeShape& input2_shape, \
const T* input2_data, const RuntimeShape& output_shape, \
bool* output_data) { \
- gemmlowp::ScopedProfilingLabel label(#name "/8bit"); \
+ gemmlowp::ScopedProfilingLabel label(#name "WithScaling/8bit"); \
ComparisonWithScaling<T, name##Fn>(op_params, input1_shape, input1_data, \
input2_shape, input2_data, \
output_shape, output_data); \
} \
+ template <typename T> \
+ inline void Broadcast4DSlow##name##NoScaling( \
+ const ComparisonParams& op_params, const RuntimeShape& input1_shape, \
+ const T* input1_data, const RuntimeShape& input2_shape, \
+ const T* input2_data, const RuntimeShape& output_shape, \
+ bool* output_data) { \
+ gemmlowp::ScopedProfilingLabel label("Broadcast4DSlow" #name "NoScaling"); \
+ BroadcastComparison4DSlowImpl<T, name##Fn>( \
+ op_params, input1_shape, input1_data, input2_shape, input2_data, \
+ output_shape, output_data); \
+ } \
inline void Broadcast4DSlow##name( \
const ComparisonParams& op_params, const RuntimeShape& input1_shape, \
const float* input1_data, const RuntimeShape& input2_shape, \
const float* input2_data, const RuntimeShape& output_shape, \
bool* output_data) { \
- gemmlowp::ScopedProfilingLabel label("Broadcast" #name); \
+ gemmlowp::ScopedProfilingLabel label("Broadcast4DSlow" #name); \
BroadcastComparison4DSlow<name##Fn>(op_params, input1_shape, input1_data, \
input2_shape, input2_data, \
output_shape, output_data); \
@@ -4945,7 +4137,7 @@ inline void BroadcastComparison(int left_shift, const T* input1_data,
const T* input1_data, const RuntimeShape& input2_shape, \
const T* input2_data, const RuntimeShape& output_shape, \
bool* output_data) { \
- gemmlowp::ScopedProfilingLabel label("Broadcast" #name "/8bit"); \
+ gemmlowp::ScopedProfilingLabel label("Broadcast4DSlow" #name "/8bit"); \
BroadcastComparison4DSlowWithScaling<T, name##Fn>( \
op_params, input1_shape, input1_data, input2_shape, input2_data, \
output_shape, output_data); \
@@ -4972,19 +4164,6 @@ void Select(const RuntimeShape& input_condition_shape,
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-template <typename D, typename T>
-inline void Select(const D* input_condition_data,
- const Dims<4>& input_condition_dims, const T* input_x_data,
- const Dims<4>& input_x_dims, const T* input_y_data,
- const Dims<4>& input_y_dims, T* output_data,
- const Dims<4>& output_dims) {
- Select(DimsToShape(input_condition_dims), input_condition_data,
- DimsToShape(input_x_dims), input_x_data, DimsToShape(input_y_dims),
- input_y_data, DimsToShape(output_dims), output_data);
-}
-
template <typename D, typename T>
void RankOneSelect(const RuntimeShape& input_condition_shape,
const D* input_condition_data,
@@ -5006,20 +4185,6 @@ void RankOneSelect(const RuntimeShape& input_condition_shape,
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-template <typename D, typename T>
-inline void RankOneSelect(const D* input_condition_data,
- const Dims<4>& input_condition_dims,
- const T* input_x_data, const Dims<4>& input_x_dims,
- const T* input_y_data, const Dims<4>& input_y_dims,
- T* output_data, const Dims<4>& output_dims) {
- RankOneSelect(DimsToShape(input_condition_dims), input_condition_data,
- DimsToShape(input_x_dims), input_x_data,
- DimsToShape(input_y_dims), input_y_data,
- DimsToShape(output_dims), output_data);
-}
-
// For easy implementation, the indices is always a vector of size-4 vectors.
template <typename T, typename TI>
inline void SparseToDense(const std::vector<std::vector<TI>>& indices,
@@ -5061,16 +4226,6 @@ inline void SparseToDense(const std::vector<std::vector<TI>>& indices,
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-template <typename T, typename TI>
-inline void SparseToDense(const std::vector<std::vector<TI>>& indices,
- const T* values, T default_value, T* output_data,
- const Dims<4>& output_dims, bool value_is_scalar) {
- SparseToDense(indices, values, default_value, value_is_scalar,
- DimsToShape(output_dims), output_data);
-}
-
template <typename T>
inline void Pow(const RuntimeShape& input1_shape, const T* input1_data,
const RuntimeShape& input2_shape, const T* input2_data,
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/softmax.h b/tensorflow/contrib/lite/kernels/internal/reference/softmax.h
new file mode 100644
index 0000000000..7d44296134
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/reference/softmax.h
@@ -0,0 +1,179 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_SOFTMAX_H_
+#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_SOFTMAX_H_
+
+#include "fixedpoint/fixedpoint.h"
+#include "tensorflow/contrib/lite/kernels/internal/common.h"
+#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/contrib/lite/kernels/internal/round.h"
+#include "tensorflow/contrib/lite/kernels/internal/types.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace reference_ops {
+
+inline void Softmax(const SoftmaxParams& params,
+ const RuntimeShape& input_shape, const float* input_data,
+ const RuntimeShape& output_shape, float* output_data) {
+ const int trailing_dim = input_shape.DimensionsCount() - 1;
+ const int outer_size =
+ MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
+ const int depth =
+ MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
+
+ for (int i = 0; i < outer_size; ++i) {
+ // Find max element value which we'll use to ensure numerical stability
+ // taking advantage of the following equality:
+ // exp(x[i])/sum(exp(x[i])) == exp(x[i]+C)/sum(exp(x[i]+C))
+ float max = std::numeric_limits<float>::lowest();
+ for (int c = 0; c < depth; ++c) {
+ max = std::max(max, input_data[i * depth + c]);
+ }
+
+ // Compute sum.
+ float sum = 0.f;
+ for (int c = 0; c < depth; ++c) {
+ sum += std::exp((input_data[i * depth + c] - max) * params.beta);
+ }
+
+ // Compute result.
+ for (int c = 0; c < depth; ++c) {
+ output_data[i * depth + c] =
+ std::exp((input_data[i * depth + c] - max) * params.beta) / sum;
+ }
+ }
+}
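The float reference implementation above leans on the shift invariance of softmax for numerical stability: subtracting the per-row maximum before exponentiating cannot change the result but keeps std::exp out of overflow territory. A minimal standalone sketch (not part of the patch, plain C++ with no TFLite dependency) of the identity the comment cites:

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  const std::vector<float> logits = {1.0f, 2.0f, 3.0f};
  float max = logits[0];
  for (float v : logits) max = std::max(max, v);
  float naive_sum = 0.f;
  float stable_sum = 0.f;
  for (float v : logits) {
    naive_sum += std::exp(v);         // exp(x[i])
    stable_sum += std::exp(v - max);  // exp(x[i] - C), C = max(x)
  }
  for (int i = 0; i < 3; ++i) {
    // Both quotients agree: exp(x)/sum(exp(x)) == exp(x-C)/sum(exp(x-C)).
    std::printf("%d: naive=%f stable=%f\n", i,
                std::exp(logits[i]) / naive_sum,
                std::exp(logits[i] - max) / stable_sum);
  }
  return 0;
}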
+
+inline void Softmax(const SoftmaxParams& params,
+ const RuntimeShape& input_shape, const uint8* input_data,
+ const RuntimeShape& output_shape, uint8* output_data) {
+ const int32 input_beta_multiplier = params.input_multiplier;
+ const int32 input_beta_left_shift = params.input_left_shift;
+ const int diff_min = params.diff_min;
+ // The representation chosen for the input to the exp() function is Q5.26.
+ // We need to leave extra space since values that we skip might be as large as
+ // -32 before multiplying by input_beta_multiplier, and therefore as large as
+ // -16 afterwards. Note that exp(-8) is definitely not insignificant to
+ // accumulation, but exp(-16) definitely is.
+ static const int kScaledDiffIntegerBits = 5;
+ static const int kAccumulationIntegerBits = 12;
+ using FixedPointScaledDiff =
+ gemmlowp::FixedPoint<int32, kScaledDiffIntegerBits>;
+ using FixedPointAccum = gemmlowp::FixedPoint<int32, kAccumulationIntegerBits>;
+ using FixedPoint0 = gemmlowp::FixedPoint<int32, 0>;
+
+ const int trailing_dim = input_shape.DimensionsCount() - 1;
+ const int outer_size =
+ MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
+ const int depth =
+ MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
+
+ for (int i = 0; i < outer_size; ++i) {
+ uint8 max_in_row = 0;
+ for (int c = 0; c < depth; ++c) {
+ max_in_row = std::max(max_in_row, input_data[i * depth + c]);
+ }
+
+ FixedPointAccum sum_of_exps = FixedPointAccum::Zero();
+ for (int c = 0; c < depth; ++c) {
+ int32 input_diff =
+ static_cast<int32>(input_data[i * depth + c]) - max_in_row;
+ if (input_diff >= diff_min) {
+ const int32 input_diff_rescaled =
+ MultiplyByQuantizedMultiplierGreaterThanOne(
+ input_diff, input_beta_multiplier, input_beta_left_shift);
+ const FixedPointScaledDiff scaled_diff_f8 =
+ FixedPointScaledDiff::FromRaw(input_diff_rescaled);
+ sum_of_exps = sum_of_exps + gemmlowp::Rescale<kAccumulationIntegerBits>(
+ exp_on_negative_values(scaled_diff_f8));
+ }
+ }
+
+ int32 fixed_sum_of_exps = sum_of_exps.raw();
+ int headroom_plus_one =
+ CountLeadingZeros(static_cast<uint32>(fixed_sum_of_exps));
+ // This is the number of bits to the left of the binary point above 1.0.
+ // Consider fixed_sum_of_exps=1.25. In that case shifted_scale=0.8 and
+ // no later adjustment will be needed.
+ int num_bits_over_unit = kAccumulationIntegerBits - headroom_plus_one;
+ int32 shifted_sum_minus_one = static_cast<int32>(
+ (static_cast<uint32>(fixed_sum_of_exps) << headroom_plus_one) -
+ (static_cast<uint32>(1) << 31));
+
+ FixedPoint0 shifted_scale = gemmlowp::one_over_one_plus_x_for_x_in_0_1(
+ FixedPoint0::FromRaw(shifted_sum_minus_one));
+
+ for (int c = 0; c < depth; ++c) {
+ int32 input_diff =
+ static_cast<int32>(input_data[i * depth + c]) - max_in_row;
+ if (input_diff >= diff_min) {
+ const int32 input_diff_rescaled =
+ MultiplyByQuantizedMultiplierGreaterThanOne(
+ input_diff, input_beta_multiplier, input_beta_left_shift);
+ const FixedPointScaledDiff scaled_diff_f8 =
+ FixedPointScaledDiff::FromRaw(input_diff_rescaled);
+
+ FixedPoint0 exp_in_0 = exp_on_negative_values(scaled_diff_f8);
+ int32 unsat_output = gemmlowp::RoundingDivideByPOT(
+ (shifted_scale * exp_in_0).raw(), num_bits_over_unit + 31 - 8);
+
+ output_data[i * depth + c] = static_cast<uint8>(
+ std::max(std::min(unsat_output, static_cast<int32>(255)),
+ static_cast<int32>(0)));
+
+ } else {
+ output_data[i * depth + c] = 0;
+ }
+ }
+ }
+}
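The quantized path above does all of its arithmetic in gemmlowp fixed point: the scaled differences live in Q5.26, the accumulated sum of exponentials in Q12.19, and the final normalization is folded into one_over_one_plus_x_for_x_in_0_1 plus a rounding shift down to 8 bits. A small standalone sketch (not part of the patch; it only mirrors the Q-format encoding and has no gemmlowp dependency) of how a Q5.26 raw value relates to a real number:

#include <cstdint>
#include <cstdio>

int main() {
  const int kIntegerBits = 5;                     // kScaledDiffIntegerBits
  const int kFractionalBits = 31 - kIntegerBits;  // 26 fractional bits
  const double real = -3.5;                       // a typical negative diff
  const int32_t raw =
      static_cast<int32_t>(real * (1LL << kFractionalBits));
  const double decoded =
      static_cast<double>(raw) / (1LL << kFractionalBits);
  // real_value = raw / 2^26; the encode/decode round trip is exact here.
  std::printf("raw=%d decodes back to %f\n", static_cast<int>(raw), decoded);
  return 0;
}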
+
+// Performs softmax along the input of size (input_size * batch_size).
+inline void Softmax(const float* in, const int input_size, const int batch_size,
+ const float beta, float* out) {
+ // TF_LITE_ASSERT(input_size > 0);
+
+ // For each batch
+ for (int b = 0; b < batch_size; b++) {
+ // Find the max coeff.
+ float max_coeff = in[0];
+ for (int i = 1; i < input_size; i++) {
+ if (in[i] > max_coeff) max_coeff = in[i];
+ }
+
+ // Compute the normalized sum of exps.
+ float exp_sum = 0.0;
+ for (int i = 0; i < input_size; i++) {
+ out[i] = std::exp((in[i] - max_coeff) * beta);
+ exp_sum += out[i];
+ }
+
+ // Divide by the sum of exps.
+ float reciprocal_sum_exp = 1.f / exp_sum;
+ for (int i = 0; i < input_size; i++) {
+ out[i] *= reciprocal_sum_exp;
+ }
+
+ // Advance in and out pointers for the next batch.
+ in += input_size;
+ out += input_size;
+ }
+}
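A short usage sketch (not part of the patch, assuming the TFLite include paths are available) for the flat (input_size * batch_size) overload above; each row of `in` is softmaxed independently:

#include <cstdio>

#include "tensorflow/contrib/lite/kernels/internal/reference/softmax.h"

int main() {
  const float in[6] = {1.f, 2.f, 3.f, 0.f, 0.f, 0.f};  // two rows of three
  float out[6];
  tflite::reference_ops::Softmax(in, /*input_size=*/3, /*batch_size=*/2,
                                 /*beta=*/1.0f, out);
  for (int i = 0; i < 6; ++i) std::printf("%f ", out[i]);
  std::printf("\n");  // second row is uniform: each entry is 1/3
  return 0;
}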
+
+} // namespace reference_ops
+} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_SOFTMAX_H_
diff --git a/tensorflow/contrib/lite/kernels/internal/softmax_quantized_test.cc b/tensorflow/contrib/lite/kernels/internal/softmax_quantized_test.cc
index ca94e7740e..831fb3c243 100644
--- a/tensorflow/contrib/lite/kernels/internal/softmax_quantized_test.cc
+++ b/tensorflow/contrib/lite/kernels/internal/softmax_quantized_test.cc
@@ -43,11 +43,15 @@ void RunSoftmaxFloatReference(const uint8* input_data,
// Reference data generated via Dequant of input into float, and then applying
// float Softmax.
- reference_ops::Dequantize(
- input_data, ToRuntimeDims(shape_common), input_offset, input_scale,
- reference_dequant_data.data(), ToRuntimeDims(shape_common));
- optimized_ops::Softmax(reference_dequant_data.data(), shape_common, beta,
- reference_output_float_data.data(), shape_common);
+ DequantizationParams dq_params;
+ dq_params.zero_point = input_offset;
+ dq_params.scale = input_scale;
+ reference_ops::Dequantize(dq_params, shape_common, input_data, shape_common,
+ reference_dequant_data.data());
+ SoftmaxParams sm_params;
+ sm_params.beta = beta;
+ optimized_ops::Softmax(sm_params, shape_common, reference_dequant_data.data(),
+ shape_common, reference_output_float_data.data());
// Work with quantized scaling for Softmax, under which 256 represents 1, but
// we limit this to 255.
for (int i = 0; i < ref_buffer_size; i++) {
@@ -116,12 +120,14 @@ void RunOneSoftmaxTest(const uint8* input_data,
const int diff_min = -tflite::CalculateInputRadius(kScaledDiffIntegerBits,
input_beta_left_shift);
- optimized_ops::Softmax(input_data, shape_common, input_beta_multiplier,
- input_beta_left_shift, diff_min,
- optimized_softmax_output.data(), shape_common);
- reference_ops::Softmax(input_data, shape_common, input_beta_multiplier,
- input_beta_left_shift, diff_min,
- reference_quant_softmax_output.data(), shape_common);
+ SoftmaxParams params;
+ params.input_multiplier = input_beta_multiplier;
+ params.input_left_shift = input_beta_left_shift;
+ params.diff_min = diff_min;
+ optimized_ops::Softmax(params, shape_common, input_data, shape_common,
+ optimized_softmax_output.data());
+ reference_ops::Softmax(params, shape_common, input_data, shape_common,
+ reference_quant_softmax_output.data());
CheckOutputData(optimized_softmax_output.data(),
reference_float_softmax_output.data(), shape_common,
diff --git a/tensorflow/contrib/lite/kernels/internal/tensor.h b/tensorflow/contrib/lite/kernels/internal/tensor.h
index 13106456df..689cea03e7 100644
--- a/tensorflow/contrib/lite/kernels/internal/tensor.h
+++ b/tensorflow/contrib/lite/kernels/internal/tensor.h
@@ -37,10 +37,6 @@ inline const std::complex<float>* GetTensorData(const TfLiteTensor* tensor) {
: nullptr;
}
-inline Dims<4> GetTensorDims(std::vector<int32_t> data) {
- return GetTensorDims(data.data(), data.size());
-}
-
inline RuntimeShape GetTensorShape(std::vector<int32_t> data) {
return RuntimeShape(data.size(), data.data());
}
@@ -56,20 +52,20 @@ class VectorOfTensors {
int num_tensors = tensor_list.size;
all_data_.reserve(num_tensors);
- all_dims_.reserve(num_tensors);
- all_dims_ptr_.reserve(num_tensors);
+ all_shape_.reserve(num_tensors);
+ all_shape_ptr_.reserve(num_tensors);
for (int i = 0; i < num_tensors; ++i) {
TfLiteTensor* t = &context.tensors[tensor_list.data[i]];
all_data_.push_back(GetTensorData<T>(t));
- all_dims_.push_back(GetTensorDims(t));
+ all_shape_.push_back(GetTensorShape(t));
}
// Taking the pointer from inside a std::vector is only OK if the vector is
- // never modified, so we populate all_dims in the previous loop and then we
+ // never modified, so we populate all_shape in the previous loop and then we
// are free to grab iterators here.
for (int i = 0; i < num_tensors; ++i) {
- all_dims_ptr_.push_back(&all_dims_[i]);
+ all_shape_ptr_.push_back(&all_shape_[i]);
}
}
// Return a pointer to the data pointers of all tensors in the list. For
@@ -78,16 +74,16 @@ class VectorOfTensors {
// f[0][1] is the second element of the first tensor.
T* const* data() const { return all_data_.data(); }
- // Return a pointer the dim pointers of all tensors in the list. For
+  // Return a pointer to the shape pointers of all tensors in the list. For
// example:
- // const Dims<4>* const* d = v.dims();
+  //   const RuntimeShape* const* d = v.shapes();
   //   d[1] points to the shape of the second tensor in the list.
- const Dims<4>* const* dims() const { return all_dims_ptr_.data(); }
+ const RuntimeShape* const* shapes() const { return all_shape_ptr_.data(); }
private:
std::vector<T*> all_data_;
- std::vector<Dims<4>> all_dims_;
- std::vector<Dims<4>*> all_dims_ptr_;
+ std::vector<RuntimeShape> all_shape_;
+ std::vector<RuntimeShape*> all_shape_ptr_;
};
// A list of quantized tensors in a format that can be used by kernels like
diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h b/tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h
index 77e22a08b4..9f5b33d217 100644
--- a/tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h
+++ b/tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h
@@ -86,39 +86,6 @@ inline const bool* GetTensorData(const TfLiteTensor* tensor) {
return tensor != nullptr ? tensor->data.b : nullptr;
}
-inline int RemapDim(int max_dimensions, int d) {
- return max_dimensions - d - 1;
-}
-
-// TODO(ahentz): the implementations in kernels/internal/ take a Dims<4> object
-// even if the original tensors were not 4D. We should consider rewriting them
-// to take a more generic 'shape' object.
-inline Dims<4> GetTensorDims(const int data[], const int size) {
- Dims<4> d;
- for (int i = 0; i < 4; ++i) {
- int src = size - i - 1;
- if (src >= 0) {
- d.sizes[i] = data[src];
- } else {
- d.sizes[i] = 1;
- }
- }
- d.strides[0] = 1;
- for (int i = 1; i < 4; i++) {
- d.strides[i] = d.strides[i - 1] * d.sizes[i - 1];
- }
- return d;
-}
-
-inline Dims<4> GetTensorDims(const TfLiteTensor* tensor) {
- if (tensor == nullptr) {
- return Dims<4>();
- }
-
- auto* dims = tensor->dims;
- return GetTensorDims(dims->data, dims->size);
-}
-
inline RuntimeShape GetTensorShape(const TfLiteTensor* tensor) {
if (tensor == nullptr) {
return RuntimeShape();
diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_test.cc b/tensorflow/contrib/lite/kernels/internal/tensor_test.cc
index bf2068d320..2ed73ba82d 100644
--- a/tensorflow/contrib/lite/kernels/internal/tensor_test.cc
+++ b/tensorflow/contrib/lite/kernels/internal/tensor_test.cc
@@ -21,28 +21,32 @@ namespace {
using ::testing::ElementsAre;
-TEST(TensorTest, GetTensorDims4D) {
- Dims<4> d = GetTensorDims({2, 3, 4, 5});
- EXPECT_THAT(d.sizes, ElementsAre(5, 4, 3, 2));
- EXPECT_THAT(d.strides, ElementsAre(1, 5, 20, 60));
+TEST(TensorTest, GetTensorShape4D) {
+ RuntimeShape d = GetTensorShape({2, 3, 4, 5});
+ EXPECT_THAT(
+ std::vector<int32>(d.DimsData(), d.DimsData() + d.DimensionsCount()),
+ ElementsAre(2, 3, 4, 5));
}
-TEST(TensorTest, GetTensorDims3D) {
- Dims<4> d = GetTensorDims({3, 4, 5});
- EXPECT_THAT(d.sizes, ElementsAre(5, 4, 3, 1));
- EXPECT_THAT(d.strides, ElementsAre(1, 5, 20, 60));
+TEST(TensorTest, GetTensorShape3D) {
+ RuntimeShape d = GetTensorShape({3, 4, 5});
+ EXPECT_THAT(
+ std::vector<int32>(d.DimsData(), d.DimsData() + d.DimensionsCount()),
+ ElementsAre(3, 4, 5));
}
-TEST(TensorTest, GetTensorDims2D) {
- Dims<4> d = GetTensorDims({4, 5});
- EXPECT_THAT(d.sizes, ElementsAre(5, 4, 1, 1));
- EXPECT_THAT(d.strides, ElementsAre(1, 5, 20, 20));
+TEST(TensorTest, GetTensorShape2D) {
+ RuntimeShape d = GetTensorShape({4, 5});
+ EXPECT_THAT(
+ std::vector<int32>(d.DimsData(), d.DimsData() + d.DimensionsCount()),
+ ElementsAre(4, 5));
}
-TEST(TensorTest, GetTensorDims1D) {
- Dims<4> d = GetTensorDims({5});
- EXPECT_THAT(d.sizes, ElementsAre(5, 1, 1, 1));
- EXPECT_THAT(d.strides, ElementsAre(1, 5, 5, 5));
+TEST(TensorTest, GetTensorShape1D) {
+ RuntimeShape d = GetTensorShape({5});
+ EXPECT_THAT(
+ std::vector<int32>(d.DimsData(), d.DimsData() + d.DimensionsCount()),
+ ElementsAre(5));
}
} // namespace
diff --git a/tensorflow/contrib/lite/kernels/internal/test_util.cc b/tensorflow/contrib/lite/kernels/internal/test_util.cc
index 9b1fd9b344..75d568ae3a 100644
--- a/tensorflow/contrib/lite/kernels/internal/test_util.cc
+++ b/tensorflow/contrib/lite/kernels/internal/test_util.cc
@@ -19,41 +19,24 @@ limitations under the License.
namespace tflite {
-Dims<4> MakeDimsForInference(int depth, int width, int height, int batch) {
- Dims<4> result;
- int cum_prod = 1;
-
- result.sizes[0] = depth;
- result.strides[0] = cum_prod;
- cum_prod *= result.sizes[0];
-
- result.sizes[1] = width;
- result.strides[1] = cum_prod;
- cum_prod *= result.sizes[1];
-
- result.sizes[2] = height;
- result.strides[2] = cum_prod;
- cum_prod *= result.sizes[2];
-
- result.sizes[3] = batch;
- result.strides[3] = cum_prod;
-
- return result;
-}
-
 // This is copied from an internal function in propagate_fixed_sizes.cc.
-bool ComputeConvSizes(Dims<4> input_dims, int output_depth, int filter_width,
- int filter_height, int stride, PaddingType padding_type,
- Dims<4>* output_dims, int* pad_width, int* pad_height) {
- const int input_width = ArraySize(input_dims, 1);
- const int input_height = ArraySize(input_dims, 2);
- const int batch = ArraySize(input_dims, 3);
+bool ComputeConvSizes(const RuntimeShape& input_shape, int output_depth,
+ int filter_width, int filter_height, int stride,
+ int dilation_width_factor, int dilation_height_factor,
+ PaddingType padding_type, RuntimeShape* output_shape,
+ int* pad_width, int* pad_height) {
+ const int input_width = input_shape.Dims(2);
+ const int input_height = input_shape.Dims(1);
+ const int batch = input_shape.Dims(0);
+
+ int dilated_filter_width = dilation_width_factor * (filter_width - 1) + 1;
+ int dilated_filter_height = dilation_height_factor * (filter_height - 1) + 1;
int output_height = 0;
int output_width = 0;
if (padding_type == PaddingType::kValid) {
- output_height = (input_height + stride - filter_height) / stride;
- output_width = (input_width + stride - filter_width) / stride;
+ output_height = (input_height + stride - dilated_filter_height) / stride;
+ output_width = (input_width + stride - dilated_filter_width) / stride;
} else if (padding_type == PaddingType::kSame) {
output_height = (input_height + stride - 1) / stride;
output_width = (input_width + stride - 1) / stride;
@@ -65,11 +48,14 @@ bool ComputeConvSizes(Dims<4> input_dims, int output_depth, int filter_width,
return false;
}
- *pad_height =
- ((output_height - 1) * stride + filter_height - input_height) / 2;
- *pad_width = ((output_width - 1) * stride + filter_width - input_width) / 2;
- *output_dims =
- MakeDimsForInference(output_depth, output_width, output_height, batch);
+ *pad_height = std::max(
+ 0, ((output_height - 1) * stride + dilated_filter_height - input_height) /
+ 2);
+ *pad_width = std::max(
+ 0,
+ ((output_width - 1) * stride + dilated_filter_width - input_width) / 2);
+
+ output_shape->BuildFrom({batch, output_height, output_width, output_depth});
return true;
}
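ComputeConvSizes now folds dilation into the filter's effective footprint before computing output sizes and padding. A standalone sketch (not from the patch) of the same arithmetic for a 3-wide filter dilated by 2 on a 7-wide input with stride 1:

#include <cstdio>

int main() {
  const int input = 7, filter = 3, stride = 1, dilation = 2;
  const int dilated_filter = dilation * (filter - 1) + 1;            // 5
  const int valid_out = (input + stride - dilated_filter) / stride;  // 3
  const int same_out = (input + stride - 1) / stride;                // 7
  std::printf("dilated=%d valid=%d same=%d\n", dilated_filter, valid_out,
              same_out);
  return 0;
}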
diff --git a/tensorflow/contrib/lite/kernels/internal/test_util.h b/tensorflow/contrib/lite/kernels/internal/test_util.h
index 26078cef49..e4a383bedf 100644
--- a/tensorflow/contrib/lite/kernels/internal/test_util.h
+++ b/tensorflow/contrib/lite/kernels/internal/test_util.h
@@ -26,13 +26,12 @@ limitations under the License.
namespace tflite {
-// Creates a Dims struct from a set of dimensions.
-Dims<4> MakeDimsForInference(int depth, int width, int height, int batch);
-
// Computes output and padding dimensions.
-bool ComputeConvSizes(Dims<4> input_dims, int output_depth, int filter_width,
- int filter_height, int stride, PaddingType padding_type,
- Dims<4>* output_dims, int* pad_width, int* pad_height);
+bool ComputeConvSizes(const RuntimeShape& input_shape, int output_depth,
+ int filter_width, int filter_height, int stride,
+ int dilation_width_factor, int dilation_height_factor,
+ PaddingType padding_type, RuntimeShape* output_shape,
+ int* pad_width, int* pad_height);
// Returns a mt19937 random engine.
std::mt19937& RandomEngine();
diff --git a/tensorflow/contrib/lite/kernels/internal/types.h b/tensorflow/contrib/lite/kernels/internal/types.h
index f6636acc58..b39347758a 100644
--- a/tensorflow/contrib/lite/kernels/internal/types.h
+++ b/tensorflow/contrib/lite/kernels/internal/types.h
@@ -15,9 +15,10 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_
#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_
+#include <algorithm>
#include <cstring>
-#include <iterator>
+#include "absl/base/macros.h"
#include "tensorflow/contrib/lite/kernels/internal/compatibility.h"
namespace tflite {
@@ -125,7 +126,11 @@ class RuntimeShape {
explicit RuntimeShape(int dimensions_count) : size_(dimensions_count) {
if (dimensions_count > kMaxSmallSize) {
+#ifdef TF_LITE_STATIC_MEMORY
+ TFLITE_CHECK(false && "No shape resizing supported on this platform");
+#else // TF_LITE_STATIC_MEMORY
dims_pointer_ = new int32[dimensions_count];
+#endif // TF_LITE_STATIC_MEMORY
}
}
@@ -160,7 +165,11 @@ class RuntimeShape {
~RuntimeShape() {
if (size_ > kMaxSmallSize) {
+#ifdef TF_LITE_STATIC_MEMORY
+ TFLITE_CHECK(false && "No shape resizing supported on this platform");
+#else // TF_LITE_STATIC_MEMORY
delete[] dims_pointer_;
+#endif // TF_LITE_STATIC_MEMORY
}
}
@@ -179,20 +188,31 @@ class RuntimeShape {
dims_[i] = val;
}
}
+
inline int32* DimsData() {
return size_ > kMaxSmallSize ? dims_pointer_ : dims_;
}
inline const int32* DimsData() const {
return size_ > kMaxSmallSize ? dims_pointer_ : dims_;
}
+ // The caller must ensure that the shape is no bigger than 4-D.
+ inline const int32* DimsDataUpTo4D() const { return dims_; }
inline void Resize(int dimensions_count) {
if (size_ > kMaxSmallSize) {
+#ifdef TF_LITE_STATIC_MEMORY
+ TFLITE_CHECK(false && "No shape resizing supported on this platform");
+#else // TF_LITE_STATIC_MEMORY
delete[] dims_pointer_;
+#endif // TF_LITE_STATIC_MEMORY
}
size_ = dimensions_count;
if (dimensions_count > kMaxSmallSize) {
+#ifdef TF_LITE_STATIC_MEMORY
+ TFLITE_CHECK(false && "No shape resizing supported on this platform");
+#else // TF_LITE_STATIC_MEMORY
dims_pointer_ = new int32[dimensions_count];
+#endif // TF_LITE_STATIC_MEMORY
}
}
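On TF_LITE_STATIC_MEMORY builds the heap path above is compiled out, so any shape that does not fit the inline small-size array becomes a hard TFLITE_CHECK failure instead of an allocation. A hypothetical illustration (not from the patch; the dimension counts are examples and assume kMaxSmallSize keeps its usual value of 4):

#include "tensorflow/contrib/lite/kernels/internal/types.h"

int main() {
  // Fits the inline dims_ array: no allocation on any build.
  tflite::RuntimeShape small(4);
  // Would normally allocate `new int32[8]`; on a TF_LITE_STATIC_MEMORY build
  // this constructor instead hits the "No shape resizing supported" check.
  tflite::RuntimeShape large(8);
  return 0;
}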
@@ -346,11 +366,12 @@ inline size_t ReducedOutputOffset(const int num_dims, const int* dims,
}
inline int Offset(const RuntimeShape& shape, int i0, int i1, int i2, int i3) {
- TFLITE_DCHECK(i0 >= 0 && i0 < shape.Dims(0));
- TFLITE_DCHECK(i1 >= 0 && i1 < shape.Dims(1));
- TFLITE_DCHECK(i2 >= 0 && i2 < shape.Dims(2));
- TFLITE_DCHECK(i3 >= 0 && i3 < shape.Dims(3));
- const int* dims_data = shape.DimsData();
+ TFLITE_DCHECK_EQ(shape.DimensionsCount(), 4);
+ const int* dims_data = shape.DimsDataUpTo4D();
+ TFLITE_DCHECK(i0 >= 0 && i0 < dims_data[0]);
+ TFLITE_DCHECK(i1 >= 0 && i1 < dims_data[1]);
+ TFLITE_DCHECK(i2 >= 0 && i2 < dims_data[2]);
+ TFLITE_DCHECK(i3 >= 0 && i3 < dims_data[3]);
return ((i0 * dims_data[1] + i1) * dims_data[2] + i2) * dims_data[3] + i3;
}
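Offset() now asserts the shape is exactly 4-D and reads the dimensions through DimsDataUpTo4D(); the flattening itself is the usual row-major rule. A standalone sketch (not from the patch) of the same expression for shape {2, 3, 4, 5}:

#include <cstdio>

int main() {
  const int dims[4] = {2, 3, 4, 5};
  const int i0 = 1, i1 = 2, i2 = 3, i3 = 4;
  // Same expression as Offset(); the implied strides are 60, 20, 5, 1.
  const int offset = ((i0 * dims[1] + i1) * dims[2] + i2) * dims[3] + i3;
  std::printf("flat offset = %d\n", offset);  // 1*60 + 2*20 + 3*5 + 4 = 119
  return 0;
}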
@@ -420,7 +441,7 @@ inline int FlatSize(const Dims<N>& dims) {
return flat_size;
}
-// Deprecated. Prefer FlatSize.
+ABSL_DEPRECATED("Prefer FlatSize.")
inline int RequiredBufferSizeForDims(const Dims<4>& dims) {
return FlatSize(dims);
}
@@ -772,6 +793,8 @@ struct DepthwiseParams {
PaddingValues padding_values;
int16 stride_width;
int16 stride_height;
+ int16 dilation_width_factor;
+ int16 dilation_height_factor;
int16 depth_multiplier;
// uint8 inference params.
// TODO(b/65838351): Use smaller types if appropriate.
@@ -852,6 +875,15 @@ struct MeanParams {
int16 axis[4];
};
+struct PackParams {
+ int8 axis;
+ const int32* input_zeropoint;
+ const float* input_scale;
+ uint16 inputs_count;
+ int32 output_zeropoint;
+ float output_scale;
+};
+
struct PadParams {
int8 left_padding_count;
int32 left_padding[4];
@@ -952,6 +984,11 @@ struct TransposeParams {
int32 perm[4];
};
+struct UnpackParams {
+ uint16 num_split;
+ int16 axis;
+};
+
template <typename P>
inline void SetActivationParams(float min, float max, P* params) {
params->float_activation_min = min;
diff --git a/tensorflow/contrib/lite/kernels/kernel_util.cc b/tensorflow/contrib/lite/kernels/kernel_util.cc
index 08f942c933..503ef28459 100644
--- a/tensorflow/contrib/lite/kernels/kernel_util.cc
+++ b/tensorflow/contrib/lite/kernels/kernel_util.cc
@@ -107,6 +107,9 @@ bool HaveSameShapes(const TfLiteTensor* input1, const TfLiteTensor* input2) {
return TfLiteIntArrayEqual(input1->dims, input2->dims);
}
+// TODO(petewarden): Having macros around this is ugly, look at other strategies
+// before replicating this approach elsewhere.
+#ifndef TF_LITE_STATIC_MEMORY
TfLiteStatus CalculateShapeForBroadcast(TfLiteContext* context,
const TfLiteTensor* input1,
const TfLiteTensor* input2,
@@ -125,5 +128,6 @@ TfLiteStatus CalculateShapeForBroadcast(TfLiteContext* context,
*output_shape = shape.release();
return kTfLiteOk;
}
+#endif // TF_LITE_STATIC_MEMORY
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/layer_norm_lstm.cc b/tensorflow/contrib/lite/kernels/layer_norm_lstm.cc
index 1bbea67b93..9739fd4514 100644
--- a/tensorflow/contrib/lite/kernels/layer_norm_lstm.cc
+++ b/tensorflow/contrib/lite/kernels/layer_norm_lstm.cc
@@ -16,7 +16,7 @@ limitations under the License.
// Layer Normalization LSTM op that applies normalization by mean and standard
// deviation to the activation of the LSTM layers. Please see
// https://arxiv.org/abs/1607.06450 for details.
-#include "flatbuffers/flexbuffers.h" // flatbuffers
+#include "flatbuffers/flexbuffers.h" // TF:flatbuffers
#include "tensorflow/contrib/lite/context.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/layer_norm_lstm_test.cc b/tensorflow/contrib/lite/kernels/layer_norm_lstm_test.cc
index abc229f85a..479f6a7d3c 100644
--- a/tensorflow/contrib/lite/kernels/layer_norm_lstm_test.cc
+++ b/tensorflow/contrib/lite/kernels/layer_norm_lstm_test.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include <gmock/gmock.h>
#include <gtest/gtest.h>
-#include "flatbuffers/flexbuffers.h" // flatbuffers
+#include "flatbuffers/flexbuffers.h" // TF:flatbuffers
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/kernels/test_util.h"
diff --git a/tensorflow/contrib/lite/kernels/log_softmax_test.cc b/tensorflow/contrib/lite/kernels/log_softmax_test.cc
index 9a8d35e82c..1acc966cdc 100644
--- a/tensorflow/contrib/lite/kernels/log_softmax_test.cc
+++ b/tensorflow/contrib/lite/kernels/log_softmax_test.cc
@@ -91,8 +91,9 @@ TEST(LogSoftmaxOpTest, CompareWithTFmini) {
std::unique_ptr<float[]> output_buffer(new float[input_size * batch_size]);
auto input_shape = RuntimeShape({batch_size, 1, 1, input_size});
- tflite::reference_ops::LogSoftmax(input_buffer, input_shape,
- output_buffer.get(), input_shape);
+ SoftmaxParams params;
+ tflite::reference_ops::LogSoftmax(params, input_shape, input_buffer,
+ input_shape, output_buffer.get());
std::vector<float> expected;
expected.insert(expected.end(), output_buffer.get(),
diff --git a/tensorflow/contrib/lite/kernels/lstm.cc b/tensorflow/contrib/lite/kernels/lstm.cc
index aaa3ce966e..5b996d00bc 100644
--- a/tensorflow/contrib/lite/kernels/lstm.cc
+++ b/tensorflow/contrib/lite/kernels/lstm.cc
@@ -893,18 +893,21 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
activation_out->type == kTfLiteFloat32 &&
concat_temp->type == kTfLiteFloat32 &&
activation_temp->type == kTfLiteFloat32) {
+ tflite::LstmCellParams op_params;
+ // Float LSTM cell does not need parameters to be set: leave untouched.
optimized_ops::LstmCell(
+ op_params,
// Inputs.
- GetTensorData<float>(input), GetTensorDims(input),
- GetTensorData<float>(prev_activation), GetTensorDims(prev_activation),
- GetTensorData<float>(weights), GetTensorDims(weights),
- GetTensorData<float>(bias), GetTensorDims(bias),
- GetTensorData<float>(prev_state), GetTensorDims(prev_state),
+ GetTensorShape(input), GetTensorData<float>(input),
+ GetTensorShape(prev_activation), GetTensorData<float>(prev_activation),
+ GetTensorShape(weights), GetTensorData<float>(weights),
+ GetTensorShape(bias), GetTensorData<float>(bias),
+ GetTensorShape(prev_state), GetTensorData<float>(prev_state),
// Outputs.
- GetTensorData<float>(state_out), GetTensorDims(state_out),
- GetTensorData<float>(activation_out), GetTensorDims(activation_out),
- GetTensorData<float>(concat_temp), GetTensorDims(concat_temp),
- GetTensorData<float>(activation_temp), GetTensorDims(activation_temp));
+ GetTensorShape(state_out), GetTensorData<float>(state_out),
+ GetTensorShape(activation_out), GetTensorData<float>(activation_out),
+ GetTensorShape(concat_temp), GetTensorData<float>(concat_temp),
+ GetTensorShape(activation_temp), GetTensorData<float>(activation_temp));
} else if (input->type == kTfLiteUInt8 &&
prev_activation->type == kTfLiteUInt8 &&
weights->type == kTfLiteUInt8 && bias->type == kTfLiteInt32 &&
@@ -934,20 +937,25 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
int accum_shift;
tflite::QuantizeMultiplier(real_accum_multiplier, &accum_multiplier,
&accum_shift);
+ tflite::LstmCellParams op_params;
+ op_params.weights_zero_point = weights->params.zero_point;
+ op_params.accum_multiplier = accum_multiplier;
+ op_params.accum_shift = accum_shift;
optimized_ops::LstmCell<4>(
+ op_params,
// Inputs.
- GetTensorData<uint8_t>(input), GetTensorDims(input),
- GetTensorData<uint8_t>(prev_activation), GetTensorDims(prev_activation),
- GetTensorData<uint8_t>(weights), GetTensorDims(weights),
- GetTensorData<int32_t>(bias), GetTensorDims(bias),
- GetTensorData<int16_t>(prev_state), GetTensorDims(prev_state),
+ GetTensorShape(input), GetTensorData<uint8_t>(input),
+ GetTensorShape(prev_activation),
+ GetTensorData<uint8_t>(prev_activation), GetTensorShape(weights),
+ GetTensorData<uint8_t>(weights), GetTensorShape(bias),
+ GetTensorData<int32_t>(bias), GetTensorShape(prev_state),
+ GetTensorData<int16_t>(prev_state),
// Outputs.
- GetTensorData<int16_t>(state_out), GetTensorDims(state_out),
- GetTensorData<uint8_t>(activation_out), GetTensorDims(activation_out),
- GetTensorData<uint8_t>(concat_temp), GetTensorDims(concat_temp),
- GetTensorData<int16_t>(activation_temp), GetTensorDims(activation_temp),
- weights->params.zero_point, accum_multiplier, accum_shift,
- gemm_context);
+ GetTensorShape(state_out), GetTensorData<int16_t>(state_out),
+ GetTensorShape(activation_out), GetTensorData<uint8_t>(activation_out),
+ GetTensorShape(concat_temp), GetTensorData<uint8_t>(concat_temp),
+ GetTensorShape(activation_temp),
+ GetTensorData<int16_t>(activation_temp), gemm_context);
} else {
context->ReportError(context,
"Unsupported combination of data types for LstmCell");
diff --git a/tensorflow/contrib/lite/kernels/mfcc.cc b/tensorflow/contrib/lite/kernels/mfcc.cc
index 66cf147d75..5153ce5634 100644
--- a/tensorflow/contrib/lite/kernels/mfcc.cc
+++ b/tensorflow/contrib/lite/kernels/mfcc.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/lite/kernels/internal/mfcc.h"
-#include "flatbuffers/flexbuffers.h" // flatbuffers
+#include "flatbuffers/flexbuffers.h" // TF:flatbuffers
#include "tensorflow/contrib/lite/c/builtin_op_data.h"
#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/mfcc_dct.h"
diff --git a/tensorflow/contrib/lite/kernels/mfcc_test.cc b/tensorflow/contrib/lite/kernels/mfcc_test.cc
index c9124adcaf..fe69223222 100644
--- a/tensorflow/contrib/lite/kernels/mfcc_test.cc
+++ b/tensorflow/contrib/lite/kernels/mfcc_test.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include <vector>
#include <gtest/gtest.h>
-#include "flatbuffers/flexbuffers.h" // flatbuffers
+#include "flatbuffers/flexbuffers.h" // TF:flatbuffers
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/kernels/test_util.h"
diff --git a/tensorflow/contrib/lite/kernels/op_macros.h b/tensorflow/contrib/lite/kernels/op_macros.h
index d66364c4d8..11e814daee 100644
--- a/tensorflow/contrib/lite/kernels/op_macros.h
+++ b/tensorflow/contrib/lite/kernels/op_macros.h
@@ -15,17 +15,55 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_OP_MACROS_H_
#define TENSORFLOW_CONTRIB_LITE_KERNELS_OP_MACROS_H_
+// If we're on a platform without standard IO functions, fall back to a
+// non-portable function.
+#ifdef TF_LITE_MCU_DEBUG_LOG
+
+// This header is pulled in from the support library at
+// https://github.com/google/stm32_bare_lib
+#include <debug_log.h>
+
+#define DEBUG_LOG(x) \
+ do { \
+ DebugLog(x); \
+ } while (0)
+
+inline void InfiniteLoop() {
+ DEBUG_LOG("HALTED\n");
+ while (1) {
+ }
+}
+#define TFLITE_ASSERT_FALSE InfiniteLoop();
+#define TFLITE_ABORT InfiniteLoop();
+
+#else // TF_LITE_MCU_DEBUG_LOG
+
+#include <cassert>
#include <cstdio>
+#include <cstdlib>
-#define TF_LITE_FATAL(msg) \
- do { \
- fprintf(stderr, "%s\n", (msg)); \
- exit(1); \
+#define DEBUG_LOG(x) \
+ do { \
+ fprintf(stderr, "%s", (x)); \
} while (0)
+
+#define TFLITE_ASSERT_FALSE assert(false)
+#define TFLITE_ABORT abort()
+
+#endif // TF_LITE_MCU_DEBUG_LOG
+
+#define TF_LITE_FATAL(msg) \
+ do { \
+ DEBUG_LOG(msg); \
+ DEBUG_LOG("\nFATAL\n"); \
+ TFLITE_ABORT; \
+ } while (0)
+
#define TF_LITE_ASSERT(x) \
do { \
if (!(x)) TF_LITE_FATAL(#x); \
} while (0)
+
#define TF_LITE_ASSERT_EQ(x, y) \
do { \
if ((x) != (y)) TF_LITE_FATAL(#x " didn't equal " #y); \
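A hypothetical usage sketch (not part of the patch, assuming a desktop build where TF_LITE_MCU_DEBUG_LOG is not defined) showing how the reworked macros compose: TF_LITE_ASSERT routes through TF_LITE_FATAL, which logs via DEBUG_LOG and then aborts through TFLITE_ABORT:

#include "tensorflow/contrib/lite/kernels/op_macros.h"

int main() {
  int buffer_size = 0;
  // Prints "buffer_size > 0" followed by "FATAL" to stderr, then abort()s.
  TF_LITE_ASSERT(buffer_size > 0);
  return 0;
}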
diff --git a/tensorflow/contrib/lite/kernels/pack.cc b/tensorflow/contrib/lite/kernels/pack.cc
index 4cb98fdd19..c368582ef7 100644
--- a/tensorflow/contrib/lite/kernels/pack.cc
+++ b/tensorflow/contrib/lite/kernels/pack.cc
@@ -85,9 +85,12 @@ template <typename T>
void PackImpl(TfLiteContext* context, TfLiteNode* node, TfLiteTensor* output,
int values_count, int axis) {
VectorOfTensors<T> all_inputs(*context, *node->inputs);
- reference_ops::Pack<T>(RemapDim(NumDimensions(output), axis),
- all_inputs.data(), all_inputs.dims(), values_count,
- GetTensorData<T>(output), GetTensorDims(output));
+ tflite::PackParams op_params;
+ op_params.axis = axis;
+ op_params.inputs_count = values_count;
+
+ reference_ops::Pack<T>(op_params, all_inputs.shapes(), all_inputs.data(),
+ GetTensorShape(output), GetTensorData<T>(output));
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
diff --git a/tensorflow/contrib/lite/kernels/reduce.cc b/tensorflow/contrib/lite/kernels/reduce.cc
index d94d821e87..4732a37a65 100644
--- a/tensorflow/contrib/lite/kernels/reduce.cc
+++ b/tensorflow/contrib/lite/kernels/reduce.cc
@@ -215,7 +215,7 @@ TfLiteStatus PrepareAny(TfLiteContext* context, TfLiteNode* node) {
return PrepareSimple(context, node);
}
-TfLiteStatus PrepareMean(TfLiteContext* context, TfLiteNode* node) {
+TfLiteStatus PrepareMeanOrSum(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(context, PrepareSimple(context, node));
// reduce_mean requires a buffer to store intermediate sum result.
@@ -274,7 +274,7 @@ TfLiteStatus EvalMean(TfLiteContext* context, TfLiteNode* node) {
} else {
TF_LITE_ENSURE(
context,
- reference_ops::Mean<>(
+ reference_ops::QuantizedMeanOrSum<>(
GetTensorData<uint8_t>(op_context.input),
op_context.input->params.zero_point,
op_context.input->params.scale, op_context.input->dims->data,
@@ -286,7 +286,7 @@ TfLiteStatus EvalMean(TfLiteContext* context, TfLiteNode* node) {
GetTensorData<int>(op_context.axis), num_axis,
op_context.params->keep_dims, GetTensorData<int>(temp_index),
GetTensorData<int>(resolved_axis),
- GetTensorData<int>(temp_sum)));
+ GetTensorData<int>(temp_sum), /*compute_sum=*/false));
}
break;
default:
@@ -416,19 +416,57 @@ TfLiteStatus EvalGeneric(TfLiteContext* context, TfLiteNode* node) {
}
}
+TfLiteStatus EvalSum(TfLiteContext* context, TfLiteNode* node) {
+ OpContext op_context(context, node);
+ const auto& input = op_context.input;
+ const auto& output = op_context.output;
+ if (input->type != kTfLiteUInt8 ||
+ (input->params.scale == output->params.scale &&
+ input->params.zero_point == output->params.zero_point)) {
+ return EvalGeneric<kReference, kSum>(context, node);
+ } else {
+ // Rescaling 8bit reduce sum.
+ int num_axis = static_cast<int>(NumElements(op_context.axis));
+ TfLiteTensor* temp_index = GetTemporary(context, node, /*index=*/0);
+ TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1);
+ TfLiteTensor* temp_sum = GetTemporary(context, node, /*index=*/2);
+ // Resize the output tensor if the output tensor is dynamic.
+ if (IsDynamicTensor(op_context.output)) {
+ TF_LITE_ENSURE_OK(context,
+ ResizeTempAxis(context, &op_context, resolved_axis));
+ TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
+ TF_LITE_ENSURE_OK(context, ResizeTempSum(context, &op_context, temp_sum));
+ }
+
+ TF_LITE_ENSURE(
+ context,
+ reference_ops::QuantizedMeanOrSum<>(
+ GetTensorData<uint8_t>(op_context.input),
+ op_context.input->params.zero_point, op_context.input->params.scale,
+ op_context.input->dims->data, op_context.input->dims->size,
+ GetTensorData<uint8_t>(op_context.output),
+ op_context.output->params.zero_point,
+ op_context.output->params.scale, op_context.output->dims->data,
+ op_context.output->dims->size, GetTensorData<int>(op_context.axis),
+ num_axis, op_context.params->keep_dims,
+ GetTensorData<int>(temp_index), GetTensorData<int>(resolved_axis),
+ GetTensorData<int32>(temp_sum), /*compute_sum=*/true));
+ }
+
+ return kTfLiteOk;
+}
} // namespace reduce
TfLiteRegistration* Register_MEAN_REF() {
static TfLiteRegistration r = {reduce::Init, reduce::Free,
- reduce::PrepareMean,
+ reduce::PrepareMeanOrSum,
reduce::EvalMean<reduce::kReference>};
return &r;
}
TfLiteRegistration* Register_SUM_REF() {
- static TfLiteRegistration r = {
- reduce::Init, reduce::Free, reduce::PrepareSimple,
- reduce::EvalGeneric<reduce::kReference, reduce::kSum>};
+ static TfLiteRegistration r = {reduce::Init, reduce::Free,
+ reduce::PrepareMeanOrSum, reduce::EvalSum};
return &r;
}
diff --git a/tensorflow/contrib/lite/kernels/reduce_test.cc b/tensorflow/contrib/lite/kernels/reduce_test.cc
index 6d289b14d8..fb2ec58ab2 100644
--- a/tensorflow/contrib/lite/kernels/reduce_test.cc
+++ b/tensorflow/contrib/lite/kernels/reduce_test.cc
@@ -488,6 +488,18 @@ TEST(ConstUint8SumOpTest, NotKeepDims) {
ArrayFloatNear({-0.823529, -0.815686}, kQuantizedTolerance)));
}
+TEST(ConstUint8SumOpTest, NotKeepDimsRescaling) {
+ float kQuantizedTolerance = GetTolerance(0.0, 2.0);
+ std::vector<float> data = {0.4, 0.2, 0.3, 0.4, 0.5, 0.6};
+ SumOpConstModel m({TensorType_UINT8, {1, 3, 2}, 0.0, 1.0},
+ {TensorType_UINT8, {2}, 0.0, 2.0}, {1}, {1}, false);
+ m.QuantizeAndPopulate<uint8_t>(m.Input(), data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2}));
+ EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear(
+ {1.2, 1.2}, kQuantizedTolerance)));
+}
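The new NotKeepDimsRescaling test quantizes the input over [0, 1] but the output over [0, 2], so the reduced sums must be requantized with the output scale rather than copied through. Worked arithmetic for one output element (a standalone sketch, not from the patch; the scale value assumes the test helper's usual (max - min) / 255 mapping):

#include <cmath>
#include <cstdio>

int main() {
  const float output_scale = 2.0f / 255.0f;   // output quantized over [0, 2]
  // One output column of the {1, 3, 2} input reduced over axis 1.
  const float real_sum = 0.4f + 0.3f + 0.5f;  // = 1.2
  const int quantized = static_cast<int>(std::round(real_sum / output_scale));
  // 1.2 -> uint8 value 153 -> dequantizes back to 1.2, the expected result.
  std::printf("sum %.2f -> %d -> %.3f\n", real_sum, quantized,
              quantized * output_scale);
  return 0;
}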
+
TEST(ConstUint8SumOpTest, KeepDims) {
float kQuantizedTolerance = GetTolerance(-1.0, 1.0);
std::vector<float> data = {0.4, 0.2, 0.3, 0.4, 0.5, 0.6};
diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc
index 14296d3a9f..9402105fa7 100644
--- a/tensorflow/contrib/lite/kernels/register.cc
+++ b/tensorflow/contrib/lite/kernels/register.cc
@@ -119,12 +119,13 @@ TfLiteRegistration* Register_LOGICAL_NOT();
TfLiteRegistration* Register_UNPACK();
TfLiteRegistration* Register_FLOOR_DIV();
TfLiteRegistration* Register_SQUARE();
+TfLiteRegistration* Register_ZEROS_LIKE();
TfLiteStatus UnsupportedTensorFlowOp(TfLiteContext* context, TfLiteNode* node) {
context->ReportError(
context,
"Regular TensorFlow ops are not supported by this interpreter. Make sure "
- "you invoke the Eager delegate before inference.");
+ "you invoke the Flex delegate before inference.");
return kTfLiteError;
}
@@ -135,13 +136,13 @@ const TfLiteRegistration* BuiltinOpResolver::FindOp(tflite::BuiltinOperator op,
const TfLiteRegistration* BuiltinOpResolver::FindOp(const char* op,
int version) const {
- // Return the NULL Op for all ops whose name start with "Eager", allowing
+  // Return the NULL Op for all ops whose name starts with "Flex", allowing
// the interpreter to delegate their execution.
- if (IsEagerOp(op)) {
+ if (IsFlexOp(op)) {
static TfLiteRegistration null_op{
nullptr, nullptr, &UnsupportedTensorFlowOp,
nullptr, nullptr, BuiltinOperator_CUSTOM,
- "Eager", 1};
+ "Flex", 1};
return &null_op;
}
return MutableOpResolver::FindOp(op, version);
@@ -157,7 +158,9 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_MAX_POOL_2D, Register_MAX_POOL_2D());
AddBuiltin(BuiltinOperator_L2_POOL_2D, Register_L2_POOL_2D());
AddBuiltin(BuiltinOperator_CONV_2D, Register_CONV_2D());
- AddBuiltin(BuiltinOperator_DEPTHWISE_CONV_2D, Register_DEPTHWISE_CONV_2D());
+ AddBuiltin(BuiltinOperator_DEPTHWISE_CONV_2D, Register_DEPTHWISE_CONV_2D(),
+ /* min_version */ 1,
+ /* max_version */ 2);
AddBuiltin(BuiltinOperator_SVDF, Register_SVDF());
AddBuiltin(BuiltinOperator_RNN, Register_RNN());
AddBuiltin(BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN,
@@ -245,6 +248,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_UNPACK, Register_UNPACK());
AddBuiltin(BuiltinOperator_FLOOR_DIV, Register_FLOOR_DIV());
AddBuiltin(BuiltinOperator_SQUARE, Register_SQUARE());
+ AddBuiltin(BuiltinOperator_ZEROS_LIKE, Register_ZEROS_LIKE());
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
// custom ops aren't always included by default.
diff --git a/tensorflow/contrib/lite/kernels/relu1_test.cc b/tensorflow/contrib/lite/kernels/relu1_test.cc
index c1e0149c20..b1d25a9f50 100644
--- a/tensorflow/contrib/lite/kernels/relu1_test.cc
+++ b/tensorflow/contrib/lite/kernels/relu1_test.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <gtest/gtest.h>
-#include "flatbuffers/flexbuffers.h" // flatbuffers
+#include "flatbuffers/flexbuffers.h" // TF:flatbuffers
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/kernels/test_util.h"
diff --git a/tensorflow/contrib/lite/kernels/select.cc b/tensorflow/contrib/lite/kernels/select.cc
index 3959502d91..4780a86ee5 100644
--- a/tensorflow/contrib/lite/kernels/select.cc
+++ b/tensorflow/contrib/lite/kernels/select.cc
@@ -70,12 +70,12 @@ TfLiteStatus SelectEval(TfLiteContext* context, TfLiteNode* node) {
bool is_rank_one = !HaveSameShapes(input_condition, input_x);
-#define TF_LITE_SELECT(type, op) \
- reference_ops::op(GetTensorData<bool>(input_condition), \
- GetTensorDims(input_condition), \
- GetTensorData<type>(input_x), GetTensorDims(input_x), \
- GetTensorData<type>(input_y), GetTensorDims(input_y), \
- GetTensorData<type>(output), GetTensorDims(output));
+#define TF_LITE_SELECT(type, op) \
+ reference_ops::op(GetTensorShape(input_condition), \
+ GetTensorData<bool>(input_condition), \
+ GetTensorShape(input_x), GetTensorData<type>(input_x), \
+ GetTensorShape(input_y), GetTensorData<type>(input_y), \
+ GetTensorShape(output), GetTensorData<type>(output));
#define TF_LITE_SWITCH(type, op) \
switch (type) { \
diff --git a/tensorflow/contrib/lite/kernels/softmax_test.cc b/tensorflow/contrib/lite/kernels/softmax_test.cc
index 727822f6be..bd66980226 100644
--- a/tensorflow/contrib/lite/kernels/softmax_test.cc
+++ b/tensorflow/contrib/lite/kernels/softmax_test.cc
@@ -93,8 +93,10 @@ TEST(SoftmaxOpTest, CompareWithTFminiBetaEq1) {
std::unique_ptr<float[]> output_buffer(new float[input_size * batch_size]);
auto input_shape = RuntimeShape({batch_size, 1, 1, input_size});
- tflite::reference_ops::Softmax(input_buffer, input_shape, beta,
- output_buffer.get(), input_shape);
+ SoftmaxParams params;
+ params.beta = beta;
+ tflite::reference_ops::Softmax(params, input_shape, input_buffer, input_shape,
+ output_buffer.get());
std::vector<float> expected;
expected.insert(expected.end(), output_buffer.get(),
@@ -120,8 +122,10 @@ TEST(SoftmaxOpTest, CompareWithTFminiBetaNotEq1) {
std::unique_ptr<float[]> output_buffer(new float[input_size * batch_size]);
auto input_shape = RuntimeShape({batch_size, 1, 1, input_size});
- tflite::reference_ops::Softmax(input_buffer, input_shape, beta,
- output_buffer.get(), input_shape);
+ SoftmaxParams params;
+ params.beta = beta;
+ tflite::reference_ops::Softmax(params, input_shape, input_buffer, input_shape,
+ output_buffer.get());
std::vector<float> expected;
expected.insert(expected.end(), output_buffer.get(),
diff --git a/tensorflow/contrib/lite/kernels/sparse_to_dense.cc b/tensorflow/contrib/lite/kernels/sparse_to_dense.cc
index 178568e07c..349fa0bd28 100644
--- a/tensorflow/contrib/lite/kernels/sparse_to_dense.cc
+++ b/tensorflow/contrib/lite/kernels/sparse_to_dense.cc
@@ -210,8 +210,9 @@ TfLiteStatus SparseToDenseImpl(TfLiteContext* context, TfLiteNode* node) {
&indices_vector));
reference_ops::SparseToDense(indices_vector, GetTensorData<T>(values),
*GetTensorData<T>(default_value),
- GetTensorData<T>(output), GetTensorDims(output),
- value_is_scalar);
+ value_is_scalar, GetTensorShape(output),
+ GetTensorData<T>(output));
+
return kTfLiteOk;
}
diff --git a/tensorflow/contrib/lite/kernels/split.cc b/tensorflow/contrib/lite/kernels/split.cc
index 719e2dc606..dab887bf9c 100644
--- a/tensorflow/contrib/lite/kernels/split.cc
+++ b/tensorflow/contrib/lite/kernels/split.cc
@@ -109,25 +109,24 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
if (axis_value < 0) {
axis_value += NumDimensions(op_context.input);
}
- axis_value = RemapDim(NumDimensions(op_context.input), axis_value);
// TODO(ahentz): Our usage of VectorOfTensors could be optimized by
// calculating it in Prepare, unless we defer shape calculation.
// TODO(ahentz): We can improve the optimized_ops version to handle other
// cases too.
-#define TF_LITE_SPLIT(scalar) \
- VectorOfTensors<scalar> all_outputs(*context, *node->outputs); \
- if (axis_value == NumDimensions(op_context.input)) { \
- optimized_ops::TensorFlowSplit<FusedActivationFunctionType::kNone, \
- scalar>( \
- GetTensorData<scalar>(op_context.input), \
- GetTensorDims(op_context.input), NumOutputs(node), all_outputs.data(), \
- all_outputs.dims()); \
- } else { \
- reference_ops::TensorFlowSplit<scalar>( \
- GetTensorData<scalar>(op_context.input), \
- GetTensorDims(op_context.input), axis_value, NumOutputs(node), \
- all_outputs.data(), all_outputs.dims()); \
+#define TF_LITE_SPLIT(scalar) \
+ VectorOfTensors<scalar> all_outputs(*context, *node->outputs); \
+ tflite::SplitParams op_params; \
+ op_params.num_split = NumOutputs(node); \
+ op_params.axis = axis_value; \
+ if (axis_value == 0) { \
+ optimized_ops::Split(op_params, GetTensorShape(op_context.input), \
+ GetTensorData<scalar>(op_context.input), \
+ all_outputs.shapes(), all_outputs.data()); \
+ } else { \
+ reference_ops::Split(op_params, GetTensorShape(op_context.input), \
+ GetTensorData<scalar>(op_context.input), \
+ all_outputs.shapes(), all_outputs.data()); \
}
switch (op_context.input->type) {
case kTfLiteFloat32: {
diff --git a/tensorflow/contrib/lite/kernels/strided_slice.cc b/tensorflow/contrib/lite/kernels/strided_slice.cc
index 87ffcc4110..06b36dd196 100644
--- a/tensorflow/contrib/lite/kernels/strided_slice.cc
+++ b/tensorflow/contrib/lite/kernels/strided_slice.cc
@@ -57,17 +57,6 @@ struct StridedSliceContext {
int dims;
};
-// Reverse order of bits in the mask to match the expected order in kernel
-inline int ReverseMaskBits(int mask, int num_dimensions) {
- int out = 0;
- for (int dim = 0; dim < num_dimensions; dim++) {
- out <<= 1;
- out += (mask & 1);
- mask >>= 1;
- }
- return out;
-}
-
// This Op only supports 1-4D cases and since we use the reference 4D
// implementation, the 1-3D tensors are mapped to 4D.
const int kMaxDim = 4;
@@ -198,30 +187,31 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
std::vector<int32_t> stops;
std::vector<int32_t> strides;
- for (int idx = op_context.dims - 1; idx >= 0; --idx) {
- starts.emplace_back(GetTensorData<int32_t>(op_context.begin)[idx]);
- stops.emplace_back(GetTensorData<int32_t>(op_context.end)[idx]);
- strides.emplace_back(GetTensorData<int32_t>(op_context.strides)[idx]);
- }
-
for (int i = op_context.dims; i < kMaxDim; i++) {
starts.emplace_back(0);
stops.emplace_back(1);
strides.emplace_back(1);
}
- int begin_mask =
- ReverseMaskBits(op_context.params->begin_mask, op_context.dims);
- int end_mask = ReverseMaskBits(op_context.params->end_mask, op_context.dims);
- int shrink_axis_mask =
- ReverseMaskBits(op_context.params->shrink_axis_mask, op_context.dims);
-
-#define TF_LITE_STRIDED_SLICE(kernel_type, data_type) \
- kernel_type::StridedSlice( \
- GetTensorData<data_type>(op_context.input), \
- GetTensorDims(op_context.input), begin_mask, end_mask, shrink_axis_mask, \
- starts, stops, strides, GetTensorData<data_type>(op_context.output), \
- GetTensorDims(op_context.output))
+ for (int idx = 0; idx < op_context.dims; ++idx) {
+ starts.emplace_back(GetTensorData<int32_t>(op_context.begin)[idx]);
+ stops.emplace_back(GetTensorData<int32_t>(op_context.end)[idx]);
+ strides.emplace_back(GetTensorData<int32_t>(op_context.strides)[idx]);
+ }
+
+ int begin_mask = op_context.params->begin_mask << (4 - op_context.dims);
+ int end_mask = op_context.params->end_mask << (4 - op_context.dims);
+ int shrink_axis_mask = op_context.params->shrink_axis_mask
+ << (4 - op_context.dims);
+ TF_LITE_ENSURE_EQ(context, starts.size(), 4);
+ auto op_params = ::tflite::strided_slice::BuildStridedSliceParams(
+ begin_mask, end_mask, shrink_axis_mask, starts, stops, strides);
+
+#define TF_LITE_STRIDED_SLICE(kernel_type, data_type) \
+ kernel_type::StridedSlice(op_params, GetTensorShape(op_context.input), \
+ GetTensorData<data_type>(op_context.input), \
+ GetTensorShape(op_context.output), \
+ GetTensorData<data_type>(op_context.output))
switch (op_context.input->type) {
case kTfLiteFloat32:
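With the bit-reversal helper gone, a rank-`dims` slice is padded to the fixed 4-D kernel layout by prepending unit dimensions, so original axis d lands at padded index d + (4 - dims) and its mask bit moves up by the same amount, which is exactly what the `mask << (4 - dims)` shifts do. A standalone sketch (not from the patch) for a rank-2 slice:

#include <cstdio>

int main() {
  const int dims = 2;
  const int begin_mask = 1;  // bit 0 set: ignore `begin` on original axis 0
  // Original axis 0 becomes padded axis 2, so its bit moves from 0 to 2.
  const int padded_mask = begin_mask << (4 - dims);
  std::printf("rank-2 mask %d becomes 4-D mask %d\n", begin_mask, padded_mask);
  return 0;
}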
diff --git a/tensorflow/contrib/lite/kernels/test_util.cc b/tensorflow/contrib/lite/kernels/test_util.cc
index 0fdb0a3935..05a7c23ba1 100644
--- a/tensorflow/contrib/lite/kernels/test_util.cc
+++ b/tensorflow/contrib/lite/kernels/test_util.cc
@@ -122,7 +122,7 @@ void SingleOpModel::BuildInterpreter(std::vector<std::vector<int>> input_shapes,
CHECK(interpreter_->AllocateTensors() == kTfLiteOk)
<< "Cannot allocate tensors";
- interpreter_->ResetVariableTensorsToZero();
+ interpreter_->ResetVariableTensors();
}
void SingleOpModel::Invoke() { CHECK(interpreter_->Invoke() == kTfLiteOk); }
diff --git a/tensorflow/contrib/lite/kernels/transpose.cc b/tensorflow/contrib/lite/kernels/transpose.cc
index 95359962e0..e42a30420b 100644
--- a/tensorflow/contrib/lite/kernels/transpose.cc
+++ b/tensorflow/contrib/lite/kernels/transpose.cc
@@ -92,26 +92,19 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
}
- // Reverse the permuted axes and convert to 4D due to the way Dims are
- // constructed in GetTensorDims.
const int* perm_data = GetTensorData<int32_t>(op_context.perm);
const int size = op_context.perm->dims->data[0];
- const int kOutputDimensionNum = 4;
- int reversed_perm[kOutputDimensionNum];
-
- for (int output_k = 0, input_k = size - 1; output_k < size;
- ++output_k, --input_k) {
- reversed_perm[output_k] = size - perm_data[input_k] - 1;
- }
- for (int k = size; k < kOutputDimensionNum; ++k) {
- reversed_perm[k] = k;
+ TransposeParams params;
+ params.perm_count = size;
+ for (int i = 0; i < size; ++i) {
+ params.perm[i] = perm_data[i];
}
#define TF_LITE_TRANSPOSE(type, scalar) \
- type::Transpose(GetTensorData<scalar>(op_context.input), \
- GetTensorDims(op_context.input), \
- GetTensorData<scalar>(op_context.output), \
- GetTensorDims(op_context.output), reversed_perm)
+ type::Transpose(params, GetTensorShape(op_context.input), \
+ GetTensorData<scalar>(op_context.input), \
+ GetTensorShape(op_context.output), \
+ GetTensorData<scalar>(op_context.output))
switch (op_context.input->type) {
case kTfLiteFloat32:
diff --git a/tensorflow/contrib/lite/kernels/transpose_conv.cc b/tensorflow/contrib/lite/kernels/transpose_conv.cc
index 6f2d98ede8..1c4a5ee91d 100644
--- a/tensorflow/contrib/lite/kernels/transpose_conv.cc
+++ b/tensorflow/contrib/lite/kernels/transpose_conv.cc
@@ -69,7 +69,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);
TF_LITE_ENSURE_EQ(context, NumDimensions(weights), 4);
- // Currenlty only supports float32.
+ // Currently only supports float32.
const TfLiteType data_type = input->type;
TF_LITE_ENSURE(context, data_type == kTfLiteFloat32);
TF_LITE_ENSURE_EQ(context, output->type, data_type);
@@ -117,19 +117,26 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
// Currently only support float32.
switch (input->type) {
- case kTfLiteFloat32:
+ case kTfLiteFloat32: {
+ tflite::ConvParams op_params;
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = padding_size.width;
+ op_params.padding_values.height = padding_size.height;
+ op_params.stride_width = stride_width;
+ op_params.stride_height = stride_height;
+
reference_ops::TransposeConv(
- GetTensorData<float>(input), GetTensorDims(input),
- GetTensorData<float>(weights), GetTensorDims(weights), stride_width,
- stride_height, padding_size.width, padding_size.height,
- GetTensorData<float>(output), GetTensorDims(output),
+ op_params, GetTensorShape(input), GetTensorData<float>(input),
+ GetTensorShape(weights), GetTensorData<float>(weights),
+ GetTensorShape(output), GetTensorData<float>(output),
// Last two args specify im2col which reference_ops ignores.
// (Note this does not lead to a performance regression, as the
// previous optimized version was just a copy of the reference code.)
// TODO(b/110208176): Allocate im2col tensors and switch to
// optimized_ops.
- GetTensorData<float>(output), GetTensorDims(output));
+ GetTensorShape(output), GetTensorData<float>(output));
break;
+ }
default:
context->ReportError(context, "Type %d, not currently supported.",
input->type);
diff --git a/tensorflow/contrib/lite/kernels/transpose_test.cc b/tensorflow/contrib/lite/kernels/transpose_test.cc
index 337bc144b9..79ef0a7c56 100644
--- a/tensorflow/contrib/lite/kernels/transpose_test.cc
+++ b/tensorflow/contrib/lite/kernels/transpose_test.cc
@@ -51,21 +51,21 @@ void RunTestPermutation(const std::vector<int>& shape,
reversed_perms[k] = k;
}
- // Make input and output dims (i.e. reversed shape and dest_shape).
- Dims<4> input_dims = GetTensorDims(shape);
- Dims<4> output_dims;
- for (int i = 0; i < 4; i++) {
- output_dims.sizes[i] = input_dims.sizes[reversed_perms[i]];
+ // Make input and output shapes.
+ const RuntimeShape input_shape = GetTensorShape(shape);
+ RuntimeShape output_shape(perms.size());
+ for (int i = 0; i < perms.size(); i++) {
+ output_shape.SetDim(i, input_shape.Dims(perms[i]));
}
- output_dims.strides[0] = 1;
- for (int k = 1; k < 4; k++) {
- output_dims.strides[k] =
- output_dims.strides[k - 1] * output_dims.sizes[k - 1];
+
+ TransposeParams params;
+ params.perm_count = perms.size();
+ for (int i = 0; i < perms.size(); ++i) {
+ params.perm[i] = perms[i];
}
- reference_ops::Transpose<float>(input.data(), input_dims,
- input_transposed->data(), output_dims,
- reversed_perms);
+ reference_ops::Transpose<float>(params, input_shape, input.data(),
+ output_shape, input_transposed->data());
}
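The TransposeParams convention used above is that output dimension i takes its extent from input dimension perm[i], which is exactly how the test builds output_shape. A standalone sketch (not from the patch):

#include <cstdio>

int main() {
  const int input_dims[3] = {2, 3, 4};
  const int perm[3] = {2, 0, 1};
  int output_dims[3];
  // Output dim i gets the size of input dim perm[i].
  for (int i = 0; i < 3; ++i) output_dims[i] = input_dims[perm[i]];
  std::printf("%d %d %d\n", output_dims[0], output_dims[1],
              output_dims[2]);  // 4 2 3
  return 0;
}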
TEST(TransposeTest, TestRefOps1D) {
diff --git a/tensorflow/contrib/lite/kernels/unpack.cc b/tensorflow/contrib/lite/kernels/unpack.cc
index 9ff06f8331..a7d3a9bc76 100644
--- a/tensorflow/contrib/lite/kernels/unpack.cc
+++ b/tensorflow/contrib/lite/kernels/unpack.cc
@@ -88,10 +88,13 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
template <typename T>
void UnpackImpl(TfLiteContext* context, TfLiteNode* node,
const TfLiteTensor* input, int output_count, int axis) {
+ tflite::UnpackParams op_params;
+ op_params.axis = axis;
+ op_params.num_split = output_count;
VectorOfTensors<T> all_outputs(*context, *node->outputs);
- reference_ops::Unpack<T>(axis, GetTensorData<T>(input), GetTensorDims(input),
- NumDimensions(input), output_count,
- all_outputs.data(), **all_outputs.dims());
+ reference_ops::Unpack<T>(op_params, GetTensorShape(input),
+ GetTensorData<T>(input), **all_outputs.shapes(),
+ all_outputs.data());
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
diff --git a/tensorflow/contrib/lite/kernels/zeros_like.cc b/tensorflow/contrib/lite/kernels/zeros_like.cc
new file mode 100644
index 0000000000..cce5240a9b
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/zeros_like.cc
@@ -0,0 +1,73 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace zeros_like {
+
+constexpr int kInputTensor = 0;
+constexpr int kOutputTensor = 0;
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+ output->type = input->type;
+
+ return context->ResizeTensor(context, output,
+ TfLiteIntArrayCopy(input->dims));
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+ const int num_elements = NumElements(input);
+ switch (input->type) {
+ case kTfLiteInt64:
+ memset(GetTensorData<int64_t>(output), 0, num_elements * sizeof(int64_t));
+ break;
+ case kTfLiteInt32:
+ memset(GetTensorData<int32_t>(output), 0, num_elements * sizeof(int32_t));
+ break;
+ case kTfLiteFloat32:
+ memset(GetTensorData<float>(output), 0, num_elements * sizeof(float));
+ break;
+ default:
+ context->ReportError(context,
+ "ZerosLike only currently supports int64, int32, "
+ "and float32, got %d.",
+ input->type);
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
+
+} // namespace zeros_like
+
+TfLiteRegistration* Register_ZEROS_LIKE() {
+ static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
+ zeros_like::Prepare, zeros_like::Eval};
+ return &r;
+}
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/zeros_like_test.cc b/tensorflow/contrib/lite/kernels/zeros_like_test.cc
new file mode 100644
index 0000000000..d3382d1d5b
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/zeros_like_test.cc
@@ -0,0 +1,78 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+class ZerosLikeOpModel : public SingleOpModel {
+ public:
+ explicit ZerosLikeOpModel(const TensorData& input) {
+ input_ = AddInput(input);
+ output_ = AddOutput(input);
+ SetBuiltinOp(BuiltinOperator_ZEROS_LIKE, BuiltinOptions_ZerosLikeOptions,
+ CreateZerosLikeOptions(builder_).Union());
+ BuildInterpreter({GetShape(input_)});
+ }
+
+ int input() { return input_; }
+ int output() { return output_; }
+
+ protected:
+ int input_;
+ int output_;
+};
+
+TEST(ZerosLikeOpModel, ZerosLikeFloat) {
+ ZerosLikeOpModel m({TensorType_FLOAT32, {2, 3}});
+ m.PopulateTensor<float>(m.input(), {-2.0, -1.0, 0.0, 1.0, 2.0, 3.0});
+ m.Invoke();
+ EXPECT_THAT(m.ExtractVector<float>(m.output()),
+ ElementsAreArray({0.0, 0.0, 0.0, 0.0, 0.0, 0.0}));
+ EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({2, 3}));
+}
+
+TEST(ZerosLikeOpModel, ZerosLikeInt32) {
+ ZerosLikeOpModel m({TensorType_INT32, {1, 2, 2, 1}});
+ m.PopulateTensor<int32_t>(m.input(), {-2, -1, 0, 3});
+ m.Invoke();
+ EXPECT_THAT(m.ExtractVector<int32_t>(m.output()),
+ ElementsAreArray({0, 0, 0, 0}));
+ EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 2, 2, 1}));
+}
+
+TEST(ZerosLikeOpModel, ZerosLikeInt64) {
+ ZerosLikeOpModel m({TensorType_INT64, {1, 2, 2, 1}});
+ m.PopulateTensor<int64_t>(m.input(), {-2, -1, 0, 3});
+ m.Invoke();
+ EXPECT_THAT(m.ExtractVector<int64_t>(m.output()),
+ ElementsAreArray({0, 0, 0, 0}));
+ EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 2, 2, 1}));
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc
index 6311d60b91..d50c345194 100644
--- a/tensorflow/contrib/lite/model.cc
+++ b/tensorflow/contrib/lite/model.cc
@@ -27,8 +27,8 @@ limitations under the License.
#ifndef TFLITE_MCU
#include "tensorflow/contrib/lite/nnapi_delegate.h"
#endif
-#if defined(TFLITE_EXTENDED)
-#include "tensorflow/contrib/lite/delegates/eager/delegate.h"
+#if defined(TFLITE_FLEX)
+#include "tensorflow/contrib/lite/delegates/flex/delegate.h"
#endif
#include "tensorflow/contrib/lite/version.h"
@@ -189,6 +189,13 @@ std::vector<int> FlatBufferIntArrayToVector(T* flat_array) {
return ret;
}
+// Used to determine how the op data parsing function creates its working space.
+class MallocDataAllocator : public BuiltinDataAllocator {
+ public:
+ void* Allocate(size_t size) override { return malloc(size); }
+ void Deallocate(void* data) override { free(data); }
+};
+
} // namespace
TfLiteStatus InterpreterBuilder::ParseNodes(
@@ -234,8 +241,9 @@ TfLiteStatus InterpreterBuilder::ParseNodes(
op->custom_options()->size(), nullptr, registration);
} else {
void* builtin_data = nullptr;
- TF_LITE_ENSURE_STATUS(
- ParseOpData(op, op_type, error_reporter_, &builtin_data));
+ MallocDataAllocator malloc_allocator;
+ TF_LITE_ENSURE_STATUS(ParseOpData(op, op_type, error_reporter_,
+ &malloc_allocator, &builtin_data));
interpreter->AddNodeWithParameters(
FlatBufferIntArrayToVector(op->inputs()),
FlatBufferIntArrayToVector(op->outputs()), nullptr, 0, builtin_data,
@@ -442,8 +450,8 @@ TfLiteStatus InterpreterBuilder::operator()(
}
(**interpreter).SetVariables(std::move(variables));
-#if defined(TFLITE_EXTENDED)
- if (auto delegate = EagerDelegate::Create()) {
+#if defined(TFLITE_FLEX)
+ if (auto delegate = FlexDelegate::Create()) {
(**interpreter)
.ModifyGraphWithDelegate(std::move(delegate),
/*allow_dynamic_tensors=*/true);
diff --git a/tensorflow/contrib/lite/mutable_op_resolver.cc b/tensorflow/contrib/lite/mutable_op_resolver.cc
index d7c0181720..a36404399b 100644
--- a/tensorflow/contrib/lite/mutable_op_resolver.cc
+++ b/tensorflow/contrib/lite/mutable_op_resolver.cc
@@ -30,7 +30,7 @@ const TfLiteRegistration* MutableOpResolver::FindOp(const char* op,
}
void MutableOpResolver::AddBuiltin(tflite::BuiltinOperator op,
- TfLiteRegistration* registration,
+ const TfLiteRegistration* registration,
int min_version, int max_version) {
for (int version = min_version; version <= max_version; ++version) {
TfLiteRegistration new_registration = *registration;
@@ -43,7 +43,7 @@ void MutableOpResolver::AddBuiltin(tflite::BuiltinOperator op,
}
void MutableOpResolver::AddCustom(const char* name,
- TfLiteRegistration* registration,
+ const TfLiteRegistration* registration,
int min_version, int max_version) {
for (int version = min_version; version <= max_version; ++version) {
TfLiteRegistration new_registration = *registration;
@@ -55,4 +55,15 @@ void MutableOpResolver::AddCustom(const char* name,
}
}
+void MutableOpResolver::AddAll(const MutableOpResolver& other) {
+ // map::insert does not replace existing elements, and map::insert_or_assign
+ // wasn't added until C++17.
+ for (const auto& other_builtin : other.builtins_) {
+ builtins_[other_builtin.first] = other_builtin.second;
+ }
+ for (const auto& other_custom_op : other.custom_ops_) {
+ custom_ops_[other_custom_op.first] = other_custom_op.second;
+ }
+}
+
} // namespace tflite
diff --git a/tensorflow/contrib/lite/mutable_op_resolver.h b/tensorflow/contrib/lite/mutable_op_resolver.h
index c319041e9b..efd6cfac2a 100644
--- a/tensorflow/contrib/lite/mutable_op_resolver.h
+++ b/tensorflow/contrib/lite/mutable_op_resolver.h
@@ -57,10 +57,12 @@ class MutableOpResolver : public OpResolver {
const TfLiteRegistration* FindOp(tflite::BuiltinOperator op,
int version) const override;
const TfLiteRegistration* FindOp(const char* op, int version) const override;
- void AddBuiltin(tflite::BuiltinOperator op, TfLiteRegistration* registration,
- int min_version = 1, int max_version = 1);
- void AddCustom(const char* name, TfLiteRegistration* registration,
+ void AddBuiltin(tflite::BuiltinOperator op,
+ const TfLiteRegistration* registration, int min_version = 1,
+ int max_version = 1);
+ void AddCustom(const char* name, const TfLiteRegistration* registration,
int min_version = 1, int max_version = 1);
+ void AddAll(const MutableOpResolver& other);
private:
typedef std::pair<tflite::BuiltinOperator, int> BuiltinOperatorKey;
diff --git a/tensorflow/contrib/lite/mutable_op_resolver_test.cc b/tensorflow/contrib/lite/mutable_op_resolver_test.cc
index db690eaab9..b70c703839 100644
--- a/tensorflow/contrib/lite/mutable_op_resolver_test.cc
+++ b/tensorflow/contrib/lite/mutable_op_resolver_test.cc
@@ -36,6 +36,20 @@ TfLiteRegistration* GetDummyRegistration() {
return &registration;
}
+TfLiteStatus Dummy2Invoke(TfLiteContext* context, TfLiteNode* node) {
+ return kTfLiteOk;
+}
+
+TfLiteRegistration* GetDummy2Registration() {
+ static TfLiteRegistration registration = {
+ .init = nullptr,
+ .free = nullptr,
+ .prepare = nullptr,
+ .invoke = Dummy2Invoke,
+ };
+ return &registration;
+}
+
TEST(MutableOpResolverTest, FinOp) {
MutableOpResolver resolver;
resolver.AddBuiltin(BuiltinOperator_ADD, GetDummyRegistration());
@@ -119,6 +133,26 @@ TEST(MutableOpResolverTest, FindCustomOpWithUnsupportedVersion) {
EXPECT_EQ(found_registration, nullptr);
}
+TEST(MutableOpResolverTest, AddAll) {
+ MutableOpResolver resolver1;
+ resolver1.AddBuiltin(BuiltinOperator_ADD, GetDummyRegistration());
+ resolver1.AddBuiltin(BuiltinOperator_MUL, GetDummy2Registration());
+
+ MutableOpResolver resolver2;
+ resolver2.AddBuiltin(BuiltinOperator_SUB, GetDummyRegistration());
+ resolver2.AddBuiltin(BuiltinOperator_ADD, GetDummy2Registration());
+
+ // resolver2's ADD op should replace resolver1's ADD op, while augmenting
+ // non-overlapping ops.
+ resolver1.AddAll(resolver2);
+ ASSERT_EQ(resolver1.FindOp(BuiltinOperator_ADD, 1)->invoke,
+ GetDummy2Registration()->invoke);
+ ASSERT_EQ(resolver1.FindOp(BuiltinOperator_MUL, 1)->invoke,
+ GetDummy2Registration()->invoke);
+ ASSERT_EQ(resolver1.FindOp(BuiltinOperator_SUB, 1)->invoke,
+ GetDummyRegistration()->invoke);
+}
+
} // namespace
} // namespace tflite
diff --git a/tensorflow/contrib/lite/nnapi_delegate.cc b/tensorflow/contrib/lite/nnapi_delegate.cc
index f814b90d66..f23a0ccb80 100644
--- a/tensorflow/contrib/lite/nnapi_delegate.cc
+++ b/tensorflow/contrib/lite/nnapi_delegate.cc
@@ -512,6 +512,10 @@ TfLiteStatus AddOpsAndParams(
nn_op_type = ANEURALNETWORKS_FULLY_CONNECTED;
break;
case tflite::BuiltinOperator_RESHAPE:
+ if (node.inputs->size != 2) {
+ logError("NNAPI only supports 2-input RESHAPE");
+ return kTfLiteError;
+ }
nn_op_type = ANEURALNETWORKS_RESHAPE;
// add_reshape_params(node.builtin_data);
break;
@@ -673,6 +677,8 @@ TfLiteStatus AddOpsAndParams(
case tflite::BuiltinOperator_FLOOR_DIV:
case tflite::BuiltinOperator_REDUCE_ANY:
case tflite::BuiltinOperator_SQUARE:
+ case tflite::BuiltinOperator_ZEROS_LIKE:
+ case tflite::BuiltinOperator_FILL:
logError("Op code %d is currently not delegated to NNAPI", builtin);
return kTfLiteError;
break;
diff --git a/tensorflow/contrib/lite/optional_debug_tools.cc b/tensorflow/contrib/lite/optional_debug_tools.cc
index f1f025f777..64ba2d8baa 100644
--- a/tensorflow/contrib/lite/optional_debug_tools.cc
+++ b/tensorflow/contrib/lite/optional_debug_tools.cc
@@ -25,7 +25,7 @@ void PrintIntVector(const std::vector<int>& v) {
void PrintTfLiteIntVector(const TfLiteIntArray* v) {
if (!v) {
- printf(" (null)");
+ printf(" (null)\n");
return;
}
for (int k = 0; k < v->size; k++) {
@@ -99,8 +99,12 @@ void PrintInterpreterState(Interpreter* interpreter) {
interpreter->node_and_registration(node_index);
const TfLiteNode& node = node_and_reg->first;
const TfLiteRegistration& reg = node_and_reg->second;
- printf("Node %3d Operator Builtin Code %3d\n", node_index,
- reg.builtin_code);
+ if (reg.custom_name != nullptr) {
+ printf("Node %3d Operator Custom Name %s\n", node_index, reg.custom_name);
+ } else {
+ printf("Node %3d Operator Builtin Code %3d\n", node_index,
+ reg.builtin_code);
+ }
printf(" Inputs:");
PrintTfLiteIntVector(node.inputs);
printf(" Outputs:");
diff --git a/tensorflow/contrib/lite/python/BUILD b/tensorflow/contrib/lite/python/BUILD
index 57e1290e07..916788f215 100644
--- a/tensorflow/contrib/lite/python/BUILD
+++ b/tensorflow/contrib/lite/python/BUILD
@@ -144,7 +144,7 @@ py_library(
name = "convert_saved_model",
srcs = ["convert_saved_model.py"],
srcs_version = "PY2AND3",
- visibility = ["//visibility:public"],
+ visibility = ["//tensorflow/contrib/lite:__subpackages__"],
deps = [
":convert",
"//tensorflow/contrib/saved_model:saved_model_py",
diff --git a/tensorflow/contrib/lite/python/convert.py b/tensorflow/contrib/lite/python/convert.py
index 1f48a826d4..613a1530f7 100644
--- a/tensorflow/contrib/lite/python/convert.py
+++ b/tensorflow/contrib/lite/python/convert.py
@@ -67,12 +67,12 @@ class ConverterMode(enum.Enum):
# Convert model using TOCO such that only unsupported operations are
# represented as TensorFlow ops.
# WARNING: Experimental interface, subject to change.
- TOCO_EXTENDED = "TOCO_EXTENDED"
+ TOCO_FLEX = "TOCO_FLEX"
# Convert model using TOCO such that all operations are represented as
# TensorFlow ops.
# WARNING: Experimental interface, subject to change.
- TOCO_EXTENDED_ALL = "TOCO_EXTENDED_ALL"
+ TOCO_FLEX_ALL = "TOCO_FLEX_ALL"
def __str__(self):
return self.value
@@ -240,11 +240,11 @@ def build_toco_convert_protos(input_tensors,
if dump_graphviz_dir:
toco.dump_graphviz_dir = dump_graphviz_dir
toco.dump_graphviz_include_video = dump_graphviz_video
- if converter_mode == ConverterMode.TOCO_EXTENDED:
- toco.allow_eager_ops = True
- elif converter_mode == ConverterMode.TOCO_EXTENDED_ALL:
- toco.allow_eager_ops = True
- toco.force_eager_ops = True
+ if converter_mode == ConverterMode.TOCO_FLEX:
+ toco.allow_flex_ops = True
+ elif converter_mode == ConverterMode.TOCO_FLEX_ALL:
+ toco.allow_flex_ops = True
+ toco.force_flex_ops = True
model = _model_flags_pb2.ModelFlags()
model.change_concat_input_ranges = change_concat_input_ranges
@@ -343,13 +343,14 @@ def toco_convert_impl(input_data, input_tensors, output_tensors, *args,
return data
-@deprecation.deprecated(None, "Use `lite.TocoConverter` instead.")
+@deprecation.deprecated(None, "Use `lite.TFLiteConverter` instead.")
def toco_convert(input_data, input_tensors, output_tensors, *args, **kwargs):
- """"Convert a model using TOCO.
+ """Convert a model using TOCO.
Typically this function is used to convert from TensorFlow GraphDef to TFLite.
Conversion can be customized by providing arguments that are forwarded to
- `build_toco_convert_protos` (see documentation for details).
+ `build_toco_convert_protos` (see documentation for details). This function has
+ been deprecated. Please use `lite.TFLiteConverter` instead.
Args:
input_data: Input data (i.e. often `sess.graph_def`),
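
The hunk above renames the experimental converter modes and the TOCO flags they drive: `ConverterMode.TOCO_FLEX` sets `allow_flex_ops`, and `ConverterMode.TOCO_FLEX_ALL` additionally sets `force_flex_ops`. A minimal sketch of selecting the new mode through the Python API, mirroring the usage in lite_test.py later in this change; the session, tensors, and shapes are placeholders.

```python
import tensorflow as tf
from tensorflow.contrib.lite.python import lite

with tf.Session() as sess:
  in_tensor = tf.placeholder(shape=[1, 16, 16, 3], dtype=tf.float32)
  out_tensor = in_tensor + in_tensor

  converter = lite.TFLiteConverter.from_session(sess, [in_tensor], [out_tensor])
  # TOCO_FLEX lowers only the ops TOCO cannot handle to TensorFlow (flex) ops;
  # TOCO_FLEX_ALL forces every op through the flex path.
  converter.converter_mode = lite.ConverterMode.TOCO_FLEX_ALL
  tflite_model = converter.convert()
```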
diff --git a/tensorflow/contrib/lite/python/convert_saved_model.py b/tensorflow/contrib/lite/python/convert_saved_model.py
index 1553464b9f..d18b60d0ea 100644
--- a/tensorflow/contrib/lite/python/convert_saved_model.py
+++ b/tensorflow/contrib/lite/python/convert_saved_model.py
@@ -44,7 +44,7 @@ def _log_tensor_details(tensor_info):
dtype)
-def _get_meta_graph_def(saved_model_dir, tag_set):
+def get_meta_graph_def(saved_model_dir, tag_set):
"""Validate saved_model and extract MetaGraphDef.
Args:
@@ -61,7 +61,7 @@ def _get_meta_graph_def(saved_model_dir, tag_set):
return loader.load(sess, tag_set, saved_model_dir)
-def _get_signature_def(meta_graph, signature_key):
+def get_signature_def(meta_graph, signature_key):
"""Get the signature def from meta_graph with given signature_key.
Args:
@@ -86,7 +86,7 @@ def _get_signature_def(meta_graph, signature_key):
return signature_def_map[signature_key]
-def _get_inputs_outputs(signature_def):
+def get_inputs_outputs(signature_def):
"""Get inputs and outputs from SignatureDef.
Args:
@@ -236,9 +236,9 @@ def freeze_saved_model(saved_model_dir, input_arrays, input_shapes,
input_arrays or output_arrays are not valid.
"""
# Read SignatureDef.
- meta_graph = _get_meta_graph_def(saved_model_dir, tag_set)
- signature_def = _get_signature_def(meta_graph, signature_key)
- inputs, outputs = _get_inputs_outputs(signature_def)
+ meta_graph = get_meta_graph_def(saved_model_dir, tag_set)
+ signature_def = get_signature_def(meta_graph, signature_key)
+ inputs, outputs = get_inputs_outputs(signature_def)
# Check SavedModel for assets directory.
collection_def = meta_graph.collection_def
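
Dropping the leading underscore makes the three SavedModel helpers reusable by other code under tf.contrib.lite (the BUILD change above narrows the library's visibility to lite subpackages at the same time). A rough sketch of chaining them the way `freeze_saved_model` does, assuming a serving SavedModel at a placeholder path and the usual serving tag and signature key:

```python
from tensorflow.contrib.lite.python import convert_saved_model

saved_model_dir = "/tmp/my_saved_model"  # placeholder path
meta_graph = convert_saved_model.get_meta_graph_def(saved_model_dir, tag_set={"serve"})
signature_def = convert_saved_model.get_signature_def(meta_graph, "serving_default")
inputs, outputs = convert_saved_model.get_inputs_outputs(signature_def)
```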
diff --git a/tensorflow/contrib/lite/python/interpreter.py b/tensorflow/contrib/lite/python/interpreter.py
index 1be61fe053..5700bf7892 100644
--- a/tensorflow/contrib/lite/python/interpreter.py
+++ b/tensorflow/contrib/lite/python/interpreter.py
@@ -253,5 +253,5 @@ class Interpreter(object):
self._ensure_safe()
self._interpreter.Invoke()
- def reset_all_variables_to_zero(self):
- return self._interpreter.ResetVariableTensorsToZero()
+ def reset_all_variables(self):
+ return self._interpreter.ResetVariableTensors()
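
The Python method is renamed to match the C++ `ResetVariableTensors` entry point changed below. A short sketch of the renamed call, assuming `tflite_model` holds converter output for a model that actually has variable (stateful) tensors:

```python
from tensorflow.contrib.lite.python.interpreter import Interpreter

interpreter = Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
interpreter.invoke()
# Renamed from reset_all_variables_to_zero(); resets the model's variable
# tensors through the underlying ResetVariableTensors() wrapper.
interpreter.reset_all_variables()
```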
diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc
index 9ab05f3068..418f19a179 100644
--- a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc
+++ b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc
@@ -466,9 +466,9 @@ InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromBuffer(
error_msg);
}
-PyObject* InterpreterWrapper::ResetVariableTensorsToZero() {
+PyObject* InterpreterWrapper::ResetVariableTensors() {
TFLITE_PY_ENSURE_VALID_INTERPRETER();
- TFLITE_PY_CHECK(interpreter_->ResetVariableTensorsToZero());
+ TFLITE_PY_CHECK(interpreter_->ResetVariableTensors());
Py_RETURN_NONE;
}
diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h
index 641dd93db5..f5ca81e62a 100644
--- a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h
+++ b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h
@@ -65,7 +65,7 @@ class InterpreterWrapper {
PyObject* TensorQuantization(int i) const;
PyObject* SetTensor(int i, PyObject* value);
PyObject* GetTensor(int i) const;
- PyObject* ResetVariableTensorsToZero();
+ PyObject* ResetVariableTensors();
// Returns a reference to tensor index i as a numpy array. The base_object
// should be the interpreter object providing the memory.
diff --git a/tensorflow/contrib/lite/python/lite.py b/tensorflow/contrib/lite/python/lite.py
index 2be24455d8..09365f101f 100644
--- a/tensorflow/contrib/lite/python/lite.py
+++ b/tensorflow/contrib/lite/python/lite.py
@@ -17,6 +17,7 @@
EXPERIMENTAL: APIs here are unstable and likely to change without notice.
@@TocoConverter
+@@TFLiteConverter
@@toco_convert
@@toco_convert_protos
@@Interpreter
@@ -62,9 +63,10 @@ from tensorflow.python.framework.importer import import_graph_def as _import_gra
from tensorflow.python.lib.io import file_io as _file_io
from tensorflow.python.saved_model import signature_constants as _signature_constants
from tensorflow.python.saved_model import tag_constants as _tag_constants
+from tensorflow.python.util import deprecation as _deprecation
-class TocoConverter(object):
+class TFLiteConverter(object):
"""Convert a TensorFlow model into `output_format` using TOCO.
This is used to convert from a TensorFlow GraphDef or SavedModel into either a
@@ -121,22 +123,22 @@ class TocoConverter(object):
```python
# Converting a GraphDef from session.
- converter = lite.TocoConverter.from_session(sess, in_tensors, out_tensors)
+ converter = lite.TFLiteConverter.from_session(sess, in_tensors, out_tensors)
tflite_model = converter.convert()
open("converted_model.tflite", "wb").write(tflite_model)
# Converting a GraphDef from file.
- converter = lite.TocoConverter.from_frozen_graph(
+ converter = lite.TFLiteConverter.from_frozen_graph(
graph_def_file, input_arrays, output_arrays)
tflite_model = converter.convert()
open("converted_model.tflite", "wb").write(tflite_model)
# Converting a SavedModel.
- converter = lite.TocoConverter.from_saved_model(saved_model_dir)
+ converter = lite.TFLiteConverter.from_saved_model(saved_model_dir)
tflite_model = converter.convert()
# Converting a tf.keras model.
- converter = lite.TocoConverter.from_keras_model_file(keras_model)
+ converter = lite.TFLiteConverter.from_keras_model_file(keras_model)
tflite_model = converter.convert()
```
"""
@@ -147,10 +149,9 @@ class TocoConverter(object):
output_tensors,
input_arrays_with_shape=None,
output_arrays=None):
- """Constructor for TocoConverter.
+ """Constructor for TFLiteConverter.
Args:
-
graph_def: Frozen TensorFlow GraphDef.
input_tensors: List of input tensors. Type and shape are computed using
`foo.get_shape()` and `foo.dtype`.
@@ -158,8 +159,8 @@ class TocoConverter(object):
input_arrays_with_shape: Tuple of strings representing input tensor names
and list of integers representing input shapes
(e.g., [("foo" : [1, 16, 16, 3])]). Use only when graph cannot be loaded
- into TensorFlow and when `input_tensors` and `output_tensors` are None.
- (default None)
+ into TensorFlow and when `input_tensors` and `output_tensors` are
+ None. (default None)
output_arrays: List of output tensors to freeze graph with. Use only when
graph cannot be loaded into TensorFlow and when `input_tensors` and
`output_tensors` are None. (default None)
@@ -195,7 +196,7 @@ class TocoConverter(object):
@classmethod
def from_session(cls, sess, input_tensors, output_tensors):
- """Creates a TocoConverter class from a TensorFlow Session.
+ """Creates a TFLiteConverter class from a TensorFlow Session.
Args:
sess: TensorFlow Session.
@@ -204,7 +205,7 @@ class TocoConverter(object):
output_tensors: List of output tensors (only .name is used from this).
Returns:
- TocoConverter class.
+ TFLiteConverter class.
"""
graph_def = _freeze_graph(sess, output_tensors)
return cls(graph_def, input_tensors, output_tensors)
@@ -215,7 +216,7 @@ class TocoConverter(object):
input_arrays,
output_arrays,
input_shapes=None):
- """Creates a TocoConverter class from a file containing a frozen GraphDef.
+ """Creates a TFLiteConverter class from a file containing a frozen GraphDef.
Args:
graph_def_file: Full filepath of file containing frozen GraphDef.
@@ -224,10 +225,10 @@ class TocoConverter(object):
input_shapes: Dict of strings representing input tensor names to list of
integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}).
Automatically determined when input shapes is None (e.g., {"foo" :
- None}). (default None)
+ None}). (default None)
Returns:
- TocoConverter class.
+ TFLiteConverter class.
Raises:
IOError:
@@ -310,7 +311,7 @@ class TocoConverter(object):
output_arrays=None,
tag_set=None,
signature_key=None):
- """Creates a TocoConverter class from a SavedModel.
+ """Creates a TFLiteConverter class from a SavedModel.
Args:
saved_model_dir: SavedModel directory to convert.
@@ -319,7 +320,7 @@ class TocoConverter(object):
input_shapes: Dict of strings representing input tensor names to list of
integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}).
Automatically determined when input shapes is None (e.g., {"foo" :
- None}). (default None)
+ None}). (default None)
output_arrays: List of output tensors to freeze graph with. Uses output
arrays from SignatureDef when none are provided. (default None)
tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to
@@ -328,7 +329,7 @@ class TocoConverter(object):
(default DEFAULT_SERVING_SIGNATURE_DEF_KEY)
Returns:
- TocoConverter class.
+ TFLiteConverter class.
"""
if tag_set is None:
tag_set = set([_tag_constants.SERVING])
@@ -346,7 +347,7 @@ class TocoConverter(object):
input_arrays=None,
input_shapes=None,
output_arrays=None):
- """Creates a TocoConverter class from a tf.keras model file.
+ """Creates a TFLiteConverter class from a tf.keras model file.
Args:
model_file: Full filepath of HDF5 file containing the tf.keras model.
@@ -355,12 +356,12 @@ class TocoConverter(object):
input_shapes: Dict of strings representing input tensor names to list of
integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}).
Automatically determined when input shapes is None (e.g., {"foo" :
- None}). (default None)
+ None}). (default None)
output_arrays: List of output tensors to freeze graph with. Uses output
arrays from SignatureDef when none are provided. (default None)
Returns:
- TocoConverter class.
+ TFLiteConverter class.
"""
_keras.backend.clear_session()
_keras.backend.set_learning_phase(False)
@@ -502,6 +503,59 @@ class TocoConverter(object):
tensor.set_shape(shape)
+class TocoConverter(object):
+ """Convert a TensorFlow model into `output_format` using TOCO.
+
+ This class has been deprecated. Please use `lite.TFLiteConverter` instead.
+ """
+
+ @classmethod
+ @_deprecation.deprecated(None,
+ "Use `lite.TFLiteConverter.from_session` instead.")
+ def from_session(cls, sess, input_tensors, output_tensors):
+ """Creates a TocoConverter class from a TensorFlow Session."""
+ return TFLiteConverter.from_session(sess, input_tensors, output_tensors)
+
+ @classmethod
+ @_deprecation.deprecated(
+ None, "Use `lite.TFLiteConverter.from_frozen_graph` instead.")
+ def from_frozen_graph(cls,
+ graph_def_file,
+ input_arrays,
+ output_arrays,
+ input_shapes=None):
+ """Creates a TocoConverter class from a file containing a frozen graph."""
+ return TFLiteConverter.from_frozen_graph(graph_def_file, input_arrays,
+ output_arrays, input_shapes)
+
+ @classmethod
+ @_deprecation.deprecated(
+ None, "Use `lite.TFLiteConverter.from_saved_model` instead.")
+ def from_saved_model(cls,
+ saved_model_dir,
+ input_arrays=None,
+ input_shapes=None,
+ output_arrays=None,
+ tag_set=None,
+ signature_key=None):
+ """Creates a TocoConverter class from a SavedModel."""
+ return TFLiteConverter.from_saved_model(saved_model_dir, input_arrays,
+ input_shapes, output_arrays,
+ tag_set, signature_key)
+
+ @classmethod
+ @_deprecation.deprecated(
+ None, "Use `lite.TFLiteConverter.from_keras_model_file` instead.")
+ def from_keras_model_file(cls,
+ model_file,
+ input_arrays=None,
+ input_shapes=None,
+ output_arrays=None):
+ """Creates a TocoConverter class from a tf.keras model file."""
+ return TFLiteConverter.from_keras_model_file(model_file, input_arrays,
+ input_shapes, output_arrays)
+
+
def _is_frozen_graph(sess):
"""Determines if the graph is frozen.
diff --git a/tensorflow/contrib/lite/python/lite_test.py b/tensorflow/contrib/lite/python/lite_test.py
index f112ed5cdd..d243a494f6 100644
--- a/tensorflow/contrib/lite/python/lite_test.py
+++ b/tensorflow/contrib/lite/python/lite_test.py
@@ -50,18 +50,18 @@ class FromConstructor(test_util.TensorFlowTestCase):
# `output_arrays` is not defined.
with self.assertRaises(ValueError) as error:
- lite.TocoConverter(
+ lite.TFLiteConverter(
None, None, [], input_arrays_with_shape=[('input', [3, 9])])
self.assertEqual(message, str(error.exception))
# `input_arrays_with_shape` is not defined.
with self.assertRaises(ValueError) as error:
- lite.TocoConverter(None, [], None, output_arrays=['output'])
+ lite.TFLiteConverter(None, [], None, output_arrays=['output'])
self.assertEqual(message, str(error.exception))
# Tests valid constructors using a dummy value for the GraphDef.
def testValidConstructor(self):
- converter = lite.TocoConverter(
+ converter = lite.TFLiteConverter(
None,
None,
None,
@@ -76,7 +76,7 @@ class FromConstructor(test_util.TensorFlowTestCase):
'The batch size cannot be set for this model. Please use '
'input_shapes parameter.', str(error.exception))
- converter = lite.TocoConverter(None, ['input_tensor'], ['output_tensor'])
+ converter = lite.TFLiteConverter(None, ['input_tensor'], ['output_tensor'])
self.assertTrue(converter._has_valid_tensors())
@@ -89,7 +89,8 @@ class FromSessionTest(test_util.TensorFlowTestCase):
sess = session.Session()
# Convert model and ensure model is not None.
- converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
+ converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
+ [out_tensor])
tflite_model = converter.convert()
self.assertTrue(tflite_model)
@@ -121,7 +122,7 @@ class FromSessionTest(test_util.TensorFlowTestCase):
sess = session.Session()
# Convert model and ensure model is not None.
- converter = lite.TocoConverter.from_session(
+ converter = lite.TFLiteConverter.from_session(
sess, [in_tensor_1, in_tensor_2], [out_tensor])
converter.inference_type = lite_constants.QUANTIZED_UINT8
converter.quantized_input_stats = {
@@ -166,7 +167,7 @@ class FromSessionTest(test_util.TensorFlowTestCase):
sess = session.Session()
# Convert model and ensure model is not None.
- converter = lite.TocoConverter.from_session(
+ converter = lite.TFLiteConverter.from_session(
sess, [in_tensor_1, in_tensor_2], [out_tensor])
converter.inference_type = lite_constants.QUANTIZED_UINT8
converter.quantized_input_stats = {'inputA': (0., 1.)} # mean, std_dev
@@ -182,7 +183,8 @@ class FromSessionTest(test_util.TensorFlowTestCase):
sess = session.Session()
# Test invalid shape. None after 1st dimension.
- converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
+ converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
+ [out_tensor])
with self.assertRaises(ValueError) as error:
converter.convert()
self.assertEqual('Provide an input shape for input array \'Placeholder\'.',
@@ -195,7 +197,8 @@ class FromSessionTest(test_util.TensorFlowTestCase):
sess = session.Session()
# Test invalid shape. None after 1st dimension.
- converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
+ converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
+ [out_tensor])
with self.assertRaises(ValueError) as error:
converter.convert()
self.assertEqual(
@@ -210,7 +213,8 @@ class FromSessionTest(test_util.TensorFlowTestCase):
sess = session.Session()
# Convert model and ensure model is not None.
- converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
+ converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
+ [out_tensor])
tflite_model = converter.convert()
self.assertTrue(tflite_model)
@@ -242,7 +246,8 @@ class FromSessionTest(test_util.TensorFlowTestCase):
sess.run(_global_variables_initializer())
# Convert model and ensure model is not None.
- converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
+ converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
+ [out_tensor])
tflite_model = converter.convert()
self.assertTrue(tflite_model)
@@ -272,7 +277,8 @@ class FromSessionTest(test_util.TensorFlowTestCase):
sess = session.Session()
# Convert model and ensure model is not None.
- converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
+ converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
+ [out_tensor])
converter.output_format = lite_constants.GRAPHVIZ_DOT
graphviz_output = converter.convert()
self.assertTrue(graphviz_output)
@@ -285,7 +291,8 @@ class FromSessionTest(test_util.TensorFlowTestCase):
sess = session.Session()
# Convert model and ensure model is not None.
- converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
+ converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
+ [out_tensor])
graphviz_dir = self.get_temp_dir()
converter.dump_graphviz_dir = graphviz_dir
tflite_model = converter.convert()
@@ -299,7 +306,8 @@ class FromSessionTest(test_util.TensorFlowTestCase):
self.assertTrue(num_items_graphviz)
# Convert model and ensure model is not None.
- converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
+ converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
+ [out_tensor])
graphviz_dir = self.get_temp_dir()
converter.dump_graphviz_dir = graphviz_dir
converter.dump_graphviz_video = True
@@ -317,7 +325,8 @@ class FromSessionTest(test_util.TensorFlowTestCase):
sess = session.Session()
# Convert model and ensure model is not None.
- converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
+ converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
+ [out_tensor])
converter.inference_input_type = lite_constants.QUANTIZED_UINT8
converter.quantized_input_stats = {'Placeholder': (0., 1.)} # mean, std_dev
tflite_model = converter.convert()
@@ -347,7 +356,8 @@ class FromSessionTest(test_util.TensorFlowTestCase):
sess = session.Session()
# Convert model and ensure model is not None.
- converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
+ converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
+ [out_tensor])
converter.inference_type = lite_constants.QUANTIZED_UINT8
converter.quantized_input_stats = {'Placeholder': (0., 1.)} # mean, std_dev
converter.default_ranges_stats = (0, 6) # min, max
@@ -387,13 +397,13 @@ class FromSessionTest(test_util.TensorFlowTestCase):
sess = session.Session()
# Convert float model.
- float_converter = lite.TocoConverter.from_session(sess, [in_tensor_1],
- [out_tensor])
+ float_converter = lite.TFLiteConverter.from_session(sess, [in_tensor_1],
+ [out_tensor])
float_tflite = float_converter.convert()
self.assertTrue(float_tflite)
# Convert quantized weights model.
- quantized_converter = lite.TocoConverter.from_session(
+ quantized_converter = lite.TFLiteConverter.from_session(
sess, [in_tensor_1], [out_tensor])
quantized_converter.post_training_quantize = True
quantized_tflite = quantized_converter.convert()
@@ -402,15 +412,16 @@ class FromSessionTest(test_util.TensorFlowTestCase):
# Ensure that the quantized weights tflite model is smaller.
self.assertTrue(len(quantized_tflite) < len(float_tflite))
- def testExtendedMode(self):
+ def testFlexMode(self):
in_tensor = array_ops.placeholder(
shape=[1, 16, 16, 3], dtype=dtypes.float32)
out_tensor = in_tensor + in_tensor
sess = session.Session()
# Convert model and ensure model is not None.
- converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
- converter.converter_mode = lite.ConverterMode.TOCO_EXTENDED_ALL
+ converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
+ [out_tensor])
+ converter.converter_mode = lite.ConverterMode.TOCO_FLEX_ALL
tflite_model = converter.convert()
self.assertTrue(tflite_model)
@@ -421,9 +432,25 @@ class FromSessionTest(test_util.TensorFlowTestCase):
interpreter.allocate_tensors()
self.assertIn(
'Regular TensorFlow ops are not supported by this interpreter. Make '
- 'sure you invoke the Eager delegate before inference.',
+ 'sure you invoke the Flex delegate before inference.',
str(error.exception))
+ def testFloatTocoConverter(self):
+    """Tests the deprecated TocoConverter."""
+ in_tensor = array_ops.placeholder(
+ shape=[1, 16, 16, 3], dtype=dtypes.float32)
+ out_tensor = in_tensor + in_tensor
+ sess = session.Session()
+
+ # Convert model and ensure model is not None.
+ converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
+ tflite_model = converter.convert()
+ self.assertTrue(tflite_model)
+
+ # Ensure the interpreter is able to load.
+ interpreter = Interpreter(model_content=tflite_model)
+ interpreter.allocate_tensors()
+
class FromFrozenGraphFile(test_util.TensorFlowTestCase):
@@ -439,8 +466,8 @@ class FromFrozenGraphFile(test_util.TensorFlowTestCase):
sess.close()
# Convert model and ensure model is not None.
- converter = lite.TocoConverter.from_frozen_graph(graph_def_file,
- ['Placeholder'], ['add'])
+ converter = lite.TFLiteConverter.from_frozen_graph(graph_def_file,
+ ['Placeholder'], ['add'])
tflite_model = converter.convert()
self.assertTrue(tflite_model)
@@ -474,7 +501,7 @@ class FromFrozenGraphFile(test_util.TensorFlowTestCase):
sess.close()
# Convert model and ensure model is not None.
- converter = lite.TocoConverter.from_frozen_graph(
+ converter = lite.TFLiteConverter.from_frozen_graph(
graph_def_file, ['Placeholder'], ['add'],
input_shapes={'Placeholder': [1, 16, 16, 3]})
tflite_model = converter.convert()
@@ -503,8 +530,8 @@ class FromFrozenGraphFile(test_util.TensorFlowTestCase):
# Ensure the graph with variables cannot be converted.
with self.assertRaises(ValueError) as error:
- lite.TocoConverter.from_frozen_graph(graph_def_file, ['Placeholder'],
- ['add'])
+ lite.TFLiteConverter.from_frozen_graph(graph_def_file, ['Placeholder'],
+ ['add'])
self.assertEqual('Please freeze the graph using freeze_graph.py.',
str(error.exception))
@@ -520,8 +547,8 @@ class FromFrozenGraphFile(test_util.TensorFlowTestCase):
sess.close()
# Convert model and ensure model is not None.
- converter = lite.TocoConverter.from_frozen_graph(graph_def_file,
- ['Placeholder'], ['add'])
+ converter = lite.TFLiteConverter.from_frozen_graph(graph_def_file,
+ ['Placeholder'], ['add'])
tflite_model = converter.convert()
self.assertTrue(tflite_model)
@@ -545,8 +572,8 @@ class FromFrozenGraphFile(test_util.TensorFlowTestCase):
def testInvalidFileNotFound(self):
with self.assertRaises(IOError) as error:
- lite.TocoConverter.from_frozen_graph('invalid_file', ['Placeholder'],
- ['add'])
+ lite.TFLiteConverter.from_frozen_graph('invalid_file', ['Placeholder'],
+ ['add'])
self.assertEqual('File \'invalid_file\' does not exist.',
str(error.exception))
@@ -558,8 +585,8 @@ class FromFrozenGraphFile(test_util.TensorFlowTestCase):
# Attempts to convert the invalid model.
with self.assertRaises(IOError) as error:
- lite.TocoConverter.from_frozen_graph(graph_def_file, ['Placeholder'],
- ['add'])
+ lite.TFLiteConverter.from_frozen_graph(graph_def_file, ['Placeholder'],
+ ['add'])
self.assertEqual(
'Unable to parse input file \'{}\'.'.format(graph_def_file),
str(error.exception))
@@ -580,7 +607,7 @@ class FromFrozenGraphFile(test_util.TensorFlowTestCase):
# Tests the object detection model that cannot be loaded in TensorFlow.
self._initObjectDetectionArgs()
- converter = lite.TocoConverter.from_frozen_graph(
+ converter = lite.TFLiteConverter.from_frozen_graph(
self._graph_def_file, self._input_arrays, self._output_arrays,
self._input_shapes)
converter.allow_custom_ops = True
@@ -621,7 +648,7 @@ class FromFrozenGraphFile(test_util.TensorFlowTestCase):
# Missing `input_shapes`.
with self.assertRaises(ValueError) as error:
- lite.TocoConverter.from_frozen_graph(
+ lite.TFLiteConverter.from_frozen_graph(
self._graph_def_file, self._input_arrays, self._output_arrays)
self.assertEqual('input_shapes must be defined for this model.',
str(error.exception))
@@ -632,7 +659,7 @@ class FromFrozenGraphFile(test_util.TensorFlowTestCase):
# `input_shapes` does not contain the names in `input_arrays`.
with self.assertRaises(ValueError) as error:
- lite.TocoConverter.from_frozen_graph(
+ lite.TFLiteConverter.from_frozen_graph(
self._graph_def_file,
self._input_arrays,
self._output_arrays,
@@ -641,6 +668,27 @@ class FromFrozenGraphFile(test_util.TensorFlowTestCase):
'input_shapes must contain a value for each item in input_array.',
str(error.exception))
+ def testFloatTocoConverter(self):
+ in_tensor = array_ops.placeholder(
+ shape=[1, 16, 16, 3], dtype=dtypes.float32)
+ _ = in_tensor + in_tensor
+ sess = session.Session()
+
+ # Write graph to file.
+ graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
+ write_graph(sess.graph_def, '', graph_def_file, False)
+ sess.close()
+
+ # Convert model and ensure model is not None.
+ converter = lite.TocoConverter.from_frozen_graph(graph_def_file,
+ ['Placeholder'], ['add'])
+ tflite_model = converter.convert()
+ self.assertTrue(tflite_model)
+
+ # Ensure the model is able to load.
+ interpreter = Interpreter(model_content=tflite_model)
+ interpreter.allocate_tensors()
+
class FromSavedModelTest(test_util.TensorFlowTestCase):
@@ -663,7 +711,7 @@ class FromSavedModelTest(test_util.TensorFlowTestCase):
saved_model_dir = self._createSavedModel(shape=[1, 16, 16, 3])
# Convert model and ensure model is not None.
- converter = lite.TocoConverter.from_saved_model(saved_model_dir)
+ converter = lite.TFLiteConverter.from_saved_model(saved_model_dir)
tflite_model = converter.convert()
self.assertTrue(tflite_model)
@@ -693,7 +741,7 @@ class FromSavedModelTest(test_util.TensorFlowTestCase):
"""Test a SavedModel, with None in input tensor's shape."""
saved_model_dir = self._createSavedModel(shape=[None, 16, 16, 3])
- converter = lite.TocoConverter.from_saved_model(saved_model_dir)
+ converter = lite.TFLiteConverter.from_saved_model(saved_model_dir)
tflite_model = converter.convert()
self.assertTrue(tflite_model)
@@ -724,7 +772,7 @@ class FromSavedModelTest(test_util.TensorFlowTestCase):
"""Test a SavedModel ordering of input arrays."""
saved_model_dir = self._createSavedModel(shape=[1, 16, 16, 3])
- converter = lite.TocoConverter.from_saved_model(
+ converter = lite.TFLiteConverter.from_saved_model(
saved_model_dir, input_arrays=['inputB', 'inputA'])
tflite_model = converter.convert()
self.assertTrue(tflite_model)
@@ -757,7 +805,7 @@ class FromSavedModelTest(test_util.TensorFlowTestCase):
saved_model_dir = self._createSavedModel(shape=[1, 16, 16, 3])
# Check case where input shape is given.
- converter = lite.TocoConverter.from_saved_model(
+ converter = lite.TFLiteConverter.from_saved_model(
saved_model_dir,
input_arrays=['inputA'],
input_shapes={'inputA': [1, 16, 16, 3]})
@@ -766,12 +814,25 @@ class FromSavedModelTest(test_util.TensorFlowTestCase):
self.assertTrue(tflite_model)
# Check case where input shape is None.
- converter = lite.TocoConverter.from_saved_model(
+ converter = lite.TFLiteConverter.from_saved_model(
saved_model_dir, input_arrays=['inputA'], input_shapes={'inputA': None})
tflite_model = converter.convert()
self.assertTrue(tflite_model)
+ def testSimpleModelTocoConverter(self):
+    """Test a SavedModel with the deprecated TocoConverter."""
+ saved_model_dir = self._createSavedModel(shape=[1, 16, 16, 3])
+
+ # Convert model and ensure model is not None.
+ converter = lite.TocoConverter.from_saved_model(saved_model_dir)
+ tflite_model = converter.convert()
+ self.assertTrue(tflite_model)
+
+ # Ensure the model is able to load.
+ interpreter = Interpreter(model_content=tflite_model)
+ interpreter.allocate_tensors()
+
class FromKerasFile(test_util.TensorFlowTestCase):
@@ -805,7 +866,7 @@ class FromKerasFile(test_util.TensorFlowTestCase):
"""Test a Sequential tf.keras model with default inputs."""
keras_file = self._getSequentialModel()
- converter = lite.TocoConverter.from_keras_model_file(keras_file)
+ converter = lite.TFLiteConverter.from_keras_model_file(keras_file)
tflite_model = converter.convert()
self.assertTrue(tflite_model)
@@ -845,13 +906,13 @@ class FromKerasFile(test_util.TensorFlowTestCase):
# Invalid input array raises error.
with self.assertRaises(ValueError) as error:
- lite.TocoConverter.from_keras_model_file(
+ lite.TFLiteConverter.from_keras_model_file(
keras_file, input_arrays=['invalid-input'])
self.assertEqual("Invalid tensors 'invalid-input' were found.",
str(error.exception))
# Valid input array.
- converter = lite.TocoConverter.from_keras_model_file(
+ converter = lite.TFLiteConverter.from_keras_model_file(
keras_file, input_arrays=['dense_input'])
tflite_model = converter.convert()
os.remove(keras_file)
@@ -863,13 +924,13 @@ class FromKerasFile(test_util.TensorFlowTestCase):
# Passing in shape of invalid input array has no impact as long as all input
# arrays have a shape.
- converter = lite.TocoConverter.from_keras_model_file(
+ converter = lite.TFLiteConverter.from_keras_model_file(
keras_file, input_shapes={'invalid-input': [2, 3]})
tflite_model = converter.convert()
self.assertTrue(tflite_model)
# Passing in shape of valid input array.
- converter = lite.TocoConverter.from_keras_model_file(
+ converter = lite.TFLiteConverter.from_keras_model_file(
keras_file, input_shapes={'dense_input': [2, 3]})
tflite_model = converter.convert()
os.remove(keras_file)
@@ -890,13 +951,13 @@ class FromKerasFile(test_util.TensorFlowTestCase):
# Invalid output array raises error.
with self.assertRaises(ValueError) as error:
- lite.TocoConverter.from_keras_model_file(
+ lite.TFLiteConverter.from_keras_model_file(
keras_file, output_arrays=['invalid-output'])
self.assertEqual("Invalid tensors 'invalid-output' were found.",
str(error.exception))
# Valid output array.
- converter = lite.TocoConverter.from_keras_model_file(
+ converter = lite.TFLiteConverter.from_keras_model_file(
keras_file, output_arrays=['time_distributed/Reshape_1'])
tflite_model = converter.convert()
os.remove(keras_file)
@@ -926,7 +987,7 @@ class FromKerasFile(test_util.TensorFlowTestCase):
os.close(fd)
# Convert to TFLite model.
- converter = lite.TocoConverter.from_keras_model_file(keras_file)
+ converter = lite.TFLiteConverter.from_keras_model_file(keras_file)
tflite_model = converter.convert()
self.assertTrue(tflite_model)
@@ -991,7 +1052,7 @@ class FromKerasFile(test_util.TensorFlowTestCase):
os.close(fd)
# Convert to TFLite model.
- converter = lite.TocoConverter.from_keras_model_file(keras_file)
+ converter = lite.TFLiteConverter.from_keras_model_file(keras_file)
tflite_model = converter.convert()
self.assertTrue(tflite_model)
@@ -1052,7 +1113,7 @@ class FromKerasFile(test_util.TensorFlowTestCase):
os.close(fd)
# Convert to TFLite model.
- converter = lite.TocoConverter.from_keras_model_file(keras_file)
+ converter = lite.TFLiteConverter.from_keras_model_file(keras_file)
tflite_model = converter.convert()
self.assertTrue(tflite_model)
@@ -1086,6 +1147,18 @@ class FromKerasFile(test_util.TensorFlowTestCase):
np.testing.assert_almost_equal(tflite_result, keras_result, 5)
os.remove(keras_file)
+ def testSequentialModelTocoConverter(self):
+    """Test a Sequential tf.keras model with the deprecated TocoConverter."""
+ keras_file = self._getSequentialModel()
+
+ converter = lite.TocoConverter.from_keras_model_file(keras_file)
+ tflite_model = converter.convert()
+ self.assertTrue(tflite_model)
+
+ # Ensure the model is able to load.
+ interpreter = Interpreter(model_content=tflite_model)
+ interpreter.allocate_tensors()
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/lite/python/tflite_convert.py b/tensorflow/contrib/lite/python/tflite_convert.py
index c0ff7f37f9..d6d9052a4e 100644
--- a/tensorflow/contrib/lite/python/tflite_convert.py
+++ b/tensorflow/contrib/lite/python/tflite_convert.py
@@ -40,13 +40,13 @@ def _parse_set(values):
def _get_toco_converter(flags):
- """Makes a TocoConverter object based on the flags provided.
+ """Makes a TFLiteConverter object based on the flags provided.
Args:
flags: argparse.Namespace object containing TFLite flags.
Returns:
- TocoConverter object.
+ TFLiteConverter object.
Raises:
ValueError: Invalid flags.
@@ -68,17 +68,17 @@ def _get_toco_converter(flags):
"output_arrays": output_arrays
}
- # Create TocoConverter.
+ # Create TFLiteConverter.
if flags.graph_def_file:
- converter_fn = lite.TocoConverter.from_frozen_graph
+ converter_fn = lite.TFLiteConverter.from_frozen_graph
converter_kwargs["graph_def_file"] = flags.graph_def_file
elif flags.saved_model_dir:
- converter_fn = lite.TocoConverter.from_saved_model
+ converter_fn = lite.TFLiteConverter.from_saved_model
converter_kwargs["saved_model_dir"] = flags.saved_model_dir
converter_kwargs["tag_set"] = _parse_set(flags.saved_model_tag_set)
converter_kwargs["signature_key"] = flags.saved_model_signature_key
elif flags.keras_model_file:
- converter_fn = lite.TocoConverter.from_keras_model_file
+ converter_fn = lite.TFLiteConverter.from_keras_model_file
converter_kwargs["model_file"] = flags.keras_model_file
else:
raise ValueError("--graph_def_file, --saved_model_dir, or "
diff --git a/tensorflow/contrib/lite/schema/BUILD b/tensorflow/contrib/lite/schema/BUILD
index 55bf2c48b9..d892466c7a 100644
--- a/tensorflow/contrib/lite/schema/BUILD
+++ b/tensorflow/contrib/lite/schema/BUILD
@@ -25,14 +25,18 @@ py_binary(
],
)
+# TODO(wvo): re-enable this test once the latest FlatBuffers release has landed.
+
py_test(
name = "upgrade_schema_test",
size = "small",
srcs = ["upgrade_schema_test.py"],
srcs_version = "PY2AND3",
tags = [
+ "manual",
"no_oss",
"no_pip",
+ "notap",
],
deps = [
":upgrade_schema",
diff --git a/tensorflow/contrib/lite/schema/flatbuffer_compatibility_test.cc b/tensorflow/contrib/lite/schema/flatbuffer_compatibility_test.cc
index 11057203a8..22b4616ccb 100644
--- a/tensorflow/contrib/lite/schema/flatbuffer_compatibility_test.cc
+++ b/tensorflow/contrib/lite/schema/flatbuffer_compatibility_test.cc
@@ -15,7 +15,7 @@ limitations under the License.
#include <fstream>
#include <gtest/gtest.h>
-#include "flatbuffers/flatc.h" // flatbuffers
+#include "flatbuffers/flatc.h" // TF:flatbuffers
#include "tensorflow/core/platform/platform.h"
#ifdef PLATFORM_GOOGLE
diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs
index f0db22d581..3da3188c3a 100644
--- a/tensorflow/contrib/lite/schema/schema.fbs
+++ b/tensorflow/contrib/lite/schema/schema.fbs
@@ -174,6 +174,8 @@ enum BuiltinOperator : byte {
FLOOR_DIV = 90,
REDUCE_ANY = 91,
SQUARE = 92,
+ ZEROS_LIKE = 93,
+ FILL = 94,
}
// Options for the builtin operators.
@@ -244,6 +246,8 @@ union BuiltinOptions {
UnpackOptions,
FloorDivOptions,
SquareOptions,
+ ZerosLikeOptions,
+ FillOptions,
}
enum Padding : byte { SAME, VALID }
@@ -588,6 +592,12 @@ table FloorDivOptions {
table SquareOptions {
}
+table ZerosLikeOptions {
+}
+
+table FillOptions {
+}
+
// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
// builtin, or a string if the operator is custom.
table OperatorCode {
diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h
index 8c086a5e67..23ac8484de 100755
--- a/tensorflow/contrib/lite/schema/schema_generated.h
+++ b/tensorflow/contrib/lite/schema/schema_generated.h
@@ -229,6 +229,12 @@ struct FloorDivOptionsT;
struct SquareOptions;
struct SquareOptionsT;
+struct ZerosLikeOptions;
+struct ZerosLikeOptionsT;
+
+struct FillOptions;
+struct FillOptionsT;
+
struct OperatorCode;
struct OperatorCodeT;
@@ -258,8 +264,8 @@ enum TensorType {
TensorType_MAX = TensorType_COMPLEX64
};
-inline TensorType (&EnumValuesTensorType())[9] {
- static TensorType values[] = {
+inline const TensorType (&EnumValuesTensorType())[9] {
+ static const TensorType values[] = {
TensorType_FLOAT32,
TensorType_FLOAT16,
TensorType_INT32,
@@ -273,8 +279,8 @@ inline TensorType (&EnumValuesTensorType())[9] {
return values;
}
-inline const char **EnumNamesTensorType() {
- static const char *names[] = {
+inline const char * const *EnumNamesTensorType() {
+ static const char * const names[] = {
"FLOAT32",
"FLOAT16",
"INT32",
@@ -387,12 +393,14 @@ enum BuiltinOperator {
BuiltinOperator_FLOOR_DIV = 90,
BuiltinOperator_REDUCE_ANY = 91,
BuiltinOperator_SQUARE = 92,
+ BuiltinOperator_ZEROS_LIKE = 93,
+ BuiltinOperator_FILL = 94,
BuiltinOperator_MIN = BuiltinOperator_ADD,
- BuiltinOperator_MAX = BuiltinOperator_SQUARE
+ BuiltinOperator_MAX = BuiltinOperator_FILL
};
-inline BuiltinOperator (&EnumValuesBuiltinOperator())[92] {
- static BuiltinOperator values[] = {
+inline const BuiltinOperator (&EnumValuesBuiltinOperator())[94] {
+ static const BuiltinOperator values[] = {
BuiltinOperator_ADD,
BuiltinOperator_AVERAGE_POOL_2D,
BuiltinOperator_CONCATENATION,
@@ -484,13 +492,15 @@ inline BuiltinOperator (&EnumValuesBuiltinOperator())[92] {
BuiltinOperator_REDUCE_MIN,
BuiltinOperator_FLOOR_DIV,
BuiltinOperator_REDUCE_ANY,
- BuiltinOperator_SQUARE
+ BuiltinOperator_SQUARE,
+ BuiltinOperator_ZEROS_LIKE,
+ BuiltinOperator_FILL
};
return values;
}
-inline const char **EnumNamesBuiltinOperator() {
- static const char *names[] = {
+inline const char * const *EnumNamesBuiltinOperator() {
+ static const char * const names[] = {
"ADD",
"AVERAGE_POOL_2D",
"CONCATENATION",
@@ -584,6 +594,8 @@ inline const char **EnumNamesBuiltinOperator() {
"FLOOR_DIV",
"REDUCE_ANY",
"SQUARE",
+ "ZEROS_LIKE",
+ "FILL",
nullptr
};
return names;
@@ -662,12 +674,14 @@ enum BuiltinOptions {
BuiltinOptions_UnpackOptions = 64,
BuiltinOptions_FloorDivOptions = 65,
BuiltinOptions_SquareOptions = 66,
+ BuiltinOptions_ZerosLikeOptions = 67,
+ BuiltinOptions_FillOptions = 68,
BuiltinOptions_MIN = BuiltinOptions_NONE,
- BuiltinOptions_MAX = BuiltinOptions_SquareOptions
+ BuiltinOptions_MAX = BuiltinOptions_FillOptions
};
-inline BuiltinOptions (&EnumValuesBuiltinOptions())[67] {
- static BuiltinOptions values[] = {
+inline const BuiltinOptions (&EnumValuesBuiltinOptions())[69] {
+ static const BuiltinOptions values[] = {
BuiltinOptions_NONE,
BuiltinOptions_Conv2DOptions,
BuiltinOptions_DepthwiseConv2DOptions,
@@ -734,13 +748,15 @@ inline BuiltinOptions (&EnumValuesBuiltinOptions())[67] {
BuiltinOptions_LogicalNotOptions,
BuiltinOptions_UnpackOptions,
BuiltinOptions_FloorDivOptions,
- BuiltinOptions_SquareOptions
+ BuiltinOptions_SquareOptions,
+ BuiltinOptions_ZerosLikeOptions,
+ BuiltinOptions_FillOptions
};
return values;
}
-inline const char **EnumNamesBuiltinOptions() {
- static const char *names[] = {
+inline const char * const *EnumNamesBuiltinOptions() {
+ static const char * const names[] = {
"NONE",
"Conv2DOptions",
"DepthwiseConv2DOptions",
@@ -808,6 +824,8 @@ inline const char **EnumNamesBuiltinOptions() {
"UnpackOptions",
"FloorDivOptions",
"SquareOptions",
+ "ZerosLikeOptions",
+ "FillOptions",
nullptr
};
return names;
@@ -1086,6 +1104,14 @@ template<> struct BuiltinOptionsTraits<SquareOptions> {
static const BuiltinOptions enum_value = BuiltinOptions_SquareOptions;
};
+template<> struct BuiltinOptionsTraits<ZerosLikeOptions> {
+ static const BuiltinOptions enum_value = BuiltinOptions_ZerosLikeOptions;
+};
+
+template<> struct BuiltinOptionsTraits<FillOptions> {
+ static const BuiltinOptions enum_value = BuiltinOptions_FillOptions;
+};
+
struct BuiltinOptionsUnion {
BuiltinOptions type;
void *value;
@@ -1645,6 +1671,22 @@ struct BuiltinOptionsUnion {
return type == BuiltinOptions_SquareOptions ?
reinterpret_cast<const SquareOptionsT *>(value) : nullptr;
}
+ ZerosLikeOptionsT *AsZerosLikeOptions() {
+ return type == BuiltinOptions_ZerosLikeOptions ?
+ reinterpret_cast<ZerosLikeOptionsT *>(value) : nullptr;
+ }
+ const ZerosLikeOptionsT *AsZerosLikeOptions() const {
+ return type == BuiltinOptions_ZerosLikeOptions ?
+ reinterpret_cast<const ZerosLikeOptionsT *>(value) : nullptr;
+ }
+ FillOptionsT *AsFillOptions() {
+ return type == BuiltinOptions_FillOptions ?
+ reinterpret_cast<FillOptionsT *>(value) : nullptr;
+ }
+ const FillOptionsT *AsFillOptions() const {
+ return type == BuiltinOptions_FillOptions ?
+ reinterpret_cast<const FillOptionsT *>(value) : nullptr;
+ }
};
bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type);
@@ -1657,16 +1699,16 @@ enum Padding {
Padding_MAX = Padding_VALID
};
-inline Padding (&EnumValuesPadding())[2] {
- static Padding values[] = {
+inline const Padding (&EnumValuesPadding())[2] {
+ static const Padding values[] = {
Padding_SAME,
Padding_VALID
};
return values;
}
-inline const char **EnumNamesPadding() {
- static const char *names[] = {
+inline const char * const *EnumNamesPadding() {
+ static const char * const names[] = {
"SAME",
"VALID",
nullptr
@@ -1690,8 +1732,8 @@ enum ActivationFunctionType {
ActivationFunctionType_MAX = ActivationFunctionType_SIGN_BIT
};
-inline ActivationFunctionType (&EnumValuesActivationFunctionType())[6] {
- static ActivationFunctionType values[] = {
+inline const ActivationFunctionType (&EnumValuesActivationFunctionType())[6] {
+ static const ActivationFunctionType values[] = {
ActivationFunctionType_NONE,
ActivationFunctionType_RELU,
ActivationFunctionType_RELU_N1_TO_1,
@@ -1702,8 +1744,8 @@ inline ActivationFunctionType (&EnumValuesActivationFunctionType())[6] {
return values;
}
-inline const char **EnumNamesActivationFunctionType() {
- static const char *names[] = {
+inline const char * const *EnumNamesActivationFunctionType() {
+ static const char * const names[] = {
"NONE",
"RELU",
"RELU_N1_TO_1",
@@ -1728,8 +1770,8 @@ enum LSHProjectionType {
LSHProjectionType_MAX = LSHProjectionType_DENSE
};
-inline LSHProjectionType (&EnumValuesLSHProjectionType())[3] {
- static LSHProjectionType values[] = {
+inline const LSHProjectionType (&EnumValuesLSHProjectionType())[3] {
+ static const LSHProjectionType values[] = {
LSHProjectionType_UNKNOWN,
LSHProjectionType_SPARSE,
LSHProjectionType_DENSE
@@ -1737,8 +1779,8 @@ inline LSHProjectionType (&EnumValuesLSHProjectionType())[3] {
return values;
}
-inline const char **EnumNamesLSHProjectionType() {
- static const char *names[] = {
+inline const char * const *EnumNamesLSHProjectionType() {
+ static const char * const names[] = {
"UNKNOWN",
"SPARSE",
"DENSE",
@@ -1759,16 +1801,16 @@ enum FullyConnectedOptionsWeightsFormat {
FullyConnectedOptionsWeightsFormat_MAX = FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8
};
-inline FullyConnectedOptionsWeightsFormat (&EnumValuesFullyConnectedOptionsWeightsFormat())[2] {
- static FullyConnectedOptionsWeightsFormat values[] = {
+inline const FullyConnectedOptionsWeightsFormat (&EnumValuesFullyConnectedOptionsWeightsFormat())[2] {
+ static const FullyConnectedOptionsWeightsFormat values[] = {
FullyConnectedOptionsWeightsFormat_DEFAULT,
FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8
};
return values;
}
-inline const char **EnumNamesFullyConnectedOptionsWeightsFormat() {
- static const char *names[] = {
+inline const char * const *EnumNamesFullyConnectedOptionsWeightsFormat() {
+ static const char * const names[] = {
"DEFAULT",
"SHUFFLED4x16INT8",
nullptr
@@ -1788,16 +1830,16 @@ enum LSTMKernelType {
LSTMKernelType_MAX = LSTMKernelType_BASIC
};
-inline LSTMKernelType (&EnumValuesLSTMKernelType())[2] {
- static LSTMKernelType values[] = {
+inline const LSTMKernelType (&EnumValuesLSTMKernelType())[2] {
+ static const LSTMKernelType values[] = {
LSTMKernelType_FULL,
LSTMKernelType_BASIC
};
return values;
}
-inline const char **EnumNamesLSTMKernelType() {
- static const char *names[] = {
+inline const char * const *EnumNamesLSTMKernelType() {
+ static const char * const names[] = {
"FULL",
"BASIC",
nullptr
@@ -1818,8 +1860,8 @@ enum CombinerType {
CombinerType_MAX = CombinerType_SQRTN
};
-inline CombinerType (&EnumValuesCombinerType())[3] {
- static CombinerType values[] = {
+inline const CombinerType (&EnumValuesCombinerType())[3] {
+ static const CombinerType values[] = {
CombinerType_SUM,
CombinerType_MEAN,
CombinerType_SQRTN
@@ -1827,8 +1869,8 @@ inline CombinerType (&EnumValuesCombinerType())[3] {
return values;
}
-inline const char **EnumNamesCombinerType() {
- static const char *names[] = {
+inline const char * const *EnumNamesCombinerType() {
+ static const char * const names[] = {
"SUM",
"MEAN",
"SQRTN",
@@ -1848,15 +1890,15 @@ enum CustomOptionsFormat {
CustomOptionsFormat_MAX = CustomOptionsFormat_FLEXBUFFERS
};
-inline CustomOptionsFormat (&EnumValuesCustomOptionsFormat())[1] {
- static CustomOptionsFormat values[] = {
+inline const CustomOptionsFormat (&EnumValuesCustomOptionsFormat())[1] {
+ static const CustomOptionsFormat values[] = {
CustomOptionsFormat_FLEXBUFFERS
};
return values;
}
-inline const char **EnumNamesCustomOptionsFormat() {
- static const char *names[] = {
+inline const char * const *EnumNamesCustomOptionsFormat() {
+ static const char * const names[] = {
"FLEXBUFFERS",
nullptr
};
@@ -1901,13 +1943,13 @@ struct QuantizationParameters FLATBUFFERS_FINAL_CLASS : private flatbuffers::Tab
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyOffset(verifier, VT_MIN) &&
- verifier.Verify(min()) &&
+ verifier.VerifyVector(min()) &&
VerifyOffset(verifier, VT_MAX) &&
- verifier.Verify(max()) &&
+ verifier.VerifyVector(max()) &&
VerifyOffset(verifier, VT_SCALE) &&
- verifier.Verify(scale()) &&
+ verifier.VerifyVector(scale()) &&
VerifyOffset(verifier, VT_ZERO_POINT) &&
- verifier.Verify(zero_point()) &&
+ verifier.VerifyVector(zero_point()) &&
verifier.EndTable();
}
QuantizationParametersT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@@ -2018,11 +2060,11 @@ struct Tensor FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyOffset(verifier, VT_SHAPE) &&
- verifier.Verify(shape()) &&
+ verifier.VerifyVector(shape()) &&
VerifyField<int8_t>(verifier, VT_TYPE) &&
VerifyField<uint32_t>(verifier, VT_BUFFER) &&
VerifyOffset(verifier, VT_NAME) &&
- verifier.Verify(name()) &&
+ verifier.VerifyString(name()) &&
VerifyOffset(verifier, VT_QUANTIZATION) &&
verifier.VerifyTable(quantization()) &&
VerifyField<uint8_t>(verifier, VT_IS_VARIABLE) &&
@@ -2488,9 +2530,9 @@ struct ConcatEmbeddingsOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Ta
return VerifyTableStart(verifier) &&
VerifyField<int32_t>(verifier, VT_NUM_CHANNELS) &&
VerifyOffset(verifier, VT_NUM_COLUMNS_PER_CHANNEL) &&
- verifier.Verify(num_columns_per_channel()) &&
+ verifier.VerifyVector(num_columns_per_channel()) &&
VerifyOffset(verifier, VT_EMBEDDING_DIM_PER_CHANNEL) &&
- verifier.Verify(embedding_dim_per_channel()) &&
+ verifier.VerifyVector(embedding_dim_per_channel()) &&
verifier.EndTable();
}
ConcatEmbeddingsOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@@ -3588,7 +3630,7 @@ struct ReshapeOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyOffset(verifier, VT_NEW_SHAPE) &&
- verifier.Verify(new_shape()) &&
+ verifier.VerifyVector(new_shape()) &&
verifier.EndTable();
}
ReshapeOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@@ -4252,7 +4294,7 @@ struct SqueezeOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyOffset(verifier, VT_SQUEEZE_DIMS) &&
- verifier.Verify(squeeze_dims()) &&
+ verifier.VerifyVector(squeeze_dims()) &&
verifier.EndTable();
}
SqueezeOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@@ -5888,6 +5930,86 @@ inline flatbuffers::Offset<SquareOptions> CreateSquareOptions(
flatbuffers::Offset<SquareOptions> CreateSquareOptions(flatbuffers::FlatBufferBuilder &_fbb, const SquareOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+struct ZerosLikeOptionsT : public flatbuffers::NativeTable {
+ typedef ZerosLikeOptions TableType;
+ ZerosLikeOptionsT() {
+ }
+};
+
+struct ZerosLikeOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef ZerosLikeOptionsT NativeTableType;
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ verifier.EndTable();
+ }
+ ZerosLikeOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(ZerosLikeOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<ZerosLikeOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const ZerosLikeOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct ZerosLikeOptionsBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ explicit ZerosLikeOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ ZerosLikeOptionsBuilder &operator=(const ZerosLikeOptionsBuilder &);
+ flatbuffers::Offset<ZerosLikeOptions> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<ZerosLikeOptions>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<ZerosLikeOptions> CreateZerosLikeOptions(
+ flatbuffers::FlatBufferBuilder &_fbb) {
+ ZerosLikeOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+flatbuffers::Offset<ZerosLikeOptions> CreateZerosLikeOptions(flatbuffers::FlatBufferBuilder &_fbb, const ZerosLikeOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+struct FillOptionsT : public flatbuffers::NativeTable {
+ typedef FillOptions TableType;
+ FillOptionsT() {
+ }
+};
+
+struct FillOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef FillOptionsT NativeTableType;
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ verifier.EndTable();
+ }
+ FillOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(FillOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<FillOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const FillOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct FillOptionsBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ explicit FillOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ FillOptionsBuilder &operator=(const FillOptionsBuilder &);
+ flatbuffers::Offset<FillOptions> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<FillOptions>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<FillOptions> CreateFillOptions(
+ flatbuffers::FlatBufferBuilder &_fbb) {
+ FillOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+flatbuffers::Offset<FillOptions> CreateFillOptions(flatbuffers::FlatBufferBuilder &_fbb, const FillOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
struct OperatorCodeT : public flatbuffers::NativeTable {
typedef OperatorCode TableType;
BuiltinOperator builtin_code;
@@ -5919,7 +6041,7 @@ struct OperatorCode FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
return VerifyTableStart(verifier) &&
VerifyField<int8_t>(verifier, VT_BUILTIN_CODE) &&
VerifyOffset(verifier, VT_CUSTOM_CODE) &&
- verifier.Verify(custom_code()) &&
+ verifier.VerifyString(custom_code()) &&
VerifyField<int32_t>(verifier, VT_VERSION) &&
verifier.EndTable();
}
@@ -6219,6 +6341,12 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
const SquareOptions *builtin_options_as_SquareOptions() const {
return builtin_options_type() == BuiltinOptions_SquareOptions ? static_cast<const SquareOptions *>(builtin_options()) : nullptr;
}
+ const ZerosLikeOptions *builtin_options_as_ZerosLikeOptions() const {
+ return builtin_options_type() == BuiltinOptions_ZerosLikeOptions ? static_cast<const ZerosLikeOptions *>(builtin_options()) : nullptr;
+ }
+ const FillOptions *builtin_options_as_FillOptions() const {
+ return builtin_options_type() == BuiltinOptions_FillOptions ? static_cast<const FillOptions *>(builtin_options()) : nullptr;
+ }
const flatbuffers::Vector<uint8_t> *custom_options() const {
return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS);
}
@@ -6232,17 +6360,17 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
return VerifyTableStart(verifier) &&
VerifyField<uint32_t>(verifier, VT_OPCODE_INDEX) &&
VerifyOffset(verifier, VT_INPUTS) &&
- verifier.Verify(inputs()) &&
+ verifier.VerifyVector(inputs()) &&
VerifyOffset(verifier, VT_OUTPUTS) &&
- verifier.Verify(outputs()) &&
+ verifier.VerifyVector(outputs()) &&
VerifyField<uint8_t>(verifier, VT_BUILTIN_OPTIONS_TYPE) &&
VerifyOffset(verifier, VT_BUILTIN_OPTIONS) &&
VerifyBuiltinOptions(verifier, builtin_options(), builtin_options_type()) &&
VerifyOffset(verifier, VT_CUSTOM_OPTIONS) &&
- verifier.Verify(custom_options()) &&
+ verifier.VerifyVector(custom_options()) &&
VerifyField<int8_t>(verifier, VT_CUSTOM_OPTIONS_FORMAT) &&
VerifyOffset(verifier, VT_MUTATING_VARIABLE_INPUTS) &&
- verifier.Verify(mutating_variable_inputs()) &&
+ verifier.VerifyVector(mutating_variable_inputs()) &&
verifier.EndTable();
}
OperatorT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@@ -6514,6 +6642,14 @@ template<> inline const SquareOptions *Operator::builtin_options_as<SquareOption
return builtin_options_as_SquareOptions();
}
+template<> inline const ZerosLikeOptions *Operator::builtin_options_as<ZerosLikeOptions>() const {
+ return builtin_options_as_ZerosLikeOptions();
+}
+
+template<> inline const FillOptions *Operator::builtin_options_as<FillOptions>() const {
+ return builtin_options_as_FillOptions();
+}
+
struct OperatorBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
@@ -6637,17 +6773,17 @@ struct SubGraph FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyOffset(verifier, VT_TENSORS) &&
- verifier.Verify(tensors()) &&
+ verifier.VerifyVector(tensors()) &&
verifier.VerifyVectorOfTables(tensors()) &&
VerifyOffset(verifier, VT_INPUTS) &&
- verifier.Verify(inputs()) &&
+ verifier.VerifyVector(inputs()) &&
VerifyOffset(verifier, VT_OUTPUTS) &&
- verifier.Verify(outputs()) &&
+ verifier.VerifyVector(outputs()) &&
VerifyOffset(verifier, VT_OPERATORS) &&
- verifier.Verify(operators()) &&
+ verifier.VerifyVector(operators()) &&
verifier.VerifyVectorOfTables(operators()) &&
VerifyOffset(verifier, VT_NAME) &&
- verifier.Verify(name()) &&
+ verifier.VerifyString(name()) &&
verifier.EndTable();
}
SubGraphT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@@ -6737,7 +6873,7 @@ struct Buffer FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyOffset(verifier, VT_DATA) &&
- verifier.Verify(data()) &&
+ verifier.VerifyVector(data()) &&
verifier.EndTable();
}
BufferT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@@ -6826,18 +6962,18 @@ struct Model FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
return VerifyTableStart(verifier) &&
VerifyField<uint32_t>(verifier, VT_VERSION) &&
VerifyOffset(verifier, VT_OPERATOR_CODES) &&
- verifier.Verify(operator_codes()) &&
+ verifier.VerifyVector(operator_codes()) &&
verifier.VerifyVectorOfTables(operator_codes()) &&
VerifyOffset(verifier, VT_SUBGRAPHS) &&
- verifier.Verify(subgraphs()) &&
+ verifier.VerifyVector(subgraphs()) &&
verifier.VerifyVectorOfTables(subgraphs()) &&
VerifyOffset(verifier, VT_DESCRIPTION) &&
- verifier.Verify(description()) &&
+ verifier.VerifyString(description()) &&
VerifyOffset(verifier, VT_BUFFERS) &&
- verifier.Verify(buffers()) &&
+ verifier.VerifyVector(buffers()) &&
verifier.VerifyVectorOfTables(buffers()) &&
VerifyOffset(verifier, VT_METADATA_BUFFER) &&
- verifier.Verify(metadata_buffer()) &&
+ verifier.VerifyVector(metadata_buffer()) &&
verifier.EndTable();
}
ModelT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@@ -8782,6 +8918,52 @@ inline flatbuffers::Offset<SquareOptions> CreateSquareOptions(flatbuffers::FlatB
_fbb);
}
+inline ZerosLikeOptionsT *ZerosLikeOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new ZerosLikeOptionsT();
+ UnPackTo(_o, _resolver);
+ return _o;
+}
+
+inline void ZerosLikeOptions::UnPackTo(ZerosLikeOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+}
+
+inline flatbuffers::Offset<ZerosLikeOptions> ZerosLikeOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const ZerosLikeOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateZerosLikeOptions(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<ZerosLikeOptions> CreateZerosLikeOptions(flatbuffers::FlatBufferBuilder &_fbb, const ZerosLikeOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const ZerosLikeOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ return tflite::CreateZerosLikeOptions(
+ _fbb);
+}
+
+inline FillOptionsT *FillOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new FillOptionsT();
+ UnPackTo(_o, _resolver);
+ return _o;
+}
+
+inline void FillOptions::UnPackTo(FillOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+}
+
+inline flatbuffers::Offset<FillOptions> FillOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const FillOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateFillOptions(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<FillOptions> CreateFillOptions(flatbuffers::FlatBufferBuilder &_fbb, const FillOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const FillOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ return tflite::CreateFillOptions(
+ _fbb);
+}
+
inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new OperatorCodeT();
UnPackTo(_o, _resolver);
@@ -9235,6 +9417,14 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob
auto ptr = reinterpret_cast<const SquareOptions *>(obj);
return verifier.VerifyTable(ptr);
}
+ case BuiltinOptions_ZerosLikeOptions: {
+ auto ptr = reinterpret_cast<const ZerosLikeOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_FillOptions: {
+ auto ptr = reinterpret_cast<const FillOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
default: return false;
}
}
@@ -9517,6 +9707,14 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c
auto ptr = reinterpret_cast<const SquareOptions *>(obj);
return ptr->UnPack(resolver);
}
+ case BuiltinOptions_ZerosLikeOptions: {
+ auto ptr = reinterpret_cast<const ZerosLikeOptions *>(obj);
+ return ptr->UnPack(resolver);
+ }
+ case BuiltinOptions_FillOptions: {
+ auto ptr = reinterpret_cast<const FillOptions *>(obj);
+ return ptr->UnPack(resolver);
+ }
default: return nullptr;
}
}
@@ -9787,6 +9985,14 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff
auto ptr = reinterpret_cast<const SquareOptionsT *>(value);
return CreateSquareOptions(_fbb, ptr, _rehasher).Union();
}
+ case BuiltinOptions_ZerosLikeOptions: {
+ auto ptr = reinterpret_cast<const ZerosLikeOptionsT *>(value);
+ return CreateZerosLikeOptions(_fbb, ptr, _rehasher).Union();
+ }
+ case BuiltinOptions_FillOptions: {
+ auto ptr = reinterpret_cast<const FillOptionsT *>(value);
+ return CreateFillOptions(_fbb, ptr, _rehasher).Union();
+ }
default: return 0;
}
}
@@ -10057,6 +10263,14 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL
value = new SquareOptionsT(*reinterpret_cast<SquareOptionsT *>(u.value));
break;
}
+ case BuiltinOptions_ZerosLikeOptions: {
+ value = new ZerosLikeOptionsT(*reinterpret_cast<ZerosLikeOptionsT *>(u.value));
+ break;
+ }
+ case BuiltinOptions_FillOptions: {
+ value = new FillOptionsT(*reinterpret_cast<FillOptionsT *>(u.value));
+ break;
+ }
default:
break;
}
@@ -10394,6 +10608,16 @@ inline void BuiltinOptionsUnion::Reset() {
delete ptr;
break;
}
+ case BuiltinOptions_ZerosLikeOptions: {
+ auto ptr = reinterpret_cast<ZerosLikeOptionsT *>(value);
+ delete ptr;
+ break;
+ }
+ case BuiltinOptions_FillOptions: {
+ auto ptr = reinterpret_cast<FillOptionsT *>(value);
+ delete ptr;
+ break;
+ }
default: break;
}
value = nullptr;
@@ -10404,6 +10628,10 @@ inline const tflite::Model *GetModel(const void *buf) {
return flatbuffers::GetRoot<tflite::Model>(buf);
}
+inline const tflite::Model *GetSizePrefixedModel(const void *buf) {
+ return flatbuffers::GetSizePrefixedRoot<tflite::Model>(buf);
+}
+
inline const char *ModelIdentifier() {
return "TFL3";
}
@@ -10418,6 +10646,11 @@ inline bool VerifyModelBuffer(
return verifier.VerifyBuffer<tflite::Model>(ModelIdentifier());
}
+inline bool VerifySizePrefixedModelBuffer(
+ flatbuffers::Verifier &verifier) {
+ return verifier.VerifySizePrefixedBuffer<tflite::Model>(ModelIdentifier());
+}
+
inline const char *ModelExtension() {
return "tflite";
}
@@ -10428,6 +10661,12 @@ inline void FinishModelBuffer(
fbb.Finish(root, ModelIdentifier());
}
+inline void FinishSizePrefixedModelBuffer(
+ flatbuffers::FlatBufferBuilder &fbb,
+ flatbuffers::Offset<tflite::Model> root) {
+ fbb.FinishSizePrefixed(root, ModelIdentifier());
+}
+
inline std::unique_ptr<ModelT> UnPackModel(
const void *buf,
const flatbuffers::resolver_function_t *res = nullptr) {
diff --git a/tensorflow/contrib/lite/testing/BUILD b/tensorflow/contrib/lite/testing/BUILD
index a4736bfee9..f0bfec2338 100644
--- a/tensorflow/contrib/lite/testing/BUILD
+++ b/tensorflow/contrib/lite/testing/BUILD
@@ -13,6 +13,7 @@ load("//tensorflow/contrib/lite:special_rules.bzl", "tflite_portable_test_suite"
load(
"//tensorflow:tensorflow.bzl",
"tf_cc_test",
+ "py_test",
)
[gen_zip_test(
@@ -163,7 +164,7 @@ cc_library(
":test_runner",
"//tensorflow/contrib/lite:builtin_op_data",
"//tensorflow/contrib/lite:framework",
- "//tensorflow/contrib/lite/delegates/eager:delegate",
+ "//tensorflow/contrib/lite/delegates/flex:delegate",
"//tensorflow/contrib/lite/kernels:builtin_ops",
],
)
@@ -362,4 +363,32 @@ cc_binary(
],
)
+py_binary(
+ name = "model_coverage_lib",
+ srcs = ["//tensorflow/contrib/lite/testing:model_coverage/model_coverage_lib.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ visibility = ["//tensorflow/contrib/lite:__subpackages__"],
+ deps = [
+ "//tensorflow/contrib/lite/python:lite",
+ "//tensorflow/python:platform",
+ ],
+)
+
+py_test(
+ name = "model_coverage_lib_test",
+ srcs = ["//tensorflow/contrib/lite/testing:model_coverage/model_coverage_lib_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ "notap",
+ ],
+ deps = [
+ ":model_coverage_lib",
+ "//tensorflow/python:client_testlib",
+ ],
+)
+
tflite_portable_test_suite()
diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py
index 3754b58b23..18036fac6f 100644
--- a/tensorflow/contrib/lite/testing/generate_examples.py
+++ b/tensorflow/contrib/lite/testing/generate_examples.py
@@ -81,9 +81,9 @@ parser.add_argument(
action="store_true",
help="Include intermediate graphdefs in the output zip files.")
parser.add_argument(
- "--run_with_extended",
+ "--run_with_flex",
action="store_true",
- help="Whether the TFLite Extended converter is being used.")
+ help="Whether the TFLite Flex converter is being used.")
RANDOM_SEED = 342
TEST_INPUT_DEPTH = 3
@@ -339,11 +339,11 @@ def toco_convert(graph_def_str, input_tensors, output_tensors,
graphdef_file.flush()
# TODO(aselle): Switch this to subprocess at some point.
- if "pb2lite" in bin_path and FLAGS.run_with_extended:
+ if "pb2lite" in bin_path and FLAGS.run_with_flex:
opts = ("--input_arrays={0} --output_arrays={1}".format(
",".join(input_arrays), ",".join(output_tensors)))
- elif FLAGS.run_with_extended:
- opts += " --allow_eager_ops --force_eager_ops"
+ elif FLAGS.run_with_flex:
+ opts += " --allow_flex_ops --force_flex_ops"
cmd = ("%s --input_file=%s --output_file=%s %s > %s 2>&1" %
(bin_path, graphdef_file.name, output_file.name, opts,
stdout_file.name))
@@ -2834,6 +2834,31 @@ def make_neg_tests(zip_path):
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+def make_zeros_like_tests(zip_path):
+ """Make a set of tests to do zeros_like."""
+
+ test_parameters = [{
+ "input_dtype": [tf.float32, tf.int32, tf.int64],
+ "input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
+ }]
+
+ def build_graph(parameters):
+ """Build the zeros_like op testing graph."""
+ input_tensor = tf.placeholder(
+ dtype=parameters["input_dtype"],
+ name="input",
+ shape=parameters["input_shape"])
+ out = tf.zeros_like(input_tensor)
+ return [input_tensor], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ values = create_tensor_data(parameters["input_dtype"],
+ parameters["input_shape"])
+ return [values], sess.run(outputs, feed_dict=dict(zip(inputs, [values])))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
def _make_elementwise_tests(op):
"""Make a set of tests to do element-wise operations."""
@@ -3308,7 +3333,7 @@ def main(unused_args):
# list of valid conversion modes is defined in
# generated_test_conversion_modes() in build_def.bzl.
test_function = ("make_%s_tests" % (out.replace(".zip", "").replace(
- "pb2lite", "").replace("toco-extended", "").rstrip("_")))
+ "pb2lite", "").replace("toco-flex", "").rstrip("_")))
if test_function not in globals():
raise RuntimeError("Can't find a test function to create %r. Tried %r" %
(out, test_function))
diff --git a/tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib.py b/tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib.py
new file mode 100644
index 0000000000..5ca57d083d
--- /dev/null
+++ b/tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib.py
@@ -0,0 +1,249 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Functions to test TFLite models."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.lite.python import convert_saved_model as _convert_saved_model
+from tensorflow.contrib.lite.python import lite as _lite
+from tensorflow.core.framework import graph_pb2 as _graph_pb2
+from tensorflow.python import keras as _keras
+from tensorflow.python.client import session as _session
+from tensorflow.python.framework.importer import import_graph_def as _import_graph_def
+from tensorflow.python.lib.io import file_io as _file_io
+from tensorflow.python.saved_model import signature_constants as _signature_constants
+from tensorflow.python.saved_model import tag_constants as _tag_constants
+
+
+def _convert(converter, **kwargs):
+ """Converts the model.
+
+ Args:
+ converter: TocoConverter object.
+ **kwargs: Additional arguments to be passed into the converter. Supported
+ flags are {"converter_mode", "post_training_quant"}.
+
+ Returns:
+ The converted TFLite model in serialized format.
+ """
+ if "converter_mode" in kwargs:
+ converter.converter_mode = kwargs["converter_mode"]
+ if "post_training_quantize" in kwargs:
+ converter.post_training_quantize = kwargs["post_training_quantize"]
+ return converter.convert()
+
+
+def _generate_random_input_data(tflite_model, seed=None):
+ """Generates input data based on the input tensors in the TFLite model.
+
+ Args:
+ tflite_model: Serialized TensorFlow Lite model.
+ seed: Integer seed for the random generator. (default None)
+
+ Returns:
+ List of np.ndarray.
+ """
+ interpreter = _lite.Interpreter(model_content=tflite_model)
+ interpreter.allocate_tensors()
+ input_details = interpreter.get_input_details()
+
+ if seed:
+ np.random.seed(seed=seed)
+ return [
+ np.array(
+ np.random.random_sample(input_tensor["shape"]),
+ dtype=input_tensor["dtype"]) for input_tensor in input_details
+ ]
+
+
+def _evaluate_tflite_model(tflite_model, input_data):
+ """Returns evaluation of input data on TFLite model.
+
+ Args:
+ tflite_model: Serialized TensorFlow Lite model.
+ input_data: List of np.ndarray.
+
+ Returns:
+ List of np.ndarray.
+ """
+ interpreter = _lite.Interpreter(model_content=tflite_model)
+ interpreter.allocate_tensors()
+
+ input_details = interpreter.get_input_details()
+ output_details = interpreter.get_output_details()
+
+ for input_tensor, tensor_data in zip(input_details, input_data):
+ interpreter.set_tensor(input_tensor["index"], tensor_data)
+
+ interpreter.invoke()
+ output_data = [
+ interpreter.get_tensor(output_tensor["index"])
+ for output_tensor in output_details
+ ]
+ return output_data
+
+
+def evaluate_frozen_graph(filename, input_arrays, output_arrays):
+ """Returns a function that evaluates the frozen graph on input data.
+
+ Args:
+ filename: Full filepath of file containing frozen GraphDef.
+ input_arrays: List of input tensors to freeze graph with.
+ output_arrays: List of output tensors to freeze graph with.
+
+ Returns:
+ Lambda function ([np.ndarray data] : [np.ndarray result]).
+ """
+ with _session.Session().as_default() as sess:
+ with _file_io.FileIO(filename, "rb") as f:
+ file_content = f.read()
+
+ graph_def = _graph_pb2.GraphDef()
+ graph_def.ParseFromString(file_content)
+ _import_graph_def(graph_def, name="")
+
+ inputs = _convert_saved_model.get_tensors_from_tensor_names(
+ sess.graph, input_arrays)
+ outputs = _convert_saved_model.get_tensors_from_tensor_names(
+ sess.graph, output_arrays)
+
+ return lambda input_data: sess.run(outputs, dict(zip(inputs, input_data)))
+
+
+def evaluate_saved_model(directory, tag_set, signature_key):
+ """Returns a function that evaluates the SavedModel on input data.
+
+ Args:
+ directory: SavedModel directory to convert.
+ tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to
+ analyze. All tags in the tag set must be present.
+ signature_key: Key identifying SignatureDef containing inputs and outputs.
+
+ Returns:
+ Lambda function ([np.ndarray data] : [np.ndarray result]).
+ """
+ with _session.Session().as_default() as sess:
+ if tag_set is None:
+ tag_set = set([_tag_constants.SERVING])
+ if signature_key is None:
+ signature_key = _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
+
+ meta_graph = _convert_saved_model.get_meta_graph_def(directory, tag_set)
+ signature_def = _convert_saved_model.get_signature_def(
+ meta_graph, signature_key)
+ inputs, outputs = _convert_saved_model.get_inputs_outputs(signature_def)
+
+ return lambda input_data: sess.run(outputs, dict(zip(inputs, input_data)))
+
+
+def evaluate_keras_model(filename):
+ """Returns a function that evaluates the tf.keras model on input data.
+
+ Args:
+ filename: Full filepath of HDF5 file containing the tf.keras model.
+
+ Returns:
+ Lambda function ([np.ndarray data] : [np.ndarray result]).
+ """
+ keras_model = _keras.models.load_model(filename)
+ return lambda input_data: [keras_model.predict(input_data)]
+
+
+# TODO(nupurgarg): Make this function a parameter to test_frozen_graph (and
+# related functions) in order to make it easy to use different data generators.
+def compare_models_random_data(tflite_model, tf_eval_func, tolerance=5):
+ """Compares TensorFlow and TFLite models with random data.
+
+ Args:
+ tflite_model: Serialized TensorFlow Lite model.
+ tf_eval_func: Lambda function that takes in input data and outputs the
+ results of the TensorFlow model ([np.ndarray data] : [np.ndarray result]).
+ tolerance: Decimal place to check accuracy to.
+ """
+ input_data = _generate_random_input_data(tflite_model)
+ tf_results = tf_eval_func(input_data)
+ tflite_results = _evaluate_tflite_model(tflite_model, input_data)
+ for tf_result, tflite_result in zip(tf_results, tflite_results):
+ np.testing.assert_almost_equal(tf_result, tflite_result, tolerance)
+
+
+def test_frozen_graph(filename,
+ input_arrays,
+ output_arrays,
+ input_shapes=None,
+ **kwargs):
+ """Validates the TensorFlow frozen graph converts to a TFLite model.
+
+ Converts the TensorFlow frozen graph to TFLite and checks the accuracy of the
+ model on random data.
+
+ Args:
+ filename: Full filepath of file containing frozen GraphDef.
+ input_arrays: List of input tensors to freeze graph with.
+ output_arrays: List of output tensors to freeze graph with.
+ input_shapes: Dict of strings representing input tensor names to list of
+ integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}).
+ Automatically determined when an input shape is None (e.g., {"foo" : None}).
+ (default None)
+ **kwargs: Additional arguments to be passed into the converter.
+ """
+ converter = _lite.TocoConverter.from_frozen_graph(filename, input_arrays,
+ output_arrays, input_shapes)
+ tflite_model = _convert(converter, **kwargs)
+
+ tf_eval_func = evaluate_frozen_graph(filename, input_arrays, output_arrays)
+ compare_models_random_data(tflite_model, tf_eval_func)
+
+
+def test_saved_model(directory, tag_set=None, signature_key=None, **kwargs):
+ """Validates the TensorFlow SavedModel converts to a TFLite model.
+
+ Converts the TensorFlow SavedModel to TFLite and checks the accuracy of the
+ model on random data.
+
+ Args:
+ directory: SavedModel directory to convert.
+ tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to
+ analyze. All tags in the tag set must be present.
+ signature_key: Key identifying SignatureDef containing inputs and outputs.
+ **kwargs: Additional arguments to be passed into the converter.
+ """
+ converter = _lite.TocoConverter.from_saved_model(directory, tag_set,
+ signature_key)
+ tflite_model = _convert(converter, **kwargs)
+
+ tf_eval_func = evaluate_saved_model(directory, tag_set, signature_key)
+ compare_models_random_data(tflite_model, tf_eval_func)
+
+
+def test_keras_model(filename, **kwargs):
+ """Validates the tf.keras model converts to a TFLite model.
+
+ Converts the tf.keras model to TFLite and checks the accuracy of the model on
+ random data.
+
+ Args:
+ filename: Full filepath of HDF5 file containing the tf.keras model.
+ **kwargs: Additional arguments to be passed into the converter.
+ """
+ converter = _lite.TocoConverter.from_keras_model_file(filename)
+ tflite_model = _convert(converter, **kwargs)
+
+ tf_eval_func = evaluate_keras_model(filename)
+ compare_models_random_data(tflite_model, tf_eval_func)
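A minimal usage sketch of the coverage helpers added above (the file paths and tensor names are placeholders, not part of the change):

```python
from tensorflow.contrib.lite.testing.model_coverage import model_coverage_lib as model_coverage

# Converts a frozen GraphDef with TOCO, then compares TensorFlow and TFLite
# outputs on randomly generated input data.
model_coverage.test_frozen_graph('/tmp/frozen_graph.pb', ['input'], ['output'])

# SavedModel directories and tf.keras HDF5 files follow the same pattern;
# additional keyword arguments are forwarded to the converter.
model_coverage.test_saved_model('/tmp/saved_model_dir')
model_coverage.test_keras_model('/tmp/model.h5', post_training_quantize=True)
```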
diff --git a/tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib_test.py b/tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib_test.py
new file mode 100644
index 0000000000..1498f86c6f
--- /dev/null
+++ b/tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib_test.py
@@ -0,0 +1,130 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for model_coverage_lib.py."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import tempfile
+
+from tensorflow.contrib.lite.python import lite
+from tensorflow.contrib.lite.testing.model_coverage import model_coverage_lib as model_coverage
+from tensorflow.python import keras
+from tensorflow.python.client import session
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+from tensorflow.python.saved_model import saved_model
+from tensorflow.python.training.training_util import write_graph
+
+
+class EvaluateFrozenGraph(test.TestCase):
+
+ def _saveFrozenGraph(self, sess):
+ graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
+ write_graph(sess.graph_def, '', graph_def_file, False)
+ return graph_def_file
+
+ def testFloat(self):
+ with session.Session().as_default() as sess:
+ in_tensor = array_ops.placeholder(
+ shape=[1, 16, 16, 3], dtype=dtypes.float32)
+ _ = in_tensor + in_tensor
+ filename = self._saveFrozenGraph(sess)
+
+ model_coverage.test_frozen_graph(filename, ['Placeholder'], ['add'])
+
+ def testMultipleOutputs(self):
+ with session.Session().as_default() as sess:
+ in_tensor_1 = array_ops.placeholder(
+ shape=[1, 16], dtype=dtypes.float32, name='inputA')
+ in_tensor_2 = array_ops.placeholder(
+ shape=[1, 16], dtype=dtypes.float32, name='inputB')
+
+ weight = constant_op.constant(-1.0, shape=[16, 16])
+ bias = constant_op.constant(-1.0, shape=[16])
+ layer = math_ops.matmul(in_tensor_1, weight) + bias
+ _ = math_ops.reduce_mean(math_ops.square(layer - in_tensor_2))
+ filename = self._saveFrozenGraph(sess)
+
+ model_coverage.test_frozen_graph(filename, ['inputA', 'inputB'],
+ ['add', 'Mean'])
+
+
+class EvaluateSavedModel(test.TestCase):
+
+ def testFloat(self):
+ saved_model_dir = os.path.join(self.get_temp_dir(), 'simple_savedmodel')
+ with session.Session().as_default() as sess:
+ in_tensor_1 = array_ops.placeholder(
+ shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputB')
+ in_tensor_2 = array_ops.placeholder(
+ shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputA')
+ out_tensor = in_tensor_1 + in_tensor_2
+
+ inputs = {'x': in_tensor_1, 'y': in_tensor_2}
+ outputs = {'z': out_tensor}
+ saved_model.simple_save(sess, saved_model_dir, inputs, outputs)
+ model_coverage.test_saved_model(saved_model_dir)
+
+
+class EvaluateKerasModel(test.TestCase):
+
+ def _getSingleInputKerasModel(self):
+ """Returns single input Sequential tf.keras model."""
+ keras.backend.clear_session()
+
+ xs = [-1, 0, 1, 2, 3, 4]
+ ys = [-3, -1, 1, 3, 5, 7]
+
+ model = keras.Sequential([keras.layers.Dense(units=1, input_shape=[1])])
+ model.compile(optimizer='sgd', loss='mean_squared_error')
+ model.train_on_batch(xs, ys)
+ return model
+
+ def _saveKerasModel(self, model):
+ try:
+ fd, keras_file = tempfile.mkstemp('.h5')
+ keras.models.save_model(model, keras_file)
+ finally:
+ os.close(fd)
+ return keras_file
+
+ def testFloat(self):
+ model = self._getSingleInputKerasModel()
+ keras_file = self._saveKerasModel(model)
+
+ model_coverage.test_keras_model(keras_file)
+
+ def testPostTrainingQuantize(self):
+ model = self._getSingleInputKerasModel()
+ keras_file = self._saveKerasModel(model)
+
+ model_coverage.test_keras_model(keras_file, post_training_quantize=True)
+
+ def testConverterMode(self):
+ model = self._getSingleInputKerasModel()
+ keras_file = self._saveKerasModel(model)
+
+ model_coverage.test_keras_model(
+ keras_file, converter_mode=lite.ConverterMode.TOCO_FLEX)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/lite/testing/tflite_diff_flags.h b/tensorflow/contrib/lite/testing/tflite_diff_flags.h
index 3874bc31d7..ad889a2f19 100644
--- a/tensorflow/contrib/lite/testing/tflite_diff_flags.h
+++ b/tensorflow/contrib/lite/testing/tflite_diff_flags.h
@@ -57,7 +57,7 @@ DiffOptions ParseTfliteDiffFlags(int* argc, char** argv) {
"[optional] Number of full runs in each pass."),
tensorflow::Flag("delegate", &values.delegate,
"[optional] Delegate to use for executing ops. Must be "
- "`{\"\", EAGER}`"),
+ "`{\"\", FLEX}`"),
};
bool no_inputs = *argc == 1;
@@ -70,7 +70,7 @@ DiffOptions ParseTfliteDiffFlags(int* argc, char** argv) {
values.input_layer_shape.empty() || values.output_layer.empty()) {
fprintf(stderr, "%s", tensorflow::Flags::Usage(argv[0], flags).c_str());
return {};
- } else if (!(values.delegate == "" || values.delegate == "EAGER")) {
+ } else if (!(values.delegate == "" || values.delegate == "FLEX")) {
fprintf(stderr, "%s", tensorflow::Flags::Usage(argv[0], flags).c_str());
return {};
}
diff --git a/tensorflow/contrib/lite/testing/tflite_diff_util.h b/tensorflow/contrib/lite/testing/tflite_diff_util.h
index f67992139f..28b14bd143 100644
--- a/tensorflow/contrib/lite/testing/tflite_diff_util.h
+++ b/tensorflow/contrib/lite/testing/tflite_diff_util.h
@@ -45,7 +45,7 @@ struct DiffOptions {
// second pass does multiple inferences back to back.
int num_runs_per_pass;
// Path to the delegate library to be loaded in order to execute ops. Must be
- // `{"", EAGER}`.
+ // `{"", FLEX}`.
string delegate;
};
diff --git a/tensorflow/contrib/lite/testing/tflite_driver.cc b/tensorflow/contrib/lite/testing/tflite_driver.cc
index 1836eb53b9..ef49e6f8bc 100644
--- a/tensorflow/contrib/lite/testing/tflite_driver.cc
+++ b/tensorflow/contrib/lite/testing/tflite_driver.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include <iostream>
#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/delegates/eager/delegate.h"
+#include "tensorflow/contrib/lite/delegates/flex/delegate.h"
#include "tensorflow/contrib/lite/testing/split.h"
namespace tflite {
@@ -138,8 +138,8 @@ class TfLiteDriver::Expectation {
TfLiteDriver::TfLiteDriver(bool use_nnapi, const string& delegate_name)
: use_nnapi_(use_nnapi) {
- if (delegate_name == "EAGER") {
- delegate_ = EagerDelegate::Create();
+ if (delegate_name == "FLEX") {
+ delegate_ = FlexDelegate::Create();
}
}
@@ -301,7 +301,7 @@ bool TfLiteDriver::CheckResults() {
}
void TfLiteDriver::ResetLSTMStateTensors() {
- interpreter_->ResetVariableTensorsToZero();
+ interpreter_->ResetVariableTensors();
}
} // namespace testing
diff --git a/tensorflow/contrib/lite/testing/tflite_driver.h b/tensorflow/contrib/lite/testing/tflite_driver.h
index aed35f877d..dc2a4e5877 100644
--- a/tensorflow/contrib/lite/testing/tflite_driver.h
+++ b/tensorflow/contrib/lite/testing/tflite_driver.h
@@ -17,7 +17,7 @@ limitations under the License.
#include <map>
-#include "tensorflow/contrib/lite/delegates/eager/delegate.h"
+#include "tensorflow/contrib/lite/delegates/flex/delegate.h"
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/model.h"
@@ -53,7 +53,7 @@ class TfLiteDriver : public TestRunner {
class Expectation;
- std::unique_ptr<EagerDelegate> delegate_;
+ std::unique_ptr<FlexDelegate> delegate_;
bool use_nnapi_ = false;
std::unique_ptr<FlatBufferModel> model_;
std::unique_ptr<Interpreter> interpreter_;
diff --git a/tensorflow/contrib/lite/toco/args.h b/tensorflow/contrib/lite/toco/args.h
index f14dbc258b..2699ac76e1 100644
--- a/tensorflow/contrib/lite/toco/args.h
+++ b/tensorflow/contrib/lite/toco/args.h
@@ -248,9 +248,9 @@ struct ParsedTocoFlags {
Arg<int64> dedupe_array_min_size_bytes = Arg<int64>(64);
Arg<bool> split_tflite_lstm_inputs = Arg<bool>(true);
// WARNING: Experimental interface, subject to change
- Arg<bool> allow_eager_ops = Arg<bool>(false);
+ Arg<bool> allow_flex_ops = Arg<bool>(false);
// WARNING: Experimental interface, subject to change
- Arg<bool> force_eager_ops = Arg<bool>(false);
+ Arg<bool> force_flex_ops = Arg<bool>(false);
};
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.cc b/tensorflow/contrib/lite/toco/export_tensorflow.cc
index b52a79282c..61e9106783 100644
--- a/tensorflow/contrib/lite/toco/export_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/export_tensorflow.cc
@@ -470,6 +470,17 @@ void ConvertDepthwiseConvOperator(const Model& model,
strides.mutable_list()->add_i(src_op.stride_height);
strides.mutable_list()->add_i(src_op.stride_width);
strides.mutable_list()->add_i(1);
+ // TODO(b/116063589): To return a working TF GraphDef, we should be returning
+ // the correct SpaceToBatchND and BatchToSpaceND operations before and after
+ // the conv since TF doesn't support dilations.
+ if ((src_op.dilation_width_factor != 1) ||
+ (src_op.dilation_height_factor != 1)) {
+ auto& dilations = (*dc2d_op->mutable_attr())["dilations"];
+ dilations.mutable_list()->add_i(1);
+ dilations.mutable_list()->add_i(src_op.dilation_height_factor);
+ dilations.mutable_list()->add_i(src_op.dilation_width_factor);
+ dilations.mutable_list()->add_i(1);
+ }
string padding;
if (src_op.padding.type == PaddingType::kSame) {
padding = "SAME";
@@ -1968,6 +1979,19 @@ void ConvertUnpackOperator(const Model& model, const UnpackOperator& src_op,
(*unpack_op->mutable_attr())["axis"].set_i(src_op.axis);
}
+void ConvertZerosLikeOperator(const Model& model,
+ const TensorFlowZerosLikeOperator& src_op,
+ const char* op_name, GraphDef* tensorflow_graph) {
+ tensorflow::NodeDef* zeros_like_op = tensorflow_graph->add_node();
+ zeros_like_op->set_op(op_name);
+ zeros_like_op->set_name(src_op.outputs[0]);
+ DCHECK_EQ(src_op.inputs.size(), 1);
+ *zeros_like_op->add_input() = src_op.inputs[0];
+ const tensorflow::DataType data_type =
+ GetTensorFlowDataType(model, src_op.inputs[0]);
+ (*zeros_like_op->mutable_attr())["T"].set_type(data_type);
+}
+
void ConvertOperator(const Model& model, const Operator& src_op,
GraphDef* tensorflow_graph) {
if (src_op.fused_activation_function != FusedActivationFunctionType::kNone) {
@@ -2233,6 +2257,10 @@ void ConvertOperator(const Model& model, const Operator& src_op,
} else if (src_op.type == OperatorType::kUnpack) {
ConvertUnpackOperator(model, static_cast<const UnpackOperator&>(src_op),
"Unpack", tensorflow_graph);
+ } else if (src_op.type == OperatorType::kZerosLike) {
+ ConvertZerosLikeOperator(
+ model, static_cast<const TensorFlowZerosLikeOperator&>(src_op),
+ "ZerosLike", tensorflow_graph);
} else {
LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(src_op.type);
}
diff --git a/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md b/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md
index 84680b968e..aba7536cbd 100644
--- a/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md
+++ b/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md
@@ -38,7 +38,7 @@ There are two approaches to running TOCO via command line.
examples below use `tflite_convert` for simplicity.
* Example: `tflite_convert --output_file=...`
* `bazel`: In order to run the latest version of TOCO, [clone the TensorFlow
- repository](https://www.tensorflow.org/install/install_sources#clone_the_tensorflow_repository)
+ repository](https://www.tensorflow.org/install/source)
and use `bazel`. This is the recommended approach for converting models that
utilize new features that were not supported by TOCO in TensorFlow 1.9.
* Example: `bazel run
diff --git a/tensorflow/contrib/lite/toco/g3doc/python_api.md b/tensorflow/contrib/lite/toco/g3doc/python_api.md
index 51f808d4f0..8c31c3dca8 100644
--- a/tensorflow/contrib/lite/toco/g3doc/python_api.md
+++ b/tensorflow/contrib/lite/toco/g3doc/python_api.md
@@ -39,13 +39,18 @@ The API for converting TensorFlow models to TensorFlow Lite as of TensorFlow 1.9
is `tf.contrib.lite.TocoConverter`. The API for calling the Python interpreter is
`tf.contrib.lite.Interpreter`.
+**NOTE**: As of TensorFlow 1.12, the API for converting TensorFlow models to
+TFLite will be renamed to `TFLiteConverter`. `TFLiteConverter` is semantically
+identical to `TocoConverter`. The API is available at
+`tf.contrib.lite.TFLiteConverter` as of the Sept 26 `tf-nightly`.
+
`TocoConverter` provides class methods based on the original format of the
model. `TocoConverter.from_session()` is available for GraphDefs.
`TocoConverter.from_saved_model()` is available for SavedModels.
`TocoConverter.from_keras_model_file()` is available for `tf.Keras` files.
-Example usages for simple float-point models are shown in [Basic
-Examples](#basic). Examples usages for more complex models is shown in [Complex
-Examples](#complex).
+Example usages for simple floating-point models are shown in
+[Basic Examples](#basic). Example usages for more complex models are shown in
+[Complex Examples](#complex).
**NOTE**: Currently, `TocoConverter` will cause a fatal error to the Python
interpreter when the conversion fails. This will be remedied as soon as
@@ -260,7 +265,7 @@ interpreter.allocate_tensors()
In order to run the latest version of the TOCO Python API, clone the TensorFlow
repository, configure the installation, and build and install the pip package.
Detailed instructions are available
-[here](https://www.tensorflow.org/install/install_sources).
+[here](https://www.tensorflow.org/install/source).
### Converting models prior to TensorFlow 1.9. <a name="pre-tensorflow-1.9"></a>
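A minimal sketch of the converter API described in the hunk above (the SavedModel path and output filename are placeholders):

```python
import tensorflow as tf

# Convert a SavedModel to a TFLite flatbuffer and write it to disk.
converter = tf.contrib.lite.TocoConverter.from_saved_model('/tmp/saved_model_dir')
tflite_model = converter.convert()
with open('/tmp/converted_model.tflite', 'wb') as f:
  f.write(tflite_model)
```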
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
index fdd0632451..4d213b3f9c 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
+++ b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
@@ -133,7 +133,6 @@ DECLARE_GRAPH_TRANSFORMATION(MergeLstmCellInputs)
DECLARE_GRAPH_TRANSFORMATION(MergeReshapeIntoPrecedingTranspose)
DECLARE_GRAPH_TRANSFORMATION(IdentifyRelu1)
DECLARE_GRAPH_TRANSFORMATION(IdentifyPRelu)
-DECLARE_GRAPH_TRANSFORMATION(IdentifyDilatedConv)
DECLARE_GRAPH_TRANSFORMATION(MakeInitialDequantizeOperator)
DECLARE_GRAPH_TRANSFORMATION(MoveBinaryOperatorBeforeReshape)
DECLARE_GRAPH_TRANSFORMATION(PropagateActivationFunctionIntoConstants)
@@ -266,6 +265,17 @@ class EnsureUint8WeightsSafeForFastInt8Kernels : public GraphTransformation {
bool has_default_ranges_flag_ = false;
};
+class IdentifyDilatedConv : public GraphTransformation {
+ public:
+ bool Run(Model* model, std::size_t op_index) override;
+ const char* Name() const override { return "IdentifyDilatedConv"; }
+ bool identify_depthwise_conv() const { return identify_depthwise_conv_; }
+ void set_identify_depthwise_conv(bool val) { identify_depthwise_conv_ = val; }
+
+ private:
+ bool identify_depthwise_conv_ = true;
+};
+
#undef DECLARE_GRAPH_TRANSFORMATION
} // end namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc
index d49857cfc2..aac77eb39e 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc
@@ -53,50 +53,11 @@ namespace toco {
// thrown in just for the extra headache. Padding adapts non-conforming input
// sizes, and can be discarded. The bias is necessary, so is kept.
-bool IdentifyDilatedConv::Run(Model* model, std::size_t op_index) {
- const auto it = model->operators.begin() + op_index;
- auto* stb_op = it->get();
-
- // 1. IDENTIFY OPERATORS
- // ***************************************************************************
- // SpaceToBatch Op.
- if (stb_op->type != OperatorType::kSpaceToBatchND) {
- return false;
- }
- if (stb_op->inputs.size() != 3) {
- return false;
- }
- CHECK_EQ(stb_op->outputs.size(), 1);
- // Extract the dilation factor from Input[1] of SpaceToBatch
- // TODO(mjmatthews): Support 2D dilation factors.
- const auto& block_shape_array = model->GetArray(stb_op->inputs[1]);
- if (!block_shape_array.buffer) {
- return false;
- }
- CHECK_EQ(block_shape_array.shape().dimensions_count(), 1);
- int dilation_factor =
- block_shape_array.Array::GetBuffer<ArrayDataType::kInt32>().data[0];
-
- // Expand Op
- auto* post_stb_op = GetOpWithInput(*model, stb_op->outputs[0]);
- if (!post_stb_op) {
- return false;
- }
- bool has_expand_op = false;
- if (post_stb_op->type == OperatorType::kExpandDims) {
- has_expand_op = true;
- CHECK_EQ(post_stb_op->inputs.size(), 2);
- CHECK_EQ(post_stb_op->outputs.size(), 1);
- }
-
- // Conv Op
- const string& input_of_conv_op =
- has_expand_op ? post_stb_op->outputs[0] : stb_op->outputs[0];
- auto* conv_base_op = GetOpWithInput(*model, input_of_conv_op);
- if (conv_base_op->type != OperatorType::kConv) {
- return false;
- }
- auto* conv_op = static_cast<ConvOperator*>(conv_base_op);
+template <typename T>
+bool ResolveDilatedConv(Model* model, Operator* conv_base_op, Operator* stb_op,
+ Operator* post_stb_op, bool has_expand_op,
+ int dilation_factor) {
+ auto* conv_op = static_cast<T*>(conv_base_op);
if (conv_op->inputs.size() != 2) {
// The conv op must only have weights, no bias.
return false;
@@ -158,8 +119,6 @@ bool IdentifyDilatedConv::Run(Model* model, std::size_t op_index) {
CHECK_EQ(bias_add_op->inputs.size(), 2);
CHECK_EQ(bias_add_op->outputs.size(), 1);
- LOG(INFO) << "Identified sub-network emulating dilated convolution.";
-
// 2. RE-WIRE OPERATORS
// ***************************************************************************
// Re-use the existing Conv2D op.
@@ -206,9 +165,71 @@ bool IdentifyDilatedConv::Run(Model* model, std::size_t op_index) {
DeleteArrayIfUnused(stb_op_inputs[1], model);
DeleteArrayIfUnused(stb_op_inputs[2], model);
- LOG(INFO) << "Replaced with Dilated Conv2D op outputting \""
- << conv_op->outputs[0] << "\".";
return true;
}
+bool IdentifyDilatedConv::Run(Model* model, std::size_t op_index) {
+ const auto it = model->operators.begin() + op_index;
+ auto* stb_op = it->get();
+
+ // 1. IDENTIFY OPERATORS
+ // ***************************************************************************
+ // SpaceToBatch Op.
+ if (stb_op->type != OperatorType::kSpaceToBatchND) {
+ return false;
+ }
+ if (stb_op->inputs.size() != 3) {
+ return false;
+ }
+ CHECK_EQ(stb_op->outputs.size(), 1);
+ // Extract the dilation factor from Input[1] of SpaceToBatch
+ // TODO(mjmatthews): Support 2D dilation factors.
+ const auto& block_shape_array = model->GetArray(stb_op->inputs[1]);
+ if (!block_shape_array.buffer) {
+ return false;
+ }
+ CHECK_EQ(block_shape_array.shape().dimensions_count(), 1);
+ int dilation_factor =
+ block_shape_array.Array::GetBuffer<ArrayDataType::kInt32>().data[0];
+
+ // Expand Op
+ auto* post_stb_op = GetOpWithInput(*model, stb_op->outputs[0]);
+ if (!post_stb_op) {
+ return false;
+ }
+ bool has_expand_op = false;
+ if (post_stb_op->type == OperatorType::kExpandDims) {
+ has_expand_op = true;
+ CHECK_EQ(post_stb_op->inputs.size(), 2);
+ CHECK_EQ(post_stb_op->outputs.size(), 1);
+ }
+
+ // Conv Op
+ const string& input_of_conv_op =
+ has_expand_op ? post_stb_op->outputs[0] : stb_op->outputs[0];
+ auto* conv_base_op = GetOpWithInput(*model, input_of_conv_op);
+ bool changed = false;
+ if (conv_base_op->type == OperatorType::kConv) {
+ changed = ResolveDilatedConv<ConvOperator>(model, conv_base_op, stb_op,
+ post_stb_op, has_expand_op,
+ dilation_factor);
+ if (changed) {
+ LOG(INFO) << "Replaced sub-network with Dilated Conv2D op outputting \""
+ << conv_base_op->outputs[0] << "\".";
+ }
+ } else if (identify_depthwise_conv_ &&
+ conv_base_op->type == OperatorType::kDepthwiseConv) {
+ changed = ResolveDilatedConv<DepthwiseConvOperator>(
+ model, conv_base_op, stb_op, post_stb_op, has_expand_op,
+ dilation_factor);
+ if (changed) {
+ LOG(INFO)
+ << "Replaced sub-netork with Dilated DepthwiseConv2D op outputting \""
+ << conv_base_op->outputs[0] << "\".";
+ }
+ }
+
+ return changed;
+}
+
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
index f103bb94ae..d056a8add7 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
@@ -285,7 +285,8 @@ void ProcessDepthwiseConvOperator(Model* model, DepthwiseConvOperator* op) {
const int kheight = weights_shape.dims(1);
const int kwidth = weights_shape.dims(2);
ComputeConvSizes(input_shape, output_depth, kwidth, kheight, op->stride_width,
- op->stride_height, 1, 1, op->padding.type,
+ op->stride_height, op->dilation_width_factor,
+ op->dilation_height_factor, op->padding.type,
model->GetArray(output_name).mutable_shape(),
&op->padding.GetOrCreateFixedPadding());
}
@@ -658,11 +659,16 @@ void ProcessConcatenationOperator(Model* model, ConcatenationOperator* op) {
}
}
auto& output_array = model->GetArray(op->outputs[0]);
- // Use 0 input as basis for output dimensions.
- const auto& first_input_array = model->GetArray(op->inputs[0]);
- output_array.copy_shape(first_input_array.shape());
- // Negative axis means the count starts at the back of the dims().
- if (op->axis < 0) op->axis += first_input_array.shape().dims().size();
+ // Use first non-empty input as basis for output dimensions.
+ for (const auto& input_name : op->inputs) {
+ const auto& input_array = model->GetArray(input_name);
+ if (input_array.shape().dimensions_count() > 0) {
+ output_array.copy_shape(input_array.shape());
+ // Negative axis means the count starts at the back of the dims().
+ if (op->axis < 0) op->axis += input_array.shape().dims().size();
+ break;
+ }
+ }
// Determine the concat size, and enforce that all inputs have
// the same dimensions count.
int concat_size = 0;
@@ -1655,6 +1661,7 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
case OperatorType::kLogicalAnd:
case OperatorType::kLogicalNot:
case OperatorType::kLogicalOr:
+ case OperatorType::kZerosLike:
ProcessSimpleOperator(model, op, 0);
break;
case OperatorType::kGather:
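
Editor's note: the ProcessConcatenationOperator change in this file stops assuming that input 0 already has a known shape and instead scans for the first input whose shape has been propagated. A minimal sketch of that selection logic, using plain std::vector shapes rather than toco's Array/Shape types:

// Minimal sketch with assumed plain types (std::vector<int> as a shape):
// return the first input shape that is already known (non-empty), which the
// concat output then copies before the axis dimension is summed up.
#include <vector>

std::vector<int> FirstKnownShape(const std::vector<std::vector<int>>& inputs) {
  for (const auto& shape : inputs) {
    if (!shape.empty()) {
      return shape;
    }
  }
  return {};  // no input has a propagated shape yet; try again on a later pass
}
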
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc
index 8266e2c205..8e150db6fa 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc
@@ -25,29 +25,57 @@ limitations under the License.
namespace toco {
+namespace {
+
+void RenameArray(Model* model, const string& oldname,
+ const string& desired_newname) {
+ const string& newname = AvailableArrayName(*model, desired_newname);
+ auto& arrays = model->GetMutableArrayMap();
+ arrays[newname] = std::move(arrays[oldname]);
+ arrays.erase(oldname);
+ for (const auto& op : model->operators) {
+ for (string& input : op->inputs) {
+ if (input == oldname) {
+ input = newname;
+ }
+ }
+ for (string& output : op->outputs) {
+ if (output == oldname) {
+ output = newname;
+ }
+ }
+ }
+}
+
+} // namespace
+
// Reorder the elements of an input_array according to the input_axes_order and
// output_axes_order. Then adjust the shapes of the input and output arrays
// accordingly. Note that input_array must have a buffer (that is, it is a
// constant array).
template <typename T, ArrayDataType DataType>
void ReorderAxes(AxesOrder input_axes_order, AxesOrder output_axes_order,
- Array* input_array, Array* output_array) {
- CHECK(input_array->buffer->type == DataType);
- CHECK(!output_array->buffer);
- auto& input_data = input_array->GetMutableBuffer<DataType>().data;
- std::vector<T> reordered_data;
- reordered_data.resize(RequiredBufferSizeForShape(output_array->shape()));
+ const Array& input_array, Array* output_array) {
+ DCHECK(input_array.buffer->type == DataType);
+ DCHECK(!output_array->buffer);
+ const auto& input_data = input_array.GetBuffer<DataType>().data;
+ auto& output_data = output_array->GetMutableBuffer<DataType>().data;
+ output_data.resize(RequiredBufferSizeForShape(output_array->shape()));
// TODO(b/62904716) Shapes should be used directly.
- Shape input_shape = input_array->shape();
+ Shape input_shape = input_array.shape();
Shape output_shape = output_array->shape();
if (AxesCount(input_axes_order) == 2) {
UnextendShape(&input_shape, 2);
UnextendShape(&output_shape, 2);
}
ShuffleArray(input_shape, input_axes_order, output_axes_order, output_shape,
- input_data.data(), reordered_data.data());
- input_data = reordered_data;
- input_array->copy_shape(output_array->shape());
+ input_data.data(), output_data.data());
+ if (input_array.minmax) {
+ output_array->GetOrCreateMinMax() = input_array.GetMinMax();
+ }
+ if (input_array.narrow_range) {
+ output_array->narrow_range = true;
+ }
}
bool ResolveReorderAxes::Run(Model* model, std::size_t op_index) {
@@ -57,8 +85,11 @@ bool ResolveReorderAxes::Run(Model* model, std::size_t op_index) {
return false;
}
auto* reorder_op = static_cast<ReorderAxesOperator*>(op);
- const auto& input_array_name = reorder_op->inputs[0];
- const auto& output_array_name = reorder_op->outputs[0];
+
+ // Intentionally copies, not references.
+ const string input_array_name = reorder_op->inputs[0];
+ const string output_array_name = reorder_op->outputs[0];
+
auto& input_array = model->GetArray(input_array_name);
auto& output_array = model->GetArray(output_array_name);
if (!input_array.buffer) {
@@ -72,31 +103,23 @@ bool ResolveReorderAxes::Run(Model* model, std::size_t op_index) {
if (input_array.buffer->type == ArrayDataType::kFloat) {
ReorderAxes<float, ArrayDataType::kFloat>(reorder_op->input_axes_order,
reorder_op->output_axes_order,
- &input_array, &output_array);
- } else if (input_array.buffer->type == ArrayDataType::kInt32) {
+ input_array, &output_array);
+ } else if (input_array.buffer->type == ArrayDataType::kUint8) {
+ // TODO(benoitjacob): This path seems unused.
+ // ReorderAxes is only used when importing from
+ // TensorFlow GraphDef, which does not support quantized nodes.
ReorderAxes<uint8, ArrayDataType::kUint8>(reorder_op->input_axes_order,
reorder_op->output_axes_order,
- &input_array, &output_array);
+ input_array, &output_array);
} else {
LOG(FATAL) << "Cannot ReorderAxes unless input buffer is float or uint8.";
}
- input_array.copy_shape(output_array.shape());
-
- // Update the edges of the graph to point to the input array
- for (const auto& other_op : model->operators) {
- for (auto& input : other_op->inputs) {
- if (input == output_array_name) {
- input = input_array_name;
- }
- }
- }
-
AddMessageF("Reordered axes for array %s", input_array_name);
- // Remove the op and output array.
- model->EraseArray(output_array_name);
- model->operators.erase(it);
+ DeleteOpAndArraysIfUnused(model, op);
+ RenameArray(model, output_array_name, input_array_name);
+
return true;
}
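
Editor's note: the new RenameArray helper lets ResolveReorderAxes keep the reordered constant while every consumer edge is patched to the surviving name. The same "move the value, drop the old key, rewrite references" idiom, sketched with generic containers (std::map stands in for toco's ArrayMap):

// Generic sketch of the rename idiom used above; not toco's actual types.
#include <map>
#include <string>
#include <vector>

void RenameKey(std::map<std::string, std::vector<float>>* arrays,
               std::vector<std::string>* edges, const std::string& old_name,
               const std::string& new_name) {
  (*arrays)[new_name] = std::move((*arrays)[old_name]);  // move the payload
  arrays->erase(old_name);                               // drop the old entry
  for (std::string& edge : *edges) {
    if (edge == old_name) edge = new_name;               // patch every reference
  }
}
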
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc
index fcf30bd347..65346c4fe4 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc
@@ -24,6 +24,37 @@ limitations under the License.
namespace toco {
+namespace {
+
+TransposeOperator* FindTransposeOpWithInput(const Model& model,
+ const string& array_name) {
+ for (auto it = model.operators.begin(); it != model.operators.end(); ++it) {
+ Operator* op = it->get();
+ if (op->type != OperatorType::kTranspose) {
+ continue;
+ }
+ if (op->inputs[0] != array_name) {
+ continue;
+ }
+ const auto& permutation_array = model.GetArray(op->inputs[1]);
+ if (permutation_array.data_type != ArrayDataType::kInt32) {
+ continue;
+ }
+ const auto& permutation_data =
+ permutation_array.GetBuffer<ArrayDataType::kInt32>().data;
+ if (permutation_data.size() != 2) {
+ continue;
+ }
+ if (permutation_data[0] != 1 || permutation_data[1] != 0) {
+ continue;
+ }
+ return static_cast<TransposeOperator*>(op);
+ }
+ return nullptr;
+}
+
+} // namespace
+
bool ResolveTensorFlowMatMul::Run(Model* model, std::size_t op_index) {
auto matmul_it = model->operators.begin() + op_index;
if (matmul_it->get()->type != OperatorType::kMatMul) {
@@ -37,7 +68,13 @@ bool ResolveTensorFlowMatMul::Run(Model* model, std::size_t op_index) {
// TransposeOperator. However, the second input is supposed to be 2D, so we
// can actually handle transposition of that matrix, which happens to be more
// common anyway.
- CHECK(!matmul_op->transpose_a);
+ if (matmul_op->transpose_a) {
+ AddMessageF(
+ "Not replacing %s by a FullyConnected operator, because it has "
+ "the transpose_a attribute",
+ LogName(*matmul_op));
+ return false;
+ }
// Reorder the axes on the second input. TensorFlow uses row-major ordering
// on both inputs, however this is inefficient for the FullyConnected
@@ -46,18 +83,35 @@ bool ResolveTensorFlowMatMul::Run(Model* model, std::size_t op_index) {
string input_lhs = matmul_op->inputs[0];
string input_rhs = matmul_op->inputs[1];
if (!matmul_op->transpose_b) {
- auto* transpose_op = new TransposeOperator;
- transpose_op->inputs = {
- matmul_op->inputs[1],
- CreateInt32Array(model,
- AvailableArrayName(
- *model, matmul_op->inputs[1] + "/transpose/perm"),
- {1, 0})};
- transpose_op->outputs = {
- AvailableArrayName(*model, matmul_op->inputs[1] + "/transpose")};
- model->GetOrCreateArray(transpose_op->outputs[0]);
- model->operators.emplace(matmul_it, transpose_op);
-
+ // Need to transpose input_rhs, by inserting a TransposeOperator.
+ // First, check if there already is a TransposeOperator transposing that
+ // array, so we can just reuse it.
+ auto* transpose_op = FindTransposeOpWithInput(*model, input_rhs);
+ if (!transpose_op) {
+ AddMessageF(
+ "While replacing %s by a FullyConnected operator, created new "
+ "Transpose op wrapping RHS input array %s",
+ LogName(*matmul_op), input_rhs);
+ // No such TransposeOperator found. Create one now.
+ transpose_op = new TransposeOperator;
+ transpose_op->inputs = {
+ input_rhs,
+ CreateInt32Array(
+ model, AvailableArrayName(*model, input_rhs + "/transpose/perm"),
+ {1, 0})};
+ transpose_op->outputs = {
+ AvailableArrayName(*model, input_rhs + "/transpose")};
+ model->GetOrCreateArray(transpose_op->outputs[0]);
+ model->operators.emplace(matmul_it, transpose_op);
+ // Sanity check
+ DCHECK_EQ(transpose_op, FindTransposeOpWithInput(*model, input_rhs));
+ } else {
+ AddMessageF(
+ "While replacing %s by a FullyConnected operator, reused existing "
+ "Transpose op wrapping RHS input array %s",
+ LogName(*matmul_op), input_rhs);
+ }
+ // Re-wire: have the matmul consume the transposed array.
input_rhs = transpose_op->outputs[0];
}
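
Editor's note: the MatMul rewrite needs the RHS in transposed form because FullyConnected weights are laid out as [output_depth, input_depth], i.e. the transpose of the row-major MatMul RHS; the patch now reuses an existing {1, 0} Transpose feeding on that array instead of inserting a duplicate. A standalone sketch of the underlying permutation, with hand-rolled row-major indexing (not toco's ShuffleArray):

// Standalone sketch: transpose a row-major [rows x cols] matrix, the same
// {1, 0} permutation the inserted Transpose op performs on the MatMul RHS.
#include <vector>

std::vector<float> Transpose2D(const std::vector<float>& src, int rows, int cols) {
  std::vector<float> dst(src.size());
  for (int r = 0; r < rows; ++r) {
    for (int c = 0; c < cols; ++c) {
      dst[c * rows + r] = src[r * cols + c];
    }
  }
  return dst;
}
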
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc
index 2ccfd36b7c..5eaf6e27fc 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -641,6 +641,23 @@ tensorflow::Status ConvertDepthwiseConvOperator(
CHECK_EQ(strides.i(3), 1);
conv->stride_height = strides.i(1);
conv->stride_width = strides.i(2);
+ if (HasAttr(node, "dilations")) {
+ const auto& dilations = GetListAttr(node, "dilations");
+ TF_RETURN_IF_ERROR(
+ ExpectValue(dilations.i_size(), 4, "number of dilations"));
+ if (dilations.i(0) != 1 || dilations.i(3) != 1) {
+ return tensorflow::errors::InvalidArgument(absl::StrCat(
+ "Can only import Conv ops with dilation along the height "
+ "(1st) or width (2nd) axis. TensorFlow op \"",
+ node.name(), "\" had dilations:[ ", dilations.i(0), ", ",
+ dilations.i(1), ", ", dilations.i(2), ", ", dilations.i(3), "]."));
+ }
+ conv->dilation_height_factor = dilations.i(1);
+ conv->dilation_width_factor = dilations.i(2);
+ } else {
+ conv->dilation_height_factor = 1;
+ conv->dilation_width_factor = 1;
+ }
const auto& padding = GetStringAttr(node, "padding");
if (padding == "SAME") {
conv->padding.type = PaddingType::kSame;
@@ -2065,6 +2082,7 @@ ConverterMapType GetTensorFlowNodeConverterMap() {
{"TopKV2", ConvertTopKV2Operator},
{"Transpose", ConvertSimpleOperator<TransposeOperator, 2>},
{"Unpack", ConvertUnpackOperator},
+ {"ZerosLike", ConvertSimpleOperator<TensorFlowZerosLikeOperator, 1>},
});
}
@@ -2105,9 +2123,9 @@ std::unique_ptr<Model> ImportTensorFlowGraphDef(
Model* model = new Model;
internal::ConverterMapType converter_map;
- // This is used for the TFLite "Full Eager Mode" conversion. All the ops are
+ // This is used for the TFLite "Full Flex Mode" conversion. All the ops are
// imported as `TensorFlowUnsupportedOperator`, and later all these ops are
- // converted to TFLite Eager ops.
+ // converted to TFLite Flex ops.
if (!tf_import_flags.import_all_ops_as_unsupported) {
converter_map = internal::GetTensorFlowNodeConverterMap();
}
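
Editor's note: the depthwise-conv importer above now reads the optional NHWC `dilations` attribute, which must have the form [1, dilation_h, dilation_w, 1]; anything else is rejected with InvalidArgument. A small validation sketch in plain C++ (a hypothetical helper, not the importer's actual API):

// Hypothetical helper, not part of the importer: validate an NHWC `dilations`
// attribute and pull out the two spatial factors.
#include <array>
#include <optional>
#include <utility>

std::optional<std::pair<int, int>> SpatialDilations(const std::array<int, 4>& d) {
  if (d[0] != 1 || d[3] != 1) {
    return std::nullopt;  // batch/depth dilation is not supported
  }
  return std::make_pair(d[1], d[2]);  // {dilation_height, dilation_width}
}
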
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.h b/tensorflow/contrib/lite/toco/import_tensorflow.h
index 7db23f2d44..c5ff96956a 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.h
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.h
@@ -30,7 +30,7 @@ struct TensorFlowImportFlags {
// Do not recognize any op and import all ops as
// `TensorFlowUnsupportedOperator`. This is populated from the
- // `force_eager_ops` flag.
+ // `force_flex_ops` flag.
bool import_all_ops_as_unsupported = false;
};
diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h
index 164b70f2df..6e207fdf54 100644
--- a/tensorflow/contrib/lite/toco/model.h
+++ b/tensorflow/contrib/lite/toco/model.h
@@ -150,6 +150,7 @@ enum class OperatorType : uint8 {
kLogicalOr,
kCTCBeamSearchDecoder,
kUnpack,
+ kZerosLike,
};
// Helper to deal with TensorFlow arrays using a different ordering of
@@ -1849,6 +1850,16 @@ struct UnpackOperator : Operator {
ArrayDataType dtype = ArrayDataType::kNone;
};
+// ZerosLike operator:
+//
+// Inputs:
+// inputs[0]: required: the input array
+//
+// TensorFlow equivalent: tf.zeros_like
+struct TensorFlowZerosLikeOperator : Operator {
+ TensorFlowZerosLikeOperator() : Operator(OperatorType::kZerosLike) {}
+};
+
// Alloc's are used for transient arrays only. An Alloc specifies which interval
// of the "transient_data" workspace buffer passed to inference functions, is to
// be used for the transient array at hand. The 'start' and 'end' values are
@@ -2073,6 +2084,7 @@ class Model {
}
}
const ArrayMap& GetArrayMap() const { return arrays; }
+ ArrayMap& GetMutableArrayMap() { return arrays; }
int64 ArithmeticOpsCount() const { return ops_count; }
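
Editor's note: the new TensorFlowZerosLikeOperator mirrors tf.zeros_like: the output has exactly the input's shape and type with every element zero, which is why shape propagation treats it as a "simple" operator. Sketched semantics (illustrative only, not the TFLite kernel):

// Illustrative semantics only: zeros_like returns a tensor of the input's
// size filled with zeros.
#include <vector>

std::vector<float> ZerosLike(const std::vector<float>& input) {
  return std::vector<float>(input.size(), 0.0f);
}
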
diff --git a/tensorflow/contrib/lite/toco/python/BUILD b/tensorflow/contrib/lite/toco/python/BUILD
index 33c5b16462..cf97ba7084 100644
--- a/tensorflow/contrib/lite/toco/python/BUILD
+++ b/tensorflow/contrib/lite/toco/python/BUILD
@@ -4,6 +4,7 @@ licenses(["notice"]) # Apache 2.0
load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc")
load("//tensorflow:tensorflow.bzl", "tf_py_test")
+load("//tensorflow:tensorflow.bzl", "py_binary")
cc_library(
name = "toco_python_api",
diff --git a/tensorflow/contrib/lite/toco/tflite/export.cc b/tensorflow/contrib/lite/toco/tflite/export.cc
index fee10b1dff..0c9fac249c 100644
--- a/tensorflow/contrib/lite/toco/tflite/export.cc
+++ b/tensorflow/contrib/lite/toco/tflite/export.cc
@@ -50,16 +50,16 @@ namespace {
details::OperatorKey GetOperatorKey(
const ::toco::Operator& op,
const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type,
- bool allow_eager_ops) {
+ bool allow_flex_ops) {
string custom_code;
if (op.type == OperatorType::kUnsupported) {
const TensorFlowUnsupportedOperator& unsupported_op =
static_cast<const TensorFlowUnsupportedOperator&>(op);
- // TODO(b/113715895): When `allow_eager_ops` is on, for now there's no way
+ // TODO(b/113715895): When `allow_flex_ops` is on, for now there's no way
// to populate a regular custom op. We need to find a way to fix this.
- if (allow_eager_ops) {
- custom_code = string(::tflite::kEagerCustomCodePrefix) +
+ if (allow_flex_ops) {
+ custom_code = string(::tflite::kFlexCustomCodePrefix) +
unsupported_op.tensorflow_op;
} else {
custom_code = unsupported_op.tensorflow_op;
@@ -101,11 +101,11 @@ void LoadTensorsMap(const Model& model, TensorsMap* tensors_map) {
void LoadOperatorsMap(
const Model& model, OperatorsMap* operators_map,
const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type,
- bool allow_eager_ops) {
+ bool allow_flex_ops) {
// First find a list of unique operator types.
std::set<OperatorKey> keys;
for (const auto& op : model.operators) {
- keys.insert(GetOperatorKey(*op, ops_by_type, allow_eager_ops));
+ keys.insert(GetOperatorKey(*op, ops_by_type, allow_flex_ops));
}
// Now assign indices to them and fill in the map.
int index = 0;
@@ -216,7 +216,7 @@ Offset<Vector<Offset<OperatorCode>>> ExportOperatorCodes(
for (const auto& op : model.operators) {
const details::OperatorKey operator_key =
- GetOperatorKey(*op, ops_by_type, params.allow_eager_ops);
+ GetOperatorKey(*op, ops_by_type, params.allow_flex_ops);
int op_index = operators_map.at(operator_key);
int op_version = operator_key.version;
@@ -281,7 +281,7 @@ Offset<Vector<Offset<Operator>>> ExportOperators(
}
int op_index = operators_map.at(
- GetOperatorKey(*op, ops_by_type, params.allow_eager_ops));
+ GetOperatorKey(*op, ops_by_type, params.allow_flex_ops));
auto tflite_op_it = ops_by_type.find(op->type);
BaseOperator* tflite_op = tflite_op_it == ops_by_type.end()
@@ -334,7 +334,7 @@ Offset<Vector<Offset<Buffer>>> ExportBuffers(
void Export(const Model& model, string* output_file_contents,
const ExportParams& params) {
- const auto ops_by_type = BuildOperatorByTypeMap(params.allow_eager_ops);
+ const auto ops_by_type = BuildOperatorByTypeMap(params.allow_flex_ops);
Export(model, output_file_contents, params, ops_by_type);
}
@@ -349,7 +349,7 @@ void Export(
details::OperatorsMap operators_map;
details::LoadOperatorsMap(model, &operators_map, ops_by_type,
- params.allow_eager_ops);
+ params.allow_flex_ops);
std::vector<const Array*> buffers_to_write;
Array empty_array;
@@ -388,7 +388,7 @@ void Export(
"the standard TensorFlow Lite runtime. If you have a custom "
"implementation for them you can disable this error with "
"--allow_custom_ops, or by setting allow_custom_ops=True "
- "when calling tf.contrib.lite.TocoConverter(). Here is a list "
+ "when calling tf.contrib.lite.TFLiteConverter(). Here is a list "
"of operators for which you will need custom implementations: "
<< absl::StrJoin(error_summary_final, ", ") << ".";
}
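
Editor's note: when allow_flex_ops is set, an unsupported op's custom_code becomes the "Flex" prefix (the kFlexCustomCodePrefix constant renamed later in this patch) concatenated with the TensorFlow op name, which is also how IsFlexOp recognizes such ops at load time. A minimal sketch of that round trip:

// Minimal sketch of the prefix convention; "Flex" matches the
// kFlexCustomCodePrefix constant introduced further down in this patch.
#include <string>

std::string MakeFlexCustomCode(const std::string& tensorflow_op) {
  return "Flex" + tensorflow_op;  // e.g. "FlexBatchMatMul"
}

bool LooksLikeFlexOp(const std::string& custom_code) {
  return custom_code.compare(0, 4, "Flex") == 0;  // prefix check
}
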
diff --git a/tensorflow/contrib/lite/toco/tflite/export.h b/tensorflow/contrib/lite/toco/tflite/export.h
index b070a38768..29d6de4049 100644
--- a/tensorflow/contrib/lite/toco/tflite/export.h
+++ b/tensorflow/contrib/lite/toco/tflite/export.h
@@ -26,7 +26,7 @@ namespace tflite {
// The parameters for exporting a TFLite model.
struct ExportParams {
bool allow_custom_ops = false;
- bool allow_eager_ops = false;
+ bool allow_flex_ops = false;
bool quantize_weights = false;
};
@@ -121,7 +121,7 @@ void LoadTensorsMap(const Model& model, TensorsMap* tensors_map);
void LoadOperatorsMap(
const Model& model, OperatorsMap* operators_map,
const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type,
- bool allow_eager_ops);
+ bool allow_flex_ops);
} // namespace details
} // namespace tflite
diff --git a/tensorflow/contrib/lite/toco/tflite/export_test.cc b/tensorflow/contrib/lite/toco/tflite/export_test.cc
index 8d4d197c46..93882a91a7 100644
--- a/tensorflow/contrib/lite/toco/tflite/export_test.cc
+++ b/tensorflow/contrib/lite/toco/tflite/export_test.cc
@@ -105,7 +105,7 @@ TEST_F(ExportTest, LoadOperatorsMap) {
details::OperatorsMap operators;
const auto ops_by_type = BuildOperatorByTypeMap();
- // TODO(ycling): Add a test for allow_eager_ops.
+ // TODO(ycling): Add a test for allow_flex_ops.
details::LoadOperatorsMap(input_model_, &operators, ops_by_type, false);
EXPECT_EQ(0, operators[details::OperatorKey(OperatorType::kAdd, "", 1)]);
EXPECT_EQ(1, operators[details::OperatorKey(OperatorType::kConv, "", 1)]);
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc
index 1061e7c7c4..9addbb81e7 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator.cc
@@ -1160,8 +1160,8 @@ class Unpack : public BuiltinOperator<UnpackOperator, ::tflite::UnpackOptions,
class TensorFlowUnsupported : public BaseOperator {
public:
TensorFlowUnsupported(const string& name, OperatorType type,
- bool allow_eager_ops)
- : BaseOperator(name, type), allow_eager_ops_(allow_eager_ops) {}
+ bool allow_flex_ops)
+ : BaseOperator(name, type), allow_flex_ops_(allow_flex_ops) {}
Options Serialize(const Operator& op,
flatbuffers::FlatBufferBuilder* builder) const override {
@@ -1177,9 +1177,9 @@ class TensorFlowUnsupported : public BaseOperator {
std::unique_ptr<Operator> Deserialize(
const BuiltinOptions* builtin_options,
const CustomOptions* custom_options) const override {
- // Deserializing Eager ops doesn't work now.
+ // Deserializing Flex ops doesn't work now.
// TODO(ycling): Revisit and decide if we should fix the flow for importing
- // TFLite models with Eager ops.
+ // TFLite models with Flex ops.
auto op = absl::make_unique<TensorFlowUnsupportedOperator>();
if (custom_options) {
auto flexbuffer_map =
@@ -1200,13 +1200,13 @@ class TensorFlowUnsupported : public BaseOperator {
return std::unique_ptr<flexbuffers::Builder>();
}
- if (allow_eager_ops_) {
+ if (allow_flex_ops_) {
fbb->Vector([&]() {
fbb->String(node_def.op());
fbb->String(op.tensorflow_node_def);
});
fbb->Finish();
- LOG(INFO) << "Writing eager op: " << node_def.op();
+ LOG(INFO) << "Writing flex op: " << node_def.op();
return std::unique_ptr<flexbuffers::Builder>(fbb.release());
}
@@ -1260,6 +1260,10 @@ class TensorFlowUnsupported : public BaseOperator {
return std::unique_ptr<flexbuffers::Builder>(fbb.release());
}
+// TODO(wvo): hack to make this code compile with 2 different API versions.
+// Please remove once OS/internal versions are in sync.
+// See hardcoded values in the switch below.
+
void ReadOptions(const flexbuffers::Map& m,
TensorFlowUnsupportedOperator* op) const {
::tensorflow::NodeDef node_def;
@@ -1270,16 +1274,16 @@ class TensorFlowUnsupported : public BaseOperator {
const auto key = keys[i].AsKey();
const auto& value = m[key];
switch (value.GetType()) {
- case flexbuffers::TYPE_STRING:
+ case 5: // flexbuffers::FBT_STRING:
(*attr)[key].set_s(value.AsString().c_str());
break;
- case flexbuffers::TYPE_INT:
+ case 1: // flexbuffers::FBT_INT:
(*attr)[key].set_i(value.AsInt64());
break;
- case flexbuffers::TYPE_FLOAT:
+ case 3: // flexbuffers::FBT_FLOAT:
(*attr)[key].set_f(value.AsFloat());
break;
- case flexbuffers::TYPE_BOOL:
+ case 26: // flexbuffers::FBT_BOOL:
(*attr)[key].set_b(value.AsBool());
if (string(key) == "_output_quantized") {
op->quantized = value.AsBool();
@@ -1288,7 +1292,7 @@ class TensorFlowUnsupported : public BaseOperator {
op->support_output_type_float_in_quantized_op = value.AsBool();
}
break;
- case flexbuffers::TYPE_VECTOR_INT: {
+ case 11: { // flexbuffers::FBT_VECTOR_INT: {
auto* list = (*attr)[key].mutable_list();
const auto& vector = value.AsTypedVector();
for (size_t i = 0; i < vector.size(); i++) {
@@ -1312,13 +1316,13 @@ class TensorFlowUnsupported : public BaseOperator {
}
private:
- const bool allow_eager_ops_;
+ const bool allow_flex_ops_;
};
namespace {
// Build a vector containing all the known operators.
std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList(
- bool allow_eager_ops = false) {
+ bool allow_flex_ops = false) {
std::vector<std::unique_ptr<BaseOperator>> ops;
using tensorflow::MakeUnique;
// Builtin Operators.
@@ -1430,7 +1434,7 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList(
ops.push_back(MakeUnique<CTCBeamSearchDecoder>(
"CTC_BEAM_SEARCH_DECODER", OperatorType::kCTCBeamSearchDecoder));
ops.push_back(MakeUnique<TensorFlowUnsupported>(
- "TENSORFLOW_UNSUPPORTED", OperatorType::kUnsupported, allow_eager_ops));
+ "TENSORFLOW_UNSUPPORTED", OperatorType::kUnsupported, allow_flex_ops));
// These operators are supported by Toco, but not by TF Lite, and have no
// attributes.
@@ -1500,17 +1504,19 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList(
"RSQRT", OperatorType::kRsqrt));
ops.push_back(MakeUnique<SimpleOperator<TensorFlowSquareOperator>>(
"SQUARE", OperatorType::kSquare));
+ ops.push_back(MakeUnique<SimpleOperator<TensorFlowZerosLikeOperator>>(
+ "ZEROS_LIKE", OperatorType::kZerosLike));
return ops;
}
} // namespace
std::map<OperatorType, std::unique_ptr<BaseOperator>> BuildOperatorByTypeMap(
- bool allow_eager_ops) {
+ bool allow_flex_ops) {
std::map<OperatorType, std::unique_ptr<BaseOperator>> result;
std::vector<std::unique_ptr<BaseOperator>> ops =
- BuildOperatorList(allow_eager_ops);
+ BuildOperatorList(allow_flex_ops);
for (auto& op : ops) {
result[op->type()] = std::move(op);
}
@@ -1519,11 +1525,11 @@ std::map<OperatorType, std::unique_ptr<BaseOperator>> BuildOperatorByTypeMap(
}
std::map<string, std::unique_ptr<BaseOperator>> BuildOperatorByNameMap(
- bool allow_eager_ops) {
+ bool allow_flex_ops) {
std::map<string, std::unique_ptr<BaseOperator>> result;
std::vector<std::unique_ptr<BaseOperator>> ops =
- BuildOperatorList(allow_eager_ops);
+ BuildOperatorList(allow_flex_ops);
for (auto& op : ops) {
result[op->name()] = std::move(op);
}
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.h b/tensorflow/contrib/lite/toco/tflite/operator.h
index 702fb28ea6..13d9f6c49a 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator.h
+++ b/tensorflow/contrib/lite/toco/tflite/operator.h
@@ -26,15 +26,15 @@ namespace tflite {
class BaseOperator;
// Return a map containing all known TF Lite Operators, keyed by their names.
-// TODO(ycling): The pattern to propagate parameters (e.g. allow_eager_ops)
+// TODO(ycling): The pattern to propagate parameters (e.g. allow_flex_ops)
// is ugly here. Consider refactoring.
std::map<string, std::unique_ptr<BaseOperator>> BuildOperatorByNameMap(
- bool allow_eager_ops = false);
+ bool allow_flex_ops = false);
// Return a map containing all known TF Lite Operators, keyed by the type of
// their tf.mini counterparts.
std::map<OperatorType, std::unique_ptr<BaseOperator>> BuildOperatorByTypeMap(
- bool allow_eager_ops = false);
+ bool allow_flex_ops = false);
// These are the flatbuffer types for custom and builtin options.
using CustomOptions = flatbuffers::Vector<uint8_t>;
diff --git a/tensorflow/contrib/lite/toco/tflite/operator_test.cc b/tensorflow/contrib/lite/toco/tflite/operator_test.cc
index 72e50a9aed..0bc591e647 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator_test.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator_test.cc
@@ -146,6 +146,8 @@ TEST_F(OperatorTest, SimpleOperators) {
CheckSimpleOperator<FloorDivOperator>("FLOOR_DIV", OperatorType::kFloorDiv);
CheckSimpleOperator<TensorFlowSquareOperator>("SQUARE",
OperatorType::kSquare);
+ CheckSimpleOperator<TensorFlowZerosLikeOperator>("ZEROS_LIKE",
+ OperatorType::kZerosLike);
}
TEST_F(OperatorTest, BuiltinAdd) {
diff --git a/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc b/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc
index b6aebc0470..cff79776bc 100644
--- a/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc
+++ b/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc
@@ -167,11 +167,11 @@ bool ParseTocoFlagsFromCommandLineFlags(
"converted float model. Model size will be reduced and there will "
"be latency improvements (at the cost of accuracy)."),
// WARNING: Experimental interface, subject to change
- Flag("allow_eager_ops", parsed_flags.allow_eager_ops.bind(),
- parsed_flags.allow_eager_ops.default_value(), ""),
+ Flag("allow_flex_ops", parsed_flags.allow_flex_ops.bind(),
+ parsed_flags.allow_flex_ops.default_value(), ""),
// WARNING: Experimental interface, subject to change
- Flag("force_eager_ops", parsed_flags.force_eager_ops.bind(),
- parsed_flags.force_eager_ops.default_value(), "")};
+ Flag("force_flex_ops", parsed_flags.force_flex_ops.bind(),
+ parsed_flags.force_flex_ops.default_value(), "")};
bool asked_for_help =
*argc == 2 && (!strcmp(argv[1], "--help") || !strcmp(argv[1], "-help"));
if (asked_for_help) {
@@ -266,15 +266,15 @@ void ReadTocoFlagsFromCommandLineFlags(const ParsedTocoFlags& parsed_toco_flags,
READ_TOCO_FLAG(split_tflite_lstm_inputs, FlagRequirement::kNone);
READ_TOCO_FLAG(quantize_weights, FlagRequirement::kNone);
READ_TOCO_FLAG(post_training_quantize, FlagRequirement::kNone);
- READ_TOCO_FLAG(allow_eager_ops, FlagRequirement::kNone);
- READ_TOCO_FLAG(force_eager_ops, FlagRequirement::kNone);
+ READ_TOCO_FLAG(allow_flex_ops, FlagRequirement::kNone);
+ READ_TOCO_FLAG(force_flex_ops, FlagRequirement::kNone);
- if (parsed_toco_flags.force_eager_ops.value() &&
- !parsed_toco_flags.allow_eager_ops.value()) {
- // TODO(ycling): Consider to enforce `allow_eager_ops` when
- // `force_eager_ops` is true.
- LOG(WARNING) << "--force_eager_ops should always be used with "
- "--allow_eager_ops.";
+ if (parsed_toco_flags.force_flex_ops.value() &&
+ !parsed_toco_flags.allow_flex_ops.value()) {
+ // TODO(ycling): Consider enforcing `allow_flex_ops` when
+ // `force_flex_ops` is true.
+ LOG(WARNING) << "--force_flex_ops should always be used with "
+ "--allow_flex_ops.";
}
// Deprecated flag handling.
diff --git a/tensorflow/contrib/lite/toco/toco_flags.proto b/tensorflow/contrib/lite/toco/toco_flags.proto
index 53d60fed05..ca3e64485e 100644
--- a/tensorflow/contrib/lite/toco/toco_flags.proto
+++ b/tensorflow/contrib/lite/toco/toco_flags.proto
@@ -190,16 +190,16 @@ message TocoFlags {
// (at the cost of accuracy).
optional bool post_training_quantize = 26 [default = false];
- // When enabled, unsupported ops will be converted to TFLite Eager ops.
+ // When enabled, unsupported ops will be converted to TFLite Flex ops.
// TODO(ycling): Consider renaming the following 2 flags and not calling them
- // "Eager".
- // `allow_eager_ops` should always be used with `allow_custom_ops`.
+ // "Flex".
+ // `allow_flex_ops` should always be used with `allow_custom_ops`.
// WARNING: Experimental interface, subject to change
- optional bool allow_eager_ops = 27 [default = false];
+ optional bool allow_flex_ops = 27 [default = false];
- // When enabled, all TensorFlow ops will be converted to TFLite Eager
- // ops directly. This will force `allow_eager_ops` to true.
- // `force_eager_ops` should always be used with `allow_eager_ops`.
+ // When enabled, all TensorFlow ops will be converted to TFLite Flex
+ // ops directly. This will force `allow_flex_ops` to true.
+ // `force_flex_ops` should always be used with `allow_flex_ops`.
// WARNING: Experimental interface, subject to change
- optional bool force_eager_ops = 28 [default = false];
+ optional bool force_flex_ops = 28 [default = false];
}
diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc
index a7c17156b1..106494f354 100644
--- a/tensorflow/contrib/lite/toco/toco_tooling.cc
+++ b/tensorflow/contrib/lite/toco/toco_tooling.cc
@@ -101,7 +101,6 @@ void MakeGeneralGraphTransformationsSet(
transformations->Add(new ResolveTensorFlowSwitch);
transformations->Add(new ResolveTensorFlowConcat);
transformations->Add(new ResolveMultiplyByZero);
- transformations->Add(new IdentifyDilatedConv);
transformations->Add(new IdentifyL2Normalization);
transformations->Add(new IdentifyL2Pool);
transformations->Add(new IdentifyRelu1);
@@ -199,7 +198,7 @@ std::unique_ptr<Model> Import(const TocoFlags& toco_flags,
: (toco_flags.output_format() != TENSORFLOW_GRAPHDEF);
tf_import_flags.import_all_ops_as_unsupported =
- toco_flags.force_eager_ops();
+ toco_flags.force_flex_ops();
model = ImportTensorFlowGraphDef(model_flags, tf_import_flags,
input_file_contents);
@@ -282,6 +281,14 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
}
}
transformations.Add(new ResolveConstantConcatenation);
+ // TODO(b/116063589): TF GraphDef doesn't support dilations on its depthwise
+ // conv, so we need to make sure we don't convert to dilated depthwise conv
+ // when outputting to TF GraphDef.
+ auto* identify_dilated_conv = new IdentifyDilatedConv;
+ if (output_format == TENSORFLOW_GRAPHDEF) {
+ identify_dilated_conv->set_identify_depthwise_conv(false);
+ }
+ transformations.Add(identify_dilated_conv);
RunGraphTransformations(model, "general graph transformations",
transformations);
@@ -367,9 +374,7 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
}
// Deduplicate large constant arrays.
- if (toco_flags.has_dedupe_array_min_size_bytes()) {
- DedupeConstantArrays(model, toco_flags.dedupe_array_min_size_bytes());
- }
+ DedupeConstantArrays(model, toco_flags.dedupe_array_min_size_bytes());
LogDump(kLogLevelModelChanged, "AFTER TRANSFORMATIONS", *model);
@@ -404,9 +409,9 @@ void Export(const TocoFlags& toco_flags, const Model& model,
case TFLITE: {
toco::tflite::ExportParams params;
- // Always allow custom ops when eager ops are allowed.
- if (toco_flags.force_eager_ops() || toco_flags.allow_eager_ops()) {
- params.allow_eager_ops = true;
+ // Always allow custom ops when flex ops are allowed.
+ if (toco_flags.force_flex_ops() || toco_flags.allow_flex_ops()) {
+ params.allow_flex_ops = true;
params.allow_custom_ops = true;
} else if (allow_custom_ops) {
params.allow_custom_ops = true;
diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc
index 6ab93d9316..4a1ae35cb5 100644
--- a/tensorflow/contrib/lite/toco/tooling_util.cc
+++ b/tensorflow/contrib/lite/toco/tooling_util.cc
@@ -406,6 +406,7 @@ const char* OperatorTypeName(OperatorType type) {
HANDLE_OPERATORTYPENAME_CASE(LogicalOr)
HANDLE_OPERATORTYPENAME_CASE(CTCBeamSearchDecoder)
HANDLE_OPERATORTYPENAME_CASE(Unpack)
+ HANDLE_OPERATORTYPENAME_CASE(ZerosLike)
default:
LOG(FATAL) << "Unhandled op type";
#undef HANDLE_OPERATORTYPENAME_CASE
diff --git a/tensorflow/contrib/lite/tools/benchmark/BUILD b/tensorflow/contrib/lite/tools/benchmark/BUILD
index dc97d22401..502e181139 100644
--- a/tensorflow/contrib/lite/tools/benchmark/BUILD
+++ b/tensorflow/contrib/lite/tools/benchmark/BUILD
@@ -36,11 +36,11 @@ cc_binary(
)
cc_binary(
- name = "benchmark_model_plus_eager",
+ name = "benchmark_model_plus_flex",
srcs = [
"benchmark_main.cc",
],
- copts = common_copts + ["-DTFLITE_EXTENDED"],
+ copts = common_copts + ["-DTFLITE_FLEX"],
linkopts = tflite_linkopts() + select({
"//tensorflow:android": [
"-pie", # Android 5.0 and later supports only PIE
@@ -49,7 +49,7 @@ cc_binary(
"//conditions:default": [],
}),
deps = [
- ":benchmark_tflite_model_plus_eager_lib",
+ ":benchmark_tflite_model_plus_flex_lib",
":logging",
],
)
@@ -111,19 +111,19 @@ cc_library(
)
cc_library(
- name = "benchmark_tflite_model_plus_eager_lib",
+ name = "benchmark_tflite_model_plus_flex_lib",
srcs = [
"benchmark_tflite_model.cc",
"logging.h",
],
hdrs = ["benchmark_tflite_model.h"],
- copts = common_copts + ["-DTFLITE_EXTENDED"],
+ copts = common_copts + ["-DTFLITE_FLEX"],
deps = [
":benchmark_model_lib",
":logging",
"//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite:string_util",
- "//tensorflow/contrib/lite/delegates/eager:delegate",
+ "//tensorflow/contrib/lite/delegates/flex:delegate",
"//tensorflow/contrib/lite/kernels:builtin_ops",
"//tensorflow/contrib/lite/profiling:profile_summarizer",
],
diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc
index ef4f0fa80d..463d5993f4 100644
--- a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc
+++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc
@@ -23,9 +23,9 @@ limitations under the License.
#include <unordered_set>
#include <vector>
-#ifdef TFLITE_EXTENDED
-#include "tensorflow/contrib/lite/delegates/eager/delegate.h"
-#endif // TFLITE_EXTENDED
+#ifdef TFLITE_FLEX
+#include "tensorflow/contrib/lite/delegates/flex/delegate.h"
+#endif // TFLITE_FLEX
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/model.h"
#include "tensorflow/contrib/lite/op_resolver.h"
@@ -305,14 +305,14 @@ void BenchmarkTfLiteModel::Init() {
interpreter->UseNNAPI(use_nnapi);
-#ifdef TFLITE_EXTENDED
- TFLITE_LOG(INFO) << "Instantiating Eager Delegate";
- delegate_ = EagerDelegate::Create();
+#ifdef TFLITE_FLEX
+ TFLITE_LOG(INFO) << "Instantiating Flex Delegate";
+ delegate_ = FlexDelegate::Create();
if (delegate_) {
interpreter->ModifyGraphWithDelegate(delegate_.get(),
/*allow_dynamic_tensors=*/true);
}
-#endif // TFLITE_EXTENDED
+#endif // TFLITE_FLEX
auto interpreter_inputs = interpreter->inputs();
diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h
index 8541512bc8..b091e18a29 100644
--- a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h
+++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h
@@ -20,9 +20,9 @@ limitations under the License.
#include <string>
#include <vector>
-#ifdef TFLITE_EXTENDED
-#include "tensorflow/contrib/lite/delegates/eager/delegate.h"
-#endif // TFLITE_EXTENDED
+#ifdef TFLITE_FLEX
+#include "tensorflow/contrib/lite/delegates/flex/delegate.h"
+#endif // TFLITE_FLEX
#include "tensorflow/contrib/lite/model.h"
#include "tensorflow/contrib/lite/profiling/profile_summarizer.h"
#include "tensorflow/contrib/lite/tools/benchmark/benchmark_model.h"
@@ -73,9 +73,9 @@ class BenchmarkTfLiteModel : public BenchmarkModel {
void PrepareInputsAndOutputs() override;
private:
-#ifdef TFLITE_EXTENDED
- std::unique_ptr<EagerDelegate> delegate_;
-#endif // TFLITE_EXTENDED
+#ifdef TFLITE_FLEX
+ std::unique_ptr<FlexDelegate> delegate_;
+#endif // TFLITE_FLEX
std::unique_ptr<tflite::FlatBufferModel> model;
std::unique_ptr<tflite::Interpreter> interpreter;
std::vector<InputLayerInfo> inputs;
diff --git a/tensorflow/contrib/lite/tools/make/Makefile b/tensorflow/contrib/lite/tools/make/Makefile
index 59bdb10811..16012a3fb1 100644
--- a/tensorflow/contrib/lite/tools/make/Makefile
+++ b/tensorflow/contrib/lite/tools/make/Makefile
@@ -30,6 +30,7 @@ INCLUDES := \
-I$(MAKEFILE_DIR)/../../../../../../ \
-I$(MAKEFILE_DIR)/downloads/ \
-I$(MAKEFILE_DIR)/downloads/eigen \
+-I$(MAKEFILE_DIR)/downloads/absl \
-I$(MAKEFILE_DIR)/downloads/gemmlowp \
-I$(MAKEFILE_DIR)/downloads/neon_2_sse \
-I$(MAKEFILE_DIR)/downloads/farmhash/src \
diff --git a/tensorflow/contrib/lite/tools/make/download_dependencies.sh b/tensorflow/contrib/lite/tools/make/download_dependencies.sh
index 29afa45133..3570f9a38d 100755
--- a/tensorflow/contrib/lite/tools/make/download_dependencies.sh
+++ b/tensorflow/contrib/lite/tools/make/download_dependencies.sh
@@ -35,7 +35,7 @@ GOOGLETEST_URL="https://github.com/google/googletest/archive/release-1.8.0.tar.g
ABSL_URL="$(grep -o 'https://github.com/abseil/abseil-cpp/.*tar.gz' "${BZL_FILE_PATH}" | head -n1)"
NEON_2_SSE_URL="https://github.com/intel/ARM_NEON_2_x86_SSE/archive/master.zip"
FARMHASH_URL="https://mirror.bazel.build/github.com/google/farmhash/archive/816a4ae622e964763ca0862d9dbd19324a1eaf45.tar.gz"
-FLATBUFFERS_URL="https://github.com/google/flatbuffers/archive/v1.8.0.zip"
+FLATBUFFERS_URL="https://github.com/google/flatbuffers/archive/1f5eae5d6a135ff6811724f6c57f911d1f46bb15.tar.gz"
FFT2D_URL="https://mirror.bazel.build/www.kurims.kyoto-u.ac.jp/~ooura/fft.tgz"
# TODO(petewarden): Some new code in Eigen triggers a clang bug with iOS arm64,
diff --git a/tensorflow/contrib/lite/tutorials/post_training_quant.ipynb b/tensorflow/contrib/lite/tutorials/post_training_quant.ipynb
index 4929133bda..80cdb2f080 100644
--- a/tensorflow/contrib/lite/tutorials/post_training_quant.ipynb
+++ b/tensorflow/contrib/lite/tutorials/post_training_quant.ipynb
@@ -36,7 +36,7 @@
"source": [
"## Overview\n",
"\n",
- "[TensorFlow Lite](https://www.tensorflow.org/mobile/tflite/) now supports\n",
+ "[TensorFlow Lite](https://www.tensorflow.org/lite/) now supports\n",
"converting weights to 8 bit precision as part of model conversion from\n",
"tensorflow graphdefs to TFLite's flat buffer format. Weight quantization\n",
"achieves a 4x reduction in the model size. In addition, TFLite supports on the\n",
@@ -542,7 +542,7 @@
},
"outputs": [],
"source": [
- "print(eval_model(interpreter_quant, mnist_ds))"
+ "print(eval_model(interpreter, mnist_ds))"
]
},
{
diff --git a/tensorflow/contrib/lite/util.cc b/tensorflow/contrib/lite/util.cc
index 7950653da9..6aa35b5227 100644
--- a/tensorflow/contrib/lite/util.cc
+++ b/tensorflow/contrib/lite/util.cc
@@ -18,9 +18,9 @@ limitations under the License.
namespace tflite {
-bool IsEagerOp(const char* custom_name) {
- return custom_name && strncmp(custom_name, kEagerCustomCodePrefix,
- strlen(kEagerCustomCodePrefix)) == 0;
+bool IsFlexOp(const char* custom_name) {
+ return custom_name && strncmp(custom_name, kFlexCustomCodePrefix,
+ strlen(kFlexCustomCodePrefix)) == 0;
}
TfLiteIntArray* ConvertVectorToTfLiteIntArray(const std::vector<int>& input) {
diff --git a/tensorflow/contrib/lite/util.h b/tensorflow/contrib/lite/util.h
index 6d81f844f8..31292a6f81 100644
--- a/tensorflow/contrib/lite/util.h
+++ b/tensorflow/contrib/lite/util.h
@@ -26,15 +26,15 @@ limitations under the License.
namespace tflite {
-// The prefix of Eager op custom code.
+// The prefix of Flex op custom code.
// This will be matched against the `custom_code` field in `OperatorCode`
// Flatbuffer Table.
// WARNING: This is an experimental API and subject to change.
-constexpr char kEagerCustomCodePrefix[] = "Eager";
+constexpr char kFlexCustomCodePrefix[] = "Flex";
// Checks whether the prefix of the custom name indicates the operation is an
-// Eager operation.
-bool IsEagerOp(const char* custom_name);
+// Flex operation.
+bool IsFlexOp(const char* custom_name);
// Converts a `std::vector` to a `TfLiteIntArray`. The caller takes ownership
// of the returned pointer.
diff --git a/tensorflow/contrib/lite/util_test.cc b/tensorflow/contrib/lite/util_test.cc
index c5c1709f1d..25f3aded71 100644
--- a/tensorflow/contrib/lite/util_test.cc
+++ b/tensorflow/contrib/lite/util_test.cc
@@ -41,14 +41,14 @@ TEST(ConvertVectorToTfLiteIntArray, TestWithEmptyVector) {
TfLiteIntArrayFree(output);
}
-TEST(UtilTest, IsEagerOp) {
- EXPECT_TRUE(IsEagerOp("Eager"));
- EXPECT_TRUE(IsEagerOp("EagerOp"));
- EXPECT_FALSE(IsEagerOp("eager"));
- EXPECT_FALSE(IsEagerOp("Eage"));
- EXPECT_FALSE(IsEagerOp("OpEager"));
- EXPECT_FALSE(IsEagerOp(nullptr));
- EXPECT_FALSE(IsEagerOp(""));
+TEST(UtilTest, IsFlexOp) {
+ EXPECT_TRUE(IsFlexOp("Flex"));
+ EXPECT_TRUE(IsFlexOp("FlexOp"));
+ EXPECT_FALSE(IsFlexOp("flex"));
+ EXPECT_FALSE(IsFlexOp("Fle"));
+ EXPECT_FALSE(IsFlexOp("OpFlex"));
+ EXPECT_FALSE(IsFlexOp(nullptr));
+ EXPECT_FALSE(IsFlexOp(""));
}
} // namespace
diff --git a/tensorflow/contrib/losses/python/metric_learning/metric_loss_ops_test.py b/tensorflow/contrib/losses/python/metric_learning/metric_loss_ops_test.py
index 4ec539ab42..9c389144ff 100644
--- a/tensorflow/contrib/losses/python/metric_learning/metric_loss_ops_test.py
+++ b/tensorflow/contrib/losses/python/metric_learning/metric_loss_ops_test.py
@@ -61,7 +61,7 @@ def pairwise_distance_np(feature, squared=False):
class ContrastiveLossTest(test.TestCase):
def testContrastive(self):
- with self.test_session():
+ with self.cached_session():
num_data = 10
feat_dim = 6
margin = 1.0
@@ -90,7 +90,7 @@ class ContrastiveLossTest(test.TestCase):
class TripletSemiHardLossTest(test.TestCase):
def testTripletSemiHard(self):
- with self.test_session():
+ with self.cached_session():
num_data = 10
feat_dim = 6
margin = 1.0
@@ -146,7 +146,7 @@ class TripletSemiHardLossTest(test.TestCase):
class LiftedStructLossTest(test.TestCase):
def testLiftedStruct(self):
- with self.test_session():
+ with self.cached_session():
num_data = 10
feat_dim = 6
margin = 1.0
@@ -217,7 +217,7 @@ def convert_to_list_of_sparse_tensor(np_matrix):
class NpairsLossTest(test.TestCase):
def testNpairs(self):
- with self.test_session():
+ with self.cached_session():
num_data = 15
feat_dim = 6
num_classes = 5
@@ -261,7 +261,7 @@ class NpairsLossTest(test.TestCase):
class NpairsLossMultiLabelTest(test.TestCase):
def testNpairsMultiLabelLossWithSingleLabelEqualsNpairsLoss(self):
- with self.test_session():
+ with self.cached_session():
num_data = 15
feat_dim = 6
reg_lambda = 0.02
@@ -290,7 +290,7 @@ class NpairsLossMultiLabelTest(test.TestCase):
self.assertAllClose(loss_npairs, loss_npairs_multilabel)
def testNpairsMultiLabel(self):
- with self.test_session():
+ with self.cached_session():
num_data = 15
feat_dim = 6
num_classes = 10
@@ -527,7 +527,7 @@ class ClusterLossTest(test.TestCase):
def testClusteringLossPAMOff(self):
if not HAS_SKLEARN:
return
- with self.test_session():
+ with self.cached_session():
margin_multiplier = 10.0
embeddings, labels = self._genClusters(n_samples=128, n_clusters=64)
@@ -544,7 +544,7 @@ class ClusterLossTest(test.TestCase):
def testClusteringLossPAMOn(self):
if not HAS_SKLEARN:
return
- with self.test_session():
+ with self.cached_session():
margin_multiplier = 10.0
embeddings, labels = self._genClusters(n_samples=128, n_clusters=64)
diff --git a/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt b/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt
index 1d6d9a60e5..0d8df93d11 100644
--- a/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt
+++ b/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt
@@ -10,7 +10,6 @@ tensorflow/core/framework/graph.pb.cc
tensorflow/core/framework/graph_transfer_info.pb.cc
tensorflow/core/framework/kernel_def.pb.cc
tensorflow/core/framework/log_memory.pb.cc
-tensorflow/core/framework/model.pb.cc
tensorflow/core/framework/node_def.pb.cc
tensorflow/core/framework/op_def.pb.cc
tensorflow/core/framework/remote_fused_graph_execute_info.pb.cc
diff --git a/tensorflow/contrib/makefile/proto_text_pb_h_files.txt b/tensorflow/contrib/makefile/proto_text_pb_h_files.txt
index 884461ecae..d982df9319 100644
--- a/tensorflow/contrib/makefile/proto_text_pb_h_files.txt
+++ b/tensorflow/contrib/makefile/proto_text_pb_h_files.txt
@@ -10,7 +10,6 @@ tensorflow/core/framework/graph.pb.h
tensorflow/core/framework/graph_transfer_info.pb.h
tensorflow/core/framework/kernel_def.pb.h
tensorflow/core/framework/log_memory.pb.h
-tensorflow/core/framework/model.pb.h
tensorflow/core/framework/node_def.pb.h
tensorflow/core/framework/op_def.pb.h
tensorflow/core/framework/remote_fused_graph_execute_info.pb.h
diff --git a/tensorflow/contrib/makefile/tf_op_files.txt b/tensorflow/contrib/makefile/tf_op_files.txt
index 08de54b8e1..91af933cff 100644
--- a/tensorflow/contrib/makefile/tf_op_files.txt
+++ b/tensorflow/contrib/makefile/tf_op_files.txt
@@ -91,6 +91,8 @@ tensorflow/core/kernels/cwise_op_square.cc
tensorflow/core/kernels/cwise_op_squared_difference.cc
tensorflow/core/kernels/cwise_op_sub.cc
tensorflow/core/kernels/cwise_op_tanh.cc
+tensorflow/core/kernels/cwise_op_xdivy.cc
+tensorflow/core/kernels/cwise_op_xlogy.cc
tensorflow/core/kernels/cwise_ops_common.cc
tensorflow/core/kernels/data_format_ops.cc
tensorflow/core/kernels/decode_bmp_op.cc
@@ -253,6 +255,7 @@ tensorflow/core/kernels/strided_slice_op_inst_5.cc
tensorflow/core/kernels/strided_slice_op_inst_6.cc
tensorflow/core/kernels/strided_slice_op_inst_7.cc
tensorflow/core/kernels/string_join_op.cc
+tensorflow/core/kernels/string_util.cc
tensorflow/core/kernels/tensor_array.cc
tensorflow/core/kernels/tensor_array_ops.cc
tensorflow/core/kernels/tile_functor_cpu.cc
diff --git a/tensorflow/contrib/makefile/tf_pb_text_files.txt b/tensorflow/contrib/makefile/tf_pb_text_files.txt
index e23f499214..f94d70db90 100644
--- a/tensorflow/contrib/makefile/tf_pb_text_files.txt
+++ b/tensorflow/contrib/makefile/tf_pb_text_files.txt
@@ -10,7 +10,6 @@ tensorflow/core/framework/graph.pb_text.cc
tensorflow/core/framework/graph_transfer_info.pb_text.cc
tensorflow/core/framework/kernel_def.pb_text.cc
tensorflow/core/framework/log_memory.pb_text.cc
-tensorflow/core/framework/model.pb_text.cc
tensorflow/core/framework/node_def.pb_text.cc
tensorflow/core/framework/op_def.pb_text.cc
tensorflow/core/framework/remote_fused_graph_execute_info.pb_text.cc
diff --git a/tensorflow/contrib/makefile/tf_proto_files.txt b/tensorflow/contrib/makefile/tf_proto_files.txt
index 5eae845d9b..8bec3e3e01 100644
--- a/tensorflow/contrib/makefile/tf_proto_files.txt
+++ b/tensorflow/contrib/makefile/tf_proto_files.txt
@@ -14,7 +14,6 @@ tensorflow/core/framework/graph.proto
tensorflow/core/framework/graph_transfer_info.proto
tensorflow/core/framework/kernel_def.proto
tensorflow/core/framework/log_memory.proto
-tensorflow/core/framework/model.proto
tensorflow/core/framework/node_def.proto
tensorflow/core/framework/op_def.proto
tensorflow/core/framework/reader_base.proto
diff --git a/tensorflow/contrib/metrics/python/kernel_tests/histogram_ops_test.py b/tensorflow/contrib/metrics/python/kernel_tests/histogram_ops_test.py
index 1d18d6beff..bed1ecb71c 100644
--- a/tensorflow/contrib/metrics/python/kernel_tests/histogram_ops_test.py
+++ b/tensorflow/contrib/metrics/python/kernel_tests/histogram_ops_test.py
@@ -31,21 +31,21 @@ class Strict1dCumsumTest(test.TestCase):
"""Test this private function."""
def test_empty_tensor_returns_empty(self):
- with self.test_session():
+ with self.cached_session():
tensor = constant_op.constant([])
result = histogram_ops._strict_1d_cumsum(tensor, 0)
expected = constant_op.constant([])
np.testing.assert_array_equal(expected.eval(), result.eval())
def test_length_1_tensor_works(self):
- with self.test_session():
+ with self.cached_session():
tensor = constant_op.constant([3], dtype=dtypes.float32)
result = histogram_ops._strict_1d_cumsum(tensor, 1)
expected = constant_op.constant([3], dtype=dtypes.float32)
np.testing.assert_array_equal(expected.eval(), result.eval())
def test_length_3_tensor_works(self):
- with self.test_session():
+ with self.cached_session():
tensor = constant_op.constant([1, 2, 3], dtype=dtypes.float32)
result = histogram_ops._strict_1d_cumsum(tensor, 3)
expected = constant_op.constant([1, 3, 6], dtype=dtypes.float32)
@@ -58,7 +58,7 @@ class AUCUsingHistogramTest(test.TestCase):
self.rng = np.random.RandomState(0)
def test_empty_labels_and_scores_gives_nan_auc(self):
- with self.test_session():
+ with self.cached_session():
labels = constant_op.constant([], shape=[0], dtype=dtypes.bool)
scores = constant_op.constant([], shape=[0], dtype=dtypes.float32)
score_range = [0, 1.]
@@ -155,7 +155,7 @@ class AUCUsingHistogramTest(test.TestCase):
from synthetic data.
"""
score_range = [0, 1.] or score_range
- with self.test_session():
+ with self.cached_session():
labels = array_ops.placeholder(dtypes.bool, shape=[num_records])
scores = array_ops.placeholder(dtypes.float32, shape=[num_records])
auc, update_op = histogram_ops.auc_using_histogram(
diff --git a/tensorflow/contrib/metrics/python/metrics/classification_test.py b/tensorflow/contrib/metrics/python/metrics/classification_test.py
index 3d0b81c1be..d6a670f97b 100644
--- a/tensorflow/contrib/metrics/python/metrics/classification_test.py
+++ b/tensorflow/contrib/metrics/python/metrics/classification_test.py
@@ -34,7 +34,7 @@ from tensorflow.python.platform import test
class ClassificationTest(test.TestCase):
def testAccuracy1D(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
pred = array_ops.placeholder(dtypes.int32, shape=[None])
labels = array_ops.placeholder(dtypes.int32, shape=[None])
acc = classification.accuracy(pred, labels)
@@ -44,7 +44,7 @@ class ClassificationTest(test.TestCase):
self.assertEqual(result, 0.5)
def testAccuracy1DBool(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
pred = array_ops.placeholder(dtypes.bool, shape=[None])
labels = array_ops.placeholder(dtypes.bool, shape=[None])
acc = classification.accuracy(pred, labels)
@@ -54,7 +54,7 @@ class ClassificationTest(test.TestCase):
self.assertEqual(result, 0.5)
def testAccuracy1DInt64(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
pred = array_ops.placeholder(dtypes.int64, shape=[None])
labels = array_ops.placeholder(dtypes.int64, shape=[None])
acc = classification.accuracy(pred, labels)
@@ -64,7 +64,7 @@ class ClassificationTest(test.TestCase):
self.assertEqual(result, 0.5)
def testAccuracy1DString(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
pred = array_ops.placeholder(dtypes.string, shape=[None])
labels = array_ops.placeholder(dtypes.string, shape=[None])
acc = classification.accuracy(pred, labels)
@@ -87,7 +87,7 @@ class ClassificationTest(test.TestCase):
classification.accuracy(pred, labels)
def testAccuracy1DWeighted(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
pred = array_ops.placeholder(dtypes.int32, shape=[None])
labels = array_ops.placeholder(dtypes.int32, shape=[None])
weights = array_ops.placeholder(dtypes.float32, shape=[None])
@@ -101,7 +101,7 @@ class ClassificationTest(test.TestCase):
self.assertEqual(result, 0.5)
def testAccuracy1DWeightedBroadcast(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
pred = array_ops.placeholder(dtypes.int32, shape=[None])
labels = array_ops.placeholder(dtypes.int32, shape=[None])
weights = array_ops.placeholder(dtypes.float32, shape=[])
@@ -161,7 +161,7 @@ class F1ScoreTest(test.TestCase):
(10, 3), maxval=2, dtype=dtypes.int64, seed=2)
f1, f1_op = classification.f1_score(predictions, labels, num_thresholds=3)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -176,7 +176,7 @@ class F1ScoreTest(test.TestCase):
def testAllCorrect(self):
inputs = np.random.randint(0, 2, size=(100, 1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(inputs, dtype=dtypes.float32)
labels = constant_op.constant(inputs)
f1, f1_op = classification.f1_score(predictions, labels, num_thresholds=3)
@@ -191,7 +191,7 @@ class F1ScoreTest(test.TestCase):
[1, 0, 1, 0], shape=(1, 4), dtype=dtypes.float32)
labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
f1, f1_op = classification.f1_score(predictions, labels, num_thresholds=1)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
sess.run([f1_op])
# Threshold 0 will have around 0.5 precision and 1 recall yielding an F1
@@ -201,7 +201,7 @@ class F1ScoreTest(test.TestCase):
def testAllIncorrect(self):
inputs = np.random.randint(0, 2, size=(10000, 1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(inputs, dtype=dtypes.float32)
labels = constant_op.constant(1 - inputs, dtype=dtypes.float32)
f1, f1_op = classification.f1_score(predictions, labels, num_thresholds=3)
@@ -214,7 +214,7 @@ class F1ScoreTest(test.TestCase):
self.assertAlmostEqual(2 * 0.5 * 1 / (1 + 0.5), f1.eval(), places=2)
def testWeights1d(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes.float32)
labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2))
@@ -228,7 +228,7 @@ class F1ScoreTest(test.TestCase):
self.assertAlmostEqual(1.0, f1.eval(), places=5)
def testWeights2d(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes.float32)
labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2))
@@ -242,7 +242,7 @@ class F1ScoreTest(test.TestCase):
self.assertAlmostEqual(1.0, f1.eval(), places=5)
def testZeroLabelsPredictions(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = array_ops.zeros([4], dtype=dtypes.float32)
labels = array_ops.zeros([4])
f1, f1_op = classification.f1_score(predictions, labels, num_thresholds=3)
@@ -300,7 +300,7 @@ class F1ScoreTest(test.TestCase):
f1, f1_op = classification.f1_score(tf_labels, tf_predictions,
num_thresholds=3)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
for _ in range(num_batches):
sess.run([f1_op])
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
index 955b83b44d..fc64f343ab 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
@@ -2069,11 +2069,11 @@ class StreamingDynamicAUCTest(test.TestCase):
num_batches = 100
labels = np.array([])
predictions = np.array([])
- tf_labels = variables.Variable(
+ tf_labels = variables.VariableV1(
array_ops.ones(batch_size, dtypes_lib.int32),
collections=[ops.GraphKeys.LOCAL_VARIABLES],
dtype=dtypes_lib.int32)
- tf_predictions = variables.Variable(
+ tf_predictions = variables.VariableV1(
array_ops.ones(batch_size),
collections=[ops.GraphKeys.LOCAL_VARIABLES],
dtype=dtypes_lib.float32)
@@ -2133,15 +2133,15 @@ class StreamingDynamicAUCTest(test.TestCase):
labels = np.array([])
predictions = np.array([])
weights = np.array([])
- tf_labels = variables.Variable(
+ tf_labels = variables.VariableV1(
array_ops.ones(batch_size, dtypes_lib.int32),
collections=[ops.GraphKeys.LOCAL_VARIABLES],
dtype=dtypes_lib.int32)
- tf_predictions = variables.Variable(
+ tf_predictions = variables.VariableV1(
array_ops.ones(batch_size),
collections=[ops.GraphKeys.LOCAL_VARIABLES],
dtype=dtypes_lib.float32)
- tf_weights = variables.Variable(
+ tf_weights = variables.VariableV1(
array_ops.ones(batch_size),
collections=[ops.GraphKeys.LOCAL_VARIABLES],
dtype=dtypes_lib.float32)
@@ -2311,10 +2311,11 @@ class AucWithConfidenceIntervalsTest(test.TestCase):
num_batches = 100
labels = np.array([])
predictions = np.array([])
- tf_labels = variables.Variable(array_ops.ones(batch_size, dtypes_lib.int32),
- collections=[ops.GraphKeys.LOCAL_VARIABLES],
- dtype=dtypes_lib.int32)
- tf_predictions = variables.Variable(
+ tf_labels = variables.VariableV1(
+ array_ops.ones(batch_size, dtypes_lib.int32),
+ collections=[ops.GraphKeys.LOCAL_VARIABLES],
+ dtype=dtypes_lib.int32)
+ tf_predictions = variables.VariableV1(
array_ops.ones(batch_size),
collections=[ops.GraphKeys.LOCAL_VARIABLES],
dtype=dtypes_lib.float32)
diff --git a/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer.py b/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer.py
index fcce52a07a..a5621b44cd 100644
--- a/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer.py
+++ b/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer.py
@@ -66,10 +66,11 @@ class LossScaleOptimizer(optimizer.Optimizer):
# Choose a loss scale manager which decides how to pick the right loss scale
# throughout the training process.
- loss_scale_manger = tf.contrib.mixed_precision.FixedLossScaleManager(5000)
+ loss_scale_manager = tf.contrib.mixed_precision.FixedLossScaleManager(5000)
# Wraps the original optimizer in a LossScaleOptimizer.
- loss_scale_optimizer = LossScaleOptimizer(opt, loss_scale_manager)
+ loss_scale_optimizer = tf.contrib.mixed_precision.LossScaleOptimizer(
+ opt, loss_scale_manager)
# Call minimize() on the loss scale optimizer.
train_op = loss_scale_optimizer.minimize(loss)
diff --git a/tensorflow/contrib/model_pruning/python/pruning.py b/tensorflow/contrib/model_pruning/python/pruning.py
index a81abac2fa..67e58ff15d 100644
--- a/tensorflow/contrib/model_pruning/python/pruning.py
+++ b/tensorflow/contrib/model_pruning/python/pruning.py
@@ -247,7 +247,8 @@ class Pruning(object):
# Stores the tensorflow sparsity variable.
# Built using self._setup_sparsity() or provided externally
- self._sparsity = sparsity if sparsity else self._setup_sparsity()
+ self._sparsity = (sparsity
+ if sparsity is not None else self._setup_sparsity())
# List of tensorflow assignments ops for new masks and thresholds
self._assign_ops = []
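
The switch to an explicit `is not None` test above matters because the old truthiness check silently discards falsy-but-valid inputs such as 0.0. A minimal plain-Python sketch of the difference (the helper function and values are hypothetical, not part of the change):

    def pick_sparsity_old(sparsity, default_fn=lambda: 0.5):
      # Old behaviour: a caller-supplied 0.0 counts as "not provided".
      return sparsity if sparsity else default_fn()

    def pick_sparsity_new(sparsity, default_fn=lambda: 0.5):
      # New behaviour: only None falls back to the default.
      return sparsity if sparsity is not None else default_fn()

    print(pick_sparsity_old(0.0))  # 0.5 -- the caller's value is silently ignored
    print(pick_sparsity_new(0.0))  # 0.0 -- the caller's value is respected
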
diff --git a/tensorflow/contrib/model_pruning/python/pruning_test.py b/tensorflow/contrib/model_pruning/python/pruning_test.py
index cd3d8e76bb..1b6da5ce2b 100644
--- a/tensorflow/contrib/model_pruning/python/pruning_test.py
+++ b/tensorflow/contrib/model_pruning/python/pruning_test.py
@@ -45,7 +45,7 @@ class PruningHParamsTest(test.TestCase):
# Add global step variable to the graph
self.global_step = training_util.get_or_create_global_step()
# Add sparsity
- self.sparsity = variables.Variable(0.5, name="sparsity")
+ self.sparsity = variables.VariableV1(0.5, name="sparsity")
# Parse hparams
self.pruning_hparams = pruning.get_pruning_hparams().parse(
self.TEST_HPARAMS)
@@ -88,7 +88,7 @@ class PruningTest(test.TestCase):
width = 10
height = 20
with self.cached_session():
- weights = variables.Variable(
+ weights = variables.VariableV1(
random_ops.random_normal([width, height], stddev=1), name="weights")
masked_weights = pruning.apply_mask(weights,
variable_scope.get_variable_scope())
@@ -99,10 +99,10 @@ class PruningTest(test.TestCase):
def testUpdateSingleMask(self):
with self.cached_session() as session:
- weights = variables.Variable(
+ weights = variables.VariableV1(
math_ops.linspace(1.0, 100.0, 100), name="weights")
masked_weights = pruning.apply_mask(weights)
- sparsity = variables.Variable(0.5, name="sparsity")
+ sparsity = variables.VariableV1(0.5, name="sparsity")
p = pruning.Pruning(sparsity=sparsity)
p._spec.threshold_decay = 0.0
mask_update_op = p.mask_update_op()
@@ -115,8 +115,8 @@ class PruningTest(test.TestCase):
def _blockMasking(self, hparams, weights, expected_mask):
- threshold = variables.Variable(0.0, name="threshold")
- sparsity = variables.Variable(0.5, name="sparsity")
+ threshold = variables.VariableV1(0.0, name="threshold")
+ sparsity = variables.VariableV1(0.5, name="sparsity")
test_spec = ",".join(hparams)
pruning_hparams = pruning.get_pruning_hparams().parse(test_spec)
@@ -169,7 +169,7 @@ class PruningTest(test.TestCase):
partitioner = partitioned_variables.variable_axis_size_partitioner(40)
with self.cached_session() as session:
with variable_scope.variable_scope("", partitioner=partitioner):
- sparsity = variables.Variable(0.5, name="Sparsity")
+ sparsity = variables.VariableV1(0.5, name="Sparsity")
weights = variable_scope.get_variable(
"weights", initializer=math_ops.linspace(1.0, 100.0, 100))
masked_weights = pruning.apply_mask(
@@ -190,10 +190,10 @@ class PruningTest(test.TestCase):
]
test_spec = ",".join(param_list)
pruning_hparams = pruning.get_pruning_hparams().parse(test_spec)
- weights = variables.Variable(
+ weights = variables.VariableV1(
math_ops.linspace(1.0, 100.0, 100), name="weights")
masked_weights = pruning.apply_mask(weights)
- sparsity = variables.Variable(0.00, name="sparsity")
+ sparsity = variables.VariableV1(0.00, name="sparsity")
# Set up pruning
p = pruning.Pruning(pruning_hparams, sparsity=sparsity)
p._spec.threshold_decay = 0.0
@@ -222,11 +222,11 @@ class PruningTest(test.TestCase):
pruning_hparams = pruning.get_pruning_hparams().parse(test_spec)
with variable_scope.variable_scope("layer1"):
- w1 = variables.Variable(
+ w1 = variables.VariableV1(
math_ops.linspace(1.0, 100.0, 100), name="weights")
_ = pruning.apply_mask(w1)
with variable_scope.variable_scope("layer2"):
- w2 = variables.Variable(
+ w2 = variables.VariableV1(
math_ops.linspace(1.0, 100.0, 100), name="weights")
_ = pruning.apply_mask(w2)
diff --git a/tensorflow/contrib/mpi/mpi_rendezvous_mgr.cc b/tensorflow/contrib/mpi/mpi_rendezvous_mgr.cc
index 6a7f5efecd..b9967fe76d 100644
--- a/tensorflow/contrib/mpi/mpi_rendezvous_mgr.cc
+++ b/tensorflow/contrib/mpi/mpi_rendezvous_mgr.cc
@@ -136,8 +136,8 @@ void MPIRemoteRendezvous::RecvFromRemoteAsync(
MPIRendezvousMgr* mgr =
reinterpret_cast<MPIRendezvousMgr*>(this->rendezvous_mgr_);
- mgr->QueueRequest(parsed.FullKey().ToString(), step_id_,
- std::move(request_call), rendezvous_call);
+ mgr->QueueRequest(string(parsed.FullKey()), step_id_, std::move(request_call),
+ rendezvous_call);
}
MPIRemoteRendezvous::~MPIRemoteRendezvous() {}
@@ -258,7 +258,7 @@ void MPIRendezvousMgr::AddRequest(RecvTensorRequest request,
std::function<MPISendTensorCall*()> res = std::bind(
send_cb, status, send_args, recv_args, val, is_dead, mpi_send_call);
- SendQueueEntry req(parsed.FullKey().ToString().c_str(), std::move(res));
+ SendQueueEntry req(string(parsed.FullKey()), std::move(res));
this->QueueSendRequest(req);
diff --git a/tensorflow/contrib/mpi/mpi_rendezvous_mgr.h b/tensorflow/contrib/mpi/mpi_rendezvous_mgr.h
index 5596601ddb..90140fcab3 100644
--- a/tensorflow/contrib/mpi/mpi_rendezvous_mgr.h
+++ b/tensorflow/contrib/mpi/mpi_rendezvous_mgr.h
@@ -71,7 +71,7 @@ class MPISendTensorCall {
void Init(const Rendezvous::ParsedKey& parsed, const int64 step_id,
const bool is_dead) {
- mRes_.set_key(parsed.FullKey().ToString());
+ mRes_.set_key(string(parsed.FullKey()));
mRes_.set_step_id(step_id);
mRes_.mutable_response()->set_is_dead(is_dead);
mRes_.mutable_response()->set_send_start_micros(
diff --git a/tensorflow/contrib/opt/BUILD b/tensorflow/contrib/opt/BUILD
index 2e4d61d931..f4ac70eb1a 100644
--- a/tensorflow/contrib/opt/BUILD
+++ b/tensorflow/contrib/opt/BUILD
@@ -16,6 +16,7 @@ py_library(
"__init__.py",
"python/training/adamax.py",
"python/training/addsign.py",
+ "python/training/agn_optimizer.py",
"python/training/drop_stale_gradient_optimizer.py",
"python/training/elastic_average_optimizer.py",
"python/training/external_optimizer.py",
@@ -246,6 +247,27 @@ tf_py_test(
)
tf_py_test(
+ name = "agn_optimizer_test",
+ srcs = ["python/training/agn_optimizer_test.py"],
+ additional_deps = [
+ ":opt_py",
+ "//tensorflow/python:client",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:variables",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:training",
+ "//tensorflow/python:ops",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//third_party/py/numpy",
+ ],
+ tags = [
+ "notap", # this test launches a local server
+ ],
+)
+
+tf_py_test(
name = "elastic_average_optimizer_test",
srcs = ["python/training/elastic_average_optimizer_test.py"],
additional_deps = [
diff --git a/tensorflow/contrib/opt/__init__.py b/tensorflow/contrib/opt/__init__.py
index ad7d7cfa6e..c7ea68efa9 100644
--- a/tensorflow/contrib/opt/__init__.py
+++ b/tensorflow/contrib/opt/__init__.py
@@ -1,4 +1,4 @@
- # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -21,6 +21,7 @@ from __future__ import print_function
# pylint: disable=wildcard-import
from tensorflow.contrib.opt.python.training.adamax import *
from tensorflow.contrib.opt.python.training.addsign import *
+from tensorflow.contrib.opt.python.training.agn_optimizer import *
from tensorflow.contrib.opt.python.training.drop_stale_gradient_optimizer import *
from tensorflow.contrib.opt.python.training.elastic_average_optimizer import *
from tensorflow.contrib.opt.python.training.external_optimizer import *
@@ -60,6 +61,8 @@ _allowed_symbols = [
'VariableClippingOptimizer',
'MultitaskOptimizerWrapper',
'clip_gradients_by_global_norm',
+ 'AGNOptimizer',
+ 'AGNCustomGetter',
'ElasticAverageOptimizer',
'ElasticAverageCustomGetter',
'ModelAverageOptimizer',
diff --git a/tensorflow/contrib/opt/python/training/addsign_test.py b/tensorflow/contrib/opt/python/training/addsign_test.py
index 628a735e72..6150fa117f 100644
--- a/tensorflow/contrib/opt/python/training/addsign_test.py
+++ b/tensorflow/contrib/opt/python/training/addsign_test.py
@@ -80,9 +80,9 @@ class AddSignTest(test.TestCase):
global_step = resource_variable_ops.ResourceVariable(
0, trainable=False)
else:
- var0 = variables.Variable(var0_np)
- var1 = variables.Variable(var1_np)
- global_step = variables.Variable(
+ var0 = variables.VariableV1(var0_np)
+ var1 = variables.VariableV1(var1_np)
+ global_step = variables.VariableV1(
0, trainable=False)
grads0 = constant_op.constant(grads0_np)
grads1 = constant_op.constant(grads1_np)
@@ -183,9 +183,9 @@ class AddSignTest(test.TestCase):
global_step = resource_variable_ops.ResourceVariable(
0, trainable=False)
else:
- var0 = variables.Variable(var0_np)
- var1 = variables.Variable(var1_np)
- global_step = variables.Variable(
+ var0 = variables.VariableV1(var0_np)
+ var1 = variables.VariableV1(var1_np)
+ global_step = variables.VariableV1(
0, trainable=False)
grads0_np_indices = np.array([0, 1], dtype=np.int32)
grads0 = ops.IndexedSlices(
diff --git a/tensorflow/contrib/opt/python/training/agn_optimizer.py b/tensorflow/contrib/opt/python/training/agn_optimizer.py
new file mode 100644
index 0000000000..9d8bab8d33
--- /dev/null
+++ b/tensorflow/contrib/opt/python/training/agn_optimizer.py
@@ -0,0 +1,262 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
+from tensorflow.python.training import optimizer
+from tensorflow.python.training import session_run_hook
+
+GLOBAL_VARIABLE_NAME = 'global_center_variable'
+GRAD_VARIABLE_NAME = 'grad_variable'
+
+
+class AGNCustomGetter(object):
+ """Custom_getter class is used to do:
+
+ 1. Change trainable variables to local collection and place them at worker
+ device
+ 2. Generate global variables(global center variables)
+ 3. Generate grad variables(gradients) which record the gradients sum
+ and place them at worker device
+ Notice that the class should be used with tf.replica_device_setter,
+ so that the global center variables and global step variable can be placed
+ at ps device.
+ """
+
+ def __init__(self, worker_device):
+ """
+ Args:
+ worker_device: put the grad_variables on worker device
+ """
+ self._worker_device = worker_device
+ self._global_map = {}
+ self._grad_map = {}
+
+ def __call__(self, getter, name, trainable, collections, *args, **kwargs):
+ if trainable:
+ with ops.device(self._worker_device):
+ local_var = getter(
+ name,
+ trainable=True,
+ collections=[ops.GraphKeys.LOCAL_VARIABLES],
+ *args,
+ **kwargs)
+ if kwargs['reuse'] == True:
+ return local_var
+ global_center_variable = getter(
+ name='%s/%s' % (GLOBAL_VARIABLE_NAME, name),
+ trainable=False,
+ collections=[ops.GraphKeys.GLOBAL_VARIABLES],
+ *args,
+ **kwargs)
+
+ with ops.device(self._worker_device):
+ grad_variable = getter(
+ name='%s/%s' % (GRAD_VARIABLE_NAME, name),
+ trainable=False,
+ collections=[ops.GraphKeys.LOCAL_VARIABLES],
+ *args,
+ **kwargs)
+ if kwargs['partitioner'] is None:
+ self._grad_map[local_var] = grad_variable
+ self._global_map[local_var] = global_center_variable
+ else:
+ v_list = list(local_var)
+ for i in range(len(v_list)):
+ self._grad_map[v_list[i]] = list(grad_variable)[i]
+ self._global_map[v_list[i]] = list(global_center_variable)[i]
+ return local_var
+ else:
+ return getter(
+ name, trainable=trainable, collections=collections, *args, **kwargs)
+
+
+class AGNOptimizer(optimizer.Optimizer):
+ """Wrapper that implements the Accumulated GradientNormalization algorithm.
+
+ Reference:
+ Accumulated Gradient Normalization: Joeri Hermans ACML2017
+ https://arxiv.org/abs/1710.02368
+ """
+
+ def __init__(self,
+ optimizer,
+ num_worker,
+ custom_getter,
+ communication_period=10,
+ use_locking=True,
+ name='AGNOptimizer'):
+ """Construct a new AGN optimizer.
+
+ Args:
+ optimizer: The inner optimizer to wrap, e.g. SGD, Momentum or Adam.
+ num_worker: The number of workers.
+ custom_getter: The AGNCustomGetter used to create the variables.
+ communication_period: An int that controls how frequently each worker
+ communicates with the ps.
+ use_locking: If True use locks for update operations.
+ name: Optional name prefix for the operations created when applying
+ gradients. Defaults to "AGNOptimizer".
+ """
+ super(AGNOptimizer, self).__init__(use_locking, name)
+ self._opt = optimizer
+ self._num_worker = num_worker
+ self._period = communication_period
+ self._global_map = custom_getter._global_map
+ self._grad_map = custom_getter._grad_map
+ self._local_step = variable_scope.get_variable(
+ initializer=0,
+ trainable=False,
+ collections=[ops.GraphKeys.LOCAL_VARIABLES],
+ name='local_step')
+ self._opt._prepare()
+
+ def apply_gradients(self, grads_and_vars, global_step=None, name=None):
+ """Apply gradients to global variables.
+
+ This is the second part of `minimize()`. It returns an `Operation` that
+ applies gradients.
+
+ Args:
+ grads_and_vars: List of (gradient, variable) pairs as returned by
+ `compute_gradients()`.
+ global_step: Optional `Variable` to increment by one after the variables
+ have been updated.
+ name: Optional name for the returned operation. Defaults to the name
+ passed to the `Optimizer` constructor.
+
+ Returns:
+ An `Operation` that applies the specified gradients. If `global_step`
+ was not None, that operation also increments `global_step`.
+ """
+ local_vars = [v for g, v in grads_and_vars if g is not None]
+ grads = [g for g, v in grads_and_vars if g is not None]
+
+ def _variable_creator(next_creator, collections, **kwargs):
+ if not collections:
+ collections = [ops.GraphKeys.LOCAL_VARIABLES]
+ elif ops.GraphKeys.GLOBAL_VARIABLES in collections:
+ collections = list(collections)
+ collections.append(ops.GraphKeys.LOCAL_VARIABLES)
+ collections.remove(ops.GraphKeys.GLOBAL_VARIABLES)
+ return next_creator(collections=collections, **kwargs)
+
+ # theta = theta - lr * grad
+ with variable_scope.variable_creator_scope(_variable_creator):
+ local_update_op = self._opt.apply_gradients(grads_and_vars)
+
+ # a = a + grad
+ update_ops = []
+ update_ops.append(local_update_op)
+ grad_vars = [self._grad_map[var] for var in local_vars]
+ for g, grad_var in zip(grads, grad_vars):
+ update_ops.append(state_ops.assign_add(grad_var, g))
+
+ global_center_vars = [self._global_map[var] for var in local_vars]
+
+ # update global variables.
+ def _Update_global_variables():
+ global_norm = []
+ # a = a / t
+ for g in grad_vars:
+ global_norm.append(state_ops.assign(g, g / self._period))
+ # apply
+ with ops.control_dependencies(global_norm):
+ apply_global_op = self._opt.apply_gradients(
+ zip(grad_vars, global_center_vars))
+
+ # pull
+ with ops.control_dependencies([apply_global_op]):
+ update_ops = []
+ if global_step:
+ with ops.colocate_with(global_step):
+ update_ops.append(state_ops.assign_add(global_step, 1))
+
+ for lvar in local_vars:
+ g_val = self._global_map[lvar].read_value()
+ update_ops.append(state_ops.assign(lvar, g_val))
+ for grad_var in grad_vars:
+ update_ops.append(
+ state_ops.assign(grad_var, array_ops.zeros_like(grad_var)))
+ variable_update = control_flow_ops.group(*(update_ops))
+ return variable_update
+
+ local_update = state_ops.assign_add(
+ self._local_step, 1, name='local_step_update').op
+
+ with ops.control_dependencies([local_update]):
+ condition = math_ops.equal(
+ math_ops.mod(self._local_step, self._period), 0)
+ with ops.control_dependencies(update_ops):
+ conditional_update = control_flow_ops.cond(
+ condition, _Update_global_variables, control_flow_ops.no_op)
+ return conditional_update
+
+ def get_init_op(self, task_index):
+ """Returns the op to let all the local variables and local center
+
+ variables equal to the global center variables before the training begins
+ """
+ init_ops = []
+ local_vars = variables.trainable_variables()
+ global_center_vars = [self._global_map[var] for var in local_vars]
+ grad_vars = [self._grad_map[var] for var in local_vars]
+ if not (local_vars and global_center_vars and grad_vars):
+ raise ValueError('The lists of local_variables, global_center_variables,'
+ ' grad_center_variables should not be empty')
+ for lvar, gc_var in zip(local_vars, global_center_vars):
+ init_ops.append(state_ops.assign(lvar, gc_var))
+ for g in grad_vars:
+ init_ops.append(state_ops.assign(g, array_ops.zeros_like(g)))
+ init_op = control_flow_ops.group(*(init_ops))
+ return init_op
+
+ def make_session_run_hook(self, is_chief, task_index):
+ """Creates a hook to handle AGNOptimizerHook ops such as initialization."""
+ return _AGNOptimizerHook(self, is_chief, task_index)
+
+
+class _AGNOptimizerHook(session_run_hook.SessionRunHook):
+
+ def __init__(self, agn_optimizer, is_chief, task_index):
+ """Creates hook to handle AGNOptimizer initialization ops.
+
+ Args:
+ agn_optimizer: `AGNOptimizer` which this hook will initialize.
+ is_chief: `Bool`, whether is this a chief replica or not.
+ task_index: int, task_index of worker
+ """
+ self._agn_optimizer = agn_optimizer
+ self._is_chief = is_chief
+ self._task_index = task_index
+
+ def begin(self):
+ self._local_init_op = variables.local_variables_initializer()
+ self._global_init_op = None
+ if self._is_chief:
+ self._global_init_op = variables.global_variables_initializer()
+ self._variable_init_op = self._agn_optimizer.get_init_op(self._task_index)
+
+ def after_create_session(self, session, coord):
+ """Run initialization ops"""
+ session.run(self._variable_init_op)
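
For orientation, here is a condensed usage sketch of the new AGNOptimizer, distilled from _get_workers() in the test file below. The device strings, hyperparameters and single-worker setup are illustrative only, and `server` is assumed to be a tf.train.Server created as in create_local_cluster() below.

    from tensorflow.contrib.opt.python.training import agn_optimizer
    from tensorflow.python.framework import constant_op
    from tensorflow.python.framework import ops
    from tensorflow.python.ops import variable_scope
    from tensorflow.python.training import adam
    from tensorflow.python.training import device_setter
    from tensorflow.python.training import training
    from tensorflow.python.training import training_util

    worker_device = "/job:worker/task:0/cpu:0"
    ps_device = device_setter.replica_device_setter(
        worker_device=worker_device, ps_device="/job:ps/task:0/cpu:0", ps_tasks=1)
    # The custom getter keeps trainable and grad variables on the worker and
    # creates the global center variables, which the device setter places on the ps.
    agn_getter = agn_optimizer.AGNCustomGetter(worker_device=worker_device)
    with variable_scope.variable_scope("", custom_getter=agn_getter), \
        ops.device(ps_device):
      global_step = training_util.get_or_create_global_step()
      var_0 = variable_scope.get_variable(initializer=0.0, name="v0")

    opt = agn_optimizer.AGNOptimizer(
        adam.AdamOptimizer(learning_rate=0.1),
        num_worker=1,
        communication_period=4,
        custom_getter=agn_getter)
    train_op = opt.apply_gradients([(constant_op.constant(-1.0), var_0)],
                                   global_step)
    hook = opt.make_session_run_hook(is_chief=True, task_index=0)
    # The hook runs the optimizer's init op right after the session is created.
    sess = training.MonitoredTrainingSession(server.target, hooks=[hook])
    sess.run(train_op)
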
diff --git a/tensorflow/contrib/opt/python/training/agn_optimizer_test.py b/tensorflow/contrib/opt/python/training/agn_optimizer_test.py
new file mode 100644
index 0000000000..d3da290bdb
--- /dev/null
+++ b/tensorflow/contrib/opt/python/training/agn_optimizer_test.py
@@ -0,0 +1,281 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""Tests for EAOptimizer."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import portpicker
+
+from tensorflow.contrib.opt.python.training import agn_optimizer
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import partitioned_variables
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+from tensorflow.python.training import adam
+from tensorflow.python.training import device_setter
+from tensorflow.python.training import server_lib
+from tensorflow.python.training import training
+from tensorflow.python.training import training_util
+
+
+
+def create_local_cluster(num_workers, num_ps, protocol="grpc"):
+ """Create local GRPC servers and return them."""
+ worker_ports = [portpicker.pick_unused_port() for _ in range(num_workers)]
+ ps_ports = [portpicker.pick_unused_port() for _ in range(num_ps)]
+ cluster_dict = {
+ "worker": ["localhost:%s" % port for port in worker_ports],
+ "ps": ["localhost:%s" % port for port in ps_ports]
+ }
+ cs = server_lib.ClusterSpec(cluster_dict)
+
+ workers = [
+ server_lib.Server(
+ cs, job_name="worker", protocol=protocol, task_index=ix, start=True)
+ for ix in range(num_workers)
+ ]
+ ps_servers = [
+ server_lib.Server(
+ cs, job_name="ps", protocol=protocol, task_index=ix, start=True)
+ for ix in range(num_ps)
+ ]
+
+ return cluster_dict, workers, ps_servers
+
+
+# Creates the workers and returns their sessions, graphs and train_ops.
+# The chief worker updates last.
+def _get_workers(num_workers, period, workers, num_ps=1):
+ sessions = []
+ graphs = []
+ train_ops = []
+ for worker_id in range(num_workers):
+ graph = ops.Graph()
+ is_chief = (worker_id == 0)
+ with graph.as_default():
+ worker_device = "/job:worker/task:%d/cpu:0" % (worker_id)
+ ps_device = device_setter.replica_device_setter(
+ worker_device=worker_device,
+ ps_device="/job:ps/task:0/cpu:0",
+ ps_tasks=1)
+ agn_getter = agn_optimizer.AGNCustomGetter(worker_device=worker_device)
+ with variable_scope.variable_scope(
+ "", custom_getter=agn_getter), ops.device(ps_device):
+ global_step = training_util.get_or_create_global_step()
+ var_0 = variable_scope.get_variable(initializer=0.0, name="v0")
+ var_1 = variable_scope.get_variable(initializer=0.5, name="v1")
+ if num_ps > 1:
+ with variable_scope.variable_scope(
+ "",
+ partitioner=partitioned_variables.fixed_size_partitioner(
+ num_ps, axis=0),
+ custom_getter=agn_getter), ops.device(ps_device):
+
+ partition_var = variable_scope.get_variable(
+ "partition_var",
+ shape=[2, 4],
+ initializer=init_ops.zeros_initializer)
+ part_0 = list(partition_var)[0]
+ part_1 = list(partition_var)[1]
+
+ with ops.device("/job:worker/task:" + str(worker_id)):
+ grads_0 = constant_op.constant(-1.0)
+ grads_1 = constant_op.constant(-1.0)
+ grads_part_0 = constant_op.constant([[-1., -1., -1., -1.]])
+ grads_part_1 = constant_op.constant([[-1., -1., -1., -1.]])
+
+ optimizer = \
+ adam.AdamOptimizer(learning_rate=0.1, beta1=0.0, beta2=0.0)
+ opt = agn_optimizer.AGNOptimizer(
+ optimizer,
+ num_worker=num_workers,
+ communication_period=period,
+ custom_getter=agn_getter)
+ if num_ps == 1:
+ train_op = [
+ opt.apply_gradients(([grads_0, var_0], [grads_1, var_1]),
+ global_step)
+ ]
+ else:
+ train_op = [
+ opt.apply_gradients(
+ ([grads_0, var_0], [grads_1, var_1], [grads_part_0, part_0],
+ [grads_part_1, part_1]), global_step)
+ ]
+ hook = opt.make_session_run_hook(is_chief, worker_id)
+ # Creates MonitoredSession
+ sess = training.MonitoredTrainingSession(
+ workers[worker_id].target, hooks=[hook])
+
+ sessions.append(sess)
+ graphs.append(graph)
+ train_ops.append(train_op)
+
+ return sessions, graphs, train_ops
+
+
+class AGNOptimizerTest(test.TestCase):
+
+ def _run(self, train_op, sess):
+ sess.run(train_op)
+
+ def test1Workers2Period(self):
+ num_workers = 1
+ communication_period = 4
+ num_ps = 1
+ _, workers, _ = create_local_cluster(num_workers=num_workers, num_ps=num_ps)
+
+ sessions, graphs, train_ops = _get_workers(num_workers,
+ communication_period, workers)
+
+ var_0 = graphs[0].get_tensor_by_name("v0:0")
+ var_1 = graphs[0].get_tensor_by_name("v1:0")
+ global_step = training_util.get_global_step(graphs[0])
+ var_0_g = graphs[0].get_tensor_by_name(
+ agn_optimizer.GLOBAL_VARIABLE_NAME + "/v0:0")
+ var_1_g = graphs[0].get_tensor_by_name(
+ agn_optimizer.GLOBAL_VARIABLE_NAME + "/v1:0")
+
+ # verify adam/beta variables not in global collection
+ with graphs[0].as_default():
+ for ele in variables.global_variables():
+ self.assertTrue(ele.op.name.find("beta") < 0)
+ if ele.op.name.find("global_center_variable") < 0:
+ self.assertTrue(ele.op.name.find("Adam") < 0)
+
+ # Verify the initialized value.
+ self.assertAllEqual(0.0, sessions[0].run(var_0))
+ self.assertAllEqual(0.5, sessions[0].run(var_1))
+ self.assertAllEqual(0.0, sessions[0].run(var_0_g))
+ self.assertAllEqual(0.5, sessions[0].run(var_1_g))
+ self.assertAllEqual(0, sessions[0].run(global_step))
+ # step 0
+ sessions[0].run(train_ops[0])
+ self.assertNear(0.1, sessions[0].run(var_0), 1e-6)
+ self.assertNear(0.6, sessions[0].run(var_1), 1e-6)
+ self.assertAllEqual(0.0, sessions[0].run(var_0_g))
+ self.assertAllEqual(0.5, sessions[0].run(var_1_g))
+ self.assertAllEqual(0, sessions[0].run(global_step))
+
+ # 2 & 3
+ sessions[0].run(train_ops[0])
+ sessions[0].run(train_ops[0])
+ self.assertNear(0.3, sessions[0].run(var_0), 1e-6)
+ self.assertNear(0.8, sessions[0].run(var_1), 1e-6)
+
+ # 4
+ sessions[0].run(train_ops[0])
+ # pull
+ self.assertAllEqual(sessions[0].run(var_0), sessions[0].run(var_0_g))
+ self.assertAllEqual(sessions[0].run(var_1), sessions[0].run(var_1_g))
+ self.assertNear(0.1, sessions[0].run(var_0), 1e-6)
+ self.assertNear(0.6, sessions[0].run(var_1), 1e-6)
+
+ sessions[0].run(train_ops[0])
+ sessions[0].run(train_ops[0])
+ sessions[0].run(train_ops[0])
+ sessions[0].run(train_ops[0])
+ self.assertAllEqual(sessions[0].run(var_0), sessions[0].run(var_0_g))
+ self.assertAllEqual(sessions[0].run(var_1), sessions[0].run(var_1_g))
+ self.assertNear(0.2, sessions[0].run(var_0), 1e-6)
+ self.assertNear(0.7, sessions[0].run(var_1), 1e-6)
+
+ def test2Worker1Period(self):
+ num_workers = 2
+ communication_period = 1
+ num_ps = 2
+ _, workers, _ = create_local_cluster(num_workers=num_workers, num_ps=num_ps)
+
+ sessions, graphs, train_ops = _get_workers(
+ num_workers, communication_period, workers, num_ps=2)
+
+ var_0 = graphs[0].get_tensor_by_name("v0:0")
+ var_1 = graphs[0].get_tensor_by_name("v1:0")
+
+ var_0_1 = graphs[1].get_tensor_by_name("v0:0")
+ var_1_1 = graphs[1].get_tensor_by_name("v1:0")
+
+ var_0_g = graphs[0].get_tensor_by_name(
+ agn_optimizer.GLOBAL_VARIABLE_NAME + "/v0:0")
+ var_1_g = graphs[0].get_tensor_by_name(
+ agn_optimizer.GLOBAL_VARIABLE_NAME + "/v1:0")
+ part_0_g = graphs[0].get_tensor_by_name(
+ agn_optimizer.GLOBAL_VARIABLE_NAME +
+ "/partition_var/part_0:0")
+ part_1_g = graphs[0].get_tensor_by_name(
+ agn_optimizer.GLOBAL_VARIABLE_NAME +
+ "/partition_var/part_1:0")
+
+ # Verify the initialized value.
+ self.assertAllEqual(0.0, sessions[0].run(var_0))
+ self.assertAllEqual(0.5, sessions[0].run(var_1))
+ self.assertAllEqual(0.0, sessions[1].run(var_0_1))
+ self.assertAllEqual(0.5, sessions[1].run(var_1_1))
+ self.assertAllEqual(0.0, sessions[0].run(var_0_g))
+ self.assertAllEqual(0.5, sessions[0].run(var_1_g))
+
+ # verify each step
+ sessions[0].run(train_ops[0])
+ self.assertNear(0.1, sessions[0].run(var_0_g), 1e-6)
+ self.assertNDArrayNear([0.1, 0.1, 0.1, 0.1], sessions[0].run(part_0_g),
+ 1e-6)
+ self.assertNDArrayNear([0.1, 0.1, 0.1, 0.1], sessions[0].run(part_1_g),
+ 1e-6)
+
+ sessions[1].run(train_ops[1])
+ self.assertNear(0.2, sessions[0].run(var_0_g), 1e-6)
+ self.assertNDArrayNear([0.2, 0.2, 0.2, 0.2], sessions[0].run(part_0_g),
+ 1e-6)
+ self.assertNDArrayNear([0.2, 0.2, 0.2, 0.2], sessions[0].run(part_1_g),
+ 1e-6)
+
+ sessions[0].run(train_ops[0])
+ sessions[1].run(train_ops[1])
+
+ sessions[0].run(train_ops[0])
+ sessions[1].run(train_ops[1])
+ self.assertNear(0.6, sessions[0].run(var_0_g), 1e-6)
+ self.assertNDArrayNear([0.6, 0.6, 0.6, 0.6], sessions[0].run(part_0_g),
+ 1e-6)
+ self.assertNDArrayNear([0.6, 0.6, 0.6, 0.6], sessions[0].run(part_1_g),
+ 1e-6)
+
+ def testAGNCustomGetter(self):
+ cluster_spec = server_lib.ClusterSpec({
+ "ps": ["ps0:2222", "ps1:2222"],
+ "worker": ["worker0:2222", "worker1:2222", "worker2:2222"]
+ })
+ agn_getter = agn_optimizer.AGNCustomGetter(
+ worker_device="/job:worker/task:0")
+ with ops.device(
+ device_setter.replica_device_setter(cluster=cluster_spec,
+ worker_device="/job:worker/task:0",
+ ps_device="/job:ps")), \
+ variable_scope.variable_scope("", custom_getter=agn_getter):
+ v = variable_scope.get_variable(initializer=[1, 2], name="v")
+ w = variable_scope.get_variable(initializer=[2, 1], name="w")
+ v_g, w_g = agn_getter._global_map[v], agn_getter._global_map[w]
+ self.assertDeviceEqual("/job:worker/task:0", v.device)
+ self.assertDeviceEqual("job:ps/task:0", v_g.device)
+ self.assertDeviceEqual("/job:worker/task:0", w.device)
+ self.assertDeviceEqual("job:ps/task:1", w_g.device)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/opt/python/training/drop_stale_gradient_optimizer_test.py b/tensorflow/contrib/opt/python/training/drop_stale_gradient_optimizer_test.py
index 53232082e1..0a69096768 100644
--- a/tensorflow/contrib/opt/python/training/drop_stale_gradient_optimizer_test.py
+++ b/tensorflow/contrib/opt/python/training/drop_stale_gradient_optimizer_test.py
@@ -61,8 +61,8 @@ def _get_workers(num_workers, staleness):
graph = ops.Graph()
with graph.as_default():
global_step = training_util.create_global_step()
- var_0 = variables.Variable(0.0, name='v0')
- var_1 = variables.Variable(1.0, name='v1')
+ var_0 = variables.VariableV1(0.0, name='v0')
+ var_1 = variables.VariableV1(1.0, name='v1')
compute_gradients_queue = data_flow_ops.FIFOQueue(
-1, global_step.dtype.base_dtype, shapes=(),
name='compute_gradients_queue', shared_name='compute_gradients_queue')
diff --git a/tensorflow/contrib/opt/python/training/external_optimizer_test.py b/tensorflow/contrib/opt/python/training/external_optimizer_test.py
index 9997103016..70c5f8ff19 100644
--- a/tensorflow/contrib/opt/python/training/external_optimizer_test.py
+++ b/tensorflow/contrib/opt/python/training/external_optimizer_test.py
@@ -69,9 +69,9 @@ class TestCase(test.TestCase):
class ExternalOptimizerInterfaceTest(TestCase):
def test_optimize(self):
- scalar = variables.Variable(random_ops.random_normal([]), 'scalar')
- vector = variables.Variable(random_ops.random_normal([2]), 'vector')
- matrix = variables.Variable(random_ops.random_normal([2, 3]), 'matrix')
+ scalar = variables.VariableV1(random_ops.random_normal([]), 'scalar')
+ vector = variables.VariableV1(random_ops.random_normal([2]), 'vector')
+ matrix = variables.VariableV1(random_ops.random_normal([2, 3]), 'matrix')
minimum_location = constant_op.constant(np.arange(9), dtype=dtypes.float32)
@@ -96,7 +96,7 @@ class ExternalOptimizerInterfaceTest(TestCase):
def test_callbacks(self):
vector_val = np.array([7., -2.], dtype=np.float32)
- vector = variables.Variable(vector_val, 'vector')
+ vector = variables.VariableV1(vector_val, 'vector')
minimum_location_val = np.arange(2)
minimum_location = constant_op.constant(
@@ -160,7 +160,7 @@ class ScipyOptimizerInterfaceTest(TestCase):
rtol=1e-5,
atol=1e-5,
dimension=5):
- x = variables.Variable(array_ops.zeros(dimension))
+ x = variables.VariableV1(array_ops.zeros(dimension))
optimizer = external_optimizer.ScipyOptimizerInterface(
self._objective(x), method=method, options=options)
@@ -173,7 +173,7 @@ class ScipyOptimizerInterfaceTest(TestCase):
def test_unconstrained(self):
dimension = 5
- x = variables.Variable(array_ops.zeros(dimension))
+ x = variables.VariableV1(array_ops.zeros(dimension))
optimizer = external_optimizer.ScipyOptimizerInterface(self._objective(x))
with self.cached_session() as sess:
@@ -230,7 +230,7 @@ class ScipyOptimizerInterfaceTest(TestCase):
def test_nonlinear_programming(self):
vector_initial_value = [7., 7.]
- vector = variables.Variable(vector_initial_value, 'vector')
+ vector = variables.VariableV1(vector_initial_value, 'vector')
# Make norm as small as possible.
loss = math_ops.reduce_sum(math_ops.square(vector))
@@ -249,7 +249,7 @@ class ScipyOptimizerInterfaceTest(TestCase):
def test_scalar_bounds(self):
vector_initial_value = [7., 7.]
- vector = variables.Variable(vector_initial_value, 'vector')
+ vector = variables.VariableV1(vector_initial_value, 'vector')
# Make norm as small as possible.
loss = math_ops.reduce_sum(math_ops.square(vector))
@@ -267,7 +267,7 @@ class ScipyOptimizerInterfaceTest(TestCase):
def test_vector_bounds(self):
vector_initial_value = [7., 7.]
- vector = variables.Variable(vector_initial_value, 'vector')
+ vector = variables.VariableV1(vector_initial_value, 'vector')
# Make norm as small as possible.
loss = math_ops.reduce_sum(math_ops.square(vector))
@@ -287,7 +287,7 @@ class ScipyOptimizerInterfaceTest(TestCase):
# after running optimizer.minimize().
# Bug reference: b/64065260
vector_initial_value = [7., 7.]
- vector = variables.Variable(vector_initial_value, 'vector')
+ vector = variables.VariableV1(vector_initial_value, 'vector')
loss = math_ops.reduce_sum(math_ops.square(vector))
optimizer = external_optimizer.ScipyOptimizerInterface(
@@ -301,7 +301,7 @@ class ScipyOptimizerInterfaceTest(TestCase):
def test_callbacks(self):
vector_val = np.array([7., -2.], dtype=np.float32)
- vector = variables.Variable(vector_val, 'vector')
+ vector = variables.VariableV1(vector_val, 'vector')
minimum_location_val = np.arange(2)
minimum_location = constant_op.constant(
diff --git a/tensorflow/contrib/opt/python/training/lazy_adam_optimizer_test.py b/tensorflow/contrib/opt/python/training/lazy_adam_optimizer_test.py
index f08ffaa36f..089ecf597d 100644
--- a/tensorflow/contrib/opt/python/training/lazy_adam_optimizer_test.py
+++ b/tensorflow/contrib/opt/python/training/lazy_adam_optimizer_test.py
@@ -236,7 +236,7 @@ class AdamOptimizerTest(test.TestCase, parameterized.TestCase):
opt.get_slot(var=var0, name="m").name)
def testBasic(self):
- with self.test_session():
+ with self.cached_session():
self.doTestBasic(use_resource=False)
@test_util.run_in_graph_and_eager_modes(reset_test=True)
@@ -249,7 +249,7 @@ class AdamOptimizerTest(test.TestCase, parameterized.TestCase):
def testTensorLearningRate(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
# Initialize variables for numpy implementation.
m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
@@ -286,7 +286,7 @@ class AdamOptimizerTest(test.TestCase, parameterized.TestCase):
def testSharing(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
# Initialize variables for numpy implementation.
m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
diff --git a/tensorflow/contrib/opt/python/training/model_average_optimizer_test.py b/tensorflow/contrib/opt/python/training/model_average_optimizer_test.py
index b1fc50a21f..a25455e95d 100644
--- a/tensorflow/contrib/opt/python/training/model_average_optimizer_test.py
+++ b/tensorflow/contrib/opt/python/training/model_average_optimizer_test.py
@@ -110,10 +110,11 @@ def _get_workers(num_workers, steps, workers):
class ModelAverageOptimizerTest(test.TestCase):
+
def _run(self, train_op, sess):
sess.run(train_op)
- def test1Workers2Period(self):
+ def disabled_test1Workers2Period(self):
num_workers = 2
steps = 2
num_ps = 1
diff --git a/tensorflow/contrib/opt/python/training/powersign_test.py b/tensorflow/contrib/opt/python/training/powersign_test.py
index 0bcf5d230a..1cf9901dc0 100644
--- a/tensorflow/contrib/opt/python/training/powersign_test.py
+++ b/tensorflow/contrib/opt/python/training/powersign_test.py
@@ -81,9 +81,9 @@ class PowerSignTest(test.TestCase):
global_step = resource_variable_ops.ResourceVariable(
0, trainable=False)
else:
- var0 = variables.Variable(var0_np)
- var1 = variables.Variable(var1_np)
- global_step = variables.Variable(
+ var0 = variables.VariableV1(var0_np)
+ var1 = variables.VariableV1(var1_np)
+ global_step = variables.VariableV1(
0, trainable=False)
grads0 = constant_op.constant(grads0_np)
grads1 = constant_op.constant(grads1_np)
@@ -188,9 +188,9 @@ class PowerSignTest(test.TestCase):
global_step = resource_variable_ops.ResourceVariable(
0, trainable=False)
else:
- var0 = variables.Variable(var0_np)
- var1 = variables.Variable(var1_np)
- global_step = variables.Variable(
+ var0 = variables.VariableV1(var0_np)
+ var1 = variables.VariableV1(var1_np)
+ global_step = variables.VariableV1(
0, trainable=False)
grads0_np_indices = np.array([0, 1], dtype=np.int32)
grads0 = ops.IndexedSlices(
diff --git a/tensorflow/contrib/optimizer_v2/adagrad.py b/tensorflow/contrib/optimizer_v2/adagrad.py
index 25ec475499..dab1e02716 100644
--- a/tensorflow/contrib/optimizer_v2/adagrad.py
+++ b/tensorflow/contrib/optimizer_v2/adagrad.py
@@ -31,7 +31,7 @@ class AdagradOptimizer(optimizer_v2.OptimizerV2):
See this [paper](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf)
or this
- [intro](http://cs.stanford.edu/~ppasupat/a9online/uploads/proximal_notes.pdf).
+ [intro](https://ppasupat.github.io/a9online/uploads/proximal_notes.pdf).
"""
def __init__(self, learning_rate, initial_accumulator_value=0.1,
diff --git a/tensorflow/contrib/predictor/BUILD b/tensorflow/contrib/predictor/BUILD
index 72ea777ca7..d50b52b8ff 100644
--- a/tensorflow/contrib/predictor/BUILD
+++ b/tensorflow/contrib/predictor/BUILD
@@ -27,7 +27,7 @@ py_library(
":contrib_estimator_predictor",
":core_estimator_predictor",
":saved_model_predictor",
- "//tensorflow/python/estimator",
+ "//tensorflow/python/estimator:estimator_py",
],
)
@@ -89,7 +89,6 @@ py_library(
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
- "//tensorflow/python/estimator",
"//tensorflow/python/estimator:estimator_py",
"//tensorflow/python/saved_model:signature_constants",
],
diff --git a/tensorflow/contrib/quantization/README.md b/tensorflow/contrib/quantization/README.md
index 359950aaf3..826e8db2d3 100644
--- a/tensorflow/contrib/quantization/README.md
+++ b/tensorflow/contrib/quantization/README.md
@@ -2,6 +2,6 @@ The contrib/quantization package exposes a few TensorFlow quantization operation
If you are looking for quantized training rewrites that allow for training
quantized models that work with
-[TensorFlow Lite](https://www.tensorflow.org/mobile/tflite/), you should look at
+[TensorFlow Lite](https://www.tensorflow.org/lite/), you should look at
the [contrib/quantize](https://www.tensorflow.org/api_docs/python/tf/contrib/quantize)
package.
diff --git a/tensorflow/contrib/quantize/BUILD b/tensorflow/contrib/quantize/BUILD
index c59f667f6a..23e3a25d71 100644
--- a/tensorflow/contrib/quantize/BUILD
+++ b/tensorflow/contrib/quantize/BUILD
@@ -20,9 +20,13 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":common",
+ "//tensorflow/contrib/layers:layers_py",
+ "//tensorflow/python:array_ops",
"//tensorflow/python:framework_ops",
"//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:init_ops",
"//tensorflow/python:math_ops",
+ "//tensorflow/python:nn_ops",
"//tensorflow/python:platform_test",
"//tensorflow/python:session",
"//tensorflow/python:variable_scope",
diff --git a/tensorflow/contrib/quantize/README.md b/tensorflow/contrib/quantize/README.md
index 3f1e7d2792..0ab19c91bb 100644
--- a/tensorflow/contrib/quantize/README.md
+++ b/tensorflow/contrib/quantize/README.md
@@ -105,7 +105,7 @@ toco \
--std_value=127.5 --mean_value=127.5
```
-See the documentation for `tf.contrib.quantize` and [TensorFlow Lite](../mobile/tflite/).
+See the documentation for `tf.contrib.quantize` and [TensorFlow Lite](../lite/).
## Quantized accuracy results
diff --git a/tensorflow/contrib/quantize/python/common.py b/tensorflow/contrib/quantize/python/common.py
index b27117dd48..e6c04bcf55 100644
--- a/tensorflow/contrib/quantize/python/common.py
+++ b/tensorflow/contrib/quantize/python/common.py
@@ -34,10 +34,10 @@ SKIPPED_PREFIXES = (
'ScalarSummary')
# Valid activation ops for quantization end points.
-_ACTIVATION_OP_SUFFIXES = ['/Relu6', '/Relu', '/Identity']
+_ACTIVATION_OP_SUFFIXES = ['Relu6', 'Relu', 'Identity']
# Regular expression for recognizing nodes that are part of batch norm group.
-_BATCHNORM_RE = re.compile(r'^(.*)/BatchNorm/batchnorm')
+_BATCHNORM_RE = re.compile(r'^(.*)BatchNorm/batchnorm')
def BatchNormGroups(graph):
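
The effect of dropping the leading '/' from these patterns is that batch norm layers created at the top-level (empty) scope are now recognized as well, which is what the new testBatchNormScope below exercises. A standalone sketch (the op name is illustrative):

    import re

    old_re = re.compile(r'^(.*)/BatchNorm/batchnorm')  # previous pattern
    new_re = re.compile(r'^(.*)BatchNorm/batchnorm')   # pattern after this change
    op_name = 'BatchNorm/batchnorm_1/add_1'            # batch norm with empty scope

    print(old_re.search(op_name))                      # None -- top-level scope missed
    print(new_re.search(op_name).group(1))             # '' -- matched, context is ''
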
diff --git a/tensorflow/contrib/quantize/python/common_test.py b/tensorflow/contrib/quantize/python/common_test.py
index 2b26302f8a..a3ce041cea 100644
--- a/tensorflow/contrib/quantize/python/common_test.py
+++ b/tensorflow/contrib/quantize/python/common_test.py
@@ -13,21 +13,26 @@
# limitations under the License.
# ==============================================================================
"""Tests for common utilities in this package."""
-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-
+from tensorflow.contrib.layers.python.layers import layers
from tensorflow.contrib.quantize.python import common
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
+batch_norm = layers.batch_norm
+conv2d = layers.conv2d
+
class CommonTest(test_util.TensorFlowTestCase):
@@ -87,6 +92,56 @@ class CommonTest(test_util.TensorFlowTestCase):
for i in inputs:
self.assertIn(i, op.inputs)
+ def testBatchNormScope(self):
+ batch_size, height, width, depth = 5, 128, 128, 3
+ g = ops.Graph()
+ with g.as_default():
+ inputs = array_ops.zeros((batch_size, height, width, depth))
+ stride = 1
+ out_depth = 32
+ scope = ''
+ node = conv2d(
+ inputs,
+ out_depth, [2, 2],
+ stride=stride,
+ padding='SAME',
+ weights_initializer=self._WeightInit(0.09),
+ activation_fn=None,
+ normalizer_fn=batch_norm,
+ normalizer_params=self._BatchNormParams(False),
+ scope=scope)
+
+ node = nn_ops.relu(node, name='Relu6')
+ bn_list = common.BatchNormGroups(g)
+ with open('/tmp/common_test.pbtxt', 'w') as f:
+ f.write(str(g.as_graph_def()))
+
+ # Exactly one batch norm layer with empty scope should be found
+ self.assertEqual(len(bn_list), 1)
+ self.assertEqual(bn_list[0], '')
+
+ def _BatchNormParams(self, fused=False, force_updates=False):
+ params = {
+ 'center': True,
+ 'scale': True,
+ 'decay': 1.0 - 0.003,
+ 'fused': fused
+ }
+ return params
+
+ def _WeightInit(self, stddev):
+ """Returns a truncated normal variable initializer.
+
+ Function is defined purely to shorten the name so that it stops wrapping.
+
+ Args:
+ stddev: Standard deviation of normal variable.
+
+ Returns:
+ An initializer that initializes with a truncated normal variable.
+ """
+ return init_ops.truncated_normal_initializer(stddev=stddev, seed=1234)
+
if __name__ == '__main__':
googletest.main()
diff --git a/tensorflow/contrib/quantize/python/fold_batch_norms.py b/tensorflow/contrib/quantize/python/fold_batch_norms.py
index 2971b28f45..7575b1b6cd 100644
--- a/tensorflow/contrib/quantize/python/fold_batch_norms.py
+++ b/tensorflow/contrib/quantize/python/fold_batch_norms.py
@@ -95,8 +95,7 @@ def _FoldFusedBatchNorms(graph, is_training, freeze_batch_norm_delay):
_ComputeBatchNormCorrections(
context='',
match=match,
- freeze_batch_norm_delay=freeze_batch_norm_delay,
- fused_batch_norm=True))
+ freeze_batch_norm_delay=freeze_batch_norm_delay))
# The shape of depthwise weights is different, so we need to reshape the
# multiplier_tensor to ensure that the scaled_weight_tensor has the
# expected shape.
@@ -296,8 +295,7 @@ def _FindFusedBatchNorms(graph):
batch_to_space_op=batch_to_space_op)
-def _ComputeBatchNormCorrections(context, match, freeze_batch_norm_delay,
- fused_batch_norm):
+def _ComputeBatchNormCorrections(context, match, freeze_batch_norm_delay):
"""Computes batch norm correction params.
Before batch normalization is frozen:
@@ -327,14 +325,14 @@ def _ComputeBatchNormCorrections(context, match, freeze_batch_norm_delay,
computation.
freeze_batch_norm_delay: Delay in steps at which computation switches
from regular batch norm to frozen mean and variance.
- fused_batch_norm: Bool, true if fused batch norm is used.
+
Returns:
A tuple of correction_scale, correction_recip, correction_offset
"""
g = ops.get_default_graph()
- prefix = '' if not context else context + '/'
+ prefix = '' if not context else context
with g.name_scope(prefix + 'batch_norm_correction'):
recip_sigma_mv = math_ops.rsqrt(
match.moving_variance_tensor + match.batch_epsilon)
@@ -420,10 +418,11 @@ def _CloneWithNewOperands(layer_op, input_tensor, weight_tensor,
transpose_b=layer_op.get_attr('transpose_b'),
name=new_layer_name)
elif layer_op.type == 'DepthwiseConv2dNative':
+ # We don't copy dilation rate because we reuse the input SpaceToBatch
+ # and create our own BatchToSpace operation below.
conv = nn.depthwise_conv2d(
input_tensor,
weight_tensor,
- rate=layer_op.get_attr('dilations'),
strides=layer_op.get_attr('strides'),
padding=layer_op.get_attr('padding'),
name=new_layer_name)
@@ -495,8 +494,23 @@ def _FoldUnfusedBatchNorms(graph, is_training, freeze_batch_norm_delay):
# Treat consumer ops in bypass modules differently since they have Add
# operations instead of Relu* above.
- add_bypass_ctx = re.search(r'^(.*)/([^/]+)', bn).group(1)
- add_bypass = graph.get_operation_by_name(add_bypass_ctx + '/Add')
+ # Make sure that the correct scope is selected for the bypass add.
+ # The rule: if the batch norm is at scope str1/str2, the bypass add is at
+ # scope str1. If the batch norm is at scope str1 (no '/'), the bypass add
+ # is at the top-level scope ''.
+ # If there is no batch norm, then there is no bypass add.
+ add_bypass_ctx = ''
+ if bn:
+ try:
+ add_bypass_ctx = re.search(r'^(.*)/([^/]+)', bn).group(1)
+ except AttributeError:
+ add_bypass_ctx = ''
+
+ if add_bypass_ctx:
+ add_bypass_ctx = add_bypass_ctx + '/'
+
+ add_bypass = graph.get_operation_by_name(add_bypass_ctx + 'Add')
nodes_modified_count = common.RerouteTensor(
folded_op.outputs[0], original_op.outputs[0], can_modify=[add_bypass])
if nodes_modified_count != 1:
@@ -505,8 +519,8 @@ def _FoldUnfusedBatchNorms(graph, is_training, freeze_batch_norm_delay):
def _IsValidUnfusedBatchNorm(graph, context):
"""Checks that the output of the unfused batch norm has consumers."""
- add_shift = graph.get_operation_by_name(
- context + '/BatchNorm/batchnorm_1/add_1')
+ add_shift = graph.get_operation_by_name(context +
+ 'BatchNorm/batchnorm_1/add_1')
# Ensure that the output tensor of batch norm has consumers, otherwise this
# is a dangling node and not a match.
return bool(add_shift.outputs[0].consumers())
@@ -538,7 +552,8 @@ def _FindMatchingTensor(graph, match_pattern, scope):
if op.name.endswith(match_pattern):
split_name = op.name.split('/')
num_matches = len(set(split_name) & split_context)
- if num_matches > 0:
+
+ if num_matches > 0 or not scope:
match_dict[op.name] = num_matches
# match_dict contains matching op names from graph with values being
# number of matches to scope. We pick the key with the most matches
@@ -597,21 +612,21 @@ def _GetBatchNormParams(graph, context, has_scaling):
# op.name = MobilenetV2/expanded_conv_3/depthwise/BatchNorm/moving_mean/read
# will have 2 matches,scope with a different conv layer will have one match.
- op_suffix_mean = '/BatchNorm/moments/Squeeze'
- op_suffix_variance = '/BatchNorm/moments/Squeeze_1'
- op_suffix_epsilon = '/BatchNorm/batchnorm_1/add/y'
- op_suffix_bn_decay_mean = '/BatchNorm/AssignMovingAvg/decay'
- op_suffix_bn_decay_var = '/BatchNorm/AssignMovingAvg_1/decay'
+ op_suffix_mean = 'BatchNorm/moments/Squeeze'
+ op_suffix_variance = 'BatchNorm/moments/Squeeze_1'
+ op_suffix_epsilon = 'BatchNorm/batchnorm_1/add/y'
+ op_suffix_bn_decay_mean = 'BatchNorm/AssignMovingAvg/decay'
+ op_suffix_bn_decay_var = 'BatchNorm/AssignMovingAvg_1/decay'
if variable_scope.get_variable_scope().use_resource:
- op_suffix_gamma = '/BatchNorm/gamma/Read/ReadVariableOp'
+ op_suffix_gamma = 'BatchNorm/gamma/Read/ReadVariableOp'
op_suffix_moving_variance = (
- '/BatchNorm/moving_variance/Read/ReadVariableOp')
- op_suffix_moving_mean = ('/BatchNorm/moving_mean/Read/ReadVariableOp')
+ 'BatchNorm/moving_variance/Read/ReadVariableOp')
+ op_suffix_moving_mean = ('BatchNorm/moving_mean/Read/ReadVariableOp')
else:
- op_suffix_gamma = '/BatchNorm/gamma'
- op_suffix_moving_variance = '/BatchNorm/moving_variance/read'
- op_suffix_moving_mean = '/BatchNorm/moving_mean/read'
+ op_suffix_gamma = 'BatchNorm/gamma'
+ op_suffix_moving_variance = 'BatchNorm/moving_variance/read'
+ op_suffix_moving_mean = 'BatchNorm/moving_mean/read'
# Parse through list of ops to find relevant ops
batch_mean_tensor = _FindMatchingTensor(graph, op_suffix_mean, context)
@@ -679,8 +694,7 @@ def _CreateFoldedOp(graph, context, has_scaling, freeze_batch_norm_delay,
the folded graph (add_fold).
"""
mul_scale_name = 'mul_1' if has_scaling else 'mul'
- mul_scale = graph.get_operation_by_name(context +
- '/BatchNorm/batchnorm_1/' +
+ mul_scale = graph.get_operation_by_name(context + 'BatchNorm/batchnorm_1/' +
mul_scale_name)
op_below = mul_scale.inputs[0].op
# Skip over the BatchToSpace operation in the case of atrous convolutions.
@@ -697,8 +711,7 @@ def _CreateFoldedOp(graph, context, has_scaling, freeze_batch_norm_delay,
_ComputeBatchNormCorrections(
context=context,
match=match,
- freeze_batch_norm_delay=freeze_batch_norm_delay,
- fused_batch_norm=False))
+ freeze_batch_norm_delay=freeze_batch_norm_delay))
# Special handling for weights of depthwise convolution.
if op_below.type == 'DepthwiseConv2dNative':
new_shape = [
@@ -706,27 +719,27 @@ def _CreateFoldedOp(graph, context, has_scaling, freeze_batch_norm_delay,
weights.get_shape().as_list()[3]
]
scale_name = 'mul' if has_scaling else 'Rsqrt'
- scale = graph.get_operation_by_name(
- context + '/BatchNorm/batchnorm_1/' + scale_name)
+ scale = graph.get_operation_by_name(context + 'BatchNorm/batchnorm_1/' +
+ scale_name)
scale = array_ops.reshape(scale.outputs[0], new_shape,
- context + '/scale_reshape')
+ context + 'scale_reshape')
if correction_scale is not None:
correction_scale = array_ops.reshape(correction_scale, new_shape,
- context + '/correction_reshape')
+ context + 'correction_reshape')
with ops.device(mul_scale.device):
weights = math_ops.multiply(correction_scale, weights,
- context + '/correction_mult')
+ context + 'correction_mult')
- mul_fold = _CloneOp(mul_scale, context + '/mul_fold', [(0, weights),
- (1, scale)])
+ mul_fold = _CloneOp(mul_scale, context + 'mul_fold', [(0, weights),
+ (1, scale)])
elif op_below.type in ['Conv2D', 'MatMul']:
if correction_scale is not None:
with ops.device(mul_scale.device):
weights = math_ops.multiply(correction_scale, weights,
- context + '/correction_mult')
- mul_fold = _CloneOp(mul_scale, context + '/mul_fold', [(0, weights)])
+ context + 'correction_mult')
+ mul_fold = _CloneOp(mul_scale, context + 'mul_fold', [(0, weights)])
else:
raise ValueError('Cannot handle operation of type: %s' % op_below.type)
_AssertShapesMatch('mul_fold', mul_fold.inputs[0], mul_fold.outputs[0])
@@ -734,8 +747,8 @@ def _CreateFoldedOp(graph, context, has_scaling, freeze_batch_norm_delay,
conv_or_fc_folded = _CloneOp(op_below, op_below.name + '_Fold',
[(1, mul_fold.outputs[0])])
- add_shift = graph.get_operation_by_name(
- context + '/BatchNorm/batchnorm_1/add_1')
+ add_shift = graph.get_operation_by_name(context +
+ 'BatchNorm/batchnorm_1/add_1')
corrected_output = conv_or_fc_folded.outputs[0]
# Copy the batch to space operation if we have a atrous convolution.
@@ -748,10 +761,10 @@ def _CreateFoldedOp(graph, context, has_scaling, freeze_batch_norm_delay,
if correction_offset is not None:
with ops.device(conv_or_fc_folded.device):
corrected_output = math_ops.multiply(correction_recip, corrected_output,
- context + '/post_conv_mul')
+ context + 'post_conv_mul')
corrected_output = math_ops.add(corrected_output, (correction_offset),
- context + '/correction_add')
- add_fold = _CloneOp(add_shift, context + '/add_fold', [(0, corrected_output)])
+ context + 'correction_add')
+ add_fold = _CloneOp(add_shift, context + 'add_fold', [(0, corrected_output)])
_AssertShapesMatch('add_fold', add_fold.inputs[0], add_fold.outputs[0])
return add_shift, add_fold
@@ -930,7 +943,7 @@ def _HasScaling(graph, input_to_ops_map, bn):
Returns:
A boolean indicating whether this batch norm layer has scaling enabled.
"""
- rsqrt_op = graph.get_operation_by_name(bn + '/BatchNorm/batchnorm_1/Rsqrt')
+ rsqrt_op = graph.get_operation_by_name(bn + 'BatchNorm/batchnorm_1/Rsqrt')
rsqrt_consumers = input_to_ops_map.ConsumerOperations(rsqrt_op)
return sum(1 for op in rsqrt_consumers if op.type == 'Mul') == 1
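
A small standalone sketch of the bypass-add lookup rule noted in _FoldUnfusedBatchNorms above (the scope names are illustrative):

    import re

    def bypass_add_name(bn):
      # Mirrors the new logic: the bypass Add lives one scope above the batch
      # norm, or at the top-level scope when the batch norm scope has no '/'.
      add_bypass_ctx = ''
      if bn:
        try:
          add_bypass_ctx = re.search(r'^(.*)/([^/]+)', bn).group(1)
        except AttributeError:
          add_bypass_ctx = ''
      if add_bypass_ctx:
        add_bypass_ctx = add_bypass_ctx + '/'
      return add_bypass_ctx + 'Add'

    print(bypass_add_name('tower_1/conv2'))  # 'tower_1/Add'
    print(bypass_add_name('conv2'))          # 'Add' -- bypass add at the top level
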
diff --git a/tensorflow/contrib/quantize/python/quantize.py b/tensorflow/contrib/quantize/python/quantize.py
index e88db0acd5..afb9de8370 100644
--- a/tensorflow/contrib/quantize/python/quantize.py
+++ b/tensorflow/contrib/quantize/python/quantize.py
@@ -97,8 +97,11 @@ def Quantize(graph,
layer_match.activation_op)
add_context = context
if layer_match.bypass_op:
- add_context = re.search(r'^(.*)/([^/]+)', context).group(1)
-
+ pattern_match_result = re.search(r'^(.*)/([^/]+)', context)
+ if pattern_match_result is not None:
+ add_context = pattern_match_result.group(1)
+ else:
+ add_context = ''
# If `scope` is given, only quantize it if the producer of weights
# (usually it's the layer op) is in the right scope.
_InsertQuantOp(
@@ -156,8 +159,12 @@ def Quantize(graph,
# Quantize bypass ops that occur after the activation.
if layer_match.post_activation_bypass_op is not None:
- post_activation_bypass_context = re.search(
- r'^(.*)/([^/]+)', layer_match.post_activation_bypass_op.name).group(1)
+ pattern_match_result = re.search(
+ r'^(.*)/([^/]+)', layer_match.post_activation_bypass_op.name)
+ if pattern_match_result is not None:
+ post_activation_bypass_context = pattern_match_result.group(1)
+ else:
+ post_activation_bypass_context = ''
# If `scope` is given, only quantize it if the producer is in the right
# scope.
# Make sure the op following this isn't an activation. In which case, we
@@ -454,8 +461,8 @@ class _LayerMatch(object):
return self._bias_add_op
-def _FollowedByFakeQuant(tensor):
- """Returns True if the tensor is followed by a FakeQuant."""
+def _GetFollowingFakeQuantOp(tensor):
+ """Returns the following FakeQuant op if it exists else None."""
fake_quant_ops = set([
'FakeQuantWithMinMaxVars', 'FakeQuantWithMinMaxArgs',
'FakeQuantWithMinMaxVarsPerChannel'
@@ -465,11 +472,11 @@ def _FollowedByFakeQuant(tensor):
while consumers:
c = consumers.pop()
if c.type in fake_quant_ops:
- return True
+ return c
elif c.type in pass_through_ops:
for output in c.outputs:
consumers.extend(output.consumers())
- return False
+ return None
def _InsertQuantOp(context,
@@ -552,44 +559,77 @@ def _InsertQuantOp(context,
# Prevent ops from being quantized multiple times. Bypass ops can sometimes
# overlap between multiple matches, so we need to ensure that we don't
# add duplicate FakeQuant operations.
- if _FollowedByFakeQuant(inputs):
- return
-
- if moving_avg:
- quant = (
- quant_ops.MovingAvgQuantize(
- inputs,
- init_min=init_min,
- init_max=init_max,
- ema_decay=ema_decay,
- is_training=is_training,
- num_bits=bits,
- narrow_range=narrow_range,
- vars_collection=vars_collection,
- name_prefix=name_prefix))
+ fake_quant_op = _GetFollowingFakeQuantOp(inputs)
+
+  # If the tensor is already followed by a FakeQuant op, we skip inserting
+  # another one and instead reroute consumers to the existing op.
+
+ if fake_quant_op is None:
+ if moving_avg:
+ quant = (
+ quant_ops.MovingAvgQuantize(
+ inputs,
+ init_min=init_min,
+ init_max=init_max,
+ ema_decay=ema_decay,
+ is_training=is_training,
+ num_bits=bits,
+ narrow_range=narrow_range,
+ vars_collection=vars_collection,
+ name_prefix=name_prefix))
+ else:
+ quant = (
+ quant_ops.LastValueQuantize(
+ inputs,
+ init_min=init_min,
+ init_max=init_max,
+ is_training=is_training,
+ num_bits=bits,
+ narrow_range=narrow_range,
+ vars_collection=vars_collection,
+ name_prefix=name_prefix))
+
+ if quant_delay and quant_delay > 0:
+ activate_quant = math_ops.greater_equal(
+ common.CreateOrGetQuantizationStep(),
+ quant_delay,
+ name=name_prefix + '/activate_quant')
+ quant = control_flow_ops.cond(
+ activate_quant,
+ lambda: quant,
+ lambda: inputs,
+ name=name_prefix + '/delayed_quant')
else:
- quant = (
- quant_ops.LastValueQuantize(
- inputs,
- init_min=init_min,
- init_max=init_max,
- is_training=is_training,
- num_bits=bits,
- narrow_range=narrow_range,
- vars_collection=vars_collection,
- name_prefix=name_prefix))
-
- if quant_delay and quant_delay > 0:
- activate_quant = math_ops.greater_equal(
- common.CreateOrGetQuantizationStep(),
- quant_delay,
- name=name_prefix + '/activate_quant')
- quant = control_flow_ops.cond(
- activate_quant,
- lambda: quant,
- lambda: inputs,
- name=name_prefix + '/delayed_quant')
-
+    # If a fake quant op is already present, make sure any downstream use of
+    # the tensor is rerouted to the appropriate quantized tensor. With no
+    # quant_delay, this is simply the output of the fake quant op. With a
+    # quant_delay, we reroute to the output of the delayed quant operation,
+    # which applies quantization only after the specified number of steps.
+
+ quant = fake_quant_op.outputs[0]
+ if quant_delay and quant_delay > 0:
+ name_prefix = '/'.join(quant.name.split('/')[:-1])
+ quant = quant.graph.get_tensor_by_name(name_prefix +
+ '/delayed_quant/Merge:0')
+ pruned_consumer_set = set()
+ for consumer in consumers:
+ fake_quant_dest_op = _GetFollowingFakeQuantOp(consumer.outputs[0])
+ if (fake_quant_dest_op is None or
+ fake_quant_dest_op.name != fake_quant_op.name):
+ pruned_consumer_set.add(consumer)
+ consumers = pruned_consumer_set
+
+  # If we have
+  #   input -> pass_through -> fake_quant
+  # there is nothing to reroute.
+  #
+  # If we have
+  #   input -> pass_through -> fake_quant
+  #   input -> consumer
+  # then we reroute so that the consumer reads the FakeQuant output instead:
+  #   input -> pass_through -> fake_quant -> consumer
if consumers:
tensors_modified_count = common.RerouteTensor(
quant, inputs, can_modify=consumers)
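
For context on the hunk above, a minimal sketch (not part of the patch; TF 1.x graph mode assumed) of the situation `_GetFollowingFakeQuantOp` handles: the tensor to be quantized already feeds, possibly through pass-through ops such as Identity or Reshape, into an existing FakeQuant, so `_InsertQuantOp` now reuses that op instead of inserting a second one.

    import tensorflow as tf  # TF 1.x assumed

    g = tf.Graph()
    with g.as_default():
      weights = tf.placeholder(tf.float32, [3, 3, 1, 8], name='weights')
      passed = tf.identity(weights)                       # pass-through op
      quant = tf.fake_quant_with_min_max_args(passed, min=-1.0, max=1.0)
    # _GetFollowingFakeQuantOp(weights) walks the consumers of `weights`, steps
    # through the Identity (a pass-through op) and returns the
    # FakeQuantWithMinMaxArgs op, so a second _InsertQuantOp call for `weights`
    # skips inserting another FakeQuant and reroutes any remaining consumers to
    # `quant` (or to its delayed_quant/Merge output when quant_delay > 0).
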
diff --git a/tensorflow/contrib/quantize/python/quantize_graph_test.py b/tensorflow/contrib/quantize/python/quantize_graph_test.py
index e80d2183a6..a9fc6c3c61 100644
--- a/tensorflow/contrib/quantize/python/quantize_graph_test.py
+++ b/tensorflow/contrib/quantize/python/quantize_graph_test.py
@@ -27,6 +27,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import template
from tensorflow.python.platform import googletest
@@ -306,6 +307,42 @@ class QuantizeGraphTest(test_util.TensorFlowTestCase):
# No ops should be inserted or removed.
self.assertEqual(op_names_before_rewrite, op_names_after_rewrite)
+ def testWithSharedWeights(self):
+
+ self._RunTestOverAllRewrites(self._TestWithSharedWeights)
+ self._RunTestOverTrainingRewrites(self._TestRewriteWithSharedWeights)
+
+ def _TestRewriteWithSharedWeights(self, rewrite_fn, quant_delay=1):
+ self._TestWithSharedWeights(rewrite_fn, quant_delay)
+
+ def _TestWithSharedWeights(self, rewrite_fn, quant_delay=None):
+ with ops.Graph().as_default() as g:
+ conv = template.make_template('shared_weights_conv', self._ConvLayer)
+ conv()
+ conv()
+ if quant_delay is None:
+ rewrite_fn()
+ else:
+ rewrite_fn(quant_delay=quant_delay)
+
+ conv_ops = [op for op in g.get_operations() if op.type == 'Conv2D']
+ weights_quants = [
+ op for op in g.get_operations()
+ if 'weights_quant' in op.name and op.type == 'FakeQuantWithMinMaxVars'
+ ]
+ # Check that the shared weights variable is not quantized multiple times
+ self.assertTrue(len(weights_quants) == 1)
+ weights_quant_tensor = weights_quants[0].outputs[0]
+ if quant_delay:
+ delayed_weights_quants = [
+ op for op in g.get_operations()
+ if 'weights_quant' in op.name and op.type == 'Merge'
+ ]
+ self.assertTrue(len(delayed_weights_quants) == 1)
+ weights_quant_tensor = delayed_weights_quants[0].outputs[0]
+ # Check that the Conv2D operations get the quantized weights
+ self.assertTrue(all(weights_quant_tensor in op.inputs for op in conv_ops))
+
def _ConvLayer(
self, input_tensor=None, scope='test', pre_activation_bypass=False,
post_activation_bypass=False):
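
The shared-weights situation the new test covers can be reproduced in a few lines; the snippet below is an illustrative sketch (TF 1.x with tf.contrib assumed), not part of the patch. Both Conv2D ops read the same weights variable, so the rewriter must attach exactly one weights FakeQuant.

    import tensorflow as tf
    from tensorflow.contrib import layers

    def _conv_layer():
      x = tf.zeros([1, 8, 8, 3])
      return layers.conv2d(x, 16, [3, 3], scope='conv')

    # make_template shares variables across calls, so both Conv2D ops read the
    # same 'shared_weights_conv/conv/weights' variable.  Before this change the
    # rewriter could insert a FakeQuant per match, quantizing the weight twice;
    # now the second match reuses the FakeQuant inserted for the first.
    conv = tf.make_template('shared_weights_conv', _conv_layer)
    conv()
    conv()
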
diff --git a/tensorflow/contrib/quantize/python/quantize_parameterized_test.py b/tensorflow/contrib/quantize/python/quantize_parameterized_test.py
index 31a2955ddb..f6bf57a789 100644
--- a/tensorflow/contrib/quantize/python/quantize_parameterized_test.py
+++ b/tensorflow/contrib/quantize/python/quantize_parameterized_test.py
@@ -58,85 +58,102 @@ class QuantizeTest(test_util.TensorFlowTestCase):
]
for params in parameters_list:
# Test everything with resource variables and normal variables.
- test_fn(params[0], params[1], params[2], params[3], False)
- test_fn(params[0], params[1], params[2], params[3], True)
+ test_fn(params[0], params[1], params[2], params[3], False, None)
+ test_fn(params[0], params[1], params[2], params[3], True, None)
+ # Test with both empty scope and an example scope
+ test_fn(params[0], params[1], params[2], params[3], False, 'test')
+ test_fn(params[0], params[1], params[2], params[3], True, 'test')
def _AssertCorrectQuantizedGraphWithoutBatchNorm(
self, graph, scope, layer, activation_op_name, with_bypass, delay,
use_resource):
quantization_node_name = 'FakeQuantWithMinMaxVars'
- weights_quant = graph.get_operation_by_name(scope + '/weights_quant/' +
- quantization_node_name)
+ conv_scope = self._GetConvScope(scope, with_bypass)
+ delim = '/' if conv_scope else ''
+
+ if scope:
+ scope = scope + '/'
+ weights_quant = graph.get_operation_by_name(
+ conv_scope + delim + 'weights_quant/' + quantization_node_name)
self.assertEqual(weights_quant.type, quantization_node_name)
# Assemble the expected inputs.
if use_resource:
expected_inputs = [
- scope + '/weights_quant/FakeQuantWithMinMaxVars/ReadVariableOp',
- scope + '/weights_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1',
+ conv_scope + delim +
+ 'weights_quant/FakeQuantWithMinMaxVars/ReadVariableOp',
+ conv_scope + delim +
+ 'weights_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1',
]
if layer == 'DepthwiseConv2dNative':
- expected_inputs.append(scope + '/depthwise/ReadVariableOp')
+ expected_inputs.append(conv_scope + delim + 'depthwise/ReadVariableOp')
else:
- expected_inputs.append(scope + '/' + layer + '/ReadVariableOp')
+ expected_inputs.append(conv_scope + delim + layer + '/ReadVariableOp')
else:
expected_inputs = [
- scope + '/weights_quant/AssignMinLast',
- scope + '/weights_quant/AssignMaxLast',
+ conv_scope + delim + 'weights_quant/AssignMinLast',
+ conv_scope + delim + 'weights_quant/AssignMaxLast',
]
if layer == 'DepthwiseConv2dNative':
- expected_inputs.append(scope + '/depthwise_weights/read')
+ expected_inputs.append(conv_scope + delim + 'depthwise_weights/read')
else:
- expected_inputs.append(scope + '/weights/read')
+ expected_inputs.append(conv_scope + delim + 'weights/read')
self._AssertInputOpsAre(weights_quant, expected_inputs)
if delay and delay > 0:
- output_op_name = scope + '/weights_quant/delayed_quant/Switch_1'
+ output_op_name = (
+ conv_scope + delim + 'weights_quant/delayed_quant/Switch_1')
else:
if layer == 'DepthwiseConv2dNative':
- output_op_name = scope + '/depthwise'
+ output_op_name = conv_scope + delim + 'depthwise'
else:
- output_op_name = scope + '/' + layer
+ output_op_name = conv_scope + delim + layer
self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name])
if with_bypass:
- conv_quant = graph.get_operation_by_name(scope + '/conv_quant/' +
- quantization_node_name)
+ conv_quant = graph.get_operation_by_name(
+ conv_scope + delim + 'conv_quant/' + quantization_node_name)
self.assertEqual(conv_quant.type, quantization_node_name)
if use_resource:
expected_inputs = [
- scope + '/conv_quant/FakeQuantWithMinMaxVars/ReadVariableOp',
- scope + '/conv_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1',
- scope + '/BiasAdd',
+ conv_scope + delim +
+ 'conv_quant/FakeQuantWithMinMaxVars/ReadVariableOp',
+ conv_scope + delim +
+ 'conv_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1',
+ conv_scope + delim + 'BiasAdd',
]
else:
expected_inputs = [
- scope + '/conv_quant/AssignMinEma',
- scope + '/conv_quant/AssignMaxEma', scope + '/BiasAdd'
+ conv_scope + delim + 'conv_quant/AssignMinEma',
+ conv_scope + delim + 'conv_quant/AssignMaxEma',
+ conv_scope + delim + 'BiasAdd'
]
self._AssertInputOpsAre(conv_quant, expected_inputs)
- output_op_name = (scope + '/conv_quant/delayed_quant/Switch_1'
- if delay else 'test/Add')
+
+ output_op_name = (
+ conv_scope + delim + 'conv_quant/delayed_quant/Switch_1'
+ if delay else scope + 'Add')
self._AssertOutputGoesToOps(conv_quant, graph, [output_op_name])
- act_quant = graph.get_operation_by_name('test/act_quant/' +
+ act_quant = graph.get_operation_by_name(scope + 'act_quant/' +
quantization_node_name)
self.assertEqual(act_quant.type, quantization_node_name)
if use_resource:
expected_inputs = [
- 'test/act_quant/FakeQuantWithMinMaxVars/ReadVariableOp',
- 'test/act_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1',
- 'test/' + activation_op_name,
+ scope + 'act_quant/FakeQuantWithMinMaxVars/ReadVariableOp',
+ scope + 'act_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1',
+ scope + activation_op_name,
]
else:
expected_inputs = [
- 'test/act_quant/AssignMinEma', 'test/act_quant/AssignMaxEma',
- 'test/' + activation_op_name
+ scope + 'act_quant/AssignMinEma', scope + 'act_quant/AssignMaxEma',
+ scope + activation_op_name
]
self._AssertInputOpsAre(act_quant, expected_inputs)
- output_op_name = ('test/act_quant/delayed_quant/Switch_1'
- if delay else 'control_dependency')
+ output_op_name = (
+ scope + 'act_quant/delayed_quant/Switch_1'
+ if delay else 'control_dependency')
self._AssertOutputGoesToOps(act_quant, graph, [output_op_name])
self._AssertIdempotent(graph)
@@ -145,7 +162,8 @@ class QuantizeTest(test_util.TensorFlowTestCase):
self._TestQuantize_Conv2dWithoutBatchNorm)
def _TestQuantize_Conv2dWithoutBatchNorm(self, activation, activation_op_name,
- with_bypass, delay, use_resource):
+ with_bypass, delay, use_resource,
+ scope):
"""Tests quantization: inputs -> Conv2d no batch norm -> Activation.
Args:
@@ -156,6 +174,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
inputs to just before Activation.
delay: Int (optional), delay in number of steps until quantization starts.
use_resource: Bool, when true uses resource variables.
+      scope: String, specifies the top-level scope for the graph.
"""
graph = ops.Graph()
with graph.as_default():
@@ -165,7 +184,9 @@ class QuantizeTest(test_util.TensorFlowTestCase):
stride = 1 if with_bypass else 2
out_depth = 3 if with_bypass else 32
activation_fn = None if with_bypass else activation
- scope = 'test/test2' if with_bypass else 'test'
+ conv_scope = self._GetConvScope(scope, with_bypass)
+ scope = '' if scope is None else scope
+ delim = '/' if scope else ''
node = conv2d(
inputs,
out_depth, [5, 5],
@@ -173,16 +194,19 @@ class QuantizeTest(test_util.TensorFlowTestCase):
padding='SAME',
weights_initializer=self._WeightInit(0.09),
activation_fn=activation_fn,
- scope=scope)
+ scope=conv_scope)
if with_bypass:
- node = math_ops.add(inputs, node, name='test/Add')
- node = activation(node, name='test/' + activation_op_name)
+ node = math_ops.add(inputs, node, name=scope + delim + 'Add')
+ node = activation(node, name=scope + delim + activation_op_name)
update_barrier = control_flow_ops.no_op(name='update_barrier')
with ops.control_dependencies([update_barrier]):
array_ops.identity(node, name='control_dependency')
quantize.Quantize(graph, True, quant_delay=delay)
+ if conv_scope is None:
+ conv_scope = ''
+
self._AssertCorrectQuantizedGraphWithoutBatchNorm(
graph, scope, 'Conv2D', activation_op_name, with_bypass, delay,
use_resource)
@@ -192,7 +216,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
self._TestQuantize_FCWithoutBatchNorm)
def _TestQuantize_FCWithoutBatchNorm(self, activation, activation_op_name,
- with_bypass, delay, use_resource):
+ with_bypass, delay, use_resource, scope):
"""Tests quantization: inputs -> FC no batch norm -> Activation.
Args:
@@ -203,6 +227,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
inputs to just before Activation.
delay: Int (optional), delay in number of steps until quantization starts.
use_resource: Bool, when true uses resource variables.
+      scope: String, specifies the top-level scope for the graph.
"""
graph = ops.Graph()
with graph.as_default():
@@ -211,16 +236,18 @@ class QuantizeTest(test_util.TensorFlowTestCase):
inputs = array_ops.zeros((batch_size, depth))
out_depth = 256 if with_bypass else 128
activation_fn = None if with_bypass else activation
- scope = 'test/test2' if with_bypass else 'test'
+ fc_scope = self._GetConvScope(scope, with_bypass)
+ scope = '' if scope is None else scope
+ delim = '/' if scope else ''
node = fully_connected(
inputs,
out_depth,
weights_initializer=self._WeightInit(0.03),
activation_fn=activation_fn,
- scope=scope)
+ scope=fc_scope)
if with_bypass:
- node = math_ops.add(inputs, node, name='test/Add')
- node = activation(node, name='test/' + activation_op_name)
+ node = math_ops.add(inputs, node, name=scope + delim + 'Add')
+ node = activation(node, name=scope + delim + activation_op_name)
update_barrier = control_flow_ops.no_op(name='update_barrier')
with ops.control_dependencies([update_barrier]):
array_ops.identity(node, name='control_dependency')
@@ -235,7 +262,8 @@ class QuantizeTest(test_util.TensorFlowTestCase):
self._TestQuantize_DepthwiseConv2dWithoutBatchNorm)
def _TestQuantize_DepthwiseConv2dWithoutBatchNorm(
- self, activation, activation_op_name, with_bypass, delay, use_resource):
+ self, activation, activation_op_name, with_bypass, delay, use_resource,
+ scope):
"""Tests quantization: inputs -> DWConv2d no batch norm -> Activation.
Args:
@@ -246,6 +274,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
inputs to just before Activation.
delay: Int (optional), delay in number of steps until quantization starts.
use_resource: Bool, when true uses resource variables.
+      scope: String, specifies the top-level scope for the graph.
"""
graph = ops.Graph()
with graph.as_default():
@@ -254,7 +283,10 @@ class QuantizeTest(test_util.TensorFlowTestCase):
inputs = array_ops.zeros((batch_size, height, width, depth))
stride = 1 if with_bypass else 2
activation_fn = None if with_bypass else activation
- scope = 'test/test2' if with_bypass else 'test'
+ conv_scope = self._GetConvScope(scope, with_bypass)
+ scope = '' if scope is None else scope
+ delim = '/' if scope else ''
+
node = separable_conv2d(
inputs,
None, [5, 5],
@@ -263,10 +295,10 @@ class QuantizeTest(test_util.TensorFlowTestCase):
padding='SAME',
weights_initializer=self._WeightInit(0.09),
activation_fn=activation_fn,
- scope=scope)
+ scope=conv_scope)
if with_bypass:
- node = math_ops.add(inputs, node, name='test/Add')
- node = activation(node, name='test/' + activation_op_name)
+ node = math_ops.add(inputs, node, name=scope + delim + 'Add')
+ node = activation(node, name=scope + delim + activation_op_name)
update_barrier = control_flow_ops.no_op(name='update_barrier')
with ops.control_dependencies([update_barrier]):
array_ops.identity(node, name='control_dependency')
@@ -280,8 +312,9 @@ class QuantizeTest(test_util.TensorFlowTestCase):
self._RunWithoutBatchNormTestOverParameters(
self._TestQuantize_AtrousConvWithoutBatchNorm)
- def _TestQuantize_AtrousConvWithoutBatchNorm(
- self, activation, activation_op_name, with_bypass, delay, use_resource):
+ def _TestQuantize_AtrousConvWithoutBatchNorm(self, activation,
+ activation_op_name, with_bypass,
+ delay, use_resource, scope):
"""Tests quantization: inputs -> atrous conv no batch norm -> Activation.
Args:
@@ -292,6 +325,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
inputs to just before Activation.
delay: Int (optional), delay in number of steps until quantization starts.
use_resource: Bool, when true uses resource variables.
+      scope: String, specifies the top-level scope for the graph.
"""
graph = ops.Graph()
with graph.as_default():
@@ -300,7 +334,10 @@ class QuantizeTest(test_util.TensorFlowTestCase):
inputs = array_ops.zeros((batch_size, height, width, depth))
dilation_rate = 2
activation_fn = None if with_bypass else activation
- scope = 'test/test2' if with_bypass else 'test'
+ conv_scope = self._GetConvScope(scope, with_bypass)
+ scope = '' if scope is None else scope
+ delim = '/' if scope else ''
+
node = separable_conv2d(
inputs,
None, [3, 3],
@@ -309,10 +346,10 @@ class QuantizeTest(test_util.TensorFlowTestCase):
padding='SAME',
weights_initializer=self._WeightInit(0.09),
activation_fn=activation_fn,
- scope=scope)
+ scope=conv_scope)
if with_bypass:
- node = math_ops.add(inputs, node, name='test/Add')
- node = activation(node, name='test/' + activation_op_name)
+ node = math_ops.add(inputs, node, name=scope + delim + 'Add')
+ node = activation(node, name=scope + delim + activation_op_name)
update_barrier = control_flow_ops.no_op(name='update_barrier')
with ops.control_dependencies([update_barrier]):
array_ops.identity(node, name='control_dependency')
@@ -353,78 +390,96 @@ class QuantizeTest(test_util.TensorFlowTestCase):
]
for params in parameters_list:
# Test everything with resource variables and normal variables.
- test_fn(params[0], params[1], params[2], params[3], params[4], False)
- test_fn(params[0], params[1], params[2], params[3], params[4], True)
+ test_fn(params[0], params[1], params[2], params[3], params[4], False,
+ None)
+ test_fn(params[0], params[1], params[2], params[3], params[4], True, None)
+ test_fn(params[0], params[1], params[2], params[3], params[4], False,
+ 'test')
+ test_fn(params[0], params[1], params[2], params[3], params[4], True,
+ 'test')
def _AssertCorrectQuantizedGraphWithBatchNorm(self, graph, scope, layer,
activation_op_name, with_bypass,
delay, use_resource):
quantization_node_name = 'FakeQuantWithMinMaxVars'
+ conv_scope = self._GetConvScope(scope, with_bypass)
+ delim = '/' if conv_scope else ''
+
+ if scope:
+ scope = scope + '/'
+
weights_quant = graph.get_operation_by_name(
- scope + '/weights_quant/' + quantization_node_name)
+ conv_scope + delim + 'weights_quant/' + quantization_node_name)
+
self.assertEqual(weights_quant.type, quantization_node_name)
if use_resource:
expected_inputs = [
- scope + '/weights_quant/FakeQuantWithMinMaxVars/ReadVariableOp',
- scope + '/weights_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1',
+ conv_scope + delim +
+ 'weights_quant/FakeQuantWithMinMaxVars/ReadVariableOp',
+ conv_scope + delim +
+ 'weights_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1',
]
else:
expected_inputs = [
- scope + '/weights_quant/' + 'AssignMinLast',
- scope + '/weights_quant/' + 'AssignMaxLast'
+ conv_scope + delim + 'weights_quant/' + 'AssignMinLast',
+ conv_scope + delim + 'weights_quant/' + 'AssignMaxLast'
]
- expected_inputs.append(scope + '/mul_fold')
+ expected_inputs.append(conv_scope + delim + 'mul_fold')
self._AssertInputOpsAre(weights_quant, expected_inputs)
if layer == 'DepthwiseConv2dNative':
- output_op_name = scope + ('/weights_quant/delayed_quant/Switch_1'
- if delay else '/depthwise_Fold')
+ output_op_name = conv_scope + delim + (
+ 'weights_quant/delayed_quant/Switch_1' if delay else 'depthwise_Fold')
else:
- output_op_name = scope + ('/weights_quant/delayed_quant/Switch_1'
- if delay else '/' + layer + '_Fold')
+ output_op_name = conv_scope + delim + (
+ 'weights_quant/delayed_quant/Switch_1' if delay else layer + '_Fold')
self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name])
if with_bypass:
conv_quant = graph.get_operation_by_name(
- scope + '/conv_quant/' + quantization_node_name)
+ conv_scope + delim + 'conv_quant/' + quantization_node_name)
self.assertEqual(conv_quant.type, quantization_node_name)
if use_resource:
expected_inputs = [
- scope + '/conv_quant/FakeQuantWithMinMaxVars/ReadVariableOp',
- scope + '/conv_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1',
+ conv_scope + delim +
+ 'conv_quant/FakeQuantWithMinMaxVars/ReadVariableOp',
+ conv_scope + delim +
+ 'conv_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1',
]
else:
expected_inputs = [
- scope + '/conv_quant/AssignMinEma',
- scope + '/conv_quant/AssignMaxEma',
+ conv_scope + delim + 'conv_quant/AssignMinEma',
+ conv_scope + delim + 'conv_quant/AssignMaxEma',
]
- expected_inputs.append(scope + '/add_fold')
+ expected_inputs.append(conv_scope + delim + 'add_fold')
self._AssertInputOpsAre(conv_quant, expected_inputs)
output_op_name = (
- scope + '/conv_quant/delayed_quant/Switch_1' if delay else 'test/Add')
+ conv_scope + delim + 'conv_quant/delayed_quant/Switch_1'
+ if delay else scope + 'Add')
self._AssertOutputGoesToOps(conv_quant, graph, [output_op_name])
- act_quant = graph.get_operation_by_name(
- 'test/act_quant/' + quantization_node_name)
+ act_quant = graph.get_operation_by_name(scope + 'act_quant/' +
+ quantization_node_name)
self.assertEqual(act_quant.type, quantization_node_name)
if use_resource:
expected_inputs = [
- 'test/act_quant/FakeQuantWithMinMaxVars/ReadVariableOp',
- 'test/act_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1',
+ scope + 'act_quant/FakeQuantWithMinMaxVars/ReadVariableOp',
+ scope + 'act_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1',
]
else:
expected_inputs = [
- 'test/act_quant/AssignMinEma',
- 'test/act_quant/AssignMaxEma',
+ scope + 'act_quant/AssignMinEma',
+ scope + 'act_quant/AssignMaxEma',
]
- expected_inputs.append('test/' + activation_op_name)
+ expected_inputs.append(scope + activation_op_name)
self._AssertInputOpsAre(act_quant, expected_inputs)
- output_op_name = ('test/act_quant/delayed_quant/Switch_1'
- if delay else 'control_dependency')
+ output_op_name = (
+ scope + 'act_quant/delayed_quant/Switch_1'
+ if delay else 'control_dependency')
self._AssertOutputGoesToOps(act_quant, graph, [output_op_name])
self._AssertIdempotent(graph)
@@ -433,7 +488,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
def _TestQuantize_Conv2dWithBatchNorm(self, activation, activation_op_name,
with_bypass, delay, fused_batch_norm,
- use_resource):
+ use_resource, scope):
"""Tests quantization: inputs -> Conv2d with batch norm -> Activation.
Args:
@@ -445,6 +500,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
delay: Int (optional), delay in number of steps until quantization starts.
fused_batch_norm: Bool, when true use FusedBatchNorm.
use_resource: Bool, when true uses resource variables.
+      scope: String, specifies the top-level scope for the graph.
"""
graph = ops.Graph()
with graph.as_default():
@@ -453,7 +509,9 @@ class QuantizeTest(test_util.TensorFlowTestCase):
inputs = array_ops.zeros((batch_size, height, width, depth))
stride = 1 if with_bypass else 2
out_depth = 3 if with_bypass else 32
- scope = 'test/test2' if with_bypass else 'test'
+ conv_scope = self._GetConvScope(scope, with_bypass)
+ scope = '' if scope is None else scope
+ delim = '/' if scope else ''
node = conv2d(
inputs,
out_depth, [5, 5],
@@ -463,13 +521,13 @@ class QuantizeTest(test_util.TensorFlowTestCase):
activation_fn=None,
normalizer_fn=batch_norm,
normalizer_params=self._BatchNormParams(fused_batch_norm),
- scope=scope)
+ scope=conv_scope)
# Manually add a bypass (optional) and an activation.
if with_bypass:
- node = math_ops.add(inputs, node, name='test/Add')
+ node = math_ops.add(inputs, node, name=scope + delim + 'Add')
- node = activation(node, name='test/' + activation_op_name)
+ node = activation(node, name=scope + delim + activation_op_name)
update_barrier = control_flow_ops.no_op(name='update_barrier')
with ops.control_dependencies([update_barrier]):
@@ -487,7 +545,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
def _TestQuantize_FCWithBatchNorm(self, activation, activation_op_name,
with_bypass, delay, fused_batch_norm,
- use_resource):
+ use_resource, scope):
"""Tests quantization: inputs -> FC with batch norm -> Activation.
Args:
@@ -499,6 +557,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
delay: Int (optional), delay in number of steps until quantization starts.
fused_batch_norm: Bool, when true use FusedBatchNorm.
use_resource: Bool, when true uses resource variables.
+      scope: String, specifies the top-level scope for the graph.
"""
graph = ops.Graph()
with graph.as_default():
@@ -506,7 +565,9 @@ class QuantizeTest(test_util.TensorFlowTestCase):
batch_size, depth = 5, 256
inputs = array_ops.zeros((batch_size, depth))
out_depth = 256 if with_bypass else 128
- scope = 'test/test2' if with_bypass else 'test'
+ conv_scope = self._GetConvScope(scope, with_bypass)
+ scope = '' if scope is None else scope
+ delim = '/' if scope else ''
node = fully_connected(
inputs,
out_depth,
@@ -514,13 +575,13 @@ class QuantizeTest(test_util.TensorFlowTestCase):
activation_fn=None,
normalizer_fn=batch_norm,
normalizer_params=self._BatchNormParams(fused_batch_norm),
- scope=scope)
+ scope=conv_scope)
# Manually add a bypass (optional) and an activation.
if with_bypass:
- node = math_ops.add(inputs, node, name='test/Add')
+ node = math_ops.add(inputs, node, name=scope + delim + 'Add')
- node = activation(node, name='test/' + activation_op_name)
+ node = activation(node, name=scope + delim + activation_op_name)
update_barrier = control_flow_ops.no_op(name='update_barrier')
with ops.control_dependencies([update_barrier]):
@@ -540,7 +601,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
def _TestQuantize_DepthwiseConv2dWithBatchNorm(
self, activation, activation_op_name, with_bypass, delay,
- fused_batch_norm, use_resource):
+ fused_batch_norm, use_resource, scope):
"""Tests quantization: inputs -> DWConv2d with batch norm -> Activation.
Args:
@@ -552,6 +613,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
delay: Int (optional), delay in number of steps until quantization starts.
fused_batch_norm: Bool, when true use FusedBatchNorm.
use_resource: Bool, when true uses resource variables.
+      scope: String, specifies the top-level scope for the graph.
"""
graph = ops.Graph()
with graph.as_default():
@@ -559,7 +621,9 @@ class QuantizeTest(test_util.TensorFlowTestCase):
batch_size, height, width, depth = 5, 128, 128, 3
inputs = array_ops.zeros((batch_size, height, width, depth))
stride = 1 if with_bypass else 2
- scope = 'test/test2' if with_bypass else 'test'
+ conv_scope = self._GetConvScope(scope, with_bypass)
+ scope = '' if scope is None else scope
+ delim = '/' if scope else ''
node = separable_conv2d(
inputs,
None, [5, 5],
@@ -570,13 +634,13 @@ class QuantizeTest(test_util.TensorFlowTestCase):
activation_fn=None,
normalizer_fn=batch_norm,
normalizer_params=self._BatchNormParams(fused_batch_norm),
- scope=scope)
+ scope=conv_scope)
# Manually add a bypass (optional) and an activation.
if with_bypass:
- node = math_ops.add(inputs, node, name='test/Add')
+ node = math_ops.add(inputs, node, name=scope + delim + 'Add')
- node = activation(node, name='test/' + activation_op_name)
+ node = activation(node, name=scope + delim + activation_op_name)
update_barrier = control_flow_ops.no_op(name='update_barrier')
with ops.control_dependencies([update_barrier]):
@@ -595,7 +659,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
def _TestQuantize_AtrousConvWithBatchNorm(
self, activation, activation_op_name, with_bypass, delay,
- fused_batch_norm, use_resource):
+ fused_batch_norm, use_resource, scope):
"""Tests quantization: inputs -> atrous conv with batch norm -> Activation.
Args:
@@ -607,6 +671,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
delay: Int (optional), delay in number of steps until quantization starts.
fused_batch_norm: Bool, when true use FusedBatchNorm.
use_resource: Bool, when true uses resource variables.
+      scope: String, specifies the top-level scope for the graph.
"""
graph = ops.Graph()
with graph.as_default():
@@ -614,7 +679,10 @@ class QuantizeTest(test_util.TensorFlowTestCase):
batch_size, height, width, depth = 5, 128, 128, 3
inputs = array_ops.zeros((batch_size, height, width, depth))
dilation_rate = 2
- scope = 'test/test2' if with_bypass else 'test'
+ conv_scope = self._GetConvScope(scope, with_bypass)
+ scope = '' if scope is None else scope
+ delim = '/' if scope else ''
+
node = separable_conv2d(
inputs,
None, [3, 3],
@@ -625,13 +693,13 @@ class QuantizeTest(test_util.TensorFlowTestCase):
activation_fn=None,
normalizer_fn=batch_norm,
normalizer_params=self._BatchNormParams(fused_batch_norm),
- scope=scope)
+ scope=conv_scope)
# Manually add a bypass (optional) and an activation.
if with_bypass:
- node = math_ops.add(inputs, node, name='test/Add')
+ node = math_ops.add(inputs, node, name=scope + delim + 'Add')
- node = activation(node, name='test/' + activation_op_name)
+ node = activation(node, name=scope + delim + activation_op_name)
update_barrier = control_flow_ops.no_op(name='update_barrier')
with ops.control_dependencies([update_barrier]):
@@ -718,6 +786,18 @@ class QuantizeTest(test_util.TensorFlowTestCase):
with open('/tmp/bn_quant_test.pbtxt', 'w') as f:
f.write(str(graph.as_graph_def()))
+ def _GetConvScope(self, scope, with_bypass):
+ if scope is None:
+ scope = ''
+ delim = '/' if scope else ''
+
+ if with_bypass:
+ conv_scope = scope + delim + 'test2'
+ else:
+ conv_scope = scope
+
+ return conv_scope
+
def _BatchNormParams(self, fused=False, force_updates=False):
params = {
'center': True,
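
For readers tracking the renamed test ops, a quick sketch of what the new `_GetConvScope` helper above returns (hypothetical assertions mirroring the helper as written; with a bypass, the conv/FC layer gets an extra 'test2' scope while the bypass Add and the activation stay in the outer scope):

    self.assertEqual(self._GetConvScope(None, False), '')
    self.assertEqual(self._GetConvScope(None, True), 'test2')
    self.assertEqual(self._GetConvScope('test', False), 'test')
    self.assertEqual(self._GetConvScope('test', True), 'test/test2')
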
diff --git a/tensorflow/contrib/rate/rate_test.py b/tensorflow/contrib/rate/rate_test.py
index 08908104f4..3dee163881 100644
--- a/tensorflow/contrib/rate/rate_test.py
+++ b/tensorflow/contrib/rate/rate_test.py
@@ -46,7 +46,7 @@ class RateTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes()
def testBasic(self):
- with self.test_session():
+ with self.cached_session():
r_ = rate.Rate()
a = r_(array_ops.ones([1]), denominator=array_ops.ones([1]))
self.evaluate(variables.global_variables_initializer())
@@ -67,7 +67,7 @@ class RateTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes()
def testWhileLoop(self):
- with self.test_session():
+ with self.cached_session():
r_ = rate.Rate()
def body(value, denom, i, ret_rate):
diff --git a/tensorflow/contrib/recurrent/python/ops/functional_rnn.py b/tensorflow/contrib/recurrent/python/ops/functional_rnn.py
index c3db71359c..3abf7bd6da 100644
--- a/tensorflow/contrib/recurrent/python/ops/functional_rnn.py
+++ b/tensorflow/contrib/recurrent/python/ops/functional_rnn.py
@@ -22,7 +22,6 @@ from __future__ import print_function
import copy
from tensorflow.contrib.recurrent.python.ops import recurrent
-from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@@ -62,7 +61,7 @@ class _FunctionalRnnCell(object):
assert initial_state is not None
# TODO(drpng): Dtype needs to be configurable.
- input_dtypes = [dtypes.float32] + _GetDTypesFromStructure(initial_state)
+ input_dtypes = [seq_inputs.dtype] + _GetDTypesFromStructure(initial_state)
# See _index.
like_inputs_t = nest.map_structure(
lambda x: array_ops.stop_gradient(array_ops.gather(x, 0)), seq_inputs)
@@ -144,7 +143,10 @@ class _FunctionalRnnCell(object):
@property
def extended_initial_state(self):
if self._prepend_output:
- return [array_ops.zeros(self._output_shape), self._state_template]
+ return [array_ops.zeros(
+ self._output_shape,
+ dtype=_GetDTypesFromStructure(self._state_template)[0]),
+ self._state_template]
else:
# The base case, where the output is just the hidden state.
return self._state_template
@@ -185,7 +187,7 @@ def _ApplyLengthsToBatch(sequence_lengths, tf_output):
lengths = array_ops.tile(
array_ops.reshape(sequence_lengths, [-1, 1]), [1, max_time])
is_less = math_ops.cast(
- math_ops.less(output_time, lengths), dtype=dtypes.float32)
+ math_ops.less(output_time, lengths), dtype=tf_output.dtype)
keep_mask = array_ops.tile(
array_ops.expand_dims(is_less, -1),
[1, 1, vector_size])
@@ -217,7 +219,7 @@ def _PickFinalStateFromHistory(acc_state, sequence_length):
def _PostProcessOutput(extended_acc_state, extended_final_state, func_cell,
- total_time, inputs_lengths):
+ total_time, inputs_lengths, is_reversed):
"""Post-process output of recurrent.
This function takes the accumulated extended state and extracts the requested
@@ -226,6 +228,8 @@ def _PostProcessOutput(extended_acc_state, extended_final_state, func_cell,
When `inputs_lengths` has been set, it extracts the output from the
accumulated state. It also zeroes outputs past `inputs_lengths`.
+ When `is_reversed` is true, the output will be reversed in this function.
+
It also sets the static shape information.
Args:
@@ -236,11 +240,12 @@ def _PostProcessOutput(extended_acc_state, extended_final_state, func_cell,
func_cell: The functional wrapper around the cell.
total_time: A scalar integer tensor.
inputs_lengths: An integer tensor with one entry per input.
+ is_reversed: A boolean to indicate if the sequence is reversed.
Returns:
A tuple with the outputs at each time, and the final state.
"""
- if inputs_lengths is None:
+ if inputs_lengths is None or is_reversed:
flat_final_state = func_cell.MaybeRemoveOutputFromState(
nest.flatten(extended_final_state))
tf_state = nest.pack_sequence_as(func_cell.state_template, flat_final_state)
@@ -254,21 +259,28 @@ def _PostProcessOutput(extended_acc_state, extended_final_state, func_cell,
tf_state = _PickFinalStateFromHistory(acc_state, inputs_lengths)
output_from_state = func_cell.GetOutputFromState(extended_acc_state)
+ if is_reversed:
+ output_from_state = array_ops.reverse(output_from_state, [0])
tf_output = array_ops.transpose(output_from_state, [1, 0, 2])
tf_output.set_shape(
[func_cell.output_shape[0], total_time, func_cell.output_shape[1]])
if inputs_lengths is not None:
      # Need to set the outputs to zero.
tf_output = _ApplyLengthsToBatch(inputs_lengths, tf_output)
- # tf_output = array_ops.zeros([4, 3, 5])
_SetShapeFromTemplate(tf_state, func_cell.state_template)
return tf_output, tf_state
# pylint: disable=invalid-name
-def functional_rnn(cell, inputs, sequence_length=None,
- initial_state=None, dtype=None, time_major=False,
- scope=None, use_tpu=False):
+def functional_rnn(cell,
+ inputs,
+ sequence_length=None,
+ initial_state=None,
+ dtype=None,
+ time_major=False,
+ scope=None,
+ use_tpu=False,
+ reverse=False):
"""Same interface as `tf.nn.dynamic_rnn`."""
with variable_scope.variable_scope(scope or 'rnn'):
if not time_major:
@@ -283,33 +295,41 @@ def functional_rnn(cell, inputs, sequence_length=None,
max_length = math_ops.reduce_max(sequence_length)
else:
max_length = None
+ if reverse:
+ inputs = array_ops.reverse(inputs, [0])
extended_acc_state, extended_final_state = recurrent.Recurrent(
theta=func_cell.theta,
state0=func_cell.extended_initial_state,
inputs=inputs,
cell_fn=func_cell.cell_step,
max_input_length=max_length,
- use_tpu=use_tpu)
+ use_tpu=use_tpu,
+ aligned_end=reverse)
+
tf_output, tf_state = _PostProcessOutput(
- extended_acc_state, extended_final_state, func_cell,
- inputs_flat[0].shape[0], sequence_length)
+ extended_acc_state,
+ extended_final_state,
+ func_cell,
+ inputs_flat[0].shape[0],
+ sequence_length,
+ is_reversed=reverse)
if time_major:
tf_output = array_ops.transpose(tf_output, [1, 0, 2])
return tf_output, tf_state
-def bidirectional_functional_rnn(
- cell_fw,
- cell_bw,
- inputs,
- initial_state_fw=None,
- initial_state_bw=None,
- dtype=None,
- sequence_length=None,
- time_major=False,
- use_tpu=False,
- scope=None):
+def bidirectional_functional_rnn(cell_fw,
+ cell_bw,
+ inputs,
+ initial_state_fw=None,
+ initial_state_bw=None,
+ dtype=None,
+ sequence_length=None,
+ time_major=False,
+ use_tpu=False,
+ fast_reverse=False,
+ scope=None):
"""Creates a bidirectional recurrent neural network.
Performs fully dynamic unrolling of inputs in both directions. Built to be API
@@ -340,6 +360,10 @@ def bidirectional_functional_rnn(
use_tpu: Whether to enable TPU-compatible operation. If True, does not truly
reverse `inputs` in the backwards RNN. Once b/69305369 is fixed, we can
remove this flag.
+ fast_reverse: Whether to use fast tf.reverse to replace tf.reverse_sequence.
+ This is only possible when either all sequence lengths are the same inside
+ the batch, or when the cell function does not change the state on padded
+ input.
scope: An optional scope name for the dynamic RNN.
Returns:
@@ -388,17 +412,29 @@ def bidirectional_functional_rnn(
return array_ops.reverse(input_, axis=[seq_dim])
with variable_scope.variable_scope('bw') as bw_scope:
- inputs_reverse = _reverse(
- inputs, seq_lengths=sequence_length,
- seq_dim=time_dim, batch_dim=batch_dim)
- tmp, output_state_bw = functional_rnn(
- cell=cell_bw, inputs=inputs_reverse, sequence_length=sequence_length,
- initial_state=initial_state_bw, dtype=dtype,
- time_major=time_major, scope=bw_scope, use_tpu=use_tpu)
-
- output_bw = _reverse(
- tmp, seq_lengths=sequence_length,
- seq_dim=time_dim, batch_dim=batch_dim)
+ if not fast_reverse:
+ inputs = _reverse(
+ inputs,
+ seq_lengths=sequence_length,
+ seq_dim=time_dim,
+ batch_dim=batch_dim)
+ output_bw, output_state_bw = functional_rnn(
+ cell=cell_bw,
+ inputs=inputs,
+ sequence_length=sequence_length,
+ initial_state=initial_state_bw,
+ dtype=dtype,
+ time_major=time_major,
+ scope=bw_scope,
+ use_tpu=use_tpu,
+ reverse=fast_reverse)
+
+ if not fast_reverse:
+ output_bw = _reverse(
+ output_bw,
+ seq_lengths=sequence_length,
+ seq_dim=time_dim,
+ batch_dim=batch_dim)
outputs = (output_fw, output_bw)
output_states = (output_state_fw, output_state_bw)
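
A hedged usage sketch for the new `reverse`/`fast_reverse` path (not part of the patch; TF 1.x, cell sizes and shapes are made up). Per the docstring above, it assumes either equal sequence lengths within the batch or a cell that leaves state untouched on padded steps:

    import tensorflow as tf
    from tensorflow.contrib.recurrent.python.ops import functional_rnn

    cell_fw = tf.nn.rnn_cell.LSTMCell(32)
    cell_bw = tf.nn.rnn_cell.LSTMCell(32)
    inputs = tf.zeros([8, 20, 16])      # [batch, time, depth]
    lengths = tf.fill([8], 20)          # every sequence uses all 20 steps

    # With fast_reverse=True the backward direction reverses with tf.reverse
    # inside functional_rnn (reverse=True) rather than applying
    # tf.reverse_sequence to the inputs and again to the outputs.
    outputs, states = functional_rnn.bidirectional_functional_rnn(
        cell_fw, cell_bw, inputs,
        sequence_length=lengths,
        dtype=tf.float32,
        fast_reverse=True)
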
diff --git a/tensorflow/contrib/recurrent/python/ops/recurrent.py b/tensorflow/contrib/recurrent/python/ops/recurrent.py
index 4f289e0c85..f51de755d8 100644
--- a/tensorflow/contrib/recurrent/python/ops/recurrent.py
+++ b/tensorflow/contrib/recurrent/python/ops/recurrent.py
@@ -274,8 +274,16 @@ def _ConvertNoneGradientToZeros(xs, dxs):
class _Recurrent(object):
"""A helper class to construct a recurrent neural net."""
- def __init__(self, cell_fn, cell_grad, theta, state0, inputs,
- max_input_length, extras, use_tpu):
+ def __init__(self,
+ cell_fn,
+ cell_grad,
+ theta,
+ state0,
+ inputs,
+ max_input_length,
+ extras,
+ use_tpu,
+ aligned_end=False):
"""RNN helper class.
Args:
@@ -294,6 +302,8 @@ class _Recurrent(object):
and shapes of this `extras`.
      use_tpu: A boolean indicating whether the computation is meant to
run on a TPU.
+ aligned_end: A boolean indicating whether the sequence is aligned at
+ the end.
"""
self._theta = theta
self._state = state0
@@ -303,6 +313,7 @@ class _Recurrent(object):
self._cell_fn = cell_fn
self._cell_grad = cell_grad
self._extras = extras
+ self._aligned_end = aligned_end
# pylint: disable=unbalanced-tuple-unpacking
@@ -417,10 +428,11 @@ class _Recurrent(object):
acc_state = _EmptyAcc(slen_dim, state0)
acc_extras = _EmptyAcc(slen_dim, extras)
- dev_t = array_ops.constant(0, dtype=dev_t_type)
+ t = slen_dim - max_input_length if self._aligned_end else 0
+ dev_t = math_ops.to_int32(t) if use_tpu else math_ops.to_int64(t)
run = functional_ops.For(
- start=0,
- limit=max_input_length,
+ start=t,
+ limit=slen_dim if self._aligned_end else max_input_length,
delta=1,
inputs=[dev_t] + _Flatten(
[theta, state0, inputs, acc_state, acc_extras]),
@@ -551,13 +563,16 @@ class _Recurrent(object):
d_theta = _EmptyLike(theta)
d_inputs = _EmptyLike(inputs)
+ slen_dim = _SeqLenDim(inputs)
+
# Loop backwards. Note the loop's limit is open-ended, so goes through
# t=0.
- t = max_input_length - 1
+ t = slen_dim - 1 if self._aligned_end else max_input_length - 1
dev_t = math_ops.to_int32(t) if use_tpu else math_ops.to_int64(t)
+ limit = slen_dim - max_input_length - 1 if self._aligned_end else -1
run = functional_ops.For(
start=t,
- limit=-1,
+ limit=limit,
delta=-1,
inputs=[dev_t] + _Flatten([
theta, state0, inputs, acc_state, acc_extras, d_theta, d_state1,
@@ -641,7 +656,8 @@ def Recurrent(theta,
cell_grad=None,
extras=None,
max_input_length=None,
- use_tpu=False):
+ use_tpu=False,
+ aligned_end=False):
"""Compute a recurrent neural net.
Roughly, Recurrent() computes the following:
@@ -684,6 +700,8 @@ def Recurrent(theta,
truncate the computation if the inputs have been allocated to a
larger size. A scalar tensor.
use_tpu: whether or not we are on TPU.
+ aligned_end: A boolean indicating whether the sequence is aligned at
+ the end.
Returns:
accumulate_state and the final state.
@@ -717,4 +735,5 @@ def Recurrent(theta,
inputs=inputs,
max_input_length=max_input_length,
extras=extras,
- use_tpu=use_tpu).Compute()
+ use_tpu=use_tpu,
+ aligned_end=aligned_end).Compute()
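
To make the new `aligned_end` loop bounds concrete, a small worked example with hypothetical numbers (`slen_dim` is the padded time dimension, `max_input_length` the number of real steps):

    slen_dim, max_input_length = 10, 6

    # aligned_end=False (default): the real steps sit at the start of the
    # padded buffer; forward runs t = 0..5, backward runs t = 5..0.
    fwd_start, fwd_limit = 0, max_input_length                           # 0, 6
    bwd_start, bwd_limit = max_input_length - 1, -1                      # 5, -1

    # aligned_end=True: the same six steps are processed, but aligned to the
    # end of the buffer, which is what the reversed-input path in
    # functional_rnn relies on; forward runs t = 4..9, backward runs t = 9..4.
    fwd_start, fwd_limit = slen_dim - max_input_length, slen_dim         # 4, 10
    bwd_start, bwd_limit = slen_dim - 1, slen_dim - max_input_length - 1 # 9, 3
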
diff --git a/tensorflow/contrib/resampler/python/ops/resampler_ops_test.py b/tensorflow/contrib/resampler/python/ops/resampler_ops_test.py
index 6253f96315..e30e7255fa 100644
--- a/tensorflow/contrib/resampler/python/ops/resampler_ops_test.py
+++ b/tensorflow/contrib/resampler/python/ops/resampler_ops_test.py
@@ -210,7 +210,7 @@ class ResamplerTest(test.TestCase):
# Input data shape is not defined over a 2D grid, i.e. its shape is not like
# (batch_size, data_height, data_width, data_channels).
- with self.test_session() as sess:
+ with self.cached_session() as sess:
data_shape = (batch_size, data_height, data_width, data_depth,
data_channels)
data = np.zeros(data_shape)
@@ -225,7 +225,7 @@ class ResamplerTest(test.TestCase):
sess.run(outputs)
# Warp tensor must be at least a matrix, with shape [batch_size, 2].
- with self.test_session() as sess:
+ with self.cached_session() as sess:
data_shape = (batch_size, data_height, data_width, data_channels)
data = np.zeros(data_shape)
warp_shape = (batch_size,)
@@ -238,7 +238,7 @@ class ResamplerTest(test.TestCase):
sess.run(outputs)
# The batch size of the data and warp tensors must be the same.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
data_shape = (batch_size, data_height, data_width, data_channels)
data = np.zeros(data_shape)
warp_shape = (batch_size+1, warp_height, warp_width, 2)
@@ -252,7 +252,7 @@ class ResamplerTest(test.TestCase):
# The warp tensor must contain 2D coordinates, i.e. its shape last dimension
# must be 2.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
data_shape = (batch_size, data_height, data_width, data_channels)
data = np.zeros(data_shape)
warp_shape = (batch_size, warp_height, warp_width, 3)
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py
index bf699db3ed..f31ad53d3c 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py
@@ -163,8 +163,8 @@ class TestStateSaverWithCounters(TestStateSaver):
def __init__(self, batch_size, state_size):
super(TestStateSaverWithCounters, self).__init__(batch_size, state_size)
- self._num_state_calls = variables_lib.Variable(0)
- self._num_save_state_calls = variables_lib.Variable(0)
+ self._num_state_calls = variables_lib.VariableV1(0)
+ self._num_save_state_calls = variables_lib.VariableV1(0)
def state(self, name):
with ops_lib.control_dependencies(
diff --git a/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py b/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py
index 1c23c28860..0d615923e0 100644
--- a/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py
+++ b/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py
@@ -49,7 +49,7 @@ class RpcOpTestBase(object):
return rpc_op.try_rpc(*args, protocol=self._protocol, **kwargs)
def testScalarHostPortRpc(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
request_tensors = (
test_example_pb2.TestCase(values=[1, 2, 3]).SerializeToString())
response_tensors = self.rpc(
@@ -63,7 +63,7 @@ class RpcOpTestBase(object):
self.assertAllEqual([2, 3, 4], response_message.values)
def testScalarHostPortTryRpc(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
request_tensors = (
test_example_pb2.TestCase(values=[1, 2, 3]).SerializeToString())
response_tensors, status_code, status_message = self.try_rpc(
@@ -83,7 +83,7 @@ class RpcOpTestBase(object):
self.assertEqual(b'', status_message_values)
def testEmptyHostPortRpc(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
request_tensors = []
response_tensors = self.rpc(
method=self.get_method_name('Increment'),
@@ -98,7 +98,7 @@ class RpcOpTestBase(object):
'/InvalidService.Increment',
self.get_method_name('InvalidMethodName')
]:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaisesOpError(self.invalid_method_string):
sess.run(self.rpc(method=method, address=self._address, request=''))
@@ -111,7 +111,7 @@ class RpcOpTestBase(object):
def testInvalidAddress(self):
# This covers the case of address='' and address='localhost:293874293874'
address = 'unix:/tmp/this_unix_socket_doesnt_exist_97820348!!@'
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaises(errors.UnavailableError):
sess.run(
self.rpc(
@@ -128,7 +128,7 @@ class RpcOpTestBase(object):
self.connect_failed_string in status_message_value.decode('ascii'))
def testAlwaysFailingMethod(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
response_tensors = self.rpc(
method=self.get_method_name('AlwaysFailWithInvalidArgument'),
address=self._address,
@@ -150,7 +150,7 @@ class RpcOpTestBase(object):
self.assertTrue(I_WARNED_YOU in status_message_value.decode('ascii'))
def testSometimesFailingMethodWithManyRequests(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Fail hard by default.
response_tensors = self.rpc(
method=self.get_method_name('SometimesFailWithInvalidArgument'),
@@ -179,7 +179,7 @@ class RpcOpTestBase(object):
self.assertAllEqual(expected_message_values, status_message_values)
def testVecHostPortRpc(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
request_tensors = [
test_example_pb2.TestCase(
values=[i, i + 1, i + 2]).SerializeToString() for i in range(20)
@@ -197,7 +197,7 @@ class RpcOpTestBase(object):
self.assertAllEqual([i + 1, i + 2, i + 3], response_message.values)
def testVecHostPortManyParallelRpcs(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
request_tensors = [
test_example_pb2.TestCase(
values=[i, i + 1, i + 2]).SerializeToString() for i in range(20)
@@ -219,7 +219,7 @@ class RpcOpTestBase(object):
self.assertAllEqual([i + 1, i + 2, i + 3], response_message.values)
def testVecHostPortRpcUsingEncodeAndDecodeProto(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
request_tensors = encode_proto_op.encode_proto(
message_type='tensorflow.contrib.rpc.TestCase',
field_names=['values'],
@@ -241,7 +241,7 @@ class RpcOpTestBase(object):
for i in range(20)], response_shape_values)
def testVecHostPortRpcCancelsUponSessionTimeOutWhenSleepingForever(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
request_tensors = [''] * 25 # This will launch 25 RPC requests.
response_tensors = self.rpc(
method=self.get_method_name('SleepForever'),
@@ -254,7 +254,7 @@ class RpcOpTestBase(object):
sess.run(response_tensors, options=options)
def testVecHostPortRpcCancelsUponConfiguredTimeOutWhenSleepingForever(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
request_tensors = [''] * 25 # This will launch 25 RPC requests.
response_tensors = self.rpc(
method=self.get_method_name('SleepForever'),
@@ -265,7 +265,7 @@ class RpcOpTestBase(object):
sess.run(response_tensors)
def testTryRpcPropagatesDeadlineErrorWithSometimesTimingOutRequests(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
response_tensors, status_code, status_message = self.try_rpc(
method=self.get_method_name('SometimesSleepForever'),
timeout_in_ms=1000,
@@ -281,7 +281,7 @@ class RpcOpTestBase(object):
def testTryRpcWithMultipleAddressesSingleRequest(self):
flatten = lambda x: list(itertools.chain.from_iterable(x))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
addresses = flatten([[
self._address, 'unix:/tmp/this_unix_socket_doesnt_exist_97820348!!@'
] for _ in range(10)])
@@ -301,7 +301,7 @@ class RpcOpTestBase(object):
def testTryRpcWithMultipleMethodsSingleRequest(self):
flatten = lambda x: list(itertools.chain.from_iterable(x))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
methods = flatten(
[[self.get_method_name('Increment'), 'InvalidMethodName']
for _ in range(10)])
@@ -319,7 +319,7 @@ class RpcOpTestBase(object):
def testTryRpcWithMultipleAddressesAndRequests(self):
flatten = lambda x: list(itertools.chain.from_iterable(x))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
addresses = flatten([[
self._address, 'unix:/tmp/this_unix_socket_doesnt_exist_97820348!!@'
] for _ in range(10)])
diff --git a/tensorflow/contrib/saved_model/BUILD b/tensorflow/contrib/saved_model/BUILD
index 4ca5274b2e..291ff83791 100644
--- a/tensorflow/contrib/saved_model/BUILD
+++ b/tensorflow/contrib/saved_model/BUILD
@@ -92,10 +92,7 @@ py_library(
"//tensorflow/python:platform",
"//tensorflow/python:saver",
"//tensorflow/python:util",
- "//tensorflow/python/estimator",
- "//tensorflow/python/estimator:export",
- "//tensorflow/python/estimator:keras",
- "//tensorflow/python/estimator:model_fn",
+ "//tensorflow/python/estimator:estimator_py",
"//tensorflow/python/keras:engine",
"//tensorflow/python/saved_model",
],
@@ -111,6 +108,7 @@ py_test(
":keras_saved_model",
"//tensorflow/python:client_testlib",
"//tensorflow/python:training",
+ "//tensorflow/python/estimator:estimator_py",
"//tensorflow/python/keras",
"//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
diff --git a/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py b/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py
index 12dd72a95b..060c504523 100644
--- a/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py
+++ b/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py
@@ -269,7 +269,7 @@ class TestModelSavedModelExport(test.TestCase, parameterized.TestCase):
def testSaveAndLoadSavedModelExport(
self, model_builder, uses_learning_phase, optimizer, train_before_export):
saved_model_path = self._save_model_dir()
- with self.test_session(graph=ops.Graph()):
+ with self.session(graph=ops.Graph()):
input_arr = np.random.random((1, 3))
target_arr = np.random.random((1, 3))
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
index f2c43f30d4..1f3b533de9 100644
--- a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
+++ b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
@@ -919,31 +919,28 @@ class AttentionWrapperTest(test.TestCase):
wrapper.BahdanauAttention, wrapper.LuongAttention)
expected_final_output = BasicDecoderOutput(
- rnn_output=ResultSummary(shape=(5, 3, 20),
- dtype=dtype('float32'),
- mean=0.11723966),
- sample_id=ResultSummary(shape=(5, 3),
- dtype=dtype('int32'),
- mean=9.2666666666666675))
+ rnn_output=ResultSummary(
+ shape=(5, 3, 20), dtype=dtype('float32'), mean=0.11723966),
+ sample_id=ResultSummary(
+ shape=(5, 3), dtype=dtype('int32'), mean=7.266666666666667))
expected_final_state = AttentionWrapperState(
cell_state=LSTMStateTuple(
- c=ResultSummary(shape=(5, 9),
- dtype=dtype('float32'),
- mean=-0.003545674),
- h=ResultSummary(shape=(5, 9),
- dtype=dtype('float32'),
- mean=-0.0018327223)),
- attention=ResultSummary(shape=(5, 20),
- dtype=dtype('float32'),
- mean=0.11728073),
+ c=ResultSummary(
+ shape=(5, 9), dtype=dtype('float32'), mean=-0.003545674),
+ h=ResultSummary(
+ shape=(5, 9), dtype=dtype('float32'), mean=-0.0018327223)),
+ attention=ResultSummary(
+ shape=(5, 20), dtype=dtype('float32'), mean=0.11601614207),
time=3,
- alignments=(
- ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125),
- ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125)),
+ alignments=(ResultSummary(
+ shape=(5, 8), dtype=dtype('float32'), mean=0.125),
+ ResultSummary(
+ shape=(5, 8), dtype=dtype('float32'), mean=0.125)),
alignment_history=(),
- attention_state=(
- ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125),
- ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125)))
+ attention_state=(ResultSummary(
+ shape=(5, 8), dtype=dtype('float32'), mean=0.125),
+ ResultSummary(
+ shape=(5, 8), dtype=dtype('float32'), mean=0.125)))
expected_final_alignment_history = (
ResultSummary(shape=(3, 5, 8), dtype=dtype('float32'), mean=0.125),
ResultSummary(shape=(3, 5, 8), dtype=dtype('float32'), mean=0.125))
diff --git a/tensorflow/contrib/session_bundle/bundle_shim.cc b/tensorflow/contrib/session_bundle/bundle_shim.cc
index 4fc36d85ed..c669ced997 100644
--- a/tensorflow/contrib/session_bundle/bundle_shim.cc
+++ b/tensorflow/contrib/session_bundle/bundle_shim.cc
@@ -355,11 +355,15 @@ Status LoadSessionBundleOrSavedModelBundle(
const SessionOptions& session_options, const RunOptions& run_options,
const string& export_dir,
const std::unordered_set<string>& saved_model_tags,
- SavedModelBundle* saved_model_bundle) {
+ SavedModelBundle* saved_model_bundle, bool* is_session_bundle) {
+ if (is_session_bundle != nullptr) {
+ *is_session_bundle = false;
+ }
if (MaybeSavedModelDirectory(export_dir)) {
LOG(INFO)
<< "Attempting to load native SavedModelBundle in bundle-shim from: "
<< export_dir;
+
return LoadSavedModel(session_options, run_options, export_dir,
saved_model_tags, saved_model_bundle);
} else if (IsPossibleExportDirectory(export_dir)) {
@@ -368,6 +372,9 @@ Status LoadSessionBundleOrSavedModelBundle(
LOG(INFO) << "Attempting to up-convert SessionBundle to SavedModelBundle "
"in bundle-shim from: "
<< export_dir;
+ if (is_session_bundle != nullptr) {
+ *is_session_bundle = true;
+ }
return LoadSavedModelFromLegacySessionBundlePath(
session_options, run_options, export_dir, saved_model_bundle);
}
diff --git a/tensorflow/contrib/session_bundle/bundle_shim.h b/tensorflow/contrib/session_bundle/bundle_shim.h
index 4628b6ab1b..7f0f9958d7 100644
--- a/tensorflow/contrib/session_bundle/bundle_shim.h
+++ b/tensorflow/contrib/session_bundle/bundle_shim.h
@@ -59,11 +59,13 @@ Status ConvertSessionBundleToSavedModelBundle(
} // namespace internal
// Loads a SavedModel from either a session-bundle path or a SavedModel bundle
-// path.
+// path. If `is_session_bundle` is not a nullptr, sets it to `true` iff
+// SavedModel was up-converted and loaded from a SessionBundle.
+// The `is_session_bundle` value should not be used if an error is returned.
Status LoadSessionBundleOrSavedModelBundle(
const SessionOptions& session_options, const RunOptions& run_options,
const string& export_dir, const std::unordered_set<string>& tags,
- SavedModelBundle* bundle);
+ SavedModelBundle* bundle, bool* is_session_bundle = nullptr);
} // namespace serving
} // namespace tensorflow
diff --git a/tensorflow/contrib/session_bundle/bundle_shim_test.cc b/tensorflow/contrib/session_bundle/bundle_shim_test.cc
index 9a1dd9303f..815beb73a0 100644
--- a/tensorflow/contrib/session_bundle/bundle_shim_test.cc
+++ b/tensorflow/contrib/session_bundle/bundle_shim_test.cc
@@ -63,12 +63,16 @@ void ValidateHalfPlusTwo(const SavedModelBundle& saved_model_bundle,
void LoadAndValidateSavedModelBundle(const string& export_dir,
const std::unordered_set<string>& tags,
- const string& signature_def_key) {
+ const string& signature_def_key,
+ bool expect_session_bundle) {
SessionOptions session_options;
RunOptions run_options;
SavedModelBundle saved_model_bundle;
+ bool is_session_bundle = false;
TF_ASSERT_OK(LoadSessionBundleOrSavedModelBundle(
- session_options, run_options, export_dir, tags, &saved_model_bundle));
+ session_options, run_options, export_dir, tags, &saved_model_bundle,
+ &is_session_bundle));
+ EXPECT_EQ(expect_session_bundle, is_session_bundle);
const MetaGraphDef meta_graph_def = saved_model_bundle.meta_graph_def;
const auto& signature_def_map = meta_graph_def.signature_def();
@@ -512,7 +516,8 @@ TEST(BundleShimTest, BasicExportSessionBundle) {
const string session_bundle_export_dir =
test_util::TestSrcDirPath(kSessionBundlePath);
LoadAndValidateSavedModelBundle(session_bundle_export_dir, tags,
- kDefaultServingSignatureDefKey);
+ kDefaultServingSignatureDefKey,
+ /*expect_session_bundle=*/true);
// Verify that the named signature is also present.
SessionOptions session_options;
@@ -558,7 +563,8 @@ TEST(BundleShimTest, BasicExportSavedModel) {
const string saved_model_bundle_export_dir =
io::JoinPath(testing::TensorFlowSrcRoot(), kSavedModelBundlePath);
LoadAndValidateSavedModelBundle(saved_model_bundle_export_dir,
- {kSavedModelTagServe}, "regress_x_to_y");
+ {kSavedModelTagServe}, "regress_x_to_y",
+ /*expect_session_bundle=*/false);
}
// Checks a basic load fails with an invalid export path.
diff --git a/tensorflow/contrib/session_bundle/exporter_test.py b/tensorflow/contrib/session_bundle/exporter_test.py
index 86df425da0..68419ffea0 100644
--- a/tensorflow/contrib/session_bundle/exporter_test.py
+++ b/tensorflow/contrib/session_bundle/exporter_test.py
@@ -64,10 +64,10 @@ class SaveRestoreShardedTest(test.TestCase):
# v2 is an unsaved variable derived from v0 and v1. It is used to
# exercise the ability to run an init op when restoring a graph.
with sess.graph.device("/cpu:0"):
- v0 = variables.Variable(10, name="v0")
+ v0 = variables.VariableV1(10, name="v0")
with sess.graph.device("/cpu:1"):
- v1 = variables.Variable(20, name="v1")
- v2 = variables.Variable(1, name="v2", trainable=False, collections=[])
+ v1 = variables.VariableV1(20, name="v1")
+ v2 = variables.VariableV1(1, name="v2", trainable=False, collections=[])
assign_v2 = state_ops.assign(v2, math_ops.add(v0, v1))
init_op = control_flow_ops.group(assign_v2, name="init_op")
diff --git a/tensorflow/contrib/stat_summarizer/python/stat_summarizer_test.py b/tensorflow/contrib/stat_summarizer/python/stat_summarizer_test.py
index e4db5f2e3c..e6a0b30567 100644
--- a/tensorflow/contrib/stat_summarizer/python/stat_summarizer_test.py
+++ b/tensorflow/contrib/stat_summarizer/python/stat_summarizer_test.py
@@ -38,7 +38,7 @@ class StatSummarizerTest(test.TestCase):
graph_def = graph.as_graph_def()
ss = pywrap_tensorflow.NewStatSummarizer(graph_def.SerializeToString())
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
for _ in range(20):
diff --git a/tensorflow/contrib/summary/summary_ops_graph_test.py b/tensorflow/contrib/summary/summary_ops_graph_test.py
index ae8336daaf..807741e05f 100644
--- a/tensorflow/contrib/summary/summary_ops_graph_test.py
+++ b/tensorflow/contrib/summary/summary_ops_graph_test.py
@@ -52,7 +52,7 @@ class GraphFileTest(test_util.TensorFlowTestCase):
summary_ops.histogram('histogram', [1.0], step=1)
summary_ops.image('image', [[[[1.0]]]], step=1)
summary_ops.audio('audio', [[1.0]], 1.0, 1, step=1)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(summary_ops.summary_writer_initializer_op())
sess.run(summary_ops.all_summary_ops())
# The working condition of the ops is tested in the C++ test so we just
@@ -64,7 +64,7 @@ class GraphFileTest(test_util.TensorFlowTestCase):
writer = summary_ops.create_file_writer(logdir, max_queue=0)
with writer.as_default(), summary_ops.always_record_summaries():
summary_ops.scalar('scalar', 2.0, step=1)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(summary_ops.summary_writer_initializer_op())
sess.run(summary_ops.all_summary_ops())
events = summary_test_util.events_from_logdir(logdir)
@@ -77,7 +77,7 @@ class GraphFileTest(test_util.TensorFlowTestCase):
with writer.as_default(), summary_ops.always_record_summaries():
with ops.name_scope('scope'):
summary_ops.scalar('scalar', 2.0, step=1)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(summary_ops.summary_writer_initializer_op())
sess.run(summary_ops.all_summary_ops())
events = summary_test_util.events_from_logdir(logdir)
@@ -90,7 +90,7 @@ class GraphFileTest(test_util.TensorFlowTestCase):
writer = summary_ops.create_file_writer(logdir, max_queue=0)
with writer.as_default(), summary_ops.always_record_summaries():
summary_ops.scalar('scalar', 2.0)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
sess.run(summary_ops.summary_writer_initializer_op())
step, _ = sess.run(
@@ -105,7 +105,7 @@ class GraphFileTest(test_util.TensorFlowTestCase):
logdir, max_queue=1, flush_millis=999999)
with writer.as_default(), summary_ops.always_record_summaries():
summary_ops.scalar('scalar', 2.0, step=1)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(summary_ops.summary_writer_initializer_op())
get_total = lambda: len(summary_test_util.events_from_logdir(logdir))
# Note: First tf.Event is always file_version.
@@ -123,7 +123,7 @@ class GraphFileTest(test_util.TensorFlowTestCase):
with writer.as_default(), summary_ops.always_record_summaries():
summary_ops.scalar('scalar', 2.0, step=1)
flush_op = summary_ops.flush()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(summary_ops.summary_writer_initializer_op())
get_total = lambda: len(summary_test_util.events_from_logdir(logdir))
# Note: First tf.Event is always file_version.
@@ -157,7 +157,7 @@ class GraphFileTest(test_util.TensorFlowTestCase):
with writer3.as_default():
summary_ops.scalar('three', 3.0, step=3)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Run init ops across writers sequentially to avoid race condition.
# TODO(nickfelt): fix race condition in resource manager lookup or create
sess.run(writer1.init())
@@ -191,7 +191,7 @@ class GraphFileTest(test_util.TensorFlowTestCase):
logdir, max_queue=100, flush_millis=1000000)
with writer.as_default():
summary_ops.scalar('one', 1.0, step=1)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(summary_ops.summary_writer_initializer_op())
get_total = lambda: len(summary_test_util.events_from_logdir(logdir))
self.assertEqual(1, get_total()) # file_version Event
@@ -219,7 +219,7 @@ class GraphFileTest(test_util.TensorFlowTestCase):
logdir, max_queue=100, flush_millis=1000000)
with writer.as_default():
summary_ops.scalar('one', 1.0, step=1)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(summary_ops.summary_writer_initializer_op())
get_total = lambda: len(summary_test_util.events_from_logdir(logdir))
self.assertEqual(1, get_total()) # file_version Event
@@ -241,7 +241,7 @@ class GraphDbTest(summary_test_util.SummaryDbTest):
training_util.get_or_create_global_step()
name = 'hi'
graph = graph_pb2.GraphDef(node=(node_def_pb2.NodeDef(name=name),))
- with self.test_session():
+ with self.cached_session():
with self.create_db_writer().as_default():
summary_ops.initialize(graph=graph)
six.assertCountEqual(self, [name],
@@ -249,7 +249,7 @@ class GraphDbTest(summary_test_util.SummaryDbTest):
def testScalarSummary(self):
"""Test record_summaries_every_n_global_steps and all_summaries()."""
- with ops.Graph().as_default(), self.test_session() as sess:
+ with ops.Graph().as_default(), self.cached_session() as sess:
global_step = training_util.get_or_create_global_step()
global_step.initializer.run()
with ops.device('/cpu:0'):
@@ -280,7 +280,7 @@ class GraphDbTest(summary_test_util.SummaryDbTest):
def testScalarSummaryNameScope(self):
"""Test record_summaries_every_n_global_steps and all_summaries()."""
- with ops.Graph().as_default(), self.test_session() as sess:
+ with ops.Graph().as_default(), self.cached_session() as sess:
global_step = training_util.get_or_create_global_step()
global_step.initializer.run()
with ops.device('/cpu:0'):
@@ -311,7 +311,7 @@ class GraphDbTest(summary_test_util.SummaryDbTest):
self.assertEqual(events[1].summary.value[0].tag, 'scope/my_scalar')
def testSummaryGraphModeCond(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
training_util.get_or_create_global_step()
logdir = tempfile.mkdtemp()
with summary_ops.create_file_writer(
@@ -332,7 +332,7 @@ class GraphDbTest(summary_test_util.SummaryDbTest):
self.assertEqual(events[1].summary.value[0].tag, 'cond/scalar')
def testSummaryGraphModeWhile(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
training_util.get_or_create_global_step()
logdir = tempfile.mkdtemp()
with summary_ops.create_file_writer(
diff --git a/tensorflow/contrib/tensor_forest/BUILD b/tensorflow/contrib/tensor_forest/BUILD
index 00c855daa3..398ac314f4 100644
--- a/tensorflow/contrib/tensor_forest/BUILD
+++ b/tensorflow/contrib/tensor_forest/BUILD
@@ -518,7 +518,7 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":client_lib",
- "//tensorflow/contrib/estimator:head",
+ "//tensorflow/contrib/estimator:estimator_py",
"//tensorflow/contrib/layers:layers_py",
"//tensorflow/contrib/learn",
"//tensorflow/python:array_ops",
diff --git a/tensorflow/contrib/tensor_forest/client/eval_metrics_test.py b/tensorflow/contrib/tensor_forest/client/eval_metrics_test.py
index aa30919167..d49928e3f1 100644
--- a/tensorflow/contrib/tensor_forest/client/eval_metrics_test.py
+++ b/tensorflow/contrib/tensor_forest/client/eval_metrics_test.py
@@ -32,7 +32,7 @@ class EvalMetricsTest(test_util.TensorFlowTestCase):
[0.9, 0.8, 0.2], [0.6, 0.4, 0.8]])
targets = constant_op.constant([[0], [2], [1], [1]])
in_top_2_op, update_op = top_2_fn(probabilities, targets)
- with self.test_session():
+ with self.cached_session():
# initializes internal accuracy vars
variables.local_variables_initializer().run()
# need to call in order to run the in_top_2_op internal operations because
@@ -49,7 +49,7 @@ class EvalMetricsTest(test_util.TensorFlowTestCase):
[0.3, 0.6, 0.9, 0.4, 0.8, 0.6]])
targets = constant_op.constant([3, 0, 2, 5, 1])
in_top_3_op, update_op = top_3_fn(probabilities, targets)
- with self.test_session():
+ with self.cached_session():
# initializes internal accuracy vars
variables.local_variables_initializer().run()
# need to call in order to run the in_top_3_op internal operations because
@@ -61,7 +61,7 @@ class EvalMetricsTest(test_util.TensorFlowTestCase):
predictions = constant_op.constant([0, 1, 3, 6, 5, 2, 7, 6, 4, 9])
targets = constant_op.constant([0, 1, 4, 6, 5, 1, 7, 5, 4, 8])
accuracy_op, update_op = eval_metrics._accuracy(predictions, targets)
- with self.test_session():
+ with self.cached_session():
variables.local_variables_initializer().run()
# need to call in order to run the accuracy_op internal operations because
# it is a streaming function
@@ -74,7 +74,7 @@ class EvalMetricsTest(test_util.TensorFlowTestCase):
targets = constant_op.constant(
[1.0, 4.3, 2.6, 0.5, 1.1, 0.7, 5.1, 3.4, 1.8])
r2_op, update_op = eval_metrics._r2(scores, targets)
- with self.test_session():
+ with self.cached_session():
# initializes internal accuracy vars
variables.local_variables_initializer().run()
# need to call in order to run the r2_op internal operations because
diff --git a/tensorflow/contrib/tensor_forest/client/random_forest.py b/tensorflow/contrib/tensor_forest/client/random_forest.py
index 0042d37acd..6e3bfbb9bd 100644
--- a/tensorflow/contrib/tensor_forest/client/random_forest.py
+++ b/tensorflow/contrib/tensor_forest/client/random_forest.py
@@ -446,6 +446,10 @@ class TensorForestEstimator(estimator.Estimator):
Returns:
A `TensorForestEstimator` instance.
"""
+ # Override default number of trainers if config is provided.
+ if num_trainers == 1 and config is not None:
+ num_trainers = max(1, config.num_worker_replicas)
+
super(TensorForestEstimator, self).__init__(
model_fn=get_model_fn(
params.fill(),
@@ -564,6 +568,10 @@ class MultiForestMultiHeadEstimator(estimator.Estimator):
local_eval=False):
"""See TensorForestEstimator.__init__."""
model_fns = []
+ # Override default number of trainers if config is provided.
+ if num_trainers == 1 and config is not None:
+ num_trainers = max(1, config.num_worker_replicas)
+
for i in range(len(params_list)):
params = params_list[i].fill()
model_fns.append(
@@ -709,6 +717,11 @@ class CoreTensorForestEstimator(core_estimator.Estimator):
Returns:
A `TensorForestEstimator` instance.
"""
+ # Override default number of trainers if config is provided.
+ if num_trainers == 1 and config is not None:
+ num_trainers = max(1, config.num_worker_replicas)
+ if trainer_id == 0 and config is not None:
+ trainer_id = config.global_id_in_cluster
super(CoreTensorForestEstimator, self).__init__(
model_fn=get_model_fn(
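The three hunks above apply the same default-override pattern: when the caller leaves `num_trainers` at its default of 1 and a run config is supplied, the number of trainers is taken from `config.num_worker_replicas` (and, for `CoreTensorForestEstimator`, a default `trainer_id` of 0 falls back to `config.global_id_in_cluster`). Below is a minimal sketch of that resolution logic on its own; `FakeRunConfig` and `resolve_trainer_settings` are hypothetical names used only for illustration.
```python
# Illustration only: mirrors the override logic added in the hunks above.
# FakeRunConfig is a hypothetical stand-in for a real tf.estimator.RunConfig.
class FakeRunConfig(object):

  def __init__(self, num_worker_replicas, global_id_in_cluster):
    self.num_worker_replicas = num_worker_replicas
    self.global_id_in_cluster = global_id_in_cluster


def resolve_trainer_settings(config=None, num_trainers=1, trainer_id=0):
  """Applies the same defaults the patched constructors apply."""
  if num_trainers == 1 and config is not None:
    num_trainers = max(1, config.num_worker_replicas)
  if trainer_id == 0 and config is not None:
    trainer_id = config.global_id_in_cluster
  return num_trainers, trainer_id


print(resolve_trainer_settings())  # (1, 0): no config, defaults kept
config = FakeRunConfig(num_worker_replicas=4, global_id_in_cluster=2)
print(resolve_trainer_settings(config=config))  # (4, 2): taken from the config
```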
diff --git a/tensorflow/contrib/tensor_forest/python/kernel_tests/scatter_add_ndim_op_test.py b/tensorflow/contrib/tensor_forest/python/kernel_tests/scatter_add_ndim_op_test.py
index e429d12e96..0b02bdcb50 100644
--- a/tensorflow/contrib/tensor_forest/python/kernel_tests/scatter_add_ndim_op_test.py
+++ b/tensorflow/contrib/tensor_forest/python/kernel_tests/scatter_add_ndim_op_test.py
@@ -27,12 +27,12 @@ from tensorflow.python.platform import googletest
class ScatterAddNdimTest(test_util.TensorFlowTestCase):
def test1dim(self):
- input_data = variables.Variable(
+ input_data = variables.VariableV1(
[1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.])
indices = [[1], [10]]
updates = [100., 200.]
- with self.test_session():
+ with self.cached_session():
variables.global_variables_initializer().run()
tensor_forest_ops.scatter_add_ndim(input_data, indices, updates).run()
self.assertAllEqual(
@@ -40,12 +40,12 @@ class ScatterAddNdimTest(test_util.TensorFlowTestCase):
input_data.eval())
def test3dim(self):
- input_data = variables.Variable([[[1., 2., 3.], [4., 5., 6.]],
- [[7., 8., 9.], [10., 11., 12.]]])
+ input_data = variables.VariableV1([[[1., 2., 3.], [4., 5., 6.]],
+ [[7., 8., 9.], [10., 11., 12.]]])
indices = [[0, 0, 1], [1, 1, 2]]
updates = [100., 200.]
- with self.test_session():
+ with self.cached_session():
variables.global_variables_initializer().run()
tensor_forest_ops.scatter_add_ndim(input_data, indices, updates).run()
self.assertAllEqual([[[1., 102., 3.], [4., 5., 6.]],
@@ -53,21 +53,21 @@ class ScatterAddNdimTest(test_util.TensorFlowTestCase):
def testNoUpdates(self):
init_val = [[[1., 2., 3.], [4., 5., 6.]], [[7., 8., 9.], [10., 11., 12.]]]
- input_data = variables.Variable(init_val)
+ input_data = variables.VariableV1(init_val)
indices = []
updates = []
- with self.test_session():
+ with self.cached_session():
variables.global_variables_initializer().run()
tensor_forest_ops.scatter_add_ndim(input_data, indices, updates).run()
self.assertAllEqual(init_val, input_data.eval())
def testBadInput(self):
init_val = [[[1., 2., 3.], [4., 5., 6.]], [[7., 8., 9.], [10., 11., 12.]]]
- input_data = variables.Variable(init_val)
+ input_data = variables.VariableV1(init_val)
indices = [[0, 0, 1], [1, 1, 2]]
updates = [100.]
- with self.test_session():
+ with self.cached_session():
variables.global_variables_initializer().run()
with self.assertRaisesOpError(
'Number of updates should be same as number of indices.'):
@@ -75,12 +75,12 @@ class ScatterAddNdimTest(test_util.TensorFlowTestCase):
self.assertAllEqual(init_val, input_data.eval())
def testIncompleteIndices(self):
- input_data = variables.Variable([[[1., 2., 3.], [4., 5., 6.]],
- [[7., 8., 9.], [10., 11., 12.]]])
+ input_data = variables.VariableV1([[[1., 2., 3.], [4., 5., 6.]],
+ [[7., 8., 9.], [10., 11., 12.]]])
indices = [[0, 0], [1, 1]]
updates = [[100., 200., 300.], [400., 500., 600.]]
- with self.test_session():
+ with self.cached_session():
variables.global_variables_initializer().run()
tensor_forest_ops.scatter_add_ndim(input_data, indices, updates).run()
self.assertAllEqual([[[101., 202., 303.], [4., 5., 6.]],
diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py b/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py
index 1c9c81827e..e0f0c0d4ff 100644
--- a/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py
+++ b/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py
@@ -149,7 +149,7 @@ class TensorForestTest(test_util.TensorFlowTestCase):
self.assertTrue(isinstance(probs, ops.Tensor))
self.assertTrue(isinstance(paths, ops.Tensor))
self.assertTrue(isinstance(var, ops.Tensor))
- with self.test_session():
+ with self.cached_session():
variables.global_variables_initializer().run()
resources.initialize_resources(resources.shared_resources()).run()
self.assertEquals(probs.eval().shape, (4, 2))
diff --git a/tensorflow/contrib/tensorboard/BUILD b/tensorflow/contrib/tensorboard/BUILD
index 2b6a2b2f3c..7f0b3255ed 100644
--- a/tensorflow/contrib/tensorboard/BUILD
+++ b/tensorflow/contrib/tensorboard/BUILD
@@ -32,7 +32,6 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":projector",
- ":trace",
],
)
@@ -60,33 +59,3 @@ py_test(
"//tensorflow/python:summary",
],
)
-
-# API methods and protos in `tf.contrib.tensorboard.plugins.trace` package.
-py_library(
- name = "trace",
- srcs = glob(
- ["plugins/trace/**/*.py"],
- exclude = ["**/*test*"],
- ),
- srcs_version = "PY2AND3",
- deps = [
- ":protos_all_py",
- "//tensorflow/python:framework_for_generated_wrappers",
- "//tensorflow/python:lib",
- "//tensorflow/python:platform",
- ],
-)
-
-py_test(
- name = "trace_test",
- size = "small",
- srcs = ["plugins/trace/trace_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_windows"],
- deps = [
- ":trace",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:framework_for_generated_wrappers",
- "//tensorflow/python:platform",
- ],
-)
diff --git a/tensorflow/contrib/tensorboard/db/loader.cc b/tensorflow/contrib/tensorboard/db/loader.cc
index 4d7337a53d..6439328022 100644
--- a/tensorflow/contrib/tensorboard/db/loader.cc
+++ b/tensorflow/contrib/tensorboard/db/loader.cc
@@ -111,10 +111,10 @@ int main(int argc, char* argv[]) {
++records;
}
uint64 elapsed = env->NowMicros() - start;
+ uint64 bps = (elapsed == 0 ? offset : static_cast<uint64>(
+ offset / (elapsed / 1000000.0)));
LOG(INFO) << "Loaded " << AddCommas(offset) << " bytes with "
- << AddCommas(records) << " records at "
- << AddCommas(offset / (elapsed / 1000000)) << " bps";
-
+ << AddCommas(records) << " records at " << AddCommas(bps) << " bps";
return 0;
}
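The hunk above replaces an integer throughput expression (`offset / (elapsed / 1000000)`), which divides by zero whenever loading finishes in under one second, with a float division guarded against `elapsed == 0`. A tiny Python sketch of the corrected arithmetic; `bytes_per_second` is an illustrative name, not a function in the codebase.
```python
def bytes_per_second(offset_bytes, elapsed_micros):
  """Mirrors the guarded throughput computation from the fix above."""
  if elapsed_micros == 0:
    # Degenerate case: report the raw byte count instead of dividing by zero.
    return offset_bytes
  # Floating-point seconds avoid the integer truncation (and division by
  # zero) that the old expression hit for loads finishing in under a second.
  return int(offset_bytes / (elapsed_micros / 1000000.0))


print(bytes_per_second(10000000, 500000))  # 20000000: 10 MB in half a second
print(bytes_per_second(10000000, 0))       # 10000000: zero elapsed time guarded
```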
diff --git a/tensorflow/contrib/tensorboard/plugins/__init__.py b/tensorflow/contrib/tensorboard/plugins/__init__.py
index 41aa77910c..4ba469eb52 100644
--- a/tensorflow/contrib/tensorboard/plugins/__init__.py
+++ b/tensorflow/contrib/tensorboard/plugins/__init__.py
@@ -20,4 +20,4 @@ from __future__ import print_function
# Add projects here, they will show up under tf.contrib.tensorboard.plugins
from tensorflow.contrib.tensorboard.plugins import projector
-from tensorflow.contrib.tensorboard.plugins import trace
+
diff --git a/tensorflow/contrib/tensorboard/plugins/trace/trace.py b/tensorflow/contrib/tensorboard/plugins/trace/trace.py
deleted file mode 100644
index 07e5316b8b..0000000000
--- a/tensorflow/contrib/tensorboard/plugins/trace/trace.py
+++ /dev/null
@@ -1,167 +0,0 @@
-# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Stores debugging information regarding TensorFlow model."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import os
-import parser
-import re
-import token
-
-from google.protobuf import json_format
-
-from tensorflow.contrib.tensorboard.plugins.trace.trace_info_pb2 import TraceInfo
-from tensorflow.python.framework import ops
-from tensorflow.python.platform import gfile
-
-# List of regex patterns that match files in the core tensorflow library.
-TF_LIB_REGEX_FPATHS = [os.sep + os.path.join('tensorflow', 'python')]
-
-LEFT_TOKENS = [token.LPAR, token.LSQB, token.LBRACE]
-RIGHT_TOKENS = [token.RPAR, token.RSQB, token.RBRACE]
-TOKENS = LEFT_TOKENS + RIGHT_TOKENS
-
-
-def store_trace_info(output_file_path,
- graph=None,
- ignore_regex_fpaths=None):
- """Collects and stores trace information for a TensorFlow model.
-
- The output proto is stored in json format.
-
- Args:
- output_file_path: The path where to store the output proto.
- graph: Optional. The data flow graph. Defaults to `tf.get_default_graph()`.
- ignore_regex_fpaths: Optional. Files whose path matches any of the regexes
- in this list will be ignored. Defaults to patterns that match the core
- tensorflow python library.
- """
- graph = graph or ops.get_default_graph()
-
- if not ignore_regex_fpaths:
- ignore_regex_fpaths = TF_LIB_REGEX_FPATHS
-
- trace_info = TraceInfo()
- # Extract trace information for every op in the graph.
- source_fpaths = set()
- for op in graph.get_operations():
- op_info = trace_info.ops.add()
- op_info.name = op.name
- op_info.op_type = op.type
- op_info.device = op.device
- for trace in op.traceback:
- fname, lineno, _, _ = trace
- # Ignore traces in specified file paths.
- if os.path.isabs(fname) and not _ignore_file_path(fname,
- ignore_regex_fpaths):
- line_trace = op_info.traceback.add()
- line_trace.file_path = fname
- line_trace.line_number = lineno
- source_fpaths.add(fname)
- _add_data_from_tensors(op.inputs, op_info.inputs)
- _add_data_from_tensors(op.outputs, op_info.outputs)
-
- # Read the source files involved in the graph construction.
- for fpath in source_fpaths:
- file_info = trace_info.files.add()
-
- with gfile.Open(fpath, 'r') as f:
- source = f.read()
-
- file_info.file_path = fpath
- file_info.source_code = source
-
- line2start = find_multiline_statements(source)
-
- for key, value in line2start.items():
- file_info.multiline_statements[key] = value
-
- # Make sure the directory for the output file exists.
- output_file_path = os.path.expanduser(output_file_path)
- output_dir = os.path.dirname(output_file_path)
- if not gfile.Exists(output_dir):
- gfile.MakeDirs(output_dir)
-
- # Store the debug information.
- with gfile.Open(output_file_path, 'w') as f:
- f.write(json_format.MessageToJson(trace_info))
-
-
-def find_multiline_statements(source):
- """Parses the python source and finds multiline statements.
-
- Based on counting the number of open and closed parenthesis on each line.
-
- Args:
- source: The source code string.
-
- Returns:
- A dict that maps a line index A to a line index B, where A is the end of a
- multiline statement and B is the start. Line indexing is 0-based.
- """
- # Get the AST.
- tree = parser.suite(source)
- line2paren_count = [0] * (source.count('\n') + 1)
- _count_brackets_braces_parenthesis(tree.totuple(True), line2paren_count)
-
- line2start = {}
- for end in range(len(line2paren_count)):
- if line2paren_count[end] >= 0:
- # This is not the end of a multiline statement.
- continue
- cumulative_paren_count = 0
- for start in range(end, -1, -1):
- cumulative_paren_count += line2paren_count[start]
- if cumulative_paren_count == 0:
- line2start[end] = start
- break
- return line2start
-
-
-def _add_data_from_tensors(tensors, info):
- for t in tensors:
- tensor_info = info.add()
-
- shape = t.get_shape()
- if shape.ndims:
- shape = [(-1 if s is None else s) for s in shape.as_list()]
- tensor_info.shape.extend(shape)
- tensor_info.dtype = t.dtype.name
- tensor_info.num_bytes_per_elem = t.dtype.size
-
- for c in t.consumers():
- tensor_info.consumers.append(c.name)
-
-
-def _ignore_file_path(fname, ignore_regex_fpaths):
- for regex_pattern in ignore_regex_fpaths:
- if re.search(regex_pattern, fname):
- return True
- return False
-
-
-def _count_brackets_braces_parenthesis(node, line2par):
- if isinstance(node[1], tuple):
- for child in node[1:]:
- _count_brackets_braces_parenthesis(child, line2par)
- else:
- tok = node[0]
- if tok in TOKENS:
- lineno = node[2]
- line2par[lineno - 1] += (1 if tok in LEFT_TOKENS else -1)
- return line2par
diff --git a/tensorflow/contrib/tensorboard/plugins/trace/trace_info.proto b/tensorflow/contrib/tensorboard/plugins/trace/trace_info.proto
deleted file mode 100644
index 9f20becb0f..0000000000
--- a/tensorflow/contrib/tensorboard/plugins/trace/trace_info.proto
+++ /dev/null
@@ -1,60 +0,0 @@
-/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-syntax = "proto3";
-
-package tensorflow.contrib.tensorboard;
-
-message TraceInfo {
- repeated OpInfo ops = 1;
- repeated FileInfo files = 2;
-}
-
-message OpInfo {
- string name = 1;
- string op_type = 2;
- string device = 3;
- repeated LineTrace traceback = 4;
- repeated TensorInfo inputs = 5;
- repeated TensorInfo outputs = 6;
-}
-
-message LineTrace {
- // Absolute file path.
- string file_path = 1;
- // 1-based line number.
- uint32 line_number = 2;
-}
-
-message TensorInfo {
- // Size of the tensor for each dimension. Value of -1 denotes "unknown"
- // size for that dimension.
- repeated int32 shape = 1;
- // The data type of the tensor.
- string dtype = 2;
- // Number of bytes per element in the tensor.
- uint32 num_bytes_per_elem = 3;
- // List of operation names that consume this tensor.
- repeated string consumers = 4;
-}
-
-message FileInfo {
- // Absolute file path to the source code.
- string file_path = 1;
- string source_code = 2;
- // Map from end of statement to start of statement. End and start are 0-based
- // line indexes.
- map<uint32, uint32> multiline_statements = 3;
-}
diff --git a/tensorflow/contrib/tensorboard/plugins/trace/trace_test.py b/tensorflow/contrib/tensorboard/plugins/trace/trace_test.py
deleted file mode 100644
index d580f04c5f..0000000000
--- a/tensorflow/contrib/tensorboard/plugins/trace/trace_test.py
+++ /dev/null
@@ -1,95 +0,0 @@
-# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for tensorflow.contrib.tensorboard.plugins.trace package."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import tempfile
-
-from google.protobuf import json_format
-
-from tensorflow.contrib.tensorboard.plugins import trace
-from tensorflow.python.framework import constant_op
-from tensorflow.python.platform import gfile
-from tensorflow.python.platform import test
-
-
-class TraceTest(test.TestCase):
-
- def setUp(self):
- self._temp_dir = tempfile.mkdtemp()
- self._temp_trace_json = self._temp_dir + 'trace.json'
-
- def tearDown(self):
- gfile.DeleteRecursively(self._temp_dir)
-
- def testEmptyGraph(self):
- trace_info = self._store_and_read_trace_info()
- self.assertEqual(len(trace_info.ops), 0)
-
- def testHasSourceCodeOfThisFile(self):
- constant_op.constant(0)
- trace_info = self._store_and_read_trace_info()
-
- self.assertTrue(trace_info.files)
- for file_info in trace_info.files:
- if file_info.file_path.endswith('trace_test.py'):
- return
- self.fail('trace_test file not found in the trace info json')
-
- def testHasTheConstantOp(self):
- constant_op.constant(0)
- trace_info = self._store_and_read_trace_info()
-
- self.assertTrue(trace_info.ops)
-
- for op in trace_info.ops:
- if op.op_type == 'Const':
- return
- self.fail('Could not find operation of type `Const` in the graph')
-
- def testMultilineStatements(self):
- source = """def test():
- a(4,
- 3,
- 1)
-
- b(3, 4, 5)
-
- c((4, 3),
- (),
- )
- """
- line2start = trace.find_multiline_statements(source)
-
- self.assertEqual(line2start[3], 1)
- self.assertEqual(line2start[9], 7)
- self.assertEqual(len(line2start), 2)
-
- def _store_and_read_trace_info(self):
- trace.store_trace_info(self._temp_trace_json)
- trace_info = trace.TraceInfo()
-
- with gfile.Open(self._temp_trace_json) as f:
- text = f.read()
- json_format.Parse(text, trace_info)
-
- return trace_info
-
-
-if __name__ == '__main__':
- test.main()
diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD
index 4ea7216ef2..9e8979bce4 100644
--- a/tensorflow/contrib/tensorrt/BUILD
+++ b/tensorflow/contrib/tensorrt/BUILD
@@ -444,6 +444,7 @@ cuda_py_test(
cuda_py_tests(
name = "tf_trt_integration_test",
srcs = [
+ "test/base_test.py",
"test/batch_matmul_test.py",
"test/biasadd_matmul_test.py",
"test/binary_tensor_weight_broadcast_test.py",
@@ -470,26 +471,6 @@ cuda_py_tests(
],
)
-cuda_py_tests(
- name = "base_test",
- srcs = [
- "test/base_test.py",
- ],
- additional_deps = [
- ":tf_trt_integration_test_base",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:framework_test_lib",
- ],
- tags = [
- "manual",
- "no_cuda_on_cpu_tap",
- "no_gpu",
- "no_windows",
- "nomac",
- "notap",
- ],
-)
-
cc_library(
name = "utils",
srcs = ["convert/utils.cc"],
diff --git a/tensorflow/contrib/tensorrt/README.md b/tensorflow/contrib/tensorrt/README.md
index 687dee07e1..caf8b6db0d 100644
--- a/tensorflow/contrib/tensorrt/README.md
+++ b/tensorflow/contrib/tensorrt/README.md
@@ -26,4 +26,4 @@ available. An example use can be found in test/test_tftrt.py script
In order to make use of TensorRT integration, you will need a local installation
of TensorRT 3.0.4 from the [NVIDIA Developer website](https://developer.nvidia.com/tensorrt).
Installation instructions for compatibility with TensorFlow are provided on the
-[TensorFlow Installation page](https://www.tensorflow.org/install/install_linux#nvidia_requirements_to_run_tensorflow_with_gpu_support).
+[TensorFlow GPU support](https://www.tensorflow.org/install/gpu) guide.
diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc
index b019c99882..7ad9bf22d3 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc
+++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc
@@ -678,7 +678,7 @@ tensorflow::Status CreateTRTNode(const std::vector<EngineInfo>& infos, int pos,
// Function to construct a funcdef from the segment and add it to the graph.
tensorflow::Status RegisterSegmentFunctionToFunctionLibrary(
tensorflow::Graph* graph, const tensorflow::GraphDef& segment,
- const string& name) {
+ const string& engine_name) {
tensorflow::Graph sgraph(graph->flib_def());
tensorflow::GraphConstructorOptions gcopts;
TF_RETURN_IF_ERROR(
@@ -761,9 +761,9 @@ tensorflow::Status RegisterSegmentFunctionToFunctionLibrary(
tensorflow::FunctionDefLibrary fdeflib;
auto native_segment = fdeflib.add_function();
TF_RETURN_IF_ERROR(tensorflow::GraphToFunctionDef(
- sgraph, StrCat(name, "_native_segment"), native_segment));
+ sgraph, StrCat(engine_name, "_native_segment"), native_segment));
if (VLOG_IS_ON(7)) {
- VLOG(7) << name << " Function_Def ";
+ VLOG(7) << engine_name << " Function_Def ";
VLOG(7) << native_segment->DebugString();
}
VLOG(1) << "Adding funcdef to graphlib";
@@ -780,12 +780,12 @@ std::pair<int, tensorflow::Allocator*> GetDeviceAndAllocator(
// If device is not set, use the first found GPU device for the conversion.
for (int tf_gpu_id_value = 0; tf_gpu_id_value < 100; ++tf_gpu_id_value) {
TfGpuId tf_gpu_id(tf_gpu_id_value);
- CudaGpuId cuda_gpu_id;
- Status s = GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id);
+ PlatformGpuId platform_gpu_id;
+ Status s = GpuIdManager::TfToPlatformGpuId(tf_gpu_id, &platform_gpu_id);
if (s.ok()) {
VLOG(1) << "Found TF GPU " << tf_gpu_id.value() << " at cuda device "
- << cuda_gpu_id.value();
- cuda_device_id = cuda_gpu_id.value();
+ << platform_gpu_id.value();
+ cuda_device_id = platform_gpu_id.value();
GPUOptions gpu_options;
// If the TF to Cuda gpu id mapping exist, the device and corresponding
// allocator must have been initialized already, so the
diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
index c98b07ad8b..0ce891782e 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
+++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
@@ -693,8 +693,15 @@ class Converter {
// TODO(jie): tf protobuf seems to be omitting the :0 suffix
string output_name = node_def.name();
if (i != 0) output_name = StrCat(output_name, ":", i);
+ // We need to check the name before setting it. For an Identity op whose
+ // output is its input, if that input is one of the engine inputs, setting
+ // the name here would overwrite the engine input bindings and cause a
+ // runtime error.
if (output.is_tensor()) {
- output.tensor()->setName(output_name.c_str());
+ const char* tensor_name = output.tensor()->getName();
+ if (tensor_name == nullptr || std::strlen(tensor_name) == 0) {
+ output.tensor()->setName(output_name.c_str());
+ }
}
VLOG(2) << "Adding out tensor " << output_name << ": "
<< output.DebugString();
@@ -779,12 +786,11 @@ class Converter {
// skip control nodes
if (input_name[0] == '^') continue;
string name = input_name;
- auto first = name.find_first_of(':');
- // TODO(aaroey): why removing the colon but not the zero? A bug?
+ auto last = name.find_last_of(':');
// TODO(aaroey): use TensorId
- if (first != string::npos && first + 2 == name.size() &&
- name[first + 1] == '0') {
- name.erase(first);
+ if (last != string::npos && last + 2 == name.size() &&
+ name[last + 1] == '0') {
+ name.erase(last);
}
if (trt_tensors_.count(name)) {
@@ -2697,7 +2703,6 @@ tensorflow::Status ConvertGraphDefToEngine(
TrtUniquePtrType<nvinfer1::IBuilder> builder(
nvinfer1::createInferBuilder(*logger));
builder->setMaxBatchSize(max_batch_size);
- // TODO(aaroey): use the allocator to allocate the TRT workspace.
builder->setMaxWorkspaceSize(max_workspace_size_bytes);
#if NV_TENSORRT_MAJOR > 3
builder->setGpuAllocator(allocator);
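The input-name handling changed above now looks for the last colon instead of the first, so a trailing `:0` is stripped even when the tensor name contains earlier colons, while names ending in other output indices are left untouched. A standalone sketch of that normalization; `strip_default_output_suffix` is a hypothetical helper, not part of the converter.
```python
def strip_default_output_suffix(name):
  """Removes a trailing ':0' from a tensor name, mirroring the C++ change."""
  last = name.rfind(':')
  # Only strip when the name ends in ':0'; other output indices are kept.
  if last != -1 and last + 2 == len(name) and name[last + 1] == '0':
    return name[:last]
  return name


print(strip_default_output_suffix('conv1/weights:0'))   # conv1/weights
print(strip_default_output_suffix('while/add:1'))       # while/add:1 (kept)
print(strip_default_output_suffix('scope:0/tensor:0'))  # scope:0/tensor
```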
diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
index 2b42d81f47..88cf8d5980 100644
--- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
+++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
@@ -565,21 +565,22 @@ tensorflow::Status TRTEngineOp::AllocateCalibrationResources(
new TRTInt8Calibrator(device_buffers_, batch_size, name()));
const string label(name());
auto segment_graph = &segment_graph_;
- const int cuda_gpu_id = ctx->device()->tensorflow_gpu_device_info()->gpu_id;
- if (cuda_gpu_id < 0) {
+ const int platform_gpu_id =
+ ctx->device()->tensorflow_gpu_device_info()->gpu_id;
+ if (platform_gpu_id < 0) {
LOG(ERROR) << "Can't get gpu_device_info from context->device()";
return tensorflow::errors::InvalidArgument(
"Context->device doesn't contain device info!");
}
const int64 workspace_size_bytes = workspace_size_;
cres->thr_.reset(new std::thread([cres, label, segment_graph, shapes,
- cuda_gpu_id, workspace_size_bytes]() {
- VLOG(0) << "Starting calibration thread on device " << cuda_gpu_id
+ platform_gpu_id, workspace_size_bytes]() {
+ VLOG(0) << "Starting calibration thread on device " << platform_gpu_id
<< ", Calibration Resource @ " << cres;
- auto err = cudaSetDevice(cuda_gpu_id);
+ auto err = cudaSetDevice(platform_gpu_id);
if (err != cudaSuccess) {
// TODO(aaroey): should return error here.
- LOG(ERROR) << "Couldn't set cuda device to " << cuda_gpu_id
+ LOG(ERROR) << "Couldn't set cuda device to " << platform_gpu_id
<< " in calibration thread";
}
// ConvertGraphDefToEngine() will try to build the engine. This thread
diff --git a/tensorflow/contrib/tensorrt/python/trt_convert_test.py b/tensorflow/contrib/tensorrt/python/trt_convert_test.py
index 118a6680fd..52cb0bd9f9 100644
--- a/tensorflow/contrib/tensorrt/python/trt_convert_test.py
+++ b/tensorflow/contrib/tensorrt/python/trt_convert_test.py
@@ -94,7 +94,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
with g.device("/GPU:0"):
inp = array_ops.placeholder(
dtype=dtypes.float32, shape=[None, 1, 1], name="input")
- var = variables.Variable([[[1.0]]], dtype=dtypes.float32, name="v1")
+ var = variables.VariableV1([[[1.0]]], dtype=dtypes.float32, name="v1")
add = inp + var.value()
mul = inp * add
add = mul + add
@@ -104,7 +104,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
def _GetGraphDef(self):
"""Get the graph def for testing."""
g, var, _, _ = self._GetGraph()
- with self.test_session(graph=g, config=self._GetConfigProto()) as sess:
+ with self.session(graph=g, config=self._GetConfigProto()) as sess:
sess.run(var.initializer)
graph_def = graph_util.convert_variables_to_constants(
sess, g.as_graph_def(add_shapes=True), ["output"])
@@ -128,7 +128,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
outputs={"myoutput": utils.build_tensor_info(out)},
method_name=signature_constants.PREDICT_METHOD_NAME)
saved_model_builder = builder.SavedModelBuilder(input_saved_model_dir)
- with self.test_session(graph=g, config=self._GetConfigProto()) as sess:
+ with self.session(graph=g, config=self._GetConfigProto()) as sess:
sess.run(var.initializer)
saved_model_builder.add_meta_graph_and_variables(
sess, [tag_constants.SERVING],
diff --git a/tensorflow/contrib/tensorrt/resources/trt_allocator.cc b/tensorflow/contrib/tensorrt/resources/trt_allocator.cc
index d8f97bfbbc..a9425864dd 100644
--- a/tensorflow/contrib/tensorrt/resources/trt_allocator.cc
+++ b/tensorflow/contrib/tensorrt/resources/trt_allocator.cc
@@ -27,12 +27,16 @@ namespace tensorflow {
namespace tensorrt {
// std::align is not supported, so this method mimics its behavior.
-void* Align(size_t alignment, size_t size, void*& ptr, size_t& space) {
- QCHECK_GT(alignment, 0) << "alignment must be greater than 0.";
+//
+// NOTE(aaroey): according to the TensorRT API,
+// nvinfer1::IGpuAllocator::allocate() uses uint64_t type for size and alignment
+// parameters, so here we use the same type to make it compatible.
+void* Align(uint64_t alignment, uint64_t size, void*& ptr, uint64_t& space) {
+ QCHECK_GT(alignment, 0ul) << "alignment must be greater than 0.";
QCHECK_EQ(0, alignment & (alignment - 1)) << "Alignment must be power of 2.";
- QCHECK_GT(size, 0) << "size must be greater than 0.";
+ QCHECK_GT(size, 0ul) << "size must be greater than 0.";
QCHECK(ptr) << "ptr must not be nullptr.";
- QCHECK_GT(space, 0) << "space must be greater than 0.";
+ QCHECK_GT(space, 0ul) << "space must be greater than 0.";
const uintptr_t ptr_val = reinterpret_cast<uintptr_t>(ptr);
QCHECK_GE(ptr_val + space, ptr_val) << "Provided space overflows.";
@@ -67,12 +71,16 @@ void TRTCudaAllocator::free(void* memory) { cudaFree(memory); }
void* TRTDeviceAllocator::allocate(uint64_t size, uint64_t alignment,
uint32_t flags) {
+ if (size == 0) return nullptr;
// WAR for allocator alignment requirement. Certain cuda API calls require GPU
// memory with alignment to cudaDeviceProp::textureAlignment.
// See issue #20856
alignment = 512;
assert((alignment & (alignment - 1)) == 0); // zero or a power of 2.
- size_t total_size = size + alignment;
+ uint64_t total_size = size + alignment;
+ // TODO(aaroey): AllocateRaw takes size_t size as input, so it'll produce
+ // unexpected result when TRT tries to allocate more bytes than size_t can
+ // carry. Fix this.
void* mem = allocator_->AllocateRaw(alignment, total_size);
if (!mem) return nullptr;
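For reference, the `Align` helper being widened to `uint64_t` above rounds a pointer up to the next multiple of `alignment` and succeeds only if the requested `size` still fits in the remaining `space`. Below is a plain-integer Python model of that arithmetic; the real function also updates `ptr` and `space` in place and enforces its preconditions with QCHECKs.
```python
def align(alignment, size, ptr, space):
  """Plain-integer model of the Align() pointer arithmetic."""
  assert alignment > 0 and (alignment & (alignment - 1)) == 0, "power of two"
  assert size > 0 and ptr > 0 and space > 0
  # Bytes needed to move ptr up to the next alignment boundary.
  padding = (alignment - ptr % alignment) % alignment
  if padding + size > space:
    return None  # The aligned request does not fit in the remaining space.
  return ptr + padding


print(align(512, 100, 0x1001, 1024))   # 4608 (0x1200): next 512-byte boundary
print(align(512, 1024, 0x1001, 1024))  # None: not enough space after padding
```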
diff --git a/tensorflow/contrib/tensorrt/resources/trt_allocator.h b/tensorflow/contrib/tensorrt/resources/trt_allocator.h
index 6f94492083..dc9862b16c 100644
--- a/tensorflow/contrib/tensorrt/resources/trt_allocator.h
+++ b/tensorflow/contrib/tensorrt/resources/trt_allocator.h
@@ -29,7 +29,7 @@ limitations under the License.
namespace tensorflow {
namespace tensorrt {
// std::align is not supported, so this function mimics its behavior.
-void* Align(size_t alignment, size_t size, void*& ptr, size_t& space);
+void* Align(uint64_t alignment, uint64_t size, void*& ptr, uint64_t& space);
} // namespace tensorrt
} // namespace tensorflow
diff --git a/tensorflow/contrib/tensorrt/resources/trt_allocator_test.cc b/tensorflow/contrib/tensorrt/resources/trt_allocator_test.cc
index f515ed03f2..ad6b1d7d4c 100644
--- a/tensorflow/contrib/tensorrt/resources/trt_allocator_test.cc
+++ b/tensorflow/contrib/tensorrt/resources/trt_allocator_test.cc
@@ -20,11 +20,11 @@ limitations under the License.
namespace tensorflow {
namespace tensorrt {
-bool RunTest(const size_t alignment, const size_t size,
- const intptr_t orig_ptr_val, const size_t orig_space) {
+bool RunTest(const uint64_t alignment, const uint64_t size,
+ const intptr_t orig_ptr_val, const uint64_t orig_space) {
void* const orig_ptr = reinterpret_cast<void*>(orig_ptr_val);
void* ptr = orig_ptr;
- size_t space = orig_space;
+ uint64_t space = orig_space;
void* result = Align(alignment, size, ptr, space);
if (result == nullptr) {
EXPECT_EQ(orig_ptr, ptr);
@@ -43,24 +43,25 @@ bool RunTest(const size_t alignment, const size_t size,
}
TEST(TRTAllocatorTest, Align) {
- for (const size_t space :
- {1, 2, 3, 4, 7, 8, 9, 10, 16, 32, 511, 512, 513, 700, 12345}) {
- for (size_t alignment = 1; alignment <= space * 4; alignment *= 2) {
- for (const intptr_t ptr_val :
+ for (const uint64_t space :
+ {1ul, 2ul, 3ul, 4ul, 7ul, 8ul, 9ul, 10ul, 16ul, 32ul, 511ul, 512ul,
+ 513ul, 700ul, 12345ul, 1ul << 32}) {
+ for (uint64_t alignment = 1; alignment <= space * 4; alignment *= 2) {
+ for (const uintptr_t ptr_val :
{1ul, alignment == 1 ? 1ul : alignment - 1, alignment, alignment + 1,
alignment + (alignment / 2)}) {
if (ptr_val % alignment == 0) {
- for (const size_t size :
+ for (const uint64_t size :
{1ul, space == 1 ? 1ul : space - 1, space, space + 1}) {
EXPECT_EQ(space >= size, RunTest(alignment, size, ptr_val, space));
}
} else {
EXPECT_FALSE(RunTest(alignment, space, ptr_val, space));
- const size_t diff = alignment - ptr_val % alignment;
+ const uint64_t diff = alignment - ptr_val % alignment;
if (space > diff) {
EXPECT_TRUE(
RunTest(alignment, space - diff, ptr_val + diff, space - diff));
- for (const size_t size :
+ for (const uint64_t size :
{1ul, space - diff > 1 ? space - diff - 1 : 1ul, space - diff,
space - diff + 1, space - 1}) {
EXPECT_EQ(space - diff >= size,
diff --git a/tensorflow/contrib/tensorrt/test/base_test.py b/tensorflow/contrib/tensorrt/test/base_test.py
index e9ac833d55..7e9ffb05ab 100644
--- a/tensorflow/contrib/tensorrt/test/base_test.py
+++ b/tensorflow/contrib/tensorrt/test/base_test.py
@@ -183,6 +183,12 @@ class PartiallyConvertedTestA(trt_test.TfTrtIntegrationTestBase):
"my_trt_op_0": ["c0", "c1", "add0", "add1", "mul0", "mul1"]
}
+ def ShouldRunTest(self, run_params):
+ """Whether to run the test."""
+ # Disable the test in fp16 mode since multiple matmul and add ops together
+ # can cause overflow.
+ return run_params.precision_mode != "FP16"
+
class PartiallyConvertedTestB(PartiallyConvertedTestA):
diff --git a/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py b/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py
index 62f4e525f7..d2f65344da 100644
--- a/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py
+++ b/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py
@@ -144,14 +144,6 @@ class BiasaddMatMulTest(trt_test.TfTrtIntegrationTestBase):
# mode, which is a bug. Re-enable this when trt library is fixed.
return not trt_test.IsQuantizationMode(run_params.precision_mode)
- def ExpectedAbsoluteTolerance(self, run_params):
- """The absolute tolerance to compare floating point results."""
- return 1.e-05 if run_params.precision_mode == "FP32" else 1.e-03
-
- def ExpectedRelativeTolerance(self, run_params):
- """The relative tolerance to compare floating point results."""
- return 1.e-05 if run_params.precision_mode == "FP32" else 1.e-03
-
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py
index fc647e4eb9..4f935a7665 100644
--- a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py
+++ b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py
@@ -134,7 +134,7 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
dims[0] for dims in self._GetParamsCached().input_dims if len(dims)
]),
max_workspace_size_bytes=1 << 25,
- precision_mode=self._ToBytes(run_params.precision_mode),
+ precision_mode=run_params.precision_mode,
minimum_segment_size=2,
is_dynamic_op=run_params.dynamic_engine,
maximum_cached_engines=1,
@@ -179,11 +179,11 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
def ExpectedAbsoluteTolerance(self, run_params):
"""The absolute tolerance to compare floating point results."""
- return 1.e-06 if run_params.precision_mode == "FP32" else 1.e-03
+ return 1.e-05 if run_params.precision_mode == "FP32" else 1.e-02
def ExpectedRelativeTolerance(self, run_params):
"""The relative tolerance to compare floating point results."""
- return 1.e-06 if run_params.precision_mode == "FP32" else 1.e-03
+ return 1.e-05 if run_params.precision_mode == "FP32" else 1.e-02
def _GetParamsCached(self):
if self._trt_test_params is None:
@@ -414,6 +414,7 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
if not self.ShouldRunTest(run_params):
return
assert run_params.precision_mode in PRECISION_MODES
+ np.random.seed(12345)
params = self._GetParamsCached()
input_gdef = params.gdef
diff --git a/tensorflow/contrib/text/python/ops/skip_gram_ops_test.py b/tensorflow/contrib/text/python/ops/skip_gram_ops_test.py
index 84e36146d5..832d34d60d 100644
--- a/tensorflow/contrib/text/python/ops/skip_gram_ops_test.py
+++ b/tensorflow/contrib/text/python/ops/skip_gram_ops_test.py
@@ -63,7 +63,7 @@ class SkipGramOpsTest(test.TestCase):
(b"jumps", b"brown"),
(b"jumps", b"fox"),
])
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(expected_tokens, tokens.eval())
self.assertAllEqual(expected_labels, labels.eval())
@@ -94,7 +94,7 @@ class SkipGramOpsTest(test.TestCase):
(b"jumps", b"fox"),
(b"jumps", b"jumps"),
])
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(expected_tokens, tokens.eval())
self.assertAllEqual(expected_labels, labels.eval())
@@ -105,7 +105,7 @@ class SkipGramOpsTest(test.TestCase):
# If emit_self_as_target is False (default), output will be empty.
tokens, labels = text.skip_gram_sample(
input_tensor, min_skips=0, max_skips=0, emit_self_as_target=False)
- with self.test_session():
+ with self.cached_session():
self.assertEqual(0, tokens.eval().size)
self.assertEqual(0, labels.eval().size)
@@ -117,7 +117,7 @@ class SkipGramOpsTest(test.TestCase):
(b"quick", b"quick"),
(b"brown", b"brown"),
])
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(expected_tokens, tokens.eval())
self.assertAllEqual(expected_labels, labels.eval())
@@ -134,7 +134,7 @@ class SkipGramOpsTest(test.TestCase):
(b"brown", b"the"),
(b"brown", b"quick"),
])
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(expected_tokens, tokens.eval())
self.assertAllEqual(expected_labels, labels.eval())
@@ -150,7 +150,7 @@ class SkipGramOpsTest(test.TestCase):
(b"quick", b"brown"),
(b"brown", b"quick"),
])
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(expected_tokens, tokens.eval())
self.assertAllEqual(expected_labels, labels.eval())
@@ -165,7 +165,7 @@ class SkipGramOpsTest(test.TestCase):
(b"quick", b"brown"),
(b"brown", b"quick"),
])
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(expected_tokens, tokens.eval())
self.assertAllEqual(expected_labels, labels.eval())
@@ -196,7 +196,7 @@ class SkipGramOpsTest(test.TestCase):
(b"over", b"fox"),
(b"over", b"jumps"),
])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
tokens_eval, labels_eval = sess.run([tokens, labels])
self.assertAllEqual(expected_tokens, tokens_eval)
self.assertAllEqual(expected_labels, labels_eval)
@@ -222,7 +222,7 @@ class SkipGramOpsTest(test.TestCase):
tokens_2, labels_2 = text.skip_gram_sample(
input_tensor, min_skips=1, max_skips=5)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
tokens_1_eval, labels_1_eval, tokens_2_eval, labels_2_eval = sess.run(
[tokens_1, labels_1, tokens_2, labels_2])
@@ -244,7 +244,7 @@ class SkipGramOpsTest(test.TestCase):
(b"brown", b"fox"),
(b"fox", b"brown"),
])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
coord = coordinator.Coordinator()
threads = queue_runner_impl.start_queue_runners(sess=sess, coord=coord)
@@ -269,7 +269,7 @@ class SkipGramOpsTest(test.TestCase):
(2, 3),
(3, 2),
])
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(expected_tokens, tokens.eval())
self.assertAllEqual(expected_labels, labels.eval())
@@ -286,7 +286,7 @@ class SkipGramOpsTest(test.TestCase):
for min_skips, max_skips in invalid_skips:
tokens, labels = text.skip_gram_sample(
input_tensor, min_skips=min_skips, max_skips=max_skips)
- with self.test_session() as sess, self.assertRaises(
+ with self.cached_session() as sess, self.assertRaises(
errors.InvalidArgumentError):
sess.run([tokens, labels])
@@ -338,7 +338,7 @@ class SkipGramOpsTest(test.TestCase):
vocab_freq_table = lookup.HashTable(
lookup.KeyValueTensorInitializer(keys, values), -1)
- with self.test_session():
+ with self.cached_session():
vocab_freq_table.init.run()
# No vocab_freq_table specified - output should be the same as input.
@@ -395,7 +395,7 @@ class SkipGramOpsTest(test.TestCase):
vocab_freq_table = lookup.HashTable(
lookup.KeyValueTensorInitializer(keys, values), -1)
- with self.test_session():
+ with self.cached_session():
vocab_freq_table.init.run()
output = skip_gram_ops._filter_input(
input_tensor=input_tensor,
@@ -464,7 +464,7 @@ class SkipGramOpsTest(test.TestCase):
(b"life", b"and"),
(b"and", b"life"),
])
- with self.test_session():
+ with self.cached_session():
lookup_ops.tables_initializer().run()
self.assertAllEqual(expected_tokens, tokens.eval())
self.assertAllEqual(expected_labels, labels.eval())
@@ -510,7 +510,7 @@ class SkipGramOpsTest(test.TestCase):
(b"to", b"life"),
(b"life", b"to"),
])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
lookup_ops.tables_initializer().run()
tokens_eval, labels_eval = sess.run([tokens, labels])
self.assertAllEqual(expected_tokens, tokens_eval)
diff --git a/tensorflow/contrib/timeseries/examples/BUILD b/tensorflow/contrib/timeseries/examples/BUILD
index 21c0c30c19..57797214d1 100644
--- a/tensorflow/contrib/timeseries/examples/BUILD
+++ b/tensorflow/contrib/timeseries/examples/BUILD
@@ -1,4 +1,5 @@
load("//tensorflow:tensorflow.bzl", "py_test")
+load("//tensorflow:tensorflow.bzl", "py_binary")
package(
default_visibility = ["//tensorflow:internal"],
diff --git a/tensorflow/contrib/timeseries/python/timeseries/ar_model.py b/tensorflow/contrib/timeseries/python/timeseries/ar_model.py
index 1d27fffc62..9bbe87e301 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/ar_model.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/ar_model.py
@@ -191,6 +191,43 @@ class ARModel(model.TimeSeriesModel):
Note that this class can also be used to regress against time only by setting
the input_window_size to zero.
+
+ Each periodicity in the `periodicities` arg is divided by the
+ `num_time_buckets` into time buckets that are represented as features added
+ to the model.
+
+ A good heuristic for picking an appropriate periodicity for a given data set
+ would be the length of cycles in the data. For example, energy usage in a
+ home is typically cyclic each day. If the time feature in a home energy
+ usage dataset is in units of hours, then 24 would be an appropriate
+ periodicity. Similarly, a good heuristic for `num_time_buckets` is how often
+ the data is expected to change within the cycle. For the aforementioned home
+ energy usage dataset with a periodicity of 24, 48 would be a reasonable
+ value if usage is expected to change every half hour.
+
+ Each feature's value for a given example with time t is the difference
+ between t and the start of the time bucket it falls under. If it doesn't fall
+ under a feature's associated time bucket, then that feature's value is zero.
+
+ For example: if `periodicities` = (9, 12) and `num_time_buckets` = 3, then 6
+ features would be added to the model, 3 for periodicity 9 and 3 for
+ periodicity 12.
+
+ For an example data point where t = 17:
+ - It's in the 3rd time bucket for periodicity 9 (2nd period is 9-18 and 3rd
+ time bucket is 15-18)
+ - It's in the 2nd time bucket for periodicity 12 (2nd period is 12-24 and
+ 2nd time bucket is between 16-20).
+
+ Therefore the 6 added features for this row with t = 17 would be:
+
+ # Feature name (periodicity#_timebucket#), feature value
+ P9_T1, 0 # not in first time bucket
+ P9_T2, 0 # not in second time bucket
+ P9_T3, 2 # 17 - 15 since 15 is the start of the 3rd time bucket
+ P12_T1, 0 # not in first time bucket
+ P12_T2, 1 # 17 - 16 since 16 is the start of the 2nd time bucket
+ P12_T3, 0 # not in third time bucket
"""
SQUARED_LOSS = "squared_loss"
NORMAL_LIKELIHOOD_LOSS = "normal_likelihood_loss"
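The worked example in the new docstring (t = 17, `periodicities` = (9, 12), `num_time_buckets` = 3) can be reproduced with a few lines of plain Python. This is an illustrative sketch of the feature construction described above, not ARModel's actual implementation; `time_bucket_features` is a hypothetical helper.
```python
def time_bucket_features(t, periodicities, num_time_buckets):
  """Builds the per-periodicity time-bucket features described above."""
  features = {}
  for p in periodicities:
    bucket_width = p / float(num_time_buckets)
    phase = t % p  # Position of t within the current period.
    for b in range(num_time_buckets):
      start = b * bucket_width
      in_bucket = start <= phase < start + bucket_width
      # Offset from the bucket start if t falls in this bucket, else zero.
      features['P%d_T%d' % (p, b + 1)] = phase - start if in_bucket else 0
  return features


features = time_bucket_features(17, (9, 12), 3)
for name in sorted(features):
  print('%s %s' % (name, features[name]))
# Prints: P12_T1 0, P12_T2 1.0, P12_T3 0, P9_T1 0, P9_T2 0, P9_T3 2.0
```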
@@ -208,7 +245,9 @@ class ARModel(model.TimeSeriesModel):
Args:
periodicities: periodicities of the input data, in the same units as the
- time feature. Note this can be a single value or a list of values for
+ time feature (for example 24 if feeding hourly data with a daily
+ periodicity, or 60 * 24 if feeding minute-level data with daily
+ periodicity). Note this can be a single value or a list of values for
multiple periodicities.
input_window_size: Number of past time steps of data to look at when doing
the regression.
@@ -218,21 +257,18 @@ class ARModel(model.TimeSeriesModel):
prediction_model_factory: A callable taking arguments `num_features`,
`input_window_size`, and `output_window_size` and returning a
`tf.keras.Model`. The `Model`'s `call()` takes two arguments: an input
- window and an output window, and returns a dictionary of
- predictions. See `FlatPredictionModel` for an example. Example usage:
+ window and an output window, and returns a dictionary of predictions.
+ See `FlatPredictionModel` for an example. Example usage:
- ```python
- model = ar_model.ARModel(
- periodicities=2, num_features=3,
- prediction_model_factory=functools.partial(
- FlatPredictionModel,
- hidden_layer_sizes=[10, 10]))
- ```
+ ```python
+ model = ar_model.ARModel(
+     periodicities=2, num_features=3,
+     prediction_model_factory=functools.partial(
+         FlatPredictionModel, hidden_layer_sizes=[10, 10]))
+ ```
The default model computes predictions as a linear function of flattened
input and output windows.
num_time_buckets: Number of buckets into which to divide (time %
- periodicity) for generating time based features.
+ periodicity). This value multiplied by the number of periodicities is
+ the number of time features added to the model.
loss: Loss function to use for training. Currently supported values are
SQUARED_LOSS and NORMAL_LIKELIHOOD_LOSS. Note that for
NORMAL_LIKELIHOOD_LOSS, we train the covariance term as well. For
@@ -240,10 +276,9 @@ class ARModel(model.TimeSeriesModel):
observations and predictions, while the training loss is computed on
normalized data (if input statistics are available).
exogenous_feature_columns: A list of `tf.feature_column`s (for example
- `tf.feature_column.embedding_column`) corresponding to exogenous
- features which provide extra information to the model but are not part
- of the series to be predicted. Passed to
- `tf.feature_column.input_layer`.
+ `tf.feature_column.embedding_column`) corresponding to
+ features which provide extra information to the model but are not part
+ of the series to be predicted.
"""
self._model_factory = prediction_model_factory
self.input_window_size = input_window_size
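
Editor's note (not part of the patch): the bucketed time features described in the ARModel docstring above can be reproduced by hand. The sketch below is purely illustrative, not the ARModel implementation, and assumes each periodicity divides evenly into `num_time_buckets` buckets.

```python
def time_bucket_features(t, periodicities, num_time_buckets):
  """Hand-rolled illustration of the P<periodicity>_T<bucket> features."""
  features = {}
  for p in periodicities:
    bucket_size = p // num_time_buckets  # assumes p divides evenly
    phase = t % p                        # position of t within its period
    bucket = phase // bucket_size        # 0-based bucket index for t
    for b in range(num_time_buckets):
      name = "P%d_T%d" % (p, b + 1)
      features[name] = phase - b * bucket_size if b == bucket else 0
  return features

print(time_bucket_features(17, (9, 12), 3))
# {'P9_T1': 0, 'P9_T2': 0, 'P9_T3': 2, 'P12_T1': 0, 'P12_T2': 1, 'P12_T3': 0}
```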
diff --git a/tensorflow/contrib/timeseries/python/timeseries/estimators.py b/tensorflow/contrib/timeseries/python/timeseries/estimators.py
index 0ddc4b4144..af68aa03cf 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/estimators.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/estimators.py
@@ -30,6 +30,7 @@ from tensorflow.contrib.timeseries.python.timeseries.state_space_models import s
from tensorflow.contrib.timeseries.python.timeseries.state_space_models.filtering_postprocessor import StateInterpolatingAnomalyDetector
from tensorflow.python.estimator import estimator_lib
+from tensorflow.python.estimator.canned import optimizers
from tensorflow.python.estimator.export import export_lib
from tensorflow.python.feature_column import feature_column
from tensorflow.python.framework import dtypes
@@ -386,6 +387,162 @@ class ARRegressor(TimeSeriesRegressor):
config=config)
+# TODO(b/113684821): Add detailed documentation on what the input_fn should do.
+# Add an example of making and returning a Dataset object. Determine if
+# endogenous features can be passed in as FeatureColumns. Move ARModel's loss
+# functions into a more general location.
+class LSTMAutoRegressor(TimeSeriesRegressor):
+ """An Estimator for an LSTM autoregressive model.
+
+ LSTMAutoRegressor is a window-based model, inputting fixed windows of length
+ `input_window_size` and outputting fixed windows of length
+ `output_window_size`. These two parameters must add up to the window_size
+ of data returned by the `input_fn`.
+
+  Each periodicity in the `periodicities` arg is divided into `num_timesteps`
+  timesteps, each of which is represented as a time feature added to the model.
+
+ A good heuristic for picking an appropriate periodicity for a given data set
+ would be the length of cycles in the data. For example, energy usage in a
+ home is typically cyclic each day. If the time feature in a home energy
+ usage dataset is in the unit of hours, then 24 would be an appropriate
+ periodicity. Similarly, a good heuristic for `num_timesteps` is how often the
+ data is expected to change within the cycle. For the aforementioned home
+  energy usage dataset with a periodicity of 24, 48 would be a reasonable
+ value if usage is expected to change every half hour.
+
+ Each feature's value for a given example with time t is the difference
+ between t and the start of the timestep it falls under. If it doesn't fall
+ under a feature's associated timestep, then that feature's value is zero.
+
+ For example: if `periodicities` = (9, 12) and `num_timesteps` = 3, then 6
+ features would be added to the model, 3 for periodicity 9 and 3 for
+ periodicity 12.
+
+ For an example data point where t = 17:
+ - It's in the 3rd timestep for periodicity 9 (2nd period is 9-18 and 3rd
+ timestep is 15-18)
+ - It's in the 2nd timestep for periodicity 12 (2nd period is 12-24 and
+ 2nd timestep is between 16-20).
+
+ Therefore the 6 added features for this row with t = 17 would be:
+
+ # Feature name (periodicity#_timestep#), feature value
+ P9_T1, 0 # not in first timestep
+ P9_T2, 0 # not in second timestep
+ P9_T3, 2 # 17 - 15 since 15 is the start of the 3rd timestep
+ P12_T1, 0 # not in first timestep
+ P12_T2, 1 # 17 - 16 since 16 is the start of the 2nd timestep
+ P12_T3, 0 # not in third timestep
+
+ Example Code:
+
+ ```python
+ extra_feature_columns = (
+ feature_column.numeric_column("exogenous_variable"),
+ )
+
+ estimator = LSTMAutoRegressor(
+ periodicities=10,
+ input_window_size=10,
+ output_window_size=5,
+ model_dir="/path/to/model/dir",
+ num_features=1,
+ extra_feature_columns=extra_feature_columns,
+ num_timesteps=50,
+ num_units=10,
+ optimizer=tf.train.ProximalAdagradOptimizer(...))
+
+ # Input builders
+ def input_fn_train():
+ return {
+ "times": tf.range(15)[None, :],
+ "values": tf.random_normal(shape=[1, 15, 1])
+ }
+ estimator.train(input_fn=input_fn_train, steps=100)
+
+ def input_fn_eval():
+ pass
+ metrics = estimator.evaluate(input_fn=input_fn_eval, steps=10)
+
+ def input_fn_predict():
+ pass
+ predictions = estimator.predict(input_fn=input_fn_predict)
+ ```
+ """
+
+ def __init__(self,
+ periodicities,
+ input_window_size,
+ output_window_size,
+ model_dir=None,
+ num_features=1,
+ extra_feature_columns=None,
+ num_timesteps=10,
+ loss=ar_model.ARModel.NORMAL_LIKELIHOOD_LOSS,
+ num_units=128,
+ optimizer="Adam",
+ config=None):
+ """Initialize the Estimator.
+
+ Args:
+ periodicities: periodicities of the input data, in the same units as the
+ time feature (for example 24 if feeding hourly data with a daily
+ periodicity, or 60 * 24 if feeding minute-level data with daily
+ periodicity). Note this can be a single value or a list of values for
+ multiple periodicities.
+ input_window_size: Number of past time steps of data to look at when doing
+ the regression.
+ output_window_size: Number of future time steps to predict. Note that
+ setting this value to > 1 empirically seems to give a better fit.
+      model_dir: Directory to save model parameters, graph, etc. This can
+        also be used to load checkpoints from the directory into an estimator
+        to continue training a previously saved model.
+ num_features: The dimensionality of the time series (default value is
+ one for univariate, more than one for multivariate).
+ extra_feature_columns: A list of `tf.feature_column`s (for example
+ `tf.feature_column.embedding_column`) corresponding to features which
+ provide extra information to the model but are not part of the series to
+ be predicted.
+ num_timesteps: Number of buckets into which to divide (time %
+ periodicity). This value multiplied by the number of periodicities is
+ the number of time features added to the model.
+ loss: Loss function to use for training. Currently supported values are
+ SQUARED_LOSS and NORMAL_LIKELIHOOD_LOSS. Note that for
+ NORMAL_LIKELIHOOD_LOSS, we train the covariance term as well. For
+ SQUARED_LOSS, the evaluation loss is reported based on un-scaled
+ observations and predictions, while the training loss is computed on
+ normalized data.
+ num_units: The size of the hidden state in the encoder and decoder LSTM
+ cells.
+ optimizer: string, `tf.train.Optimizer` object, or callable that defines
+ the optimizer algorithm to use for training. Defaults to the Adam
+ optimizer with a learning rate of 0.01.
+ config: Optional `estimator.RunConfig` object to configure the runtime
+ settings.
+ """
+ optimizer = optimizers.get_optimizer_instance(
+ optimizer, learning_rate=0.01)
+ model = ar_model.ARModel(
+ periodicities=periodicities,
+ input_window_size=input_window_size,
+ output_window_size=output_window_size,
+ num_features=num_features,
+ exogenous_feature_columns=extra_feature_columns,
+ num_time_buckets=num_timesteps,
+ loss=loss,
+ prediction_model_factory=functools.partial(
+ ar_model.LSTMPredictionModel, num_units=num_units))
+ state_manager = state_management.FilteringOnlyStateManager()
+ super(LSTMAutoRegressor, self).__init__(
+ model=model,
+ state_manager=state_manager,
+ optimizer=optimizer,
+ model_dir=model_dir,
+ config=config,
+ head_type=ts_head_lib.OneShotPredictionHead)
+
+
class StateSpaceRegressor(TimeSeriesRegressor):
"""An Estimator for general state space models."""
diff --git a/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py b/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py
index 83260fc59a..6ec7184c68 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py
@@ -226,5 +226,40 @@ class TimeSeriesRegressorTest(test.TestCase):
input_pipeline.NumpyReader(numpy_data)),
steps=1)
+ def test_ar_lstm_regressor(self):
+ dtype = dtypes.float32
+ model_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
+ exogenous_feature_columns = (
+ feature_column.numeric_column("exogenous"),
+ )
+ estimator = estimators.LSTMAutoRegressor(
+ periodicities=10,
+ input_window_size=10,
+ output_window_size=6,
+ model_dir=model_dir,
+ num_features=1,
+ extra_feature_columns=exogenous_feature_columns,
+ num_units=10,
+ config=_SeedRunConfig())
+ times = numpy.arange(20, dtype=numpy.int64)
+ values = numpy.arange(20, dtype=dtype.as_numpy_dtype)
+ exogenous = numpy.arange(20, dtype=dtype.as_numpy_dtype)
+ features = {
+ feature_keys.TrainEvalFeatures.TIMES: times,
+ feature_keys.TrainEvalFeatures.VALUES: values,
+ "exogenous": exogenous
+ }
+ train_input_fn = input_pipeline.RandomWindowInputFn(
+ input_pipeline.NumpyReader(features), shuffle_seed=2, num_threads=1,
+ batch_size=16, window_size=16)
+ eval_input_fn = input_pipeline.RandomWindowInputFn(
+ input_pipeline.NumpyReader(features), shuffle_seed=3, num_threads=1,
+ batch_size=16, window_size=16)
+ estimator.train(input_fn=train_input_fn, steps=1)
+ evaluation = estimator.evaluate(
+ input_fn=eval_input_fn, steps=1)
+ self.assertAllEqual(evaluation["loss"], evaluation["average_loss"])
+ self.assertAllEqual([], evaluation["loss"].shape)
+
if __name__ == "__main__":
test.main()
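
A side note on the `optimizer` argument exercised above: per the LSTMAutoRegressor docstring, it accepts a string, a `tf.train.Optimizer` instance, or a callable returning one (the default is Adam with a learning rate of 0.01). The sketch below is hypothetical, with arbitrary example values, not code from this commit.

```python
import tensorflow as tf
from tensorflow.contrib.timeseries.python.timeseries import estimators

# Illustrative only: override the default Adam optimizer with a callable.
estimator = estimators.LSTMAutoRegressor(
    periodicities=24,
    input_window_size=10,
    output_window_size=5,
    num_features=1,
    num_units=32,
    optimizer=lambda: tf.train.AdamOptimizer(learning_rate=0.001))
```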
diff --git a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model.py b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model.py
index 951c6546d5..d04c721007 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model.py
@@ -909,7 +909,7 @@ class StateSpaceModel(model.SequentialTimeSeriesModel):
elif unbroadcasted_shape.ndims == 2:
# Unbroadcasted shape [num features x state dimension]
broadcasted_model = array_ops.tile(
- array_ops.expand_dims(unbroadcasted_model, dim=0),
+ array_ops.expand_dims(unbroadcasted_model, axis=0),
[array_ops.shape(times)[0], 1, 1])
elif unbroadcasted_shape.ndims == 3:
broadcasted_model = unbroadcasted_model
diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD
index 298ffc1ded..e9aa037634 100644
--- a/tensorflow/contrib/tpu/BUILD
+++ b/tensorflow/contrib/tpu/BUILD
@@ -36,6 +36,27 @@ cc_library(
)
py_library(
+ name = "async_checkpoint",
+ srcs = ["python/tpu/async_checkpoint.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:init_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:state_ops",
+ "//tensorflow/python:summary",
+ "//tensorflow/python:summary_ops_v2",
+ "//tensorflow/python:training",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python:variables",
+ "//tensorflow/python/estimator:estimator_py",
+ ],
+)
+
+py_library(
name = "tpu_estimator",
srcs = [
"python/tpu/error_handling.py",
@@ -46,6 +67,7 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
+ ":async_checkpoint",
":tpu_lib",
"//tensorflow/compiler/xla/experimental/xla_sharding",
"//tensorflow/compiler/xla/python_api:xla_shape",
@@ -80,7 +102,10 @@ tf_gen_op_libs(
"tpu_embedding_ops",
],
deps = [
- "//tensorflow/contrib/tpu/proto:tpu_embedding_config_proto_cc",
+ "//tensorflow/contrib/tpu/proto:tpu_embedding_configuration_proto_cc",
+ "//tensorflow/contrib/tpu/utils:tpu_embedding_optimization_parameters_utils",
+ "//tensorflow/contrib/tpu/utils:tpu_embedding_output_layout_utils",
+ "//tensorflow/core:lib",
"//tensorflow/core:lib_proto_parsing",
"//tensorflow/core:protos_all_cc",
],
@@ -99,7 +124,9 @@ tf_custom_op_library(
"ops/tpu_embedding_ops.cc",
],
deps = [
- "//tensorflow/contrib/tpu/proto:tpu_embedding_config_proto_cc",
+ "//tensorflow/contrib/tpu/proto:tpu_embedding_configuration_proto_cc",
+ "//tensorflow/contrib/tpu/utils:tpu_embedding_optimization_parameters_utils",
+ "//tensorflow/contrib/tpu/utils:tpu_embedding_output_layout_utils",
"//tensorflow/core:lib_proto_parsing",
],
)
@@ -225,7 +252,10 @@ py_library(
":tpu_py",
"//tensorflow/contrib/cluster_resolver:tpu_cluster_resolver_py",
"//tensorflow/contrib/tpu/proto:compilation_result_proto_py",
+ "//tensorflow/contrib/tpu/proto:optimization_parameters_proto_py",
"//tensorflow/contrib/tpu/proto:topology_proto_py",
+ "//tensorflow/contrib/tpu/proto:tpu_embedding_configuration_proto_py",
+ "//tensorflow/contrib/tpu/proto:tpu_embedding_output_layout_proto_py",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:control_flow_ops",
@@ -351,7 +381,7 @@ tf_py_test(
tf_py_test(
name = "topology_test",
- size = "small",
+ size = "medium",
srcs = ["python/tpu/topology_test.py"],
additional_deps = [
":tpu",
diff --git a/tensorflow/contrib/tpu/__init__.py b/tensorflow/contrib/tpu/__init__.py
index 3c0456dc2f..766466968a 100644
--- a/tensorflow/contrib/tpu/__init__.py
+++ b/tensorflow/contrib/tpu/__init__.py
@@ -55,6 +55,7 @@
@@TPUDistributionStrategy
@@keras_to_tpu_model
+@@AsyncCheckpointSaverHook
"""
from __future__ import absolute_import
diff --git a/tensorflow/contrib/tpu/ops/cross_replica_ops.cc b/tensorflow/contrib/tpu/ops/cross_replica_ops.cc
index ea8e0e00ed..87e3a5946c 100644
--- a/tensorflow/contrib/tpu/ops/cross_replica_ops.cc
+++ b/tensorflow/contrib/tpu/ops/cross_replica_ops.cc
@@ -125,4 +125,24 @@ output: The sum of all the distributed inputs.
T: The type of elements to be summed.
)doc");
+REGISTER_OP("CollectivePermute")
+ .Input("input: T")
+ .Input("source_target_pairs: int32")
+ .Output("output: T")
+ .Attr("T: numbertype")
+ .SetShapeFn(shape_inference::UnchangedShape)
+ .Doc(R"doc(
+An Op to permute tensors across replicated TPU instances. Each instance
+supplies its own input.
+
+For example, suppose there are 4 TPU instances: `[A, B, C, D]`. Passing
+source_target_pairs=`[[0,1],[1,2],[2,3],[3,0]]` gets the outputs:
+`[D, A, B, C]`.
+
+input: The local input to be permuted. Currently only supports float and
+ bfloat16.
+source_target_pairs: A tensor with shape [num_pairs, 2].
+output: The permuted input.
+T: The type of elements to be exchanged.
+)doc");
} // namespace tensorflow
diff --git a/tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc b/tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc
index 72d37f774c..1bd1a31e11 100644
--- a/tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc
+++ b/tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc
@@ -13,11 +13,16 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/tpu/proto/tpu_embedding_config.pb.h"
+#include "tensorflow/contrib/tpu/proto/tpu_embedding_configuration.pb.h"
+#include "tensorflow/contrib/tpu/utils/tpu_embedding_optimization_parameters_utils.h"
+#include "tensorflow/contrib/tpu/utils/tpu_embedding_output_layout_utils.h"
+#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
namespace tensorflow {
@@ -53,215 +58,339 @@ namespace tensorflow {
// saving a checkpoint, the model must Retrieve the parameters back into the
// host CPU memory.
-REGISTER_OP("TPUEmbeddingLoadGradientDescentParameters")
- .Input("parameters: float32")
- .Attr("tpu_embedding_config: string")
- .Attr("table_id: int >= 0")
- .Attr("num_hosts: int >= 1")
- .Attr("host_id: int >= 0")
- .SetIsStateful()
- .SetShapeFn(shape_inference::UnknownShape)
- .Doc(R"doc(
-Load an embedding table shard into TPU memory for use with GradientDescent.
-
-TPU embeddings use dedicated per-optimizer Ops for loading and retrieving
-trainable variables and optimizer state from TPU memory. This op enables
-functionality equivalent to GradientDescentOptimizer.
-
-parameters: The shard of the embedding table resident on the host executing this
- op. For single-TPU models, this is the entire embedding table.
-tpu_embedding_config: Serialized TPUEmbeddingConfiguration proto.
-table_id: The id of the table specified in the tpu_embedding_config.
-num_hosts: The number of CPU hosts in the distributed training job.
-host_id: Which CPU host in the distributed training job will execute this op.
-)doc");
+namespace {
-namespace tpu_embedding_config_util {
+void RegisterPerTableLoadAndRetrieveOps();
-Status GradientDescentShapes(shape_inference::InferenceContext *c) {
- string config_string;
- TF_RETURN_IF_ERROR(c->GetAttr("tpu_embedding_config", &config_string));
- tpu::TPUEmbeddingConfiguration config;
- if (!config.ParseFromString(config_string)) {
- return errors::InvalidArgument("Malformed tpu_embedding_config.");
+class RegisterPerTableLoadAndRetrieveOpsOnConstruction {
+ public:
+ RegisterPerTableLoadAndRetrieveOpsOnConstruction() {
+ RegisterPerTableLoadAndRetrieveOps();
}
-
- int table_id;
- TF_RETURN_IF_ERROR(c->GetAttr("table_id", &table_id));
- int64 num_tables = config.table_config_size();
- if (table_id >= num_tables) {
- return errors::InvalidArgument("Table id >= num_tables");
+};
+
+// Object whose constructor does registrations.
+RegisterPerTableLoadAndRetrieveOpsOnConstruction
+ register_per_table_load_and_retrieve_ops_var;
+
+Status RegisterPerTableLoadOpsForAlgorithmBody(
+ tpu::OptimizationAlgorithm alg, bool is_debug_op,
+ OpRegistrationData* op_reg_data) {
+ tpu::GradientAccumulationSupport grad_accum_support;
+ TF_CHECK_OK(GetGradientAccumulationSupport(alg, &grad_accum_support));
+
+ std::vector<tpu::StateVariableSpecification> state_variable_specs;
+ TF_CHECK_OK(GetOptimizationAlgorithmStateVariables(
+ alg,
+ grad_accum_support == tpu::GradientAccumulationSupport::kSupported &&
+ is_debug_op,
+ &state_variable_specs));
+ auto* op_def = &op_reg_data->op_def;
+ op_def->set_name(
+ strings::StrCat("LoadTPUEmbedding", GetOptimizationAlgorithmName(alg),
+ "Parameters", (is_debug_op ? "GradAccumDebug" : "")));
+ // It is important for the order of the inputs to the op defined here
+ // to match the order in input_names because the indexes are used in
+ // the combining transformation.
+ for (const auto& parameter : state_variable_specs) {
+ if (parameter.has_user_defined() || is_debug_op) {
+ auto* arg = op_def->add_input_arg();
+ arg->set_name(parameter.name());
+ arg->set_description(
+ strings::StrCat("Value of ", parameter.name(), " used in the ",
+ GetOptimizationAlgorithmFriendlyName(alg),
+ " optimization algorithm."));
+ arg->set_type(DT_FLOAT);
+ }
}
- int64 width = config.table_config(table_id).width();
- int64 num_rows = config.table_config(table_id).num_rows();
-
- TF_RETURN_IF_ERROR(c->set_output("parameters", {c->Matrix(num_rows, width)}));
+ {
+ auto* table_id_attr = op_def->add_attr();
+ table_id_attr->set_name("table_id");
+ table_id_attr->set_type("int");
+ table_id_attr->set_has_minimum(true);
+ table_id_attr->set_minimum(-1);
+ table_id_attr->mutable_default_value()->set_i(-1);
+ }
+ {
+ auto* table_name_attr = op_def->add_attr();
+ table_name_attr->set_name("table_name");
+ table_name_attr->set_type("string");
+ table_name_attr->mutable_default_value()->set_s("");
+ }
+ {
+ auto* num_shards_attr = op_def->add_attr();
+ num_shards_attr->set_name("num_shards");
+ num_shards_attr->set_type("int");
+ }
+ {
+ auto* shard_id_attr = op_def->add_attr();
+ shard_id_attr->set_name("shard_id");
+ shard_id_attr->set_type("int");
+ }
+ op_def->set_summary("Load embedding parameters for a single table.");
+ string parameter_descriptions;
+ for (const auto& parameter : state_variable_specs) {
+ if (parameter.has_user_defined() || is_debug_op) {
+ strings::Appendf(&parameter_descriptions,
+ R"(
+%s: A tensor containing the initial embedding table %s to use in embedding
+lookups using the %s optimization algorithm.)",
+ parameter.name().c_str(), parameter.name().c_str(),
+ GetOptimizationAlgorithmFriendlyName(alg).c_str());
+ }
+ }
+ op_def->set_description(strings::Printf(R"doc(
+An op that loads optimization parameters into HBM for embedding. Must be
+preceded by a ConfigureTPUEmbeddingHost op that sets up the correct
+embedding table configuration. For example, this op is used to install
+parameters that are loaded from a checkpoint before a training loop is
+executed.
+%s
+table_name: Name of this table; must match a name in the
+ EmbeddingLayerConfiguration proto (overrides table_id).
+num_shards: Number of shards into which the embedding tables are divided.
+shard_id: Identifier of shard for this operation.
+table_id: Index of this table in the EmbeddingLayerConfiguration proto
+ (deprecated).
+)doc",
+ parameter_descriptions.c_str()));
+ op_def->set_is_commutative(false);
+ op_def->set_is_aggregate(false);
+ op_def->set_is_stateful(true);
+ auto shape_inference_function =
+ [state_variable_specs,
+ is_debug_op](shape_inference::InferenceContext* c) -> Status {
+ int table_id;
+ TF_RETURN_IF_ERROR(c->GetAttr("table_id", &table_id));
+ string table_name;
+ TF_RETURN_IF_ERROR(c->GetAttr("table_name", &table_name));
+ // Exactly one must be non-default.
+ if ((table_id >= 0) == (!table_name.empty())) {
+ return errors::InvalidArgument(
+ "exactly one of table_id or table_name must be non-default");
+ }
+ int num_shards;
+ TF_RETURN_IF_ERROR(c->GetAttr("num_shards", &num_shards));
+ int shard_id;
+ TF_RETURN_IF_ERROR(c->GetAttr("shard_id", &shard_id));
+ const int user_param_count =
+ std::count_if(state_variable_specs.begin(), state_variable_specs.end(),
+ [&](const tpu::StateVariableSpecification& sv) {
+ return sv.has_user_defined() || is_debug_op;
+ });
+ std::vector<shape_inference::ShapeHandle> inputs(user_param_count);
+ int input_index = 0;
+ for (int i = 0; i < state_variable_specs.size(); ++i) {
+ if (state_variable_specs[i].has_user_defined() || is_debug_op) {
+ std::vector<shape_inference::ShapeHandle> input_temp;
+ TF_RETURN_IF_ERROR(
+ c->input(state_variable_specs[i].name(), &input_temp));
+ if (input_temp.size() != 1) {
+          return errors::InvalidArgument("expected each input to be rank 1");
+ }
+ inputs[input_index] = input_temp[0];
+ ++input_index;
+ }
+ }
+ // Verify shapes have rank 2 and are compatible when they are
+ // required to be valid.
+ shape_inference::ShapeHandle parameter_shape;
+ TF_RETURN_IF_ERROR(c->WithRank(inputs[0], 2, &parameter_shape));
+ for (int j = 1; j < user_param_count; ++j) {
+ shape_inference::ShapeHandle accumulator_j_shape;
+ TF_RETURN_IF_ERROR(c->WithRank(inputs[j], 2, &accumulator_j_shape));
+ shape_inference::ShapeHandle merged;
+ TF_RETURN_IF_ERROR(
+ c->Merge(parameter_shape, accumulator_j_shape, &merged));
+ }
+ return Status::OK();
+ };
+ op_reg_data->shape_inference_fn = shape_inference_function;
return Status::OK();
}
-} // namespace tpu_embedding_config_util
-
-REGISTER_OP("TPUEmbeddingRetrieveGradientDescentParameters")
- .Output("parameters: float32")
- .Attr("tpu_embedding_config: string")
- .Attr("table_id: int")
- .Attr("num_hosts: int")
- .Attr("host_id: int")
- .SetIsStateful()
- .SetShapeFn(tpu_embedding_config_util::GradientDescentShapes)
- .Doc(R"doc(
-Retrieve an embedding table shard from TPU memory.
-
-TPU embeddings use dedicated per-optimizer Ops for loading and retrieving
-trainable variables and optimizer state from TPU memory. This op enables
-functionality equivalent to GradientDescentOptimizer.
-
-tpu_embedding_config: Serialized TPUEmbeddingConfiguration proto.
-table_id: The id of the table specified in tpu_embedding_config.
-num_hosts: The number of CPU hosts in the distributed training job.
-host_id: Which CPU host in the distributed training job will execute this op.
-)doc");
-
-REGISTER_OP("TPUEmbeddingLoadAdagradParameters")
- .Input("parameters: float32")
- .Input("accumulators: float32")
- .Attr("tpu_embedding_config: string")
- .Attr("table_id: int >= 0")
- .Attr("num_hosts: int >= 1")
- .Attr("host_id: int >= 0")
- .SetIsStateful()
- .SetShapeFn(shape_inference::UnknownShape)
- .Doc(R"doc(
-Load an embedding table shard into TensorNode memories for use with Adagrad.
-
-TPU embeddings use dedicated per-optimizer Ops for loading and retrieving
-trainable variables and optimizer state from TPU memory. This op enables
-functionality equivalent to AdagradOptimizer.
-
-parameters: The shard of the embedding table resident on the host executing this
- op. For single-TPU models, this is the entire embedding table.
-accumulators: Shard of the Adagrad accumulators resident on the host executing
- this op.
-tpu_embedding_config: Serialized TPUEmbeddingConfiguration proto.
-table_id: The id of the table specified in the embedding_config.
-num_hosts: The number of CPU hosts in the distributed training job.
-host_id: Which CPU host in the distributed training job will execute this op.
-)doc");
-
-namespace tpu_embedding_config_util {
-
-Status AdagradShapes(shape_inference::InferenceContext *c) {
- string config_string;
- TF_RETURN_IF_ERROR(c->GetAttr("tpu_embedding_config", &config_string));
- tpu::TPUEmbeddingConfiguration config;
- if (!config.ParseFromString(config_string)) {
- return errors::InvalidArgument("Malformed tpu_embedding_config.");
+Status RegisterPerTableRetrieveOpsForAlgorithmBody(
+ tpu::OptimizationAlgorithm alg, bool is_debug_op,
+ OpRegistrationData* op_reg_data) {
+ tpu::GradientAccumulationSupport grad_accum_support;
+ TF_CHECK_OK(GetGradientAccumulationSupport(alg, &grad_accum_support));
+
+ std::vector<tpu::StateVariableSpecification> state_variable_specs;
+ TF_CHECK_OK(GetOptimizationAlgorithmStateVariables(
+ alg,
+ grad_accum_support == tpu::GradientAccumulationSupport::kSupported &&
+ is_debug_op,
+ &state_variable_specs));
+
+ auto* op_def = &op_reg_data->op_def;
+ op_def->set_name(strings::StrCat(
+ "RetrieveTPUEmbedding", tpu::GetOptimizationAlgorithmName(alg),
+ "Parameters", (is_debug_op ? "GradAccumDebug" : "")));
+ // It is important for the order of the outputs of the op defined here
+ // to match the order in output_names because the indexes are used in
+ // the combining transformation.
+ for (const auto& parameter : state_variable_specs) {
+ if (parameter.has_user_defined() || is_debug_op) {
+ auto* arg = op_def->add_output_arg();
+ arg->set_name(parameter.name());
+ arg->set_description(
+ strings::StrCat("Parameter ", parameter.name(), " updated by the ",
+ tpu::GetOptimizationAlgorithmFriendlyName(alg),
+ " optimization algorithm."));
+ arg->set_type(DT_FLOAT);
+ }
}
-
- int table_id;
- TF_RETURN_IF_ERROR(c->GetAttr("table_id", &table_id));
- int64 num_tables = config.table_config_size();
- if (table_id >= num_tables) {
- return errors::InvalidArgument("Table id >= num_tables");
+ {
+ auto* table_id_attr = op_def->add_attr();
+ table_id_attr->set_name("table_id");
+ table_id_attr->set_type("int");
+ table_id_attr->set_has_minimum(true);
+ table_id_attr->set_minimum(-1);
+ table_id_attr->mutable_default_value()->set_i(-1);
}
- int64 width = config.table_config(table_id).width();
- int64 num_rows = config.table_config(table_id).num_rows();
-
- TF_RETURN_IF_ERROR(c->set_output("parameters", {c->Matrix(num_rows, width)}));
- TF_RETURN_IF_ERROR(
- c->set_output("accumulators", {c->Matrix(num_rows, width)}));
+ {
+ auto* table_name_attr = op_def->add_attr();
+ table_name_attr->set_name("table_name");
+ table_name_attr->set_type("string");
+ table_name_attr->mutable_default_value()->set_s("");
+ }
+ {
+ auto* num_shards_attr = op_def->add_attr();
+ num_shards_attr->set_name("num_shards");
+ num_shards_attr->set_type("int");
+ }
+ {
+ auto* shard_id_attr = op_def->add_attr();
+ shard_id_attr->set_name("shard_id");
+ shard_id_attr->set_type("int");
+ }
+ op_def->set_summary("Retrieve embedding parameters for a single table.");
+ string parameter_descriptions;
+ for (const auto& param : state_variable_specs) {
+ if (param.has_user_defined() || is_debug_op) {
+ strings::Appendf(&parameter_descriptions,
+ R"(
+%s: A tensor containing the embedding table %s to store with the
+parameters from embedding updates using the %s optimization algorithm.)",
+ param.name().c_str(), param.name().c_str(),
+ tpu::GetOptimizationAlgorithmFriendlyName(alg).c_str());
+ }
+ }
+ op_def->set_description(strings::Printf(R"doc(
+An op that retrieves optimization parameters from embedding to host
+memory. Must be preceded by a ConfigureTPUEmbeddingHost op that sets up
+the correct embedding table configuration. For example, this op is
+used to retrieve updated parameters before saving a checkpoint.
+%s
+table_name: Name of this table; must match a name in the
+ EmbeddingLayerConfiguration proto (overrides table_id).
+num_shards: Number of shards into which the embedding tables are divided.
+shard_id: Identifier of shard for this operation.
+table_id: Index of this table in the EmbeddingLayerConfiguration proto
+ (deprecated).
+)doc",
+ parameter_descriptions.c_str()));
+ op_def->set_is_commutative(false);
+ op_def->set_is_aggregate(false);
+ op_def->set_is_stateful(true);
+ auto shape_inference_function =
+ [state_variable_specs,
+ is_debug_op](shape_inference::InferenceContext* c) -> Status {
+ int table_id;
+ TF_RETURN_IF_ERROR(c->GetAttr("table_id", &table_id));
+ string table_name;
+ TF_RETURN_IF_ERROR(c->GetAttr("table_name", &table_name));
+ // Exactly one must be non-default.
+ if ((table_id >= 0) == (!table_name.empty())) {
+ return errors::InvalidArgument(
+ "exactly one of table_id or table_name must be non-default");
+ }
+ int num_shards;
+ TF_RETURN_IF_ERROR(c->GetAttr("num_shards", &num_shards));
+ int shard_id;
+ TF_RETURN_IF_ERROR(c->GetAttr("shard_id", &shard_id));
+ for (int j = 0; j < state_variable_specs.size(); ++j) {
+ if (state_variable_specs[j].has_user_defined() || is_debug_op) {
+ auto shape = c->MakeShape(
+ std::vector<shape_inference::DimensionHandle>(2, c->UnknownDim()));
+ TF_RETURN_IF_ERROR(
+ c->set_output(state_variable_specs[j].name(),
+ std::vector<shape_inference::ShapeHandle>(1, shape)));
+ }
+ }
+ return Status::OK();
+ };
+ op_reg_data->shape_inference_fn = shape_inference_function;
return Status::OK();
}
-} // namespace tpu_embedding_config_util
-
-REGISTER_OP("TPUEmbeddingRetrieveAdagradParameters")
- .Output("parameters: float32")
- .Output("accumulators: float32")
- .Attr("tpu_embedding_config: string")
- .Attr("table_id: int >= 0")
- .Attr("num_hosts: int >= 1")
- .Attr("host_id: int >= 0")
- .SetIsStateful()
- .SetShapeFn(tpu_embedding_config_util::AdagradShapes)
- .Doc(R"doc(
-Retrieve an embedding table shard from TPU memory.
-
-TPU embeddings use dedicated per-optimizer Ops for loading and retrieving
-trainable variables and optimizer state from TPU memory. This op enables
-functionality equivalent to AdagradOptimizer.
-
-tpu_embedding_config: Serialized TPUEmbeddingConfiguration proto.
-table_id: The id of the table specified in the embedding_config_json.
-num_hosts: The number of CPU hosts in the distributed training job.
-host_id: Which CPU host in the distributed training job will execute this op.
-)doc");
-
-REGISTER_OP("TPUEmbeddingEnqueueSparseBatch")
- .Input("sample_indices: num_tables * int32")
- .Input("embedding_indices: num_tables * int32")
- .Input("aggregation_weights: num_tables * float32")
- .Attr("num_tables: int")
- .Attr("device_ordinal: int = -1")
- .SetIsStateful()
- .SetShapeFn(shape_inference::UnknownShape)
- .Doc(R"doc(
-An op that feeds a batch of embedding indices and weights to the TPU.
-
-Embedding lookups are equivalent to sparse-dense matrix multiplications: the
-sparse matrix contains nonzeros in column j in order to retrieve row j from the
-embedding table.
-
-The three Tensor list arguments (sample_indices, embedding_indices, and
-aggregation_weights) represent these sparse matrices in COO format. The Tensor
-lists each have one entry for each embedding table specified in the model.
-For the kth embedding table, the three Tensors at position k in the list
-specify a COO-format sparse matrix. For the kth table, the row indices,
-column indices, and nonzero values of the COO sparse matrix are specified by
-sample_indices[k], embedding_indices[k], and aggregation_weights[k],
-respectively. Entries must be sorted by row index, then by column index.
-
-There should be at most one TPUEmbeddingEnqueueSparseBatch op in a signle
-training step per TPU shard.
-
-sample_indices: A list of rank 1 Tensors specifying row indices of the COO
- sparse matrix representing the embedding lookups for each table.
-embedding_indices: A list of rank 1 Tensors specifying column indices of the
- COO sparse matrix representing the embedding lookups for each table.
-aggregation_weights: A list of rank 1 Tensors specifying the nonzero values
- of the COO sparse matrix representing the embedding lookups for each table.
-device_ordinal: The TPU device to use. This should be -1 when the Op
- is running on a TPU device, and >= 0 when the Op is running on the CPU
- device.
-)doc");
-
-namespace tpu_embedding_config_util {
-
-Status ActivationShapes(shape_inference::InferenceContext *c) {
- string config_string;
- TF_RETURN_IF_ERROR(c->GetAttr("tpu_embedding_config", &config_string));
- tpu::TPUEmbeddingConfiguration config;
- if (!config.ParseFromString(config_string)) {
- return errors::InvalidArgument("Malformed tpu_embedding_config.");
+void RegisterPerTableLoadAndRetrieveOps() {
+ // Load ops
+ for (tpu::OptimizationAlgorithm alg : tpu::GetOptimizationAlgorithms()) {
+ OpRegistry::Global()->Register(
+ [alg](OpRegistrationData* op_reg_data) -> Status {
+ return RegisterPerTableLoadOpsForAlgorithmBody(alg, false,
+ op_reg_data);
+ });
+ tpu::GradientAccumulationSupport grad_accum_support;
+ TF_CHECK_OK(GetGradientAccumulationSupport(alg, &grad_accum_support));
+ if (grad_accum_support == tpu::GradientAccumulationSupport::kSupported) {
+ // TODO(gkurian): Condition this on being used internally within Google.
+ OpRegistry::Global()->Register(
+ [alg](OpRegistrationData* op_reg_data) -> Status {
+ return RegisterPerTableLoadOpsForAlgorithmBody(alg, true,
+ op_reg_data);
+ });
+ }
}
- int64 batch_size = config.batch_size();
- int64 num_tables = config.table_config_size();
- for (int table_id = 0; table_id < num_tables; ++table_id) {
- int64 width = config.table_config(table_id).width();
- int64 num_features = config.table_config(table_id).num_features();
- c->set_output(table_id, c->Matrix(batch_size * num_features, width));
+ // Retrieve ops
+ for (tpu::OptimizationAlgorithm alg : tpu::GetOptimizationAlgorithms()) {
+ OpRegistry::Global()->Register(
+ [alg](OpRegistrationData* op_reg_data) -> Status {
+ return RegisterPerTableRetrieveOpsForAlgorithmBody(alg, false,
+ op_reg_data);
+ });
+ tpu::GradientAccumulationSupport grad_accum_support;
+ TF_CHECK_OK(GetGradientAccumulationSupport(alg, &grad_accum_support));
+ if (grad_accum_support == tpu::GradientAccumulationSupport::kSupported) {
+ // TODO(gkurian): Condition this on being used internally within Google.
+ OpRegistry::Global()->Register(
+ [alg](OpRegistrationData* op_reg_data) -> Status {
+ return RegisterPerTableRetrieveOpsForAlgorithmBody(alg, true,
+ op_reg_data);
+ });
+ }
}
- return Status::OK();
}
-} // namespace tpu_embedding_config_util
+} // namespace
-REGISTER_OP("TPUEmbeddingReceiveActivations")
- .Output("outputs: num_tables * float")
- .Attr("num_tables: int >= 1")
- .Attr("tpu_embedding_config: string")
+REGISTER_OP("RecvTPUEmbeddingActivations")
+ .Output("outputs: num_outputs * float")
+ .Attr("num_outputs: int >= 1")
+ .Attr("config: string")
.SetIsStateful()
- .SetShapeFn(tpu_embedding_config_util::ActivationShapes)
+ .SetShapeFn([](shape_inference::InferenceContext* c) -> Status {
+ string config_string;
+ TF_RETURN_IF_ERROR(c->GetAttr("config", &config_string));
+ tpu::TPUEmbeddingConfiguration config;
+ if (!config.ParseFromString(config_string)) {
+ return errors::InvalidArgument("Malformed tpu_embedding_config.");
+ }
+ tpu::AddDefaultEmbeddingOutputLayoutIfNeeded(&config);
+ std::vector<TensorShapeProto> output_shapes;
+ TF_RETURN_IF_ERROR(ComputeOutputTensorShapes(config, &output_shapes));
+ if (c->num_outputs() != output_shapes.size()) {
+ return errors::InvalidArgument("num outputs != size of output shapes");
+ }
+ for (int i = 0; i < c->num_outputs(); ++i) {
+ shape_inference::ShapeHandle output_shape;
+ TF_RETURN_IF_ERROR(
+ c->MakeShapeFromShapeProto(output_shapes[i], &output_shape));
+ c->set_output(i, output_shape);
+ }
+ return Status::OK();
+ })
.Doc(R"doc(
An op that receives embedding activations on the TPU.
@@ -274,9 +403,9 @@ one ReceieveActivations op in the TPU graph.
outputs: A TensorList of embedding activations containing one Tensor per
embedding table in the model.
-num_tables: The number of output activation tensors, equal to the number of
+num_outputs: The number of output activation tensors, equal to the number of
embedding tables in the model.
-tpu_embedding_config: Serialized TPUEmbeddingConfiguration proto.
+config: Serialized TPUEmbeddingConfiguration proto.
)doc");
REGISTER_OP("TPUEmbeddingActivations")
@@ -306,10 +435,10 @@ lookup_id: Identifier of the set of embedding indices which produced these
activations.
)doc");
-REGISTER_OP("TPUEmbeddingSendGradients")
- .Input("gradients: num_tables * float32")
- .Attr("num_tables: int >= 1")
- .Attr("tpu_embedding_config: string")
+REGISTER_OP("SendTPUEmbeddingGradients")
+ .Input("inputs: N * float32")
+ .Attr("N: int >= 1")
+ .Attr("config: string")
.SetIsStateful()
.SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc(
@@ -321,8 +450,107 @@ with respect to the embedding activations. The embedding tables are updated
from these gradients via the optimizer specified in the configuration given
to tpu.initialize_system.
-gradients: A TensorList of gradients with which to update embedding tables.
-tpu_embedding_config: Serialized TPUEmbeddingConfiguration proto.
+inputs: A TensorList of gradients with which to update embedding tables.
+config: Serialized TPUEmbeddingConfiguration proto.
+)doc");
+
+REGISTER_OP("EnqueueTPUEmbeddingIntegerBatch")
+ .Input("batch: N * int32")
+ .Attr("N: int")
+ .Attr("device_ordinal: int = -1")
+ .SetIsStateful()
+ .SetShapeFn(shape_inference::UnknownShape)
+ .Doc(R"doc(
+An op that enqueues a list of input batch tensors to TPUEmbedding.
+
+batch: A list of 1D tensors, one for each embedding table, containing the
+batch inputs represented as integers.
+device_ordinal: The TPU device to use. This should be -1 when the Op
+is running on a TPU device, and >= 0 when the Op is running on the CPU
+device.
+)doc");
+
+REGISTER_OP("EnqueueTPUEmbeddingSparseBatch")
+ .Input("sample_indices: N * int32")
+ .Input("embedding_indices: N * int32")
+ .Input("aggregation_weights: N * float32")
+ .Attr("N: int")
+ .Attr("device_ordinal: int = -1")
+ .Attr("combiners: list(string) = []")
+ .SetIsStateful()
+ .SetShapeFn([](shape_inference::InferenceContext* c) -> Status {
+ std::vector<string> combiners;
+ TF_RETURN_IF_ERROR(c->GetAttr("combiners", &combiners));
+ int n;
+ TF_RETURN_IF_ERROR(c->GetAttr("N", &n));
+ if (!combiners.empty() && combiners.size() != n) {
+ return errors::InvalidArgument("Invalid length of combiners. Have ",
+ combiners.size(), " but expected 0 or ",
+ n);
+ }
+
+ return Status::OK();
+ })
+ .Doc(R"doc(
+An op that enqueues TPUEmbedding input indices from a SparseTensor.
+
+This Op eases the porting of code that uses embedding_lookup_sparse(),
+although some Python preprocessing of the SparseTensor arguments to
+embedding_lookup_sparse() is required to produce the arguments to this Op,
+since only a single EnqueueTPUEmbedding Op is allowed per training step.
+
+The tensors at corresponding positions in the three input lists
+must have the same shape, i.e. rank 1 with dim_size() equal to the total
+number of lookups into the table described by the corresponding table_id.
+
+sample_indices: A list of Rank 1 Tensors specifying the training example and
+ feature to which the corresponding embedding_indices and aggregation_weights
+ values belong. sample_indices[i] must equal b * nf + f, where nf is the
+ number of features from the corresponding table, f is in [0, nf), and
+ b is in [0, training batch size).
+embedding_indices: A list of Rank 1 Tensors, indices into the embedding tables.
+aggregation_weights: A list of Rank 1 Tensors containing per sample -- i.e. per
+ (training example, feature) -- aggregation weights.
+device_ordinal: The TPU device to use. This should be -1 when the Op
+is running on a TPU device, and >= 0 when the Op is running on the CPU
+device.
+combiners: A list of string scalars whose values are 'mean', 'sum', or 'sqrtn'
+to specify how to normalize the embedding activations after weighted summation.
+)doc");
+
+REGISTER_OP("EnqueueTPUEmbeddingSparseTensorBatch")
+ .Input("sample_indices: N * int32")
+ .Input("embedding_indices: N * int32")
+ .Input("aggregation_weights: N * float32")
+ .Attr("N: int")
+ .Attr("device_ordinal: int = -1")
+ .Attr("combiners: list(string) = []")
+ .Attr("table_ids: list(int)")
+ .SetIsStateful()
+ .SetShapeFn(shape_inference::UnknownShape)
+ .Doc(R"doc(
+This Op eases the porting of code that uses tf.nn.embedding_lookup_sparse().
+
+sample_indices[i], embedding_indices[i] and aggregation_weights[i] correspond
+to the ith feature. table_ids[i] indicates which embedding table to look up the
+ith feature in.
+
+sample_indices: A list of Rank 1 Tensors, corresponds to sp_ids.indices[:,0] in
+embedding_lookup_sparse().
+embedding_indices: A list of Rank 1 Tensors, corresponds to sp_ids.values
+ in embedding_lookup_sparse().
+aggregation_weights: A list of Rank 1 Tensors, corresponds to sp_weights.values
+ in embedding_lookup_sparse().
+device_ordinal: The TPU device to use. This should be -1 when the Op
+is running on a TPU device, and >= 0 when the Op is running on the CPU
+device.
+combiners: A list of strings, one for each embedding table, specifying the
+reduction operation. Currently, 'sum', 'mean' and 'sqrtn' are supported. It is
+invalid to have the sum of the weights be 0 for 'mean' or the sum of the squared
+weights be 0 for 'sqrtn'. If combiners isn't passed, the default is to
+use 'sum' for all tables.
+table_ids: A list of int. table_ids[i] indicates which embedding table to look
+up the ith feature in.
)doc");
} // namespace tensorflow
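
Editor's note on the `sample_indices` convention in the EnqueueTPUEmbeddingSparseBatch doc above (`sample_indices[i]` must equal `b * nf + f`): the tiny sketch below only illustrates that arithmetic, using hypothetical names that are not part of the patch.

```python
def sample_index(batch_index, feature_index, num_features_in_table):
  """Illustrative only: the b * nf + f flattening described above."""
  return batch_index * num_features_in_table + feature_index

# The 2nd feature (f=1) of the 3rd training example (b=2) in a table with
# nf=4 features maps to sample index 2 * 4 + 1 = 9.
assert sample_index(2, 1, 4) == 9
```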
diff --git a/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc b/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc
index b498599962..8e6e9aa0cd 100644
--- a/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc
+++ b/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc
@@ -156,8 +156,7 @@ bool NewSession(const string& service_addr,
channel_args));
NewProfileSessionResponse new_session_response;
TF_QCHECK_OK(FromGrpcStatus(
- stub->NewSession(&context, new_session_request, &new_session_response)))
- << new_session_response.error_message();
+ stub->NewSession(&context, new_session_request, &new_session_response)));
std::cout << "Profile session succeed for host(s):"
<< str_util::Join(hostnames, ",") << std::endl;
diff --git a/tensorflow/contrib/tpu/profiler/op_profile.proto b/tensorflow/contrib/tpu/profiler/op_profile.proto
index 68cf510e71..292108f949 100644
--- a/tensorflow/contrib/tpu/profiler/op_profile.proto
+++ b/tensorflow/contrib/tpu/profiler/op_profile.proto
@@ -18,13 +18,15 @@ message Profile {
message Node {
string name = 1; // Semantics depend on contents.
Metrics metrics = 2; // May be omitted e.g. for fused instructions.
- repeated Node children = 3;
+ repeated Node children = 3; // Subjected to pruning.
// Details about what this node represents.
oneof contents {
InstructionCategory category = 4;
XLAInstruction xla = 5;
}
+
+ int32 num_children = 6; // Total number of children before pruning.
// A category of XLA instructions.
// name is a descriptive string, like "data formatting".
message InstructionCategory {
@@ -64,8 +66,8 @@ message Metrics {
// - it does not reveal the peak core FLOPS of the hardware
double flops = 2;
- // The VMEM bandwidth used to load operands from HBM, as a fraction of
- // thereotical VMEM bandwidth on the specific hardware.
+ // The memory bandwidth used to load operands, as a fraction of
+  // theoretical memory bandwidth on the specific hardware.
double memory_bandwidth = 3;
double raw_time = 11; // Elapsed core-time in picoseconds.
diff --git a/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py b/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py
index 438f442848..63641e00c5 100644
--- a/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py
+++ b/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py
@@ -116,12 +116,13 @@ def main(unused_argv=None):
elif tpu_cluster_resolver is not None:
workers_list = get_workers_list(tpu_cluster_resolver)
- if not FLAGS.logdir:
+ if not FLAGS.logdir and not FLAGS.monitoring_level:
sys.exit('logdir must be provided.')
executable_path = os.path.join(os.path.dirname(__file__), EXECUTABLE)
- logdir = os.path.expandvars(os.path.expanduser(FLAGS.logdir))
cmd = [executable_path]
- cmd.append('--logdir=' + logdir)
+ if FLAGS.logdir is not None:
+ logdir = os.path.expandvars(os.path.expanduser(FLAGS.logdir))
+ cmd.append('--logdir=' + logdir)
cmd.append('--service_addr=' + service_addr)
cmd.append('--workers_list=' + workers_list)
cmd.append('--duration_ms=' + str(FLAGS.duration_ms))
diff --git a/tensorflow/contrib/tpu/profiler/pip_package/setup.py b/tensorflow/contrib/tpu/profiler/pip_package/setup.py
index d4ccb0f246..2415c46718 100644
--- a/tensorflow/contrib/tpu/profiler/pip_package/setup.py
+++ b/tensorflow/contrib/tpu/profiler/pip_package/setup.py
@@ -20,7 +20,7 @@ from __future__ import print_function
from setuptools import setup
-_VERSION = '1.10.0'
+_VERSION = '1.11.0'
CONSOLE_SCRIPTS = [
'capture_tpu_profile=cloud_tpu_profiler.main:run_main',
diff --git a/tensorflow/contrib/tpu/profiler/version.h b/tensorflow/contrib/tpu/profiler/version.h
index aee094177b..90d34b5ef1 100644
--- a/tensorflow/contrib/tpu/profiler/version.h
+++ b/tensorflow/contrib/tpu/profiler/version.h
@@ -16,6 +16,6 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_TPU_PROFILER_VERSION_H_
#define TENSORFLOW_CONTRIB_TPU_PROFILER_VERSION_H_
-#define TPU_PROFILER_VERSION "1.10.0"
+#define TPU_PROFILER_VERSION "1.11.0"
#endif // TENSORFLOW_CONTRIB_TPU_PROFILER_VERSION_H_
diff --git a/tensorflow/contrib/tpu/proto/BUILD b/tensorflow/contrib/tpu/proto/BUILD
index 598b73b438..c20cab844c 100644
--- a/tensorflow/contrib/tpu/proto/BUILD
+++ b/tensorflow/contrib/tpu/proto/BUILD
@@ -10,12 +10,15 @@ load(
)
tf_proto_library(
- name = "tpu_embedding_config_proto",
+ name = "tpu_embedding_configuration_proto",
srcs = [
- "tpu_embedding_config.proto",
+ "tpu_embedding_configuration.proto",
],
cc_api_version = 2,
- protodeps = [":optimization_parameters_proto"],
+ protodeps = [
+ ":tpu_embedding_output_layout_proto",
+ ":optimization_parameters_proto",
+ ],
visibility = ["//visibility:public"],
)
@@ -29,6 +32,15 @@ tf_proto_library(
)
tf_proto_library(
+ name = "tpu_embedding_output_layout_proto",
+ srcs = [
+ "tpu_embedding_output_layout.proto",
+ ],
+ cc_api_version = 2,
+ visibility = ["//visibility:public"],
+)
+
+tf_proto_library(
name = "topology_proto",
srcs = [
"topology.proto",
diff --git a/tensorflow/contrib/tpu/proto/tpu_embedding_config.proto b/tensorflow/contrib/tpu/proto/tpu_embedding_config.proto
deleted file mode 100644
index 3476cc8953..0000000000
--- a/tensorflow/contrib/tpu/proto/tpu_embedding_config.proto
+++ /dev/null
@@ -1,66 +0,0 @@
-syntax = "proto3";
-
-package tensorflow.tpu;
-
-import "tensorflow/contrib/tpu/proto/optimization_parameters.proto";
-
-// The TPUEmbeddingConfiguration contains specification of TPU Embedding lookups
-// and gradient updates separate from the TF Graph.
-message TPUEmbeddingConfiguration {
- // model_mode specifies whether the model is to be run in training or
- // inference. In inference mode, gradient updates to embedding tables are not
- // performed.
- enum ModelMode {
- INVALID = 0;
- TRAINING = 1;
- INFERENCE = 2;
- }
-
- ModelMode model_mode = 1;
-
- // num_hosts is the number of host CPU systems in the training/inference job.
- // Each embedding table must be sharded into num_hosts separate Variables,
- // placed separately on the num_hosts CPU devices in the cluster. Sharding
- // will be performed equivalently to the 'div' sharding_strategy option of
- // embedding_lookup() and embedding_lookup_sparse().
- int32 num_hosts = 2;
-
- // The total number of TensorNodes. This is equal to num_hosts times the
- // number of TensorNodes attached to each host.
- int32 num_tensornodes = 3;
-
- // The number of training examples per TensorNode.
- int32 batch_size = 4;
-
- // Each Embedding
- message TPUEmbeddingTable {
- // Name of the embedding table. This will be used to name Variables in the
- // Tensorflow Graph.
- string name = 1;
-
- // Number of rows of the embedding table. The Variable created to hold the
- // learned embedding table values will have shape (num_rows, width).
- int32 num_rows = 3;
-
- // Width of the embedding table. The Variable created to hold the
- // learned embedding table values will have shape (num_rows, width).
- int32 width = 4;
-
- // Number of distinct embedding activation vectors per training example
- // produced by lookups into this table during model evaluation. For each
- // table, the Graph will receive an activations Tensor of shape
- // (batch_size * table.num_features, table.width).
- // For example, num_features = 1 produces equivalent behavior to a single
- // tf.nn.embedding_lookup() call. In the case of 'multivalent' embeddings,
- // (i.e. tf.nn.embedding_lookup_sparse()) which compute weighted averages of
- // embedding table rows, num_features is the number of vectors produced
- // after averaging. In sequence models num_features is typically equal
- // to the sequence length, since each sequence element must be represented
- // separately to the convolutional or recurrent network.
- int32 num_features = 5;
-
- OptimizationParameters optimization_parameters = 6;
- }
-
- repeated TPUEmbeddingTable table_config = 5;
-}
diff --git a/tensorflow/contrib/tpu/proto/tpu_embedding_configuration.proto b/tensorflow/contrib/tpu/proto/tpu_embedding_configuration.proto
new file mode 100644
index 0000000000..da19b135d7
--- /dev/null
+++ b/tensorflow/contrib/tpu/proto/tpu_embedding_configuration.proto
@@ -0,0 +1,95 @@
+syntax = "proto3";
+
+package tensorflow.tpu;
+
+import "tensorflow/contrib/tpu/proto/optimization_parameters.proto";
+import "tensorflow/contrib/tpu/proto/tpu_embedding_output_layout.proto";
+
+message TPUEmbeddingConfiguration {
+ // Description of the various embedding tables.
+ message TableDescriptor {
+ // Name of the table.
+ string name = 1;
+ // Size of the vocabulary (i.e., number of rows) in the table.
+ int32 vocabulary_size = 2;
+ // The embedding dimension (i.e., the width of the embedding table).
+ int32 dimension = 3;
+ // Number of features mapped to this table.
+ int32 num_features = 4;
+ // Details of the learning algorithm used to update the embedding
+ // parameters.
+ OptimizationParameters optimization_parameters = 5;
+ }
+ repeated TableDescriptor table_descriptor = 1;
+
+ // Mode. Should the embedding layer program be run for inference (just forward
+ // pass), training (both forward and backward pass) or just the backward_pass.
+ enum Mode {
+ UNSPECIFIED = 0;
+ INFERENCE = 1;
+ TRAINING = 2;
+ BACKWARD_PASS_ONLY = 3;
+ }
+ Mode mode = 2;
+
+ // Number of samples in each batch of embedding layer activations sent to
+ // the TensorCore.
+ int32 batch_size_per_tensor_core = 3;
+
+ // Number of TPU hosts used for inference/training.
+ int32 num_hosts = 4;
+
+ // Number of TensorCore used for inference/training.
+ int32 num_tensor_cores = 5;
+
+ // Sharding strategy of the embedding tables among the hosts.
+ // If the sharding_strategy is "mod", each id is assigned to host
+ // "id % num_hosts". For instance, 13 ids are split across 5 hosts as:
+ // [[0, 5, 10], [1, 6, 11], [2, 7, 12], [3, 8], [4, 9]].
+ // If the sharding_strategy is "div", ids are assigned to hosts in a
+ // contiguous manner. In this case, 13 ids are split across 5 hosts as:
+ // [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]].
+ // In both the strategies, if the id space does not evenly divide the number
+ // of hosts, each of the first "table_descriptor.num_ids % num_hosts" hosts
+ // will be assigned one more id.
+ // This partitioning strategy exactly follows that in the embedding_lookup
+ // TensorFlow function at tensorflow/python/ops/embedding_ops.py.
+ enum ShardingStrategy {
+ DIV_DEFAULT = 0;
+ MOD = 1;
+ }
+ ShardingStrategy sharding_strategy = 6;
+
+ // This parameter determines if the execution of the sparse core will be
+ // pipelined with that of the TensorCore. This parameter only affects results
+ // when mode=TRAINING. If mode=INFERENCE or BACKWARD_PASS_ONLY, this parameter
+  // does not affect execution and hence is a don't-care value.
+ //
+ // false: The execution of the sparse core is not pipelined with that of the
+ // TensorCore. The forward pass of every step on the sparse core is executed
+ // only after the backward pass of the previous step is complete. And the
+ // backward pass on the sparse core is executed only after the embedding
+ // gradients have been computed on the TensorCore on every step. This ensures
+ // that the activations on every step observe the gradient updates from the
+ // previous step on both the sparse core and the TensorCore.
+ //
+ // true: The execution of the sparse core is pipelined with that of the
+ // TensorCore. The forward pass of every step on the sparse core can be
+ // executed after the forward pass of the previous step is complete without
+ // waiting for the backward pass. This improves the utilization of the sparse
+ // core allowing it to process step N+1 while the embedding gradients for step
+ // N are computed on the TensorCore. The backward pass of every step on the
+ // sparse core is executed directly after the forward pass for the next step
+ // is complete. The drawback is that embedding activations for step N+1 do not
+ // observe the embedding gradient updates from step N. This could affect model
+ // quality if step N and N+1 involve the same set of embedding IDs. However,
+ // since the embedding updates are sparse, this is generally not considered a
+ // problem.
+ bool pipeline_execution_with_tensor_core = 7;
+
+ // Extended output layout information; if not provided, a compatibility mode
+ // will use defaults that match the old layout. Providing a value for this
+ // field is EXPERIMENTAL and most ways of filling it will probably break. Do
+ // not set it unless you know what you are doing.
+ TPUEmbeddingOutputLayout output_layout = 8;
+}
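
The ShardingStrategy comment above gives concrete "mod" and "div" assignments for 13 ids across 5 hosts; the sketch below (illustrative only, not part of the proto or any TF API) reproduces those assignments.

```python
def shard_ids(num_ids, num_hosts, strategy):
  """Illustrative reproduction of the sharding described above."""
  if strategy == "mod":
    return [[i for i in range(num_ids) if i % num_hosts == h]
            for h in range(num_hosts)]
  # "div": contiguous blocks; the first (num_ids % num_hosts) hosts each get
  # one extra id.
  base, extra = divmod(num_ids, num_hosts)
  shards, start = [], 0
  for h in range(num_hosts):
    size = base + (1 if h < extra else 0)
    shards.append(list(range(start, start + size)))
    start += size
  return shards

print(shard_ids(13, 5, "mod"))  # [[0, 5, 10], [1, 6, 11], [2, 7, 12], [3, 8], [4, 9]]
print(shard_ids(13, 5, "div"))  # [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]
```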
diff --git a/tensorflow/contrib/tpu/proto/tpu_embedding_output_layout.proto b/tensorflow/contrib/tpu/proto/tpu_embedding_output_layout.proto
new file mode 100644
index 0000000000..aed30b2f22
--- /dev/null
+++ b/tensorflow/contrib/tpu/proto/tpu_embedding_output_layout.proto
@@ -0,0 +1,75 @@
+syntax = "proto3";
+
+package tensorflow.tpu;
+
+// In the comments here, "layout" refers to the top-level EmbeddingOutputLayout
+// proto contained in the TPUEmbeddingConfiguration.
+
+// The embedding output consists of a list of tensors, each specified by an
+// EmbeddingOutputTensor proto within the EmbeddingOutputLayout (the "output"
+// field). Each table and feature lookup is then placed into some number of
+// particular positions within some output tensor (identified by "tensor_index"
+// within OutputLocation). The tree of table lookups, feature lookups, and
+// output locations is specified by the
+// "table(table_id).feature(feature_id).output_location" repeated fields within
+// EmbeddingOutputLayout.
+
+message TPUEmbeddingOutputLayout {
+ // Location of one copy of the feature's data.
+ message OutputLocation {
+ // Which output tensor this copy of the feature will go into. Must be
+ // between 0 and layout.output_size().
+ int32 tensor_index = 1;
+
+ // Offset in dimension 0 for this feature copy. Must be between 0 and
+ // layout.output(tensor_index).dim0_size_per_sample().
+ int32 dim0_offset = 2;
+
+ // Offset in dimension 1 for this feature copy. Must be between 0 and
+ // layout.output(tensor_index).dim1_size() - table width; repeated or
+    // partially/fully overlapping values are allowed; results written to the
+    // same range will be summed (with the gradients replicated in the backward
+    // pass).
+ int32 dim1_offset = 3;
+ }
+
+ // Description of the output placement for one feature.
+ message FeatureDescriptor {
+ // Typically, only one copy of each feature is used, but multiple are
+ // allowed and the same data will be copied to all of them (with the
+ // gradients summed in the backward pass).
+ repeated OutputLocation output_location = 1;
+ }
+
+ // Description of the output placement for features of one table.
+ message TableDescriptor {
+ // Output locations for each feature loaded from this table.
+ repeated FeatureDescriptor feature = 1;
+ }
+ // Output locations for each feature of each table.
+ repeated TableDescriptor table = 1;
+
+ // Data layout and shape computation information for a single output tensor.
+ // Any unused locations in the tensor will be filled with zeros, and
+ // corresponding gradients will be ignored.
+
+ // Size and layout information for 2-D tensors.
+ message TwoDOutputTensor {
+ // Multiplier for output dimension 0 size; used to match legacy format that
+ // stacks features within a sample in dimension 0.
+ int32 dim0_size_per_sample = 2;
+
+ // The size (in dimension 1) of this output tensor.
+ int32 dim1_size = 1;
+ }
+
+ // Format information for a single output tensor.
+ message EmbeddingOutputTensor {
+ oneof output_format {
+ TwoDOutputTensor two_d = 4;
+ }
+ }
+
+ // Shape and layout information for each tensor.
+ repeated EmbeddingOutputTensor output = 2;
+}
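To make the table/feature/output-location tree above concrete, here is a hedged sketch that fills in the layout for a single table with one feature, written once into a single 2-D output tensor. The `tpu_embedding_output_layout_pb2` module name assumes the usual protoc naming convention for the file added above.

    from tensorflow.contrib.tpu.proto import tpu_embedding_output_layout_pb2 as layout_pb2

    layout = layout_pb2.TPUEmbeddingOutputLayout()

    # One output tensor of width 8 in dimension 1, one feature row per sample in
    # dimension 0.
    out = layout.output.add()
    out.two_d.dim1_size = 8
    out.two_d.dim0_size_per_sample = 1

    # Table 0, feature 0: a single copy placed at offset (0, 0) of output tensor 0.
    table = layout.table.add()
    feature = table.feature.add()
    loc = feature.output_location.add()
    loc.tensor_index = 0
    loc.dim0_offset = 0
    loc.dim1_offset = 0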
diff --git a/tensorflow/contrib/tpu/python/ops/tpu_ops.py b/tensorflow/contrib/tpu/python/ops/tpu_ops.py
index d92a0652bb..a1aee69691 100644
--- a/tensorflow/contrib/tpu/python/ops/tpu_ops.py
+++ b/tensorflow/contrib/tpu/python/ops/tpu_ops.py
@@ -95,7 +95,7 @@ if platform.system() != "Windows":
]
def cross_replica_sum(x, group_assignment=None, name=None):
- """Sum the input tensor accorss replicas according to group_assignment.
+ """Sum the input tensor across replicas according to group_assignment.
Args:
x: The local tensor to the sum.
@@ -112,6 +112,31 @@ if platform.system() != "Windows":
return gen_tpu_ops.cross_replica_sum(x, group_assignment, name=name)
+ def collective_permute(x, source_target_pairs, name=None):
+ """Permute the input tensor across replicas given source_target_pairs.
+
+ For each source_target_pair <a, b>, we send replica a's input to replica b.
+ Each replica id may appear at most once in the source column and at most
+ once in the target column.
+ For any replica id not in the target column, this op returns a zero tensor
+ with the same shape and dtype as the input x.
+
+ For example, suppose there are 4 TPU instances: `[A, B, C, D]`. Passing
+ source_target_pairs=`[[0,1],[1,2],[2,3]]` gets the outputs:
+ `[0, A, B, C]`.
+
+ Args:
+ x: The local tensor to be permuted.
+ source_target_pairs: A 2d list of integers with shape [num_pairs, 2].
+ source_target_pairs[i][0] represents the source replica id and
+ source_target_pairs[i][1] represents the target replica id.
+ name: Optional op name.
+
+ Returns:
+ A `Tensor` which is permuted.
+ """
+ return gen_tpu_ops.collective_permute(x, source_target_pairs, name=name)
+
@ops.RegisterGradient("CrossReplicaSum")
def _cross_replica_sum_grad(op, grad):
# The gradient of a cross replica sum is also a cross-replica sum.
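As an illustration of the rotation example in the docstring above, a minimal sketch of a per-replica computation that shifts each replica's tensor to the next replica; it assumes the function is run inside a four-replica TPU computation (e.g. built with `tpu.replicate`), which is not set up here.

    from tensorflow.contrib.tpu.python.ops import tpu_ops

    def rotate_right(x):
      # Replica 0 -> 1, 1 -> 2, 2 -> 3; replica 0 receives a zero tensor because
      # it never appears in the target column.
      return tpu_ops.collective_permute(
          x, source_target_pairs=[[0, 1], [1, 2], [2, 3]])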
diff --git a/tensorflow/contrib/tpu/python/tpu/async_checkpoint.py b/tensorflow/contrib/tpu/python/tpu/async_checkpoint.py
new file mode 100644
index 0000000000..e06a720e82
--- /dev/null
+++ b/tensorflow/contrib/tpu/python/tpu/async_checkpoint.py
@@ -0,0 +1,202 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ======================================
+
+"""Hook for asynchronous checkpointing.
+
+This hook dispatches checkpoint writing operations in a separate thread to
+allow execution to continue on the main thread.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import threading
+import time
+
+from tensorflow.core.util.event_pb2 import SessionLog
+
+from tensorflow.python.framework import meta_graph
+from tensorflow.python.framework import ops
+from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training import basic_session_run_hooks
+from tensorflow.python.training import session_run_hook
+from tensorflow.python.training import training_util
+from tensorflow.python.training.session_run_hook import SessionRunArgs
+from tensorflow.python.training.summary_io import SummaryWriterCache
+
+
+class AsyncCheckpointSaverHook(session_run_hook.SessionRunHook):
+ """Saves checkpoints every N steps or seconds."""
+
+ def __init__(self,
+ checkpoint_dir,
+ save_secs=None,
+ save_steps=None,
+ saver=None,
+ checkpoint_basename="model.ckpt",
+ scaffold=None,
+ listeners=None):
+ """Initializes a `CheckpointSaverHook`.
+
+ Args:
+ checkpoint_dir: `str`, base directory for the checkpoint files.
+ save_secs: `int`, save every N secs.
+ save_steps: `int`, save every N steps.
+ saver: `Saver` object, used for saving.
+ checkpoint_basename: `str`, base name for the checkpoint files.
+ scaffold: `Scaffold`, use to get saver object.
+ listeners: List of `CheckpointSaverListener` subclass instances. Used for
+ callbacks that run immediately before or after this hook saves the
+ checkpoint.
+
+ Raises:
+ ValueError: One of `save_steps` or `save_secs` should be set.
+ ValueError: At most one of `saver` or `scaffold` should be set.
+ """
+ logging.info("Create CheckpointSaverHook.")
+ if saver is not None and scaffold is not None:
+ raise ValueError("You cannot provide both saver and scaffold.")
+ self._saver = saver
+ self._save_thread = None
+ self._checkpoint_dir = checkpoint_dir
+ self._save_path = os.path.join(checkpoint_dir, checkpoint_basename)
+ self._scaffold = scaffold
+ self._timer = basic_session_run_hooks.SecondOrStepTimer(
+ every_secs=save_secs, every_steps=save_steps)
+ self._listeners = listeners or []
+ self._steps_per_run = 1
+ self._summary_writer = None
+ self._global_step_tensor = None
+
+ def _set_steps_per_run(self, steps_per_run):
+ self._steps_per_run = steps_per_run
+
+ def begin(self):
+ self._summary_writer = SummaryWriterCache.get(self._checkpoint_dir)
+ self._global_step_tensor = training_util._get_or_create_global_step_read() # pylint: disable=protected-access
+ if self._global_step_tensor is None:
+ raise RuntimeError(
+ "Global step should be created to use CheckpointSaverHook.")
+ for l in self._listeners:
+ l.begin()
+
+ def after_create_session(self, session, coord):
+ global_step = session.run(self._global_step_tensor)
+
+ # Write the graph and saver_def here rather than in begin(), since other
+ # hooks may still modify the graph and add variables during begin(). The
+ # graph is finalized only after all begin() calls.
+ training_util.write_graph(
+ ops.get_default_graph().as_graph_def(add_shapes=True),
+ self._checkpoint_dir, "graph.pbtxt")
+ saver_def = self._get_saver().saver_def if self._get_saver() else None
+ graph = ops.get_default_graph()
+ meta_graph_def = meta_graph.create_meta_graph_def(
+ graph_def=graph.as_graph_def(add_shapes=True), saver_def=saver_def)
+ self._summary_writer.add_graph(graph)
+ self._summary_writer.add_meta_graph(meta_graph_def)
+ # The checkpoint saved here is the state at step "global_step".
+ self._save(session, global_step)
+ self._timer.update_last_triggered_step(global_step)
+
+ def before_run(self, run_context): # pylint: disable=unused-argument
+ return SessionRunArgs(self._global_step_tensor)
+
+ def after_run(self, run_context, run_values):
+ stale_global_step = run_values.results
+ if self._timer.should_trigger_for_step(stale_global_step +
+ self._steps_per_run):
+ # get the real value after train op.
+ global_step = run_context.session.run(self._global_step_tensor)
+ if self._timer.should_trigger_for_step(global_step):
+ self._timer.update_last_triggered_step(global_step)
+ if self._save(run_context.session, global_step):
+ run_context.request_stop()
+
+ def end(self, session):
+ if self._save_thread:
+ logging.info("Waiting for any pending checkpoints to finish.")
+ self._save_thread.join()
+
+ last_step = session.run(self._global_step_tensor)
+
+ # Save the last checkpoint synchronously if needed.
+ if last_step != self._timer.last_triggered_step():
+ self._save(session, last_step, asynchronous=False)
+
+ for l in self._listeners:
+ l.end(session, last_step)
+
+ def _save(self, session, step, asynchronous=True):
+ """Saves the latest checkpoint, returns should_stop."""
+
+ def _save_fn():
+ """Run the saver process."""
+ logging.info("Saving checkpoints for %d into %s.", step, self._save_path)
+
+ start_time = time.time()
+ for l in self._listeners:
+ l.before_save(session, step)
+
+ self._get_saver().save(session, self._save_path, global_step=step)
+ self._summary_writer.add_session_log(
+ SessionLog(
+ status=SessionLog.CHECKPOINT, checkpoint_path=self._save_path),
+ step)
+ end_time = time.time()
+ logging.info("Checkpoint actual writing time: (%.3f sec)",
+ end_time - start_time)
+ logging.info("Checkpoint finished for %d into %s.", step, self._save_path)
+
+ logging.info("Saving checkpoints for %d into %s.", step, self._save_path)
+ for l in self._listeners:
+ l.before_save(session, step)
+
+ if not asynchronous:
+ _save_fn()
+ return
+
+ if self._save_thread is not None:
+ self._save_thread.join(timeout=0.1)
+ if self._save_thread.is_alive():
+ logging.info("Saver thread still in progress, skipping checkpoint.")
+ return
+
+ self._save_thread = threading.Thread(target=_save_fn)
+ self._save_thread.start()
+
+ def _get_saver(self):
+ if self._saver is not None:
+ return self._saver
+ elif self._scaffold is not None:
+ return self._scaffold.saver
+
+ # Get saver from the SAVERS collection if present.
+ collection_key = ops.GraphKeys.SAVERS
+ savers = ops.get_collection(collection_key)
+ if not savers:
+ raise RuntimeError(
+ "No items in collection {}. Please add a saver to the collection "
+ "or provide a saver or scaffold.".format(collection_key))
+ elif len(savers) > 1:
+ raise RuntimeError(
+ "More than one item in collection {}. "
+ "Please indicate which one to use by passing it to the constructor."
+ .format(collection_key))
+
+ self._saver = savers[0]
+ return savers[0]
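A usage sketch for the new hook. The `estimator` and `input_fn` names are hypothetical placeholders; any MonitoredSession-based training loop that accepts hooks can be wired up the same way.

    from tensorflow.contrib.tpu.python.tpu import async_checkpoint

    # Write a checkpoint roughly every 1000 steps without blocking the training
    # loop on checkpoint I/O.
    saver_hook = async_checkpoint.AsyncCheckpointSaverHook(
        checkpoint_dir='/tmp/my_model', save_steps=1000)

    # Hypothetical Estimator-style call; the hook itself is the only new piece.
    # estimator.train(input_fn=input_fn, hooks=[saver_hook], max_steps=100000)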
diff --git a/tensorflow/contrib/tpu/python/tpu/device_assignment.py b/tensorflow/contrib/tpu/python/tpu/device_assignment.py
index 471b1fa46c..b9e2a4287a 100644
--- a/tensorflow/contrib/tpu/python/tpu/device_assignment.py
+++ b/tensorflow/contrib/tpu/python/tpu/device_assignment.py
@@ -72,13 +72,12 @@ class DeviceAssignment(object):
self._invert_topology(topology))
topology_rank = self._topology_tasks.ndim
- if core_assignment.ndim != topology_rank + 2:
- raise ValueError("core_assignment must be a rank {} numpy array".format(
- topology_rank + 2))
+ if core_assignment.ndim != 3:
+ raise ValueError("core_assignment must be a rank 3 numpy array, "
+ "got shape {}".format(core_assignment.shape))
self._num_replicas = core_assignment.shape[0]
- self._computation_shape = np.array(
- core_assignment.shape[1:-1], dtype=np.int32)
+ self._num_cores_per_replica = core_assignment.shape[1]
if core_assignment.shape[-1] != topology_rank:
raise ValueError(
@@ -107,18 +106,15 @@ class DeviceAssignment(object):
"""Computes a nested dict which maps task and logical core to replicas."""
task_and_cores_to_replicas = {}
for replica in xrange(core_assignment.shape[0]):
- for dx in xrange(core_assignment.shape[1]):
- for dy in xrange(core_assignment.shape[2]):
- for dz in xrange(core_assignment.shape[3]):
- x, y, z = core_assignment[replica, dx, dy, dz, :]
- task_id = topology_tasks[x, y, z]
- if task_id not in task_and_cores_to_replicas:
- task_and_cores_to_replicas[task_id] = {}
- logical_core = (dx, dy, dz)
- if logical_core not in task_and_cores_to_replicas[task_id]:
- task_and_cores_to_replicas[task_id][logical_core] = set()
-
- task_and_cores_to_replicas[task_id][logical_core].add(replica)
+ for logical_core in xrange(core_assignment.shape[1]):
+ x, y, z = core_assignment[replica, logical_core, :]
+ task_id = topology_tasks[x, y, z]
+ if task_id not in task_and_cores_to_replicas:
+ task_and_cores_to_replicas[task_id] = {}
+ if logical_core not in task_and_cores_to_replicas[task_id]:
+ task_and_cores_to_replicas[task_id][logical_core] = set()
+
+ task_and_cores_to_replicas[task_id][logical_core].add(replica)
task_to_sorted_replica_id = {}
@@ -136,23 +132,9 @@ class DeviceAssignment(object):
return self._topology
@property
- def computation_shape(self):
- """The computation shape.
-
- Returns:
- A rank-1 int32 numpy array with size equal to the TPU topology rank.
- Describes the logical shape in numbers of core of each replica of the
- computation in the TPU topology.
-
- Returns:
- The computation shape.
- """
- return self._computation_shape
-
- @property
def num_cores_per_replica(self):
"""The number of cores per replica."""
- return np.prod(self.computation_shape)
+ return self._num_cores_per_replica
@property
def num_replicas(self):
@@ -164,33 +146,22 @@ class DeviceAssignment(object):
"""The logical to physical core mapping.
Returns:
- A numpy array of rank `topology_rank + 2`, with shape
- `[num_replicas] + computation_shape + [topology_rank]`. Maps
- (replica, logical core coordinates) pairs to physical topology
- coordinates.
+ An integer numpy array of rank 3, with shape
+ `[num_replicas, num_cores_per_replica, topology_rank]`. Maps
+ (replica, logical core) pairs to physical topology coordinates.
"""
return self._core_assignment
def _coordinates(self, replica, logical_core):
"""Returns the physical topology coordinates of a logical core."""
- if logical_core is None:
- logical_core = np.array([0, 0, 0], np.int32)
- else:
- logical_core = np.asarray(logical_core)
-
- if any(logical_core < 0) or any(logical_core >= self.computation_shape):
- raise ValueError("Invalid core {}; computation shape is {}".format(
- logical_core, self.computation_shape))
-
- logical_offset = tuple([replica] + logical_core.tolist() + [slice(3)])
- return tuple(self.core_assignment[logical_offset])
+ return tuple(self.core_assignment[replica, logical_core, :])
def lookup_replicas(self, task_id, logical_core):
"""Lookup replica ids by task number and logical core.
Args:
task_id: TensorFlow task number.
- logical_core: A tuple of three integers which represents a logical core.
+ logical_core: An integer, identifying a logical core.
Returns:
A sorted list of the replicas that are attached to that task and
logical_core.
@@ -205,17 +176,17 @@ class DeviceAssignment(object):
"Can not find any replica in task: {} contains logical_core: {} ".
format(task_id, logical_core))
- def tpu_ordinal(self, replica=0, logical_core=None):
+ def tpu_ordinal(self, replica=0, logical_core=0):
"""Returns the ordinal of the TPU device assigned to a logical core."""
coordinates = self._coordinates(replica, logical_core)
return self._topology_devices[coordinates]
- def host_device(self, replica=0, logical_core=None, job=None):
+ def host_device(self, replica=0, logical_core=0, job=None):
"""Returns the CPU device attached to a logical core."""
coordinates = self._coordinates(replica, logical_core)
return _tpu_host_device_name(job, self._topology_tasks[coordinates])
- def tpu_device(self, replica=0, logical_core=None, job=None):
+ def tpu_device(self, replica=0, logical_core=0, job=None):
"""Returns the name of the TPU device assigned to a logical core."""
coordinates = self._coordinates(replica, logical_core)
return _tpu_device_name(job, self._topology_tasks[coordinates],
@@ -228,6 +199,8 @@ def device_assignment(topology,
num_replicas=1):
"""Computes a device_assignment of a computation across a TPU topology.
+ Attempts to choose a compact grid of cores for locality.
+
Returns a `DeviceAssignment` that describes the cores in the topology assigned
to each core of each replica.
@@ -240,12 +213,12 @@ def device_assignment(topology,
`initialize_system` using `Session.run`. Either a serialized
`TopologyProto` or a `Topology` object may be passed. Note: you must
evaluate the `Tensor` first; you cannot pass an unevaluated `Tensor` here.
- computation_shape: A rank 1 int32 numpy array of size 3, describing the
- shape of the computation's block of cores. If None, the
- `computation_shape` is `[1, 1, 1]`.
- computation_stride: A rank 1 int32 numpy array of size 3, describing the
- inter-core spacing of the `computation_shape` cores in the TPU topology.
- If None, the `computation_stride` is `[1, 1, 1]`.
+ computation_shape: A rank 1 int32 numpy array with size equal to the
+ topology rank, describing the shape of the computation's block of cores.
+ If None, the `computation_shape` is `[1] * topology_rank`.
+ computation_stride: A rank 1 int32 numpy array of size `topology_rank`,
+ describing the inter-core spacing of the `computation_shape` cores in the
+ TPU topology. If None, the `computation_stride` is `[1] * topology_rank`.
num_replicas: The number of computation replicas to run. The replicas will
be packed into the free spaces of the topology.
@@ -271,21 +244,21 @@ def device_assignment(topology,
topology_rank = len(topology.mesh_shape)
mesh_shape = topology.mesh_shape
if computation_shape is None:
- computation_shape = np.array([1, 1, 1], dtype=np.int32)
+ computation_shape = np.array([1] * topology_rank, dtype=np.int32)
else:
computation_shape = np.asarray(computation_shape, dtype=np.int32)
if computation_stride is None:
- computation_stride = np.array([1, 1, 1], dtype=np.int32)
+ computation_stride = np.array([1] * topology_rank, dtype=np.int32)
else:
computation_stride = np.asarray(computation_stride, dtype=np.int32)
- if computation_shape.shape != (3,):
- raise ValueError("computation_shape must have shape [3]; got {}".format(
- computation_shape.shape))
- if computation_stride.shape != (3,):
- raise ValueError("computation_stride must have shape [3]; got {}".format(
- computation_stride.shape))
+ if computation_shape.shape != (topology_rank,):
+ raise ValueError("computation_shape must have shape [{}]; got {}".format(
+ topology_rank, computation_shape.shape))
+ if computation_stride.shape != (topology_rank,):
+ raise ValueError("computation_stride must have shape [{}]; got {}".format(
+ topology_rank, computation_stride.shape))
if any(computation_shape < 1):
raise ValueError(
@@ -315,28 +288,41 @@ def device_assignment(topology,
num_replicas, max_replicas, computation_shape, computation_stride,
mesh_shape))
- # Choose a compact layout for the cores. Choose the smaller dimension in the
- # topology to be close to the square root of the number of replicas.
- num_chips = int(math.ceil(num_replicas / replica_counts[2]))
- target_size = int(math.ceil(math.sqrt(num_chips)))
-
- # Prefer an even size, if possible. Odd numbered rows head back towards the
- # first column, so it's best if the last row has an odd index.
- if target_size % 2 != 0:
- target_size -= 1
- y_size = min(replica_counts[1], target_size)
- if y_size * replica_counts[0] < num_chips:
- y_size = replica_counts[1]
+ def ceil_of_ratio(n, m):
+ return (n + m - 1) // m
+
+ replica_shape = [0] * topology_rank
+ if num_replicas > 0:
+ remaining_replicas = num_replicas
+ remaining_dims = topology_rank
+
+ # Choose dimensions as close to an equal cube as possible, in order of
+ # increasing dimension size. By visiting dimensions in increasing size, we
+ # assign the most constrained dimension first, so we won't make infeasible
+ # choices.
+ #
+ # As a secondary sort order, visit the dimensions in reverse order. This
+ # means we try to use both cores on the same chip in preference to two cores
+ # on different chips.
+ for x, ni in sorted(((x, -i) for (i, x) in enumerate(replica_counts))):
+ i = -ni
+ target_size = int(math.ceil(remaining_replicas**(1.0 / remaining_dims)))
+ replica_shape[i] = min(target_size, x)
+ remaining_replicas = ceil_of_ratio(remaining_replicas, replica_shape[i])
+ remaining_dims -= 1
+
+ assert remaining_replicas == 1 and remaining_dims == 0
# Assigns an offset to each replica such that no two replicas overlap.
- replica_offsets = np.full([num_replicas, 3], -1, dtype=np.int32)
+ replica_offsets = np.full([num_replicas, topology_rank], -1, dtype=np.int32)
for replica in xrange(num_replicas):
- # Chooses a replica number in X/Y/Z axes.
- z = replica % replica_counts[2]
- t = replica // replica_counts[2]
- y = t % y_size
- x = t // y_size
- replica_pos = np.array([x, y, z], dtype=np.int32)
+ # Chooses a replica number in each axis.
+ t = replica
+ pos = []
+ for dim in replica_shape[::-1]:
+ pos.append(t % dim)
+ t //= dim
+ replica_pos = np.array(pos[::-1], dtype=np.int32)
# Determines where that replica starts in each axis.
outer = replica_pos // computation_stride
@@ -351,6 +337,6 @@ def device_assignment(topology,
indices = np.concatenate(
[i[..., np.newaxis] for i in np.meshgrid(*indices, indexing="ij")],
axis=-1)
- assignment = (
- indices + replica_offsets[:, np.newaxis, np.newaxis, np.newaxis, :])
+ indices = indices.reshape((-1, topology_rank))
+ assignment = indices + replica_offsets[:, np.newaxis, :]
return DeviceAssignment(topology, core_assignment=assignment)
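The replica-placement comments above describe choosing a near-cubical grid of replicas, filling the most constrained (smallest) topology dimension first. Below is a standalone restatement of that loop, useful for checking placements on paper; it mirrors the logic in `device_assignment` but is a sketch, not the function itself.

    import math

    def ceil_of_ratio(n, m):
      return (n + m - 1) // m

    def choose_replica_shape(replica_counts, num_replicas):
      """Picks a per-dimension replica grid close to an equal cube."""
      topology_rank = len(replica_counts)
      replica_shape = [0] * topology_rank
      remaining_replicas = num_replicas
      remaining_dims = topology_rank
      # Visit dimensions from smallest capacity to largest so the most
      # constrained dimension is fixed first; ties prefer later dimensions,
      # i.e. cores on the same chip.
      for x, ni in sorted((x, -i) for i, x in enumerate(replica_counts)):
        i = -ni
        target_size = int(math.ceil(remaining_replicas ** (1.0 / remaining_dims)))
        replica_shape[i] = min(target_size, x)
        remaining_replicas = ceil_of_ratio(remaining_replicas, replica_shape[i])
        remaining_dims -= 1
      return replica_shape

    # e.g. choose_replica_shape([4, 4, 2], num_replicas=8) -> [2, 2, 2]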
diff --git a/tensorflow/contrib/tpu/python/tpu/keras_support.py b/tensorflow/contrib/tpu/python/tpu/keras_support.py
index d8c3872363..696656e840 100644
--- a/tensorflow/contrib/tpu/python/tpu/keras_support.py
+++ b/tensorflow/contrib/tpu/python/tpu/keras_support.py
@@ -25,10 +25,9 @@ flattened = tf.keras.layers.Flatten()(c1)
logits = tf.keras.layers.Dense(10, activation='softmax')(flattened)
model = tf.keras.Model(inputs=[image], outputs=[logits])
-strategy = keras_support.TPUDistributionStrategy(num_cores_per_host=8)
-model = keras_support.tpu_model(model,
- strategy=strategy,
- tpu_name_or_address=tpu_name)
+resolver = tf.contrib.cluster_resolver.TPUClusterResolver(tpu=tpu_name)
+strategy = keras_support.TPUDistributionStrategy(resolver)
+model = keras_support.tpu_model(model, strategy=strategy)
# Only TF optimizers are currently supported.
model.compile(optimizer=tf.train.AdamOptimizer(), ...)
@@ -47,12 +46,12 @@ from __future__ import print_function
import abc
import collections
-import contextlib
import re
import sys
import time
import numpy as np
+import six
from tensorflow.contrib.cluster_resolver.python.training import tpu_cluster_resolver as tpu_cluster_resolver_lib
from tensorflow.contrib.framework.python.framework import experimental
@@ -69,6 +68,7 @@ from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.eager import context
from tensorflow.python.estimator import model_fn as model_fn_lib
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
@@ -76,6 +76,7 @@ from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import callbacks as cbks
+from tensorflow.python.keras import metrics as metrics_module
from tensorflow.python.keras import models
from tensorflow.python.keras import optimizers as keras_optimizers
from tensorflow.python.keras.engine import base_layer
@@ -89,34 +90,34 @@ from tensorflow.python.ops import gen_linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
-_SESSIONS = {}
-
-
-def tpu_session(cluster_resolver):
+def setup_tpu_session(cluster_resolver):
"""Construct or return a `tf.Session` connected to the given cluster."""
- global _SESSIONS
master = cluster_resolver.master()
- if master not in _SESSIONS:
- cluster_spec = cluster_resolver.cluster_spec()
- config = config_pb2.ConfigProto(isolate_session_state=True)
- if cluster_spec:
- config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())
- logging.info('Connecting to: %s', master)
- graph = ops.Graph()
- session = tf_session.Session(graph=graph, target=master, config=config)
- with graph.as_default():
- session.run(tpu.initialize_system())
+ # Use the existing session if we're already connected to this TPU
+ if (K.get_session()._target == master and
+ getattr(K.get_session(), '_tpu_initialized', None)):
+ return
+
+ cluster_spec = cluster_resolver.cluster_spec()
+ config = config_pb2.ConfigProto(isolate_session_state=True)
+ if cluster_spec:
+ config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())
- _SESSIONS[master] = session
- return _SESSIONS[master]
+ logging.info('Initialize')
+ tpu_session = tf_session.Session(target=master, config=config)
+ tpu_session.run(tpu.initialize_system())
+ tpu_session._tpu_initialized = True
+ # N.B. We have to call `K.set_session()` AND set our session as the
+ # TF default. `K.get_session()` surprisingly does not return the value
+ # supplied by K.set_session otherwise.
+ K.set_session(tpu_session)
-def reset_tpu_sessions():
- _SESSIONS.clear()
try:
from scipy.sparse import issparse # pylint: disable=g-import-not-at-top
@@ -133,9 +134,7 @@ def get_tpu_system_metadata(tpu_cluster_resolver):
cluster_def = cluster_spec.as_cluster_def() if cluster_spec else None
tpu_system_metadata = (
tpu_system_metadata_lib._query_tpu_system_metadata(
- master,
- cluster_def=cluster_def,
- query_topology=False))
+ master, cluster_def=cluster_def, query_topology=False))
return tpu_system_metadata
@@ -156,6 +155,8 @@ class TPUDistributionStrategy(object):
replication, typically using all avaiable TPU cores. If overwrites as
`True`, force the model replication using single core, i.e., no
replication.
+ Raises:
+ Exception: No TPU found on the given worker.
"""
if tpu_cluster_resolver is None:
@@ -171,7 +172,8 @@ class TPUDistributionStrategy(object):
for device in metadata.devices:
if 'TPU:0' in device.name:
self._worker_name = worker_re.search(device.name).group(1)
- break
+ return
+ raise Exception('No TPU found on given worker.')
def _make_assignment_for_model(self, cpu_model):
"""Makes a `TPUAssignment` for the passed in `cpu_model`."""
@@ -182,8 +184,7 @@ class TPUDistributionStrategy(object):
'Degrading to a single core.')
num_cores = 1
- return TPUAssignment(
- worker_name=self._worker_name, num_cores=num_cores)
+ return TPUAssignment(worker_name=self._worker_name, num_cores=num_cores)
class TPUAssignment(object):
@@ -229,6 +230,39 @@ class TPUEmbedding(embeddings.Embedding):
return math_ops.tensordot(inputs, self.embeddings, 1)
+def _cross_replica_concat(tensor, core_id, num_cores, name):
+ """Concatenate `tensor` across cores.
+
+ Args:
+ tensor: The tensor to be concatenated. Must be of type int32 or float32.
+ core_id: Tensor indicating the current TPU core.
+ num_cores: Python int. The total number of TPU cores in the system.
+ name: The string name to print for debugging.
+
+ Returns:
+ The same concatenated Tensor on each core.
+ """
+
+ input_dtype = tensor.dtype
+ if input_dtype not in [dtypes.float32, dtypes.int32]:
+ raise TypeError('For model replication, only float32 and int32 are '
+ 'supported for model outputs and targets. Got {} for '
+ '{}.'.format(input_dtype, name))
+
+ batch_size = tensor.shape[0]
+ mask = math_ops.to_float(math_ops.equal(range(num_cores), core_id))
+ mask = array_ops.reshape(mask, [num_cores] + [1] * tensor.shape.ndims)
+ result = mask * math_ops.to_float(tensor)
+ local_tensor_with_holes = array_ops.reshape(result,
+ [-1] + result.shape.as_list()[2:])
+ concat_tensor = tpu_ops.cross_replica_sum(local_tensor_with_holes)
+ concat_tensor.set_shape((num_cores * batch_size,) + tuple(tensor.shape[1:]))
+
+ if concat_tensor.dtype != input_dtype:
+ concat_tensor = math_ops.cast(concat_tensor, input_dtype)
+ return concat_tensor
+
+
class KerasCrossShardOptimizer(keras_optimizers.Optimizer):
"""An optimizer that averages gradients across TPU shards."""
@@ -246,9 +280,9 @@ class KerasCrossShardOptimizer(keras_optimizers.Optimizer):
super(KerasCrossShardOptimizer, self).__init__()
self._name = name
self._opt = opt
+ logging.info('KerasCrossShard: %s %s', self._opt, self._opt.weights)
def get_updates(self, loss, params):
- logging.info('Get updates: %s', loss)
self._opt.get_gradients = self.get_gradients
return self._opt.get_updates(loss, params)
@@ -257,17 +291,15 @@ class KerasCrossShardOptimizer(keras_optimizers.Optimizer):
grads = super(KerasCrossShardOptimizer, self).get_gradients(loss, params)
return [tpu_ops.cross_replica_sum(grad) / num_shards for grad in grads]
- def set_weights(self, weights):
- # TODO(power): Figure out whether we really need this given there is no
- # caller for this API yet.
- self._opt.set_weights()
-
def get_weights(self):
return self._opt.get_weights()
- @property
- def lr(self):
- return self._opt.lr
+ def get_config(self):
+ return self._opt.get_config()
+
+ # Defer remaining operations to the underlying optimizer
+ def __getattr__(self, key):
+ return getattr(self._opt, key)
class TPUModelOp(
@@ -293,6 +325,24 @@ def _replicated_optimizer(opt):
return KerasCrossShardOptimizer(opt)
+def _clone_optimizer(optimizer, config=None):
+ """Returns a cloned optimizer with the provided optimizer.config or config."""
+ if not isinstance(optimizer, keras_optimizers.Optimizer):
+ # In the first call to tpu_model(model), Keras may not yet have wrapped the
+ # TF optimizer in the TFOptimizer helper; e.g., the given model isn't
+ # compiled or no optimizer is set, and the generated tpu_model is compiled
+ # later with a TF optimizer.
+ return optimizer
+
+ if isinstance(optimizer, keras_optimizers.TFOptimizer):
+ return keras_optimizers.TFOptimizer(optimizer.optimizer)
+
+ if config is None:
+ config = optimizer.get_config()
+ logging.info('Cloning %s %s', optimizer.__class__.__name__, config)
+ return optimizer.__class__.from_config(config)
+
+
class TPURewriteContext(object):
"""Prepare the environment for a Keras model during `tpu.rewrite`.
@@ -381,6 +431,7 @@ class TPURewriteContext(object):
return (r, q)
else:
raise ValueError('Invalid shape passed to qr: %s' % input_shape)
+
gen_linalg_ops.qr = qr
ops.name_scope = _name_scope
@@ -396,9 +447,9 @@ class TPURewriteContext(object):
gen_linalg_ops.qr = self._default_qr
-class SizedInfeed(collections.namedtuple('SizedInfeed',
- ['sharded_infeed_tensors',
- 'infeed_ops'])):
+class SizedInfeed(
+ collections.namedtuple('SizedInfeed',
+ ['sharded_infeed_tensors', 'infeed_ops'])):
"""Represents an instantiation of the infeed ops for a concrete input shape.
sharded_infeed_tensors: A data structure of Tensors used to represent the
@@ -584,12 +635,13 @@ class TPUNumpyInfeedManager(TPUInfeedManager):
infeed_tensors, [spec.shape for spec in input_specs],
name='infeed-enqueue-%s-%d' % (execution_mode, shard_id),
device_ordinal=shard_id))
- return SizedInfeed(infeed_ops=infeed_op,
- sharded_infeed_tensors=shard_infeed_tensors)
+ return SizedInfeed(
+ infeed_ops=infeed_op, sharded_infeed_tensors=shard_infeed_tensors)
class TPUDatasetInfeedManager(TPUInfeedManager):
"""Manages infeed for a `tf.data.Dataset` into a TPU computation.
+
"""
class DatasetInfeedInstance(TPUInfeedInstance):
@@ -607,18 +659,17 @@ class TPUDatasetInfeedManager(TPUInfeedManager):
return {}
# pylint: disable=redefined-outer-name
- def __init__(self, dataset, tpu_assignment, tpu_session):
+ def __init__(self, dataset, tpu_assignment, mode):
"""Constructs a TPUDatasetInfeedManager.
- Must be called within a `KerasTPUModel.tpu_session` context!
-
Args:
dataset: A `tf.data.Dataset` to infeed.
tpu_assignment: The `TPUAssignment` used to configure the
Keras TPU model.
- tpu_session: The `tf.Session` object used for running the TPU model.
+ mode: ModeKeys enum.
"""
self._verify_dataset_shape(dataset)
+
self._dataset = dataset
self._tpu_assignment = tpu_assignment
dummy_x_shape = dataset.output_shapes[0].as_list()
@@ -626,7 +677,7 @@ class TPUDatasetInfeedManager(TPUInfeedManager):
dummy_y_shape = dataset.output_shapes[1].as_list()
dummy_y_shape[0] *= tpu_assignment.num_towers
self._iterator = dataset.make_initializable_iterator()
- tpu_session.run(self._iterator.initializer)
+ K.get_session().run(self._iterator.initializer)
self._get_next_ops = []
ctrl_deps = []
@@ -639,10 +690,10 @@ class TPUDatasetInfeedManager(TPUInfeedManager):
# Use dummy numpy inputs for the rest of Keras' shape checking. We
# intercept them when building the model.
- self._dummy_x = np.zeros(dummy_x_shape,
- dtype=dataset.output_types[0].as_numpy_dtype)
- self._dummy_y = np.zeros(dummy_y_shape,
- dtype=dataset.output_types[1].as_numpy_dtype)
+ self._dummy_x = np.zeros(
+ dummy_x_shape, dtype=dataset.output_types[0].as_numpy_dtype)
+ self._dummy_y = np.zeros(
+ dummy_y_shape, dtype=dataset.output_types[1].as_numpy_dtype)
input_specs = []
if isinstance(self._iterator.output_shapes, tuple):
@@ -658,6 +709,10 @@ class TPUDatasetInfeedManager(TPUInfeedManager):
self._iterator.output_types)
input_specs.append(spec)
+ # Pre-process the inputs and get_next_ops before caching.
+ input_specs, self._get_next_ops = (
+ _inject_tpu_inputs_for_dataset(
+ tpu_assignment, mode, input_specs, self._get_next_ops))
self._infeed_instance = self.DatasetInfeedInstance(input_specs)
def _verify_dataset_shape(self, dataset):
@@ -669,9 +724,8 @@ class TPUDatasetInfeedManager(TPUInfeedManager):
raise ValueError('The dataset must return a tuple of tf.Tensors, '
'instead it returns: %s' % dataset.output_classes)
if len(dataset.output_classes) != 2:
- raise ValueError(
- 'The dataset must return a 2-element tuple, got '
- '%s output classes instead.' % (dataset.output_classes,))
+ raise ValueError('The dataset must return a 2-element tuple, got '
+ '%s output classes instead.' % (dataset.output_classes,))
for i, cls in enumerate(dataset.output_classes):
if cls != ops.Tensor:
raise ValueError('The dataset returned a non-Tensor type (%s) at '
@@ -680,8 +734,7 @@ class TPUDatasetInfeedManager(TPUInfeedManager):
if not shape:
raise ValueError('The dataset returns a scalar tensor in '
'tuple index %d. Did you forget to batch? '
- '(Output shapes: %s).' % (i,
- dataset.output_shapes))
+ '(Output shapes: %s).' % (i, dataset.output_shapes))
for j, dim in enumerate(shape):
if dim.value is None:
if j == 0:
@@ -721,8 +774,72 @@ class TPUDatasetInfeedManager(TPUInfeedManager):
[spec.shape for spec in input_specs],
name='infeed-enqueue-%s-%d' % (execution_mode, shard_id),
device_ordinal=shard_id))
- return SizedInfeed(infeed_ops=infeed_ops,
- sharded_infeed_tensors=shard_infeed_tensors)
+ return SizedInfeed(
+ infeed_ops=infeed_ops, sharded_infeed_tensors=shard_infeed_tensors)
+
+
+def _inject_tpu_inputs_for_dataset(tpu_assignment, mode,
+ input_specs, get_next_ops):
+ """Append core information to the set of dataset inputs."""
+ # This is used during compilation to identify the current TPU core and enable
+ # concatenation operations across cores.
+ if mode not in [model_fn_lib.ModeKeys.TRAIN, model_fn_lib.ModeKeys.EVAL]:
+ return input_specs, get_next_ops
+
+ # Dataset inputs operate on a per-core basis.
+ per_core_batch_size = input_specs[0].shape.as_list()[0]
+
+ # Insert, at head, the tensor for core_id.
+ assert len(get_next_ops) == tpu_assignment.num_towers
+ for i in range(tpu_assignment.num_towers):
+ core_id_constant = constant_op.constant(
+ np.array([i] * per_core_batch_size).astype('int32'),
+ dtype=dtypes.int32,
+ name='core_id_constant')
+ get_next_ops[i] = [core_id_constant] + list(get_next_ops[i])
+
+ # Insert the input spec at head also.
+ input_specs = [tensor_spec.TensorSpec([per_core_batch_size], dtypes.int32)
+ ] + input_specs
+
+ return input_specs, get_next_ops
+
+
+def _inject_tpu_inputs_for_infeed(tpu_assignment, mode,
+ core_id_place_holder, input_tensors, inputs):
+ """Append core information to the set of inputs."""
+ # This is used during compilation to identify the current TPU core and enable
+ # concatenation operations across cores.
+ if mode not in [model_fn_lib.ModeKeys.TRAIN, model_fn_lib.ModeKeys.EVAL]:
+ return input_tensors, inputs
+
+ # Put a placeholder at the head of the input spec.
+ input_tensors = [core_id_place_holder] + input_tensors
+
+ # Now fill the core id. For `num_cores` = 2, `batch_size` = 8, we fill the
+ # core id inputs as [0, 0, 0, 0, 1, 1, 1, 1], so each core sees its core id
+ # (duplicated).
+ num_cores = tpu_assignment.num_towers
+ per_core_batch_size = inputs[0].shape[0] // num_cores
+ core_ids = np.arange(num_cores).repeat(per_core_batch_size)
+ inputs = [core_ids] + inputs
+ return input_tensors, inputs
+
+
+def _read_tpu_coreid_from_infeed(mode, infeed_tensors):
+ """Popping out the core ids from infeed."""
+ if mode not in [model_fn_lib.ModeKeys.TRAIN, model_fn_lib.ModeKeys.EVAL]:
+ return None, infeed_tensors
+
+ if len(infeed_tensors) <= 1:
+ raise RuntimeError(
+ 'The infeed on this TPU core has only {} tensors. '
+ 'This is not expected. Please report a bug.\nTensors: {}'.format(
+ len(infeed_tensors), infeed_tensors))
+
+ core_id = infeed_tensors[0][0] # Pop out the scalar version.
+ rest = infeed_tensors[1:]
+ return core_id, rest
class TPUFunction(object):
@@ -743,12 +860,11 @@ class TPUFunction(object):
self._tpu_assignment = tpu_assignment
self._compilation_cache = {}
self._cloned_model = None
-
- # Copy optimizer configuration. This is done prior to `_specialize_model`
- # as the configuration may require evaluating variables in the CPU session.
- self._optimizer_config = None
- if not isinstance(self.model.optimizer, keras_optimizers.TFOptimizer):
- self._optimizer_config = self.model.optimizer.get_config()
+ self._cloned_optimizer = None
+ # Create a placeholder for the TPU core ID. Cache the placeholder to avoid
+ # modifying the graph for every batch.
+ self._core_id_place_holder = array_ops.placeholder(
+ dtype=dtypes.int32, shape=[1], name='core_id')
def _specialize_model(self, input_specs, infeed_manager):
"""Specialize `self.model` (a Keras model) for the given input shapes."""
@@ -775,6 +891,10 @@ class TPUFunction(object):
shapes=[spec.shape for spec in input_specs],
name='infeed-%s' % self.execution_mode)
+ core_id, infeed_tensors = (
+ _read_tpu_coreid_from_infeed(
+ mode=self.execution_mode, infeed_tensors=infeed_tensors))
+
assert len(infeed_tensors) == len(infeed_layers), (
'Infeed inputs did not match model: %s vs %s' % (infeed_layers,
infeed_tensors))
@@ -790,35 +910,65 @@ class TPUFunction(object):
tpu_targets.append(tensor)
# Clone our CPU model, running within the TPU device context.
+ #
+ # We use the id of the original model as a key to avoid weight collisions
+ # (if a user re-runs the same model multiple times, in e.g. Colab).
with TPURewriteContext(tpu_input_map):
- with variable_scope.variable_scope('tpu_model_%s' % id(self.model)):
+ with variable_scope.variable_scope('tpu_%s' % id(self.model)):
with keras_tpu_variables.replicated_scope(
self._tpu_assignment.num_towers):
- self._cloned_model = models.clone_model(self.model)
+ if not self._cloned_optimizer:
+ self._cloned_optimizer = _clone_optimizer(
+ self.model.cpu_optimizer)
- # Create a copy of the optimizer for this graph.
- if isinstance(self.model.optimizer, keras_optimizers.TFOptimizer):
- cloned_optimizer = keras_optimizers.TFOptimizer(
- self.model.optimizer.optimizer)
- else:
- logging.info('Cloning %s %s', self.model.optimizer.__class__.__name__,
- self._optimizer_config)
- cloned_optimizer = self.model.optimizer.__class__.from_config(
- self._optimizer_config)
+ self._cloned_model = models.clone_model(self.model)
- if is_training or is_test:
- self._cloned_model.compile(
- optimizer=_replicated_optimizer(cloned_optimizer),
- loss=self.model.loss,
- loss_weights=self.model.loss_weights,
- metrics=self.model.metrics,
- weighted_metrics=self.model.weighted_metrics,
- target_tensors=tpu_targets,
- )
+ # When running on more than one core, concatenate outputs at the end
+ # of processing. In the backprop stage, the gradients are still computed
+ # from the local inputs, because the gradient of the cross-replica concat
+ # is zero for any outputs other than those from the local core, so the
+ # loss calculation is identical.
+ num_towers = self.model._tpu_assignment.num_towers
+ if num_towers > 1 and (is_training or is_test):
+ new_outputs = [
+ _cross_replica_concat(
+ o, core_id, num_towers,
+ name='model output ({})'.format(o.name))
+ for o in self._cloned_model.outputs
+ ]
+ self._cloned_model.outputs = new_outputs
+ tpu_targets = [
+ _cross_replica_concat(
+ tensor,
+ core_id,
+ num_towers,
+ name='model target ({})'.format(tensor.name))
+ for tensor in tpu_targets
+ ]
+
+ if is_training or is_test:
+ self._cloned_model.compile(
+ optimizer=_replicated_optimizer(self._cloned_optimizer),
+ loss=self.model.loss,
+ loss_weights=self.model.loss_weights,
+ metrics=metrics_module.clone_metrics(self.model.metrics),
+ weighted_metrics=metrics_module.clone_metrics(
+ self.model.weighted_metrics),
+ target_tensors=tpu_targets,
+ )
# Compute our outfeed depending on the execution mode
if is_training:
- self._cloned_model._make_train_function()
+ if not isinstance(self._cloned_optimizer, keras_optimizers.TFOptimizer):
+ # For Keras optimizer, we try to place the variable weights on the TPU
+ # device. Keras creates optimizer variables (e.g. momentum values for
+ # the Momentum optimizer) when _make_train_function is invoked.
+ with keras_tpu_variables.replicated_variable_for_optimizer(
+ self._tpu_assignment.num_towers):
+ self._cloned_model._make_train_function()
+ else:
+ self._cloned_model._make_train_function()
+
self._outfeed_spec = [
tensor_spec.TensorSpec(tensor.shape, tensor.dtype, tensor.name)
for tensor in self._cloned_model.train_function.outputs
@@ -923,6 +1073,7 @@ class TPUFunction(object):
for x, mgr in self.model._numpy_to_infeed_manager_list:
if inputs[0] is x:
return mgr
+
return TPUNumpyInfeedManager(self.model._tpu_assignment)
def _tpu_model_ops_for_input_specs(self, input_specs, infeed_manager):
@@ -947,13 +1098,14 @@ class TPUFunction(object):
# unique input shape.
shape_key = tuple([tuple(spec.shape.as_list()) for spec in input_specs])
if shape_key not in self._compilation_cache:
- with self.model.tpu_session():
- logging.info('New input shapes; (re-)compiling: mode=%s, %s',
- self.execution_mode, input_specs)
- new_tpu_model_ops = self._specialize_model(input_specs,
- infeed_manager)
- self._compilation_cache[shape_key] = new_tpu_model_ops
- self._test_model_compiles(new_tpu_model_ops)
+ logging.info(
+ 'New input shapes; (re-)compiling: mode=%s '
+ '(# of cores %d), %s', self.execution_mode,
+ self._tpu_assignment.num_towers, input_specs)
+ new_tpu_model_ops = self._specialize_model(input_specs,
+ infeed_manager)
+ self._compilation_cache[shape_key] = new_tpu_model_ops
+ self._test_model_compiles(new_tpu_model_ops)
return self._compilation_cache[shape_key]
@@ -970,15 +1122,29 @@ class TPUFunction(object):
# Note: this condition is possible during the prologue or epilogue of the
# pipelined loop.
return None, None
- # Strip sample weight from inputs
+
+ if (self.model.uses_learning_phase and
+ not isinstance(K.learning_phase(), int)):
+ # Remove the learning_phase flag at the end. We currently hard code the
+ # learning_phase in TPUFunction.
+ assert isinstance(inputs[-1], int), (
+ 'Expect the final element be learning_phase flag. Got {}'.format(
+ inputs[-1]))
+ inputs = inputs[:-1]
+
if (self.execution_mode == model_fn_lib.ModeKeys.TRAIN or
self.execution_mode == model_fn_lib.ModeKeys.EVAL):
+ # Strip sample weight from inputs.
input_tensors = self.model._feed_inputs + self.model._feed_targets
- inputs = inputs[:len(input_tensors)]
- return input_tensors, inputs
else:
input_tensors = self.model._feed_inputs
- return input_tensors, inputs
+
+ inputs = inputs[:len(input_tensors)]
+ input_tensors, inputs = (
+ _inject_tpu_inputs_for_infeed(
+ self._tpu_assignment, self.execution_mode,
+ self._core_id_place_holder, input_tensors, inputs))
+ return input_tensors, inputs
def _process_outputs(self, outfeed_outputs):
"""Processes the outputs of a model function execution.
@@ -1038,11 +1204,10 @@ class TPUFunction(object):
# Initialize our TPU weights on the first compile.
self.model._initialize_weights(self._cloned_model)
- with self.model.tpu_session() as session:
- _, _, outfeed_outputs = session.run([
- tpu_model_ops.infeed_op, tpu_model_ops.execute_op,
- tpu_model_ops.outfeed_op
- ], infeed_dict)
+ _, _, outfeed_outputs = K.get_session().run([
+ tpu_model_ops.infeed_op, tpu_model_ops.execute_op,
+ tpu_model_ops.outfeed_op
+ ], infeed_dict)
return self._process_outputs(outfeed_outputs)
def pipeline_run(self, cur_step_inputs, next_step_inputs):
@@ -1074,8 +1239,8 @@ class TPUFunction(object):
next_step_infeed_manager = self._lookup_infeed_manager(next_step_inputs)
cur_step_infeed_manager = self._lookup_infeed_manager(cur_step_inputs)
- if (next_step_infeed_manager is not None
- and cur_step_infeed_manager is not None):
+ if (next_step_infeed_manager is not None and
+ cur_step_infeed_manager is not None):
assert type(next_step_infeed_manager) is type(cur_step_infeed_manager)
next_input_tensors, next_step_inputs = (
@@ -1100,14 +1265,12 @@ class TPUFunction(object):
infeed_dict = None
if cur_infeed_instance and cur_input_tensors and cur_step_infeed_manager:
- cur_input_specs = cur_infeed_instance.make_input_specs(
- cur_input_tensors)
+ cur_input_specs = cur_infeed_instance.make_input_specs(cur_input_tensors)
cur_tpu_model_ops = self._tpu_model_ops_for_input_specs(
cur_input_specs, cur_step_infeed_manager)
- if (next_infeed_instance
- and next_input_tensors
- and next_step_infeed_manager):
+ if (next_infeed_instance and next_input_tensors and
+ next_step_infeed_manager):
next_input_specs = next_infeed_instance.make_input_specs(
next_input_tensors)
next_tpu_model_ops = self._tpu_model_ops_for_input_specs(
@@ -1118,26 +1281,24 @@ class TPUFunction(object):
self.model._initialize_weights(self._cloned_model)
if next_tpu_model_ops and cur_tpu_model_ops:
- with self.model.tpu_session() as session:
- _, _, outfeed_outputs = session.run([
- next_tpu_model_ops.infeed_op, cur_tpu_model_ops.execute_op,
- cur_tpu_model_ops.outfeed_op
- ], infeed_dict)
+ _, _, outfeed_outputs = K.get_session().run([
+ next_tpu_model_ops.infeed_op, cur_tpu_model_ops.execute_op,
+ cur_tpu_model_ops.outfeed_op
+ ], infeed_dict)
return self._process_outputs(outfeed_outputs)
+
if cur_tpu_model_ops:
- with self.model.tpu_session() as session:
- _, outfeed_outputs = session.run([
- cur_tpu_model_ops.execute_op, cur_tpu_model_ops.outfeed_op])
+ _, outfeed_outputs = K.get_session().run(
+ [cur_tpu_model_ops.execute_op, cur_tpu_model_ops.outfeed_op])
return self._process_outputs(outfeed_outputs)
+
if next_tpu_model_ops:
- with self.model.tpu_session() as session:
- session.run(next_tpu_model_ops.infeed_op, infeed_dict)
+ K.get_session().run(next_tpu_model_ops.infeed_op, infeed_dict)
return None
raise RuntimeError('Internal error: both current & next tpu_model_ops '
'were None')
-
class KerasTPUModel(models.Model):
"""TPU compatible Keras model wrapper."""
@@ -1164,8 +1325,6 @@ class KerasTPUModel(models.Model):
self._tpu_model = None
self._tpu_weights_initialized = False
- self._session = tpu_session(cluster_resolver)
-
# If the input CPU model has already been compiled, compile our TPU model
# immediately.
if self._cpu_model.optimizer:
@@ -1202,15 +1361,16 @@ class KerasTPUModel(models.Model):
if target_tensors:
raise ValueError('target_tensors is not supported for TPU execution.')
+ self._cpu_model.compile(
+ _clone_optimizer(optimizer), loss,
+ metrics_module.clone_metrics(metrics), loss_weights, sample_weight_mode,
+ metrics_module.clone_metrics(weighted_metrics), target_tensors,
+ **kwargs)
+
super(KerasTPUModel, self).compile(optimizer, loss, metrics, loss_weights,
sample_weight_mode, weighted_metrics,
target_tensors, **kwargs)
- if not self._cpu_model.optimizer:
- self._cpu_model.compile(optimizer, loss, metrics, loss_weights,
- sample_weight_mode, weighted_metrics,
- target_tensors, **kwargs)
-
def fit(self,
x=None,
y=None,
@@ -1243,8 +1403,8 @@ class KerasTPUModel(models.Model):
'https://github.com/tensorflow/tpu/tree/master/models/experimental'
'/keras')
if callable(x):
- with self.tpu_session() as sess,\
- ops.device('/job:%s/device:CPU:0' % self._tpu_assignment.worker_name):
+ with ops.device('/job:%s/device:CPU:0' %
+ self._tpu_assignment.worker_name):
dataset = x()
if steps_per_epoch is None:
raise ValueError('When using tf.data as input to a model, you '
@@ -1252,8 +1412,8 @@ class KerasTPUModel(models.Model):
if y is not None:
raise ValueError('When using tf.data as input to a model, y must be '
'None')
- infeed_manager = TPUDatasetInfeedManager(dataset, self._tpu_assignment,
- sess)
+ infeed_manager = TPUDatasetInfeedManager(
+ dataset, self._tpu_assignment, model_fn_lib.ModeKeys.TRAIN)
# Use dummy numpy inputs for the rest of Keras' shape checking. We
# intercept them when building the model.
x = infeed_manager.dummy_x
@@ -1269,26 +1429,24 @@ class KerasTPUModel(models.Model):
'https://github.com/tensorflow/tpu/tree/master/models/experimental'
'/keras')
if callable(validation_data):
- with self.tpu_session() as sess:
- dataset = validation_data()
- if validation_steps is None:
- raise ValueError('When using tf.data as validation for a model, you '
- 'should specify the validation_steps argument.')
- infeed_manager = TPUDatasetInfeedManager(dataset, self._tpu_assignment,
- sess)
- # Use dummy numpy inputs for the rest of Keras' shape checking. We
- # intercept them when building the model.
- val_x = infeed_manager.dummy_x
- val_y = infeed_manager.dummy_y
- infeed_managers.append((val_x, infeed_manager))
- validation_data = (val_x, val_y)
+ dataset = validation_data()
+ if validation_steps is None:
+ raise ValueError('When using tf.data as validation for a model, you '
+ 'should specify the validation_steps argument.')
+ infeed_manager = TPUDatasetInfeedManager(
+ dataset, self._tpu_assignment, model_fn_lib.ModeKeys.EVAL)
+ # Use dummy numpy inputs for the rest of Keras' shape checking. We
+ # intercept them when building the model.
+ val_x = infeed_manager.dummy_x
+ val_y = infeed_manager.dummy_y
+ infeed_managers.append((val_x, infeed_manager))
+ validation_data = (val_x, val_y)
self._numpy_to_infeed_manager_list = infeed_managers
try:
if not kwargs.get('_pipeline', True):
- logging.info(
- 'Running non-pipelined training loop (`_pipeline=%s`).',
- kwargs['_pipeline'])
+ logging.info('Running non-pipelined training loop (`_pipeline=%s`).',
+ kwargs['_pipeline'])
kwargs.pop('_pipeline')
return super(KerasTPUModel, self).fit(
x,
@@ -1344,50 +1502,32 @@ class KerasTPUModel(models.Model):
'https://github.com/tensorflow/tpu/tree/master/models/experimental'
'/keras')
if callable(x):
- with self.tpu_session() as sess:
- dataset = x()
- if steps is None:
- raise ValueError('When using tf.data as input to a model, you '
- 'should specify the steps argument.')
- if y is not None:
- raise ValueError('When using tf.data as input to a model, y must be '
- 'None')
- infeed_manager = TPUDatasetInfeedManager(dataset, self._tpu_assignment,
- sess)
- # Use dummy numpy inputs for the rest of Keras' shape checking. We
- # intercept them when building the model.
- x = infeed_manager.dummy_x
- y = infeed_manager.dummy_y
- infeed_managers.append((x, infeed_manager))
+ dataset = x()
+ if steps is None:
+ raise ValueError('When using tf.data as input to a model, you '
+ 'should specify the steps argument.')
+ if y is not None:
+ raise ValueError('When using tf.data as input to a model, y must be '
+ 'None')
+ infeed_manager = TPUDatasetInfeedManager(
+ dataset, self._tpu_assignment, model_fn_lib.ModeKeys.EVAL)
+ # Use dummy numpy inputs for the rest of Keras' shape checking. We
+ # intercept them when building the model.
+ x = infeed_manager.dummy_x
+ y = infeed_manager.dummy_y
+ infeed_managers.append((x, infeed_manager))
self._numpy_to_infeed_manager_list = infeed_managers
try:
- return super(KerasTPUModel, self).evaluate(
- x,
- y,
- batch_size,
- verbose,
- sample_weight,
- steps)
+ return super(KerasTPUModel, self).evaluate(x, y, batch_size, verbose,
+ sample_weight, steps)
finally:
self._numpy_to_infeed_manager_list = []
- def _pipeline_fit(self,
- x,
- y,
- batch_size,
- epochs,
- verbose,
- callbacks,
- validation_split,
- validation_data,
- shuffle,
- class_weight,
- sample_weight,
- initial_epoch,
- steps_per_epoch,
- validation_steps,
- **kwargs):
+ def _pipeline_fit(self, x, y, batch_size, epochs, verbose, callbacks,
+ validation_split, validation_data, shuffle, class_weight,
+ sample_weight, initial_epoch, steps_per_epoch,
+ validation_steps, **kwargs):
# Similar to super.fit(...), but modified to support software pipelining.
# Backwards compatibility
@@ -1415,13 +1555,8 @@ class KerasTPUModel(models.Model):
# Prepare validation data
val_x, val_y, val_sample_weights = self._prepare_validation_data(
- validation_data,
- validation_split,
- validation_steps,
- x,
- y,
- sample_weights,
- batch_size)
+ validation_data, validation_split, validation_steps, x, y,
+ sample_weights, batch_size)
return self._pipeline_fit_loop(
x,
y,
@@ -1594,8 +1729,8 @@ class KerasTPUModel(models.Model):
for i in indices_for_conversion_to_dense:
ins_batch[i] = ins_batch[i].toarray()
- outs = f.pipeline_run(cur_step_inputs=ins_last_batch,
- next_step_inputs=ins_batch)
+ outs = f.pipeline_run(
+ cur_step_inputs=ins_last_batch, next_step_inputs=ins_batch)
ins_last_batch = ins_batch
if batch_index == 0:
@@ -1667,8 +1802,8 @@ class KerasTPUModel(models.Model):
next_step_inputs = ins
else:
next_step_inputs = None
- outs = f.pipeline_run(cur_step_inputs=ins,
- next_step_inputs=next_step_inputs)
+ outs = f.pipeline_run(
+ cur_step_inputs=ins, next_step_inputs=next_step_inputs)
except errors.OutOfRangeError:
logging.warning('Your dataset iterator ran out of data; '
'interrupting training. Make sure that your '
@@ -1688,25 +1823,21 @@ class KerasTPUModel(models.Model):
break
if do_validation:
- val_outs = training_arrays.test_loop(self,
- val_inputs,
- val_targets,
- sample_weights=val_sample_weights,
- steps=validation_steps,
- verbose=0)
+ val_outs = training_arrays.test_loop(
+ self,
+ val_inputs,
+ val_targets,
+ sample_weights=val_sample_weights,
+ steps=validation_steps,
+ verbose=0)
if not isinstance(val_outs, list):
val_outs = [val_outs]
# Same labels assumed.
for l, o in zip(self.metrics_names, val_outs):
epoch_logs['val_' + l] = o
- def _prepare_validation_data(self,
- validation_data,
- validation_split,
- validation_steps,
- x,
- y,
- sample_weights,
+ def _prepare_validation_data(self, validation_data, validation_split,
+ validation_steps, x, y, sample_weights,
batch_size):
"""Prepares the validation dataset.
@@ -1764,8 +1895,10 @@ class KerasTPUModel(models.Model):
x, val_x = (slice_arrays(x, 0, split_at), slice_arrays(x, split_at))
y, val_y = (slice_arrays(y, 0, split_at), slice_arrays(y, split_at))
- sample_weights, val_sample_weights = (slice_arrays(
- sample_weights, 0, split_at), slice_arrays(sample_weights, split_at))
+ sample_weights, val_sample_weights = (
+ slice_arrays(sample_weights, 0, split_at),
+ slice_arrays(sample_weights, split_at)
+ )
elif validation_steps:
val_x = []
val_y = []
@@ -1777,11 +1910,20 @@ class KerasTPUModel(models.Model):
return val_x, val_y, val_sample_weights
+ @property
+ def optimizer(self):
+ if self._tpu_model:
+ return self._tpu_model.optimizer
+ return self._cpu_model.optimizer
+
+ @optimizer.setter
+ def optimizer(self, optimizer):
+ self._optimizer = optimizer
+
def _make_train_function(self):
if not self.train_function:
self.train_function = TPUFunction(
- self,
- model_fn_lib.ModeKeys.TRAIN,
+ self, model_fn_lib.ModeKeys.TRAIN,
tpu_assignment=self._tpu_assignment)
return self.train_function
@@ -1816,18 +1958,48 @@ class KerasTPUModel(models.Model):
self._tpu_weights_initialized = True
weights = self._cpu_model.get_weights()
- with self.tpu_session():
- logging.info('Setting weights on TPU model.')
- cloned_model.set_weights(weights)
+
+ if isinstance(self.cpu_optimizer, keras_optimizers.TFOptimizer):
+ cpu_optimizer_config = {}
+ else:
+ cpu_optimizer_config = self.cpu_optimizer.get_config()
+
+ logging.info('Setting weights on TPU model.')
+ cloned_model.set_weights(weights)
+ for k, v in six.iteritems(cpu_optimizer_config):
+ opt_var = getattr(self._tpu_model.optimizer, k)
+ if isinstance(opt_var, variables.Variable):
+ logging.info('CPU -> TPU %s: %s {%s}', k, v, K.get_value(opt_var))
+ K.get_session().run(opt_var.assign(v))
+ else:
+ logging.warning('Cannot update non-variable config: %s', k)
+
+ @property
+ def cpu_optimizer(self):
+ return self._cpu_model.optimizer
def sync_to_cpu(self):
"""Copy weights from the CPU, returning a synchronized CPU model."""
- if self._tpu_weights_initialized:
- with self.tpu_session():
- logging.info('Copying TPU weights to the CPU')
- tpu_weights = self._tpu_model.get_weights()
+ if not self._tpu_weights_initialized:
+ return self._cpu_model
+
+ logging.info('Copying TPU weights to the CPU')
+ tpu_weights = self._tpu_model.get_weights()
- self._cpu_model.set_weights(tpu_weights)
+ # TFOptimizers have no configurable options
+ if isinstance(self.cpu_optimizer, keras_optimizers.TFOptimizer):
+ tpu_optimizer_config = {}
+ else:
+ tpu_optimizer_config = self._tpu_model.optimizer.get_config()
+
+ self._cpu_model.set_weights(tpu_weights)
+ for k, v in six.iteritems(tpu_optimizer_config):
+ logging.info('TPU -> CPU %s: %s', k, v)
+ opt_var = getattr(self.cpu_optimizer, k)
+ if isinstance(opt_var, variables.Variable):
+ K.get_session().run(opt_var.assign(v))
+ else:
+ logging.warning('Cannot update non-variable config: %s', k)
return self._cpu_model
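
The sync paths above copy optimizer hyperparameters by reading `get_config()` on one optimizer and assigning any matching `Variable` attribute on the other. A minimal sketch of that pattern, with `src_opt`, `dst_opt`, and `sess` as assumed stand-ins rather than names from this patch:

```
# Sketch only: `src_opt` and `dst_opt` are assumed Keras optimizers whose
# hyperparameters (lr, momentum, ...) are stored as tf.Variable attributes,
# and `sess` is the session that owns them.
import six
from tensorflow.python.ops import variables


def copy_optimizer_config(src_opt, dst_opt, sess):
  for key, value in six.iteritems(src_opt.get_config()):
    attr = getattr(dst_opt, key, None)
    if isinstance(attr, variables.Variable):
      sess.run(attr.assign(value))
    else:
      # Non-variable config entries (e.g. plain Python floats) cannot be
      # assigned in place; the patch logs a warning and skips them.
      pass
```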
@@ -1848,26 +2020,6 @@ class KerasTPUModel(models.Model):
self._cpu_model.set_weights(weights)
self._tpu_weights_initialized = False
- @contextlib.contextmanager
- def tpu_session(self):
- """Yields a TPU session and sets it as the default Keras session."""
- with self._session.graph.as_default():
- default_session = K.get_session()
- # N.B. We have to call `K.set_session()` AND set our session as the
- # TF default. `K.get_session()` surprisingly does not return the value
- # supplied by K.set_session otherwise.
- K.set_session(self._session)
- with self._session.as_default():
- yield self._session
- K.set_session(default_session)
-
- def shutdown(self):
- # TODO(b/111364423): Actually shut down the system.
- logging.info('Skipping shutting down TPU system.')
- # with self.tpu_session() as session:
- # session.run(tpu.shutdown_system())
- self._session.close()
-
# pylint: disable=bad-continuation
def _validate_shapes(model):
@@ -1908,7 +2060,9 @@ Output shape: %(output_shape)s
@experimental
def tpu_model(model, strategy=None):
- """Copy `model` along with weights to the TPU. Returns a TPU model.
+ """Copy `model` along with weights to the TPU.
+
+ Returns a TPU model.
Usage:
```
@@ -1923,21 +2077,16 @@ def tpu_model(model, strategy=None):
model.compile(
optimizer=tf.train.GradientDescentOptimizer(learning_rate=1.0),
...)
- model.shutdown()
```
Args:
- model: A `KerasTPUModel`.
+ model: A `tf.keras.Model` instance.
strategy: `TPUDistributionStrategy`. The strategy to use for replicating
- model across multiple TPU cores.
+ model across multiple TPU cores.
Returns:
A new `KerasTPUModel` instance.
"""
- # Force initialization of the CPU model.
- model.get_weights()
- model.reset_states()
-
_validate_shapes(model)
# TODO(xiejw): Validate TPU model. TPUModel only?
# TODO(xiejw): Validate replicas. Full or 1. Shall we allow subset?
@@ -1951,4 +2100,34 @@ def tpu_model(model, strategy=None):
'`strategy` must have type `tf.contrib.tpu.TPUDistributionStrategy`. '
'Got: {}'.format(type(strategy)))
- return KerasTPUModel(cpu_model=model, strategy=strategy)
+ # If the model has already been initialized, grab the optimizer configuration
+ # and model weights before entering the TPU session.
+ if model.optimizer:
+ if (isinstance(model.optimizer, keras_optimizers.Optimizer) and not
+ isinstance(model.optimizer, keras_optimizers.TFOptimizer)):
+ optimizer_config = model.optimizer.get_config()
+ else:
+ optimizer_config = None
+ model_weights = model.get_weights()
+ else:
+ model_weights = None
+
+ setup_tpu_session(strategy._tpu_cluster_resolver)
+
+ # Force initialization of the CPU model in the TPU session.
+ cpu_model = models.clone_model(model)
+ if model.optimizer:
+ cpu_model.compile(
+ _clone_optimizer(model.optimizer, optimizer_config),
+ model.loss,
+ metrics_module.clone_metrics(model.metrics),
+ model.loss_weights,
+ model.sample_weight_mode,
+ metrics_module.clone_metrics(model.weighted_metrics),
+ )
+
+ if model_weights:
+ cpu_model.set_weights(model_weights)
+ cpu_model.reset_states()
+
+ return KerasTPUModel(cpu_model=cpu_model, strategy=strategy)
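
`tpu_model` now captures the optimizer configuration and weights first, opens the TPU session, and rebuilds a fresh CPU clone inside that session before wrapping it. A condensed sketch of that rebuild step, assuming a compiled Keras `model` and a `clone_optimizer` helper standing in for the private `_clone_optimizer` used in this file:

```
# Sketch of the rebuild-in-session step performed by tpu_model() above.
# `model` is an assumed compiled tf.keras model; `clone_optimizer` is a
# stand-in for the private _clone_optimizer helper.
from tensorflow.python.keras import models


def rebuild_compiled_clone(model, clone_optimizer):
  weights = model.get_weights() if model.optimizer else None
  clone = models.clone_model(model)  # new ops live in the current session
  if model.optimizer:
    clone.compile(
        optimizer=clone_optimizer(model.optimizer),
        loss=model.loss,
        metrics=model.metrics,
        loss_weights=model.loss_weights,
        sample_weight_mode=model.sample_weight_mode)
  if weights is not None:
    clone.set_weights(weights)
    clone.reset_states()  # stateful layers start from a clean slate
  return clone
```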
diff --git a/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py b/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py
index 170977d8ab..598da7418e 100644
--- a/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py
+++ b/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py
@@ -25,10 +25,15 @@ from __future__ import print_function
import contextlib
+import numpy as np
+
from tensorflow.python.client import session as session_lib
+from tensorflow.python.framework import dtypes as dtypes_module
from tensorflow.python.framework import ops
+from tensorflow.python.keras import backend
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_resource_variable_ops
+from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
@@ -285,3 +290,51 @@ def replicated_scope(num_replicas):
return variable_scope.variable_scope(
"", custom_getter=_replicated_variable_getter)
+
+
+@contextlib.contextmanager
+def replicated_variable_for_optimizer(num_replicas):
+ """Context manager for optimizer weights. Overrides K.variable."""
+ if num_replicas == 1:
+ yield
+ return
+
+ try:
+ old_v = backend.variable
+
+ def opt_variable(value, dtype=None, name=None, constraint=None):
+ """Instantiates a variable and returns it."""
+ if dtype is None:
+ dtype = backend.floatx()
+
+ variables = []
+ for i in range(num_replicas):
+      # Keras holds the variables in the optimizer class instance, so the
+      # name does not matter here. The ResourceVariable constructor picks a
+      # unique name (even when name=None) for each replica.
+ with ops.device("device:TPU:{}".format(i)):
+ v = resource_variable_ops.ResourceVariable(
+ value,
+ dtype=dtypes_module.as_dtype(dtype),
+ name=name,
+ constraint=constraint)
+ variables.append(v)
+ name = "replicate_{}_{}".format("variable" if name is None else name,
+ ops.uid())
+ v = ReplicatedVariable(name, variables)
+
+ # pylint: disable=protected-access
+
+ if isinstance(value, np.ndarray):
+ v._keras_shape = value.shape
+ elif hasattr(value, "shape"):
+ v._keras_shape = backend.int_shape(value)
+ v._uses_learning_phase = False
+ backend.track_variable(v)
+ return v
+
+ backend.variable = opt_variable
+ yield
+
+ finally:
+ backend.variable = old_v
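
`replicated_variable_for_optimizer` works by temporarily swapping out `backend.variable` so that optimizer slots created inside the context become per-core `ReplicatedVariable`s, then restoring the original in the `finally` block. A self-contained sketch of that override pattern, independent of the TPU specifics:

```
# A minimal sketch of the override pattern used above: swap backend.variable
# for the duration of a block and always restore it afterwards.
import contextlib

from tensorflow.python.keras import backend


@contextlib.contextmanager
def override_keras_variable(make_variable):
  old_variable = backend.variable
  try:
    backend.variable = make_variable  # every K.variable() call is routed here
    yield
  finally:
    backend.variable = old_variable  # restored even if the block raises
```

The `finally` clause is what keeps an error inside the block from leaving Keras permanently monkey-patched.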
diff --git a/tensorflow/contrib/tpu/python/tpu/session_support.py b/tensorflow/contrib/tpu/python/tpu/session_support.py
index 3e91e2df32..05264f5a46 100644
--- a/tensorflow/contrib/tpu/python/tpu/session_support.py
+++ b/tensorflow/contrib/tpu/python/tpu/session_support.py
@@ -41,6 +41,29 @@ class CoordinatorShutdownException(Exception):
pass
+def _make_heartbeat_op(session, device, request_ph):
+  """Returns a heartbeat op, or None if the device does not support heartbeats."""
+ try:
+    # Test whether we can connect in an isolated graph + session.
+ with ops.Graph().as_default():
+ with session_lib.Session(target=session.sess_str) as temp_session:
+ with ops.device(device):
+ heartbeat_op = tpu_ops.worker_heartbeat('')
+ options = config_pb2.RunOptions(timeout_in_ms=5000)
+ temp_session.run(heartbeat_op, options=options)
+ except errors.InvalidArgumentError as _:
+ logging.warning('Error running heartbeat on %s', device)
+ return None
+ except errors.DeadlineExceededError as _:
+ logging.warning('Timeout connecting to %s when testing heartbeat', device)
+ return None
+
+ # If we successfully connected and pinged the worker, go ahead and construct
+ # the operation.
+ with ops.device(device):
+ return tpu_ops.worker_heartbeat(request_ph)
+
+
class WorkerHeartbeatManager(object):
"""Manages the status/heartbeat monitor for a set of workers."""
@@ -72,30 +95,27 @@ class WorkerHeartbeatManager(object):
name='worker_heartbeat_request', dtype=dtypes.string)
heartbeat_ops = []
+ kept_devices = []
for device in devices:
- with ops.device(device):
- heartbeat_ops.append(tpu_ops.worker_heartbeat(request_placeholder))
+ heartbeat_op = _make_heartbeat_op(session, device, request_placeholder)
+ if heartbeat_op is not None:
+ kept_devices.append(device)
+ heartbeat_ops.append(heartbeat_op)
+ else:
+ logging.warning('Heartbeat support not available for %s', device)
- return WorkerHeartbeatManager(session, devices, heartbeat_ops,
+ return WorkerHeartbeatManager(session, kept_devices, heartbeat_ops,
request_placeholder)
- def heartbeat_supported(self):
- """Returns True if heartbeat operations are supported on all workers."""
- try:
- # Send ping to verify worker has heartbeat support.
- self.ping()
- return True
- except errors.InvalidArgumentError as _:
- return False
+ def num_workers(self):
+ return len(self._devices)
def configure(self, message):
"""Configure heartbeat manager for all devices.
Args:
message: `event_pb2.WorkerHeartbeatRequest`
-
Returns: `None`
-
"""
logging.info('Configuring worker heartbeat: %s',
text_format.MessageToString(message))
@@ -155,7 +175,7 @@ class WorkerHeartbeatManager(object):
def all_worker_devices(session):
"""Return a list of devices for each worker in the system."""
devices = session.list_devices()
- return [device.name for device in devices if 'CPU' in device.name]
+ return [device.name for device in devices if ':CPU:' in device.name]
class WatchdogManager(threading.Thread):
@@ -184,7 +204,6 @@ class WatchdogManager(threading.Thread):
"""Initialize a watchdog manager.
Args:
-
session: Session connected to worker devices. A cloned session and graph
will be created for managing worker pings.
devices: Set of devices to monitor. If none, all workers will be
@@ -277,16 +296,14 @@ class GracefulShutdownHook(session_run_hook.SessionRunHook):
target=training_session.sess_str, graph=self._graph)
self._workers = WorkerHeartbeatManager.from_devices(
self._session, all_worker_devices(self._session))
- self._heartbeat_supported = self._workers.heartbeat_supported()
+ self._heartbeat_supported = self._workers.num_workers() > 0
if self._heartbeat_supported:
self._workers.configure(
event_pb2.WorkerHeartbeatRequest(
shutdown_mode=event_pb2.WAIT_FOR_COORDINATOR))
else:
logging.warn(
- 'Worker heartbeats not supported by all workers. No failure '
- 'handling will be enabled.'
- )
+          'No workers support heartbeats. Failure handling will be disabled.')
def saver(self):
if self._saver:
@@ -303,8 +320,7 @@ class GracefulShutdownHook(session_run_hook.SessionRunHook):
logging.error(
'Multiple savers in the SAVERS collection. On-demand checkpointing '
'will be disabled. Pass an explicit `saver` to the constructor to '
- 'override this behavior.'
- )
+ 'override this behavior.')
return None
return savers[0]
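
The heartbeat changes above replace the old ping-after-build check with a probe-before-build step: each device is tested in a throwaway graph and session under a 5-second deadline, and only devices that answer get a real heartbeat op. A generic sketch of that probe pattern, with `target` and `build_op` as assumed inputs:

```
# Sketch of the probe-before-build pattern. `target` is an assumed session
# target string and `build_op` an assumed zero-argument op constructor.
import tensorflow as tf


def device_supports_op(target, device, build_op, timeout_ms=5000):
  try:
    with tf.Graph().as_default():  # keep the probe out of the real graph
      with tf.Session(target=target) as sess:
        with tf.device(device):
          probe = build_op()
        sess.run(probe, options=tf.RunOptions(timeout_in_ms=timeout_ms))
    return True
  except (tf.errors.InvalidArgumentError, tf.errors.DeadlineExceededError):
    return False
```

With this change, `GracefulShutdownHook` simply treats "no devices survived the probe" as "heartbeats unsupported".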
diff --git a/tensorflow/contrib/tpu/python/tpu/topology.py b/tensorflow/contrib/tpu/python/tpu/topology.py
index 1fb26e701a..ab89c6aa8c 100644
--- a/tensorflow/contrib/tpu/python/tpu/topology.py
+++ b/tensorflow/contrib/tpu/python/tpu/topology.py
@@ -112,6 +112,11 @@ class Topology(object):
return self._mesh_shape
@property
+ def mesh_rank(self):
+ """Returns the number of dimensions in the mesh."""
+ return len(self._mesh_shape)
+
+ @property
def device_coordinates(self):
"""Describes the mapping from TPU devices to topology coordinates.
@@ -125,6 +130,16 @@ class Topology(object):
"""
return self._device_coordinates
+ @property
+ def num_tasks(self):
+ """Returns the number of TensorFlow tasks in the TPU slice."""
+ return self._device_coordinates.shape[0]
+
+ @property
+ def num_tpus_per_task(self):
+ """Returns the number of TPU devices per task in the TPU slice."""
+ return self._device_coordinates.shape[1]
+
def serialized(self):
"""Returns the serialized form of the topology."""
if self._serialized is None:
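
The new `Topology` properties are simple shape lookups: `device_coordinates` has shape `[num_tasks, num_tpus_per_task, mesh_rank]`, so the slice geometry can be read off directly. An illustrative sketch with assumed values for a small slice:

```
# Illustrative values only; a real Topology fills these from the serialized
# TopologyProto.
import numpy as np

mesh_shape = [4, 4, 2]                                     # assumed 32-core mesh
device_coordinates = np.zeros((4, 8, 3), dtype=np.int32)   # 4 tasks x 8 TPUs

mesh_rank = len(mesh_shape)                      # -> 3
num_tasks = device_coordinates.shape[0]          # -> 4
num_tpus_per_task = device_coordinates.shape[1]  # -> 8
```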
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py
index 815a087a24..883e08bf47 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu.py
@@ -76,7 +76,7 @@ def initialize_system(embedding_config=None, job=None):
"""Initializes a distributed TPU system for use with TensorFlow.
Args:
- embedding_config: If not None, an `EmbeddingLayerConfiguration` proto
+ embedding_config: If not None, a `TPUEmbeddingConfiguration` proto
describing the desired configuration of the hardware embedding lookup
tables. If embedding_config is None, no hardware embeddings can be used.
job: The job (the XXX in TensorFlow device specification /job:XXX) that
@@ -562,13 +562,14 @@ def split_compile_and_replicate(computation,
device_assignment.core_assignment.flatten().tolist()
}
# TODO(phawkins): remove this case after the forward compatibility window
- # expires on 2018-10-6.
- if api_compat.forward_compatible(2018, 10, 6):
+ # expires on 2018-10-5.
+ if api_compat.forward_compatible(2018, 10, 5):
metadata_kwargs["num_cores_per_replica"] = (
device_assignment.num_cores_per_replica)
else:
- metadata_kwargs["computation_shape"] = (
- device_assignment.computation_shape.tolist())
+ metadata_kwargs["computation_shape"] = [
+ device_assignment.num_cores_per_replica
+ ]
if ((not isinstance(inputs, list)) or
any(not isinstance(inp, (list, tuple)) for inp in inputs)):
@@ -660,6 +661,10 @@ def split_compile_and_replicate(computation,
# be less confusing to clients if they knowingly choose to use resource
# variables.
# Partitioned variables is not supported (b/112311320).
+ vscope = variable_scope.get_variable_scope()
+ saved_use_resource = vscope.use_resource
+ saved_custom_getter = vscope.custom_getter
+
def custom_getter(getter, name, *args, **kwargs):
"""Variables on TPU have a few restrictions."""
partitioner = kwargs["partitioner"]
@@ -670,12 +675,10 @@ def split_compile_and_replicate(computation,
"`partitioner` that is {} for variable {}. "
"Setting `partitioner` to `None`."
.format(partitioner, name))
- return getter(name, *args, **kwargs)
-
- vscope = variable_scope.get_variable_scope()
-
- saved_use_resource = vscope.use_resource
- saved_custom_getter = vscope.custom_getter
+ if saved_custom_getter is None:
+ return getter(name, *args, **kwargs)
+ else:
+ return saved_custom_getter(getter, name, *args, **kwargs)
vscope.set_use_resource(True)
vscope.set_custom_getter(custom_getter)
@@ -847,8 +850,12 @@ def shard(computation,
if num_shards <= 0:
raise ValueError("num_shards must be a positive integer.")
+ inputs = [] if inputs is None else inputs
+ if not isinstance(inputs, list):
+ raise TypeError("tpu.shard()'s inputs must be a list of Tensors or None.")
+
# Converts inputs to Tensors.
- inputs = [] if inputs is None else [ops.convert_to_tensor(x) for x in inputs]
+ inputs = [ops.convert_to_tensor(x) for x in inputs]
if input_shard_axes is None:
input_shard_axes = [0] * len(inputs)
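
`tpu.shard` now insists that `inputs` is either `None` or a list before converting its elements to tensors; passing a bare tensor raises a `TypeError` instead of being silently iterated. A small usage sketch (the input tensor is an illustrative placeholder):

```
# Usage sketch; `features` is an illustrative placeholder tensor.
import tensorflow as tf
from tensorflow.contrib import tpu

features = tf.placeholder(tf.float32, [8, 4])


def computation(x):
  return x * 2.0

outputs = tpu.shard(computation, inputs=[features], num_shards=2)  # OK: a list
# tpu.shard(computation, inputs=features, num_shards=2)  # now a TypeError
```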
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_config.py b/tensorflow/contrib/tpu/python/tpu/tpu_config.py
index 18e0abdda2..9f8d147068 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_config.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_config.py
@@ -32,7 +32,6 @@ from tensorflow.python.platform import tf_logging as logging
_TF_CONFIG_ENV = run_config_lib._TF_CONFIG_ENV
_SERVICE_KEY = run_config_lib._SERVICE_KEY
_TPU_WORKER_JOB_NAME = 'tpu_worker_job_name'
-_NUM_CORES_PER_HOST = 8
# pylint: enable=protected-access
@@ -103,7 +102,7 @@ class TPUConfig(
input mode.
Raises:
- ValueError: If `num_cores_per_replica` is not 1, 2, 4 or 8.
+ ValueError: If `num_cores_per_replica` is not 1, 2, 4, 8 or 16.
"""
def __new__(cls,
@@ -139,9 +138,9 @@ class TPUConfig(
# Check num_cores_per_replica
if num_cores_per_replica is not None:
- if num_cores_per_replica not in [1, 2, 4, 8]:
+ if num_cores_per_replica not in [1, 2, 4, 8, 16]:
raise ValueError(
- 'num_cores_per_replica must be 1, 2, 4, or 8; got {}'.format(
+ 'num_cores_per_replica must be 1, 2, 4, 8, or 16; got {}'.format(
str(num_cores_per_replica)))
# per_host_input_for_training may be True, False, or integer in [1..3].
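
With this change `num_cores_per_replica` accepts 16 in addition to 1, 2, 4, and 8. A minimal configuration sketch (the other fields are illustrative defaults):

```
# Sketch: a 16-core model-parallel replica is now accepted by TPUConfig.
from tensorflow.contrib.tpu.python.tpu import tpu_config

config = tpu_config.TPUConfig(
    iterations_per_loop=100,
    num_shards=None,             # let the runtime infer the replica count
    num_cores_per_replica=16)    # previously limited to 1, 2, 4, or 8
```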
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_config_test.py b/tensorflow/contrib/tpu/python/tpu/tpu_config_test.py
index 2326fe97a8..b2fe0a6888 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_config_test.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_config_test.py
@@ -86,7 +86,7 @@ class TPURunConfigTest(test.TestCase):
def test_fail_with_invalid_num_cores_per_replica(self):
with self.assertRaisesRegexp(
- ValueError, 'num_cores_per_replica must be 1, 2, 4, or 8;'
+ ValueError, 'num_cores_per_replica must be 1, 2, 4, 8, or 16;'
' got 7'):
tpu_config_lib.TPUConfig(num_cores_per_replica=7)
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_context.py b/tensorflow/contrib/tpu/python/tpu/tpu_context.py
index 19359cb612..7cfb6c38fa 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_context.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_context.py
@@ -35,7 +35,8 @@ _NUM_CORES_TO_COMPUTATION_SHAPE = {
1: [1, 1, 1],
2: [1, 1, 2],
4: [1, 2, 2],
- 8: [2, 2, 2]
+ 8: [2, 2, 2],
+ 16: [4, 2, 2],
}
@@ -117,6 +118,11 @@ class TPUContext(object):
return self._internal_ctx.num_hosts
@property
+ def current_host(self):
+ """The current host index for the TPU system."""
+ return self._invocation_index
+
+ @property
def num_of_replicas_per_host(self):
"""The number of replicas for each host."""
if self._internal_ctx.model_parallelism_enabled:
@@ -298,6 +304,7 @@ class _InternalTPUContext(object):
@property
def num_of_replicas_per_host(self):
+ """Return the number of replicas per host."""
if self.model_parallelism_enabled:
return self.num_replicas // self.num_hosts
else:
@@ -538,8 +545,8 @@ class _InternalTPUContext(object):
"""
if self.model_parallelism_enabled:
# We put both enqueue/dequeue ops at tpu.core(0) in each replica.
- replica = self.device_assignment.lookup_replicas(
- host_id, (0, 0, 0))[shard_index_in_host]
+ replica = self.device_assignment.lookup_replicas(host_id,
+ 0)[shard_index_in_host]
return self.device_assignment.tpu_ordinal(replica=replica)
else:
return shard_index_in_host % self.num_of_cores_per_host
@@ -580,6 +587,17 @@ class _InternalTPUContext(object):
raise ValueError(message)
+ if self._config.tpu_config.num_cores_per_replica:
+ num_cores_per_replica = self._config.tpu_config.num_cores_per_replica
+ num_cores_per_host = self._get_tpu_system_metadata().num_of_cores_per_host
+ if num_cores_per_replica > num_cores_per_host:
+ raise ValueError(
+            'The number of cores required by model parallelism, specified by '
+            'TPUConfig.num_cores_per_replica, is larger than '
+ 'num_cores_per_host. num_cores_per_replica: {}, '
+ 'num_cores_per_host: {}'.format(num_cores_per_replica,
+ num_cores_per_host))
+
if mode == model_fn_lib.ModeKeys.TRAIN:
if (self._train_batch_size % num_replicas != 0 and
not self.is_input_broadcast_with_iterators()):
@@ -599,8 +617,8 @@ class _InternalTPUContext(object):
.format(self._eval_batch_size, num_replicas))
if num_hosts > 1 and not self.is_input_broadcast_with_iterators():
raise ValueError(
- 'TPUEstimator.evaluate should be running on single TPU worker. '
- 'got {}.'.format(num_hosts))
+            'TPUEstimator.evaluate should run on a single TPU worker '
+            'instead of a Pod.')
else:
assert mode == model_fn_lib.ModeKeys.PREDICT
if self._predict_batch_size is None:
@@ -685,7 +703,7 @@ def _get_tpu_context(config, train_batch_size, eval_batch_size,
config.tpu_config.num_cores_per_replica is None):
logging.warning(
'Setting TPUConfig.num_shards==1 is an unsupported behavior. '
- 'Please fix as soon as possible (leaving num_shards as None.')
+        'Please fix as soon as possible (leaving num_shards as None).')
return _OneCoreTPUContext(config, train_batch_size, eval_batch_size,
predict_batch_size, use_tpu)
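
Two things change in the context: the computation-shape table gains a 16-core entry (`[4, 2, 2]`) and mode validation now rejects a `num_cores_per_replica` that exceeds the cores available on one host. A small sketch of that check with assumed numbers:

```
# Assumed numbers for illustration; the real values come from TPUConfig and
# the TPU system metadata.
_NUM_CORES_TO_COMPUTATION_SHAPE = {
    1: [1, 1, 1], 2: [1, 1, 2], 4: [1, 2, 2], 8: [2, 2, 2], 16: [4, 2, 2]}

num_cores_per_replica = 8
num_cores_per_host = 8

computation_shape = _NUM_CORES_TO_COMPUTATION_SHAPE[num_cores_per_replica]
if num_cores_per_replica > num_cores_per_host:
  raise ValueError('num_cores_per_replica %d exceeds num_cores_per_host %d' %
                   (num_cores_per_replica, num_cores_per_host))
print(computation_shape)  # [2, 2, 2]
```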
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
index 23c54511ca..764d85877a 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
@@ -231,7 +231,7 @@ class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec): # pylint: disable=prote
`metric_fn` runs on CPU to generate metrics and `tensors` represents the
`Tensor`s transferred from TPU system to CPU host and passed to `metric_fn`.
To be precise, TPU evaluation expects a slightly different signature from the
- @{tf.estimator.Estimator}. While `EstimatorSpec.eval_metric_ops` expects a
+ `tf.estimator.Estimator`. While `EstimatorSpec.eval_metric_ops` expects a
dict, `TPUEstimatorSpec.eval_metrics` is a tuple of `metric_fn` and `tensors`.
The `tensors` could be a list of `Tensor`s or dict of names to `Tensor`s. The
`tensors` usually specify the model logits, which are transferred back from
@@ -254,7 +254,7 @@ class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec): # pylint: disable=prote
sending tensors from TPU to CPU. To reduce the overhead, try reducing the
size of the tensors. The `tensors` are concatenated along their major (batch)
dimension, and so must be >= rank 1. The `host_call` is useful for writing
- summaries with @{tf.contrib.summary.create_file_writer}.
+ summaries with `tf.contrib.summary.create_file_writer`.
"""
def __new__(cls,
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_feed.py b/tensorflow/contrib/tpu/python/tpu/tpu_feed.py
index d9c77a3ea1..e75a09492e 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_feed.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_feed.py
@@ -765,9 +765,8 @@ class _PartitionedInfeedQueue(InfeedQueue):
zip(per_host_sharded_inputs[replica_index], inputs_part_dims_flat)
]
- for core_index in xrange(self._device_assignment.num_cores_per_replica):
+ for logical_core in xrange(self._device_assignment.num_cores_per_replica):
# Places different partitions to different logic cores.
- logical_core = self._get_logical_core(core_index)
replica_id = self._device_assignment.lookup_replicas(
self._host_id, logical_core)[replica_index]
ordinal = self._device_assignment.tpu_ordinal(
@@ -784,7 +783,7 @@ class _PartitionedInfeedQueue(InfeedQueue):
inputs=infeed_inputs,
shapes=[x.shape for x in infeed_inputs],
name="enqueue/replica_{0}/input_{1}".format(
- replica_index, core_index),
+ replica_index, logical_core),
device_ordinal=ordinal))
return per_host_enqueue_ops
@@ -890,20 +889,3 @@ class _PartitionedInfeedQueue(InfeedQueue):
return nest.map_structure_up_to(
dequeues, self._tag_sharding_attribute_for_dequeued_tensor, dequeues,
dims)
-
- def _get_logical_core(self, core_index):
- """Maps the core index to the 3D coordinate within replica.
-
- The lowest dimension number in computation_shape is the slowest varying
- dimension (most major).
-
- Args:
- core_index: An integer represents the core index within replcia.
-
- Returns:
- A tuple with three integers which represents the 3D coordinate.
- """
- computation_shape = self._device_assignment.computation_shape
- return (core_index // (computation_shape[1] * computation_shape[2]),
- core_index % (computation_shape[1] * computation_shape[2]) //
- computation_shape[2], core_index % computation_shape[2])
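
The `_get_logical_core` helper is gone because `DeviceAssignment.lookup_replicas` now takes the flat logical-core index directly. For reference, the dropped index-to-coordinate conversion, written here as a standalone function:

```
# Standalone copy of the removed conversion, kept for reference only.
def logical_core_coordinate(core_index, computation_shape):
  """Maps a flat core index to its 3D coordinate (dimension 0 most major)."""
  minor = computation_shape[1] * computation_shape[2]
  return (core_index // minor,
          core_index % minor // computation_shape[2],
          core_index % computation_shape[2])

print(logical_core_coordinate(5, [2, 2, 2]))  # (1, 0, 1)
```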
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_function.py b/tensorflow/contrib/tpu/python/tpu/tpu_function.py
index de16e3b157..0c7a38dbbb 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_function.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_function.py
@@ -63,10 +63,9 @@ def check_function_argument_count(func, input_arity, infeed_queue):
"""Validate the number of input arguments to a tpu function.
Args:
- func: the Python function that will be called to generate the body
- of a TPUFunction.
- input_arity: the number of explicit arguments supplied by the
- caller.
+ func: the Python function that will be called to generate the body of an XLA
+ computation graph.
+ input_arity: the number of explicit arguments supplied by the caller.
infeed_queue: if not None, the infeed queue that will supply
additional arguments to the function.
@@ -103,4 +102,3 @@ def check_function_argument_count(func, input_arity, infeed_queue):
# Since there are varargs, func can accept any number of arguments
# greater than the minimum.
return None
-
diff --git a/tensorflow/contrib/tpu/utils/BUILD b/tensorflow/contrib/tpu/utils/BUILD
new file mode 100644
index 0000000000..c27b737287
--- /dev/null
+++ b/tensorflow/contrib/tpu/utils/BUILD
@@ -0,0 +1,30 @@
+# Description: Utilities for TPU Operations
+
+licenses(["notice"]) # Apache 2.0
+
+cc_library(
+ name = "tpu_embedding_optimization_parameters_utils",
+ srcs = ["tpu_embedding_optimization_parameters_utils.cc"],
+ hdrs = ["tpu_embedding_optimization_parameters_utils.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/contrib/tpu/proto:optimization_parameters_proto_cc",
+ "//tensorflow/core:framework_headers_lib",
+ "//tensorflow/core:lib_proto_parsing",
+ "@com_google_absl//absl/base",
+ ],
+)
+
+cc_library(
+ name = "tpu_embedding_output_layout_utils",
+ srcs = ["tpu_embedding_output_layout_utils.cc"],
+ hdrs = ["tpu_embedding_output_layout_utils.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/contrib/tpu/proto:tpu_embedding_configuration_proto_cc",
+ "//tensorflow/contrib/tpu/proto:tpu_embedding_output_layout_proto_cc",
+ "//tensorflow/core:framework_headers_lib",
+ "//tensorflow/core:lib_proto_parsing",
+ "//tensorflow/core:protos_all_cc",
+ ],
+)
diff --git a/tensorflow/contrib/tpu/utils/tpu_embedding_optimization_parameters_utils.cc b/tensorflow/contrib/tpu/utils/tpu_embedding_optimization_parameters_utils.cc
new file mode 100644
index 0000000000..76cb5531cd
--- /dev/null
+++ b/tensorflow/contrib/tpu/utils/tpu_embedding_optimization_parameters_utils.cc
@@ -0,0 +1,255 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/tpu/utils/tpu_embedding_optimization_parameters_utils.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+namespace tpu {
+
+string GetOptimizationAlgorithmName(OptimizationAlgorithm alg) {
+ switch (alg) {
+ case OptimizationAlgorithm::kAdagrad:
+ return "Adagrad";
+ case OptimizationAlgorithm::kStochasticGradientDescent:
+ return "StochasticGradientDescent";
+ case OptimizationAlgorithm::kFtrl:
+ return "FTRL";
+ case OptimizationAlgorithm::kAdam:
+ return "ADAM";
+ case OptimizationAlgorithm::kMomentum:
+ return "Momentum";
+ case OptimizationAlgorithm::kRmsProp:
+ return "RMSProp";
+ case OptimizationAlgorithm::kCenteredRmsProp:
+ return "CenteredRMSProp";
+ case OptimizationAlgorithm::kMdlAdagradLight:
+ return "MDLAdagradLight";
+ case OptimizationAlgorithm::kAdadelta:
+ return "Adadelta";
+ case OptimizationAlgorithm::kProximalAdagrad:
+ return "ProximalAdagrad";
+ case OptimizationAlgorithm::PARAMETERS_NOT_SET:
+ return "*** Not set ***";
+ }
+}
+
+string GetOptimizationAlgorithmFriendlyName(OptimizationAlgorithm alg) {
+ switch (alg) {
+ case OptimizationAlgorithm::kAdagrad:
+ return "Adagrad";
+ case OptimizationAlgorithm::kStochasticGradientDescent:
+ return "stochastic gradient descent";
+ case OptimizationAlgorithm::kFtrl:
+ return "FTRL";
+ case OptimizationAlgorithm::kAdam:
+ return "ADAM";
+ case OptimizationAlgorithm::kMomentum:
+ return "Momentum";
+ case OptimizationAlgorithm::kRmsProp:
+ return "RMSProp";
+ case OptimizationAlgorithm::kCenteredRmsProp:
+ return "centered RMSProp";
+ case OptimizationAlgorithm::kMdlAdagradLight:
+ return "MDL Adagrad Light";
+ case OptimizationAlgorithm::kAdadelta:
+ return "Adadelta";
+ case OptimizationAlgorithm::kProximalAdagrad:
+ return "proximal Adagrad";
+ case OptimizationAlgorithm::PARAMETERS_NOT_SET:
+ return "unknown (not specified)";
+ }
+}
+
+// Returns the number of optimization parameter vectors used by the optimization
+// algorithm, excluding the weights themselves and assuming no gradient
+// accumulation.
+Status GetBaseAuxiliaryParameterCount(OptimizationAlgorithm alg, int* count) {
+ switch (alg) {
+ case OptimizationAlgorithm::kAdagrad:
+ *count = 1;
+ return Status::OK();
+ case OptimizationAlgorithm::kStochasticGradientDescent:
+ *count = 0;
+ return Status::OK();
+ case OptimizationAlgorithm::kFtrl:
+ *count = 2;
+ return Status::OK();
+ case OptimizationAlgorithm::kAdam:
+ *count = 2;
+ return Status::OK();
+ case OptimizationAlgorithm::kMomentum:
+ *count = 1;
+ return Status::OK();
+ case OptimizationAlgorithm::kRmsProp:
+ *count = 2;
+ return Status::OK();
+ case OptimizationAlgorithm::kCenteredRmsProp:
+ *count = 3;
+ return Status::OK();
+ case OptimizationAlgorithm::kMdlAdagradLight:
+ *count = 3;
+ return Status::OK();
+ case OptimizationAlgorithm::kAdadelta:
+ *count = 2;
+ return Status::OK();
+ case OptimizationAlgorithm::kProximalAdagrad:
+ *count = 1;
+ return Status::OK();
+ case OptimizationAlgorithm::PARAMETERS_NOT_SET:
+ return errors::InvalidArgument("No optimization algorithm specified");
+ }
+}
+
+Status GetGradientAccumulationSupport(OptimizationAlgorithm alg,
+ GradientAccumulationSupport* support) {
+ switch (alg) {
+ case OptimizationAlgorithm::kAdagrad:
+ *support = GradientAccumulationSupport::kSupported;
+ return Status::OK();
+ case OptimizationAlgorithm::kStochasticGradientDescent:
+ *support = GradientAccumulationSupport::kUnnecessary;
+ return Status::OK();
+ default: {
+ int auxiliary_parameter_count;
+ TF_RETURN_IF_ERROR(
+ GetBaseAuxiliaryParameterCount(alg, &auxiliary_parameter_count));
+ *support = auxiliary_parameter_count + 1 <= kMaxAuxiliaryParameterCount
+ ? GradientAccumulationSupport::kSupported
+ : GradientAccumulationSupport::kNotSupported;
+ return Status::OK();
+ }
+ }
+}
+namespace {
+// Make a normal state variable specification.
+StateVariableSpecification MakeStandardStateVariableSpecification(
+ const string& name) {
+ StateVariableSpecification result;
+ result.set_name(name);
+ result.mutable_user_defined();
+ return result;
+}
+} // namespace
+
+Status GetOptimizationAlgorithmStateVariables(
+ OptimizationAlgorithm alg, bool use_gradient_accumulation,
+ std::vector<StateVariableSpecification>* state_variables) {
+ // The first parameter set is always the weights themselves.
+ state_variables->push_back(
+ MakeStandardStateVariableSpecification("parameters"));
+ // The order of the returned parameters needs to match the offsets used by
+ // the algorithm implementations in test_util.cc and
+ // address_handler_program_creator.cc.
+ switch (alg) {
+ case OptimizationAlgorithm::kAdagrad: {
+ state_variables->push_back(
+ MakeStandardStateVariableSpecification("accumulators"));
+ break;
+ }
+ case OptimizationAlgorithm::kStochasticGradientDescent: {
+ // None.
+ break;
+ }
+ case OptimizationAlgorithm::kFtrl: {
+ state_variables->push_back(
+ MakeStandardStateVariableSpecification("accumulators"));
+ state_variables->push_back(
+ MakeStandardStateVariableSpecification("linears"));
+ break;
+ }
+ case OptimizationAlgorithm::kAdam: {
+ state_variables->push_back(
+ MakeStandardStateVariableSpecification("momenta"));
+ state_variables->push_back(
+ MakeStandardStateVariableSpecification("velocities"));
+ break;
+ }
+ case OptimizationAlgorithm::kMomentum: {
+ state_variables->push_back(
+ MakeStandardStateVariableSpecification("momenta"));
+ break;
+ }
+ case OptimizationAlgorithm::kRmsProp: {
+ state_variables->push_back(MakeStandardStateVariableSpecification("ms"));
+ state_variables->push_back(MakeStandardStateVariableSpecification("mom"));
+ break;
+ }
+ case OptimizationAlgorithm::kCenteredRmsProp: {
+ state_variables->push_back(MakeStandardStateVariableSpecification("ms"));
+ state_variables->push_back(MakeStandardStateVariableSpecification("mom"));
+ state_variables->push_back(MakeStandardStateVariableSpecification("mg"));
+ break;
+ }
+ case OptimizationAlgorithm::kMdlAdagradLight: {
+ state_variables->push_back(
+ MakeStandardStateVariableSpecification("accumulators"));
+ state_variables->push_back(
+ MakeStandardStateVariableSpecification("weights"));
+ state_variables->push_back(
+ MakeStandardStateVariableSpecification("benefits"));
+ break;
+ }
+ case OptimizationAlgorithm::kAdadelta: {
+ state_variables->push_back(
+ MakeStandardStateVariableSpecification("accumulators"));
+ state_variables->push_back(
+ MakeStandardStateVariableSpecification("updates"));
+ break;
+ }
+ case OptimizationAlgorithm::kProximalAdagrad: {
+ state_variables->push_back(
+ MakeStandardStateVariableSpecification("accumulators"));
+ break;
+ }
+ case OptimizationAlgorithm::PARAMETERS_NOT_SET: {
+ return errors::InvalidArgument("No optimization algorithm specified");
+ }
+ }
+ // This needs to be last so that the save/restore ops do not need to know
+ // about gradient accumulation.
+ if (use_gradient_accumulation) {
+ StateVariableSpecification gradient_acc;
+ gradient_acc.set_name("gradient_accumulators");
+ gradient_acc.mutable_fill_with_constant()->set_initial_value(
+ kGradientAccumulatorInitialValue);
+ state_variables->push_back(std::move(gradient_acc));
+ }
+ if (state_variables->size() > kMaxAuxiliaryParameterCount + 1) {
+ return errors::InvalidArgument(
+        "Optimization algorithm ", GetOptimizationAlgorithmName(alg),
+        " does not support gradient accumulation because it "
+        "already has too many other accumulators");
+ }
+ return Status::OK();
+}
+
+std::vector<OptimizationAlgorithm> GetOptimizationAlgorithms() {
+ return {
+ OptimizationAlgorithm::kAdagrad,
+ OptimizationAlgorithm::kStochasticGradientDescent,
+ OptimizationAlgorithm::kFtrl,
+ OptimizationAlgorithm::kAdam,
+ OptimizationAlgorithm::kMomentum,
+ OptimizationAlgorithm::kRmsProp,
+ OptimizationAlgorithm::kCenteredRmsProp,
+ OptimizationAlgorithm::kMdlAdagradLight,
+ OptimizationAlgorithm::kAdadelta,
+ OptimizationAlgorithm::kProximalAdagrad,
+ };
+}
+
+} // namespace tpu
+} // namespace tensorflow
diff --git a/tensorflow/contrib/tpu/utils/tpu_embedding_optimization_parameters_utils.h b/tensorflow/contrib/tpu/utils/tpu_embedding_optimization_parameters_utils.h
new file mode 100644
index 0000000000..81d50264ed
--- /dev/null
+++ b/tensorflow/contrib/tpu/utils/tpu_embedding_optimization_parameters_utils.h
@@ -0,0 +1,90 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_TPU_UTILS_TPU_EMBEDDING_OPTIMIZATION_PARAMETERS_UTILS_H_
+#define TENSORFLOW_CONTRIB_TPU_UTILS_TPU_EMBEDDING_OPTIMIZATION_PARAMETERS_UTILS_H_
+
+#include <string>
+#include "absl/base/casts.h"
+#include "tensorflow/contrib/tpu/proto/optimization_parameters.pb.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+namespace tpu {
+
+using OptimizationAlgorithm = OptimizationParameters::ParametersCase;
+
+// Returns the name of the optimization algorithm.
+string GetOptimizationAlgorithmName(OptimizationAlgorithm alg);
+
+// Returns a user-friendly name for the optimization algorithm.
+string GetOptimizationAlgorithmFriendlyName(OptimizationAlgorithm alg);
+
+// Returns all supported optimization algorithms.
+std::vector<OptimizationAlgorithm> GetOptimizationAlgorithms();
+
+enum class GradientAccumulationSupport {
+ // Accumulation cannot be used with this optimizer.
+ kNotSupported,
+
+ // Accumulation is unnecessary because optimizer application is commutative.
+ kUnnecessary,
+
+ // Accumulation is allowed and changes optimizer behavior.
+ kSupported,
+};
+
+// Returns the number of optimization parameter vectors used by the optimization
+// algorithm, excluding the weights themselves and assuming no gradient
+// accumulation.
+Status GetBaseAuxiliaryParameterCount(OptimizationAlgorithm alg, int *count);
+
+// Returns whether (and how) an optimization algorithm supports gradient
+// accumulation.
+Status GetGradientAccumulationSupport(OptimizationAlgorithm alg,
+ GradientAccumulationSupport *support);
+
+// Returns the parameter specifications for the optimization algorithm (the main
+// parameters first, followed by any auxiliary parameters such as Adagrad
+// accumulators).
+Status GetOptimizationAlgorithmStateVariables(
+ OptimizationAlgorithm alg, bool use_gradient_accumulation,
+ std::vector<StateVariableSpecification> *state_variables);
+
+// Maximum value of auxiliary_parameter_count for any optimization algorithm.
+static constexpr int kMaxAuxiliaryParameterCount = 3;
+
+// Fill value for gradient accumulators. This is a denormal so that it will be
+// flushed to zero on the current TPU platforms and needs to continue to have
+// the following properties in the future:
+//
+// 1. Does not have the same bit pattern as a zero and can be distinguished from
+// it using integer operations.
+// 2. Treated as zero by floating-point arithmetic operations (at least addition
+// and subtraction).
+// 3. Cannot be produced by any floating-point arithmetic operation, including
+// those involving itself.
+//
+// It does not need to compare equal or not equal to zero in floating point. We
+// need to use a non-zero value here because some optimization algorithms are
+// not no-ops on zero gradients, so we need to distinguish an accumulated
+// gradient of zero from one that has been cleared after its gradients have
+// already been applied to the parameters and accumulators.
+const float kGradientAccumulatorInitialValue = absl::bit_cast<float, uint32>(1);
+
+} // namespace tpu
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CONTRIB_TPU_UTILS_TPU_EMBEDDING_OPTIMIZATION_PARAMETERS_UTILS_H_
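
The fill value above is the float32 whose bit pattern is the integer 1, i.e. the smallest positive denormal. A quick check of the properties the comment relies on, done here in Python with 32-bit packing:

```
# Sketch: reproduce kGradientAccumulatorInitialValue and check its properties.
import struct

fill = struct.unpack('<f', struct.pack('<I', 1))[0]
print(fill)               # ~1.401298464324817e-45
print(fill != 0.0)        # True here; the TPU side distinguishes it by bit pattern
print(fill + 1.0 == 1.0)  # True: it vanishes under ordinary float addition
```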
diff --git a/tensorflow/contrib/tpu/utils/tpu_embedding_output_layout_utils.cc b/tensorflow/contrib/tpu/utils/tpu_embedding_output_layout_utils.cc
new file mode 100644
index 0000000000..8480ec4b8b
--- /dev/null
+++ b/tensorflow/contrib/tpu/utils/tpu_embedding_output_layout_utils.cc
@@ -0,0 +1,98 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/tpu/utils/tpu_embedding_output_layout_utils.h"
+#include "tensorflow/contrib/tpu/proto/tpu_embedding_output_layout.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+namespace tpu {
+
+void AddDefaultEmbeddingOutputLayoutIfNeeded(
+ TPUEmbeddingConfiguration* config) {
+ if (config->has_output_layout()) {
+ // Model or previous step has already filled this in.
+ return;
+ }
+
+ TPUEmbeddingOutputLayout* layout = config->mutable_output_layout();
+ // Create output tensors.
+ for (const auto& table : config->table_descriptor()) {
+ TPUEmbeddingOutputLayout::EmbeddingOutputTensor* output =
+ layout->add_output();
+ TPUEmbeddingOutputLayout::TwoDOutputTensor* two_d = output->mutable_two_d();
+ two_d->set_dim1_size(table.dimension());
+ two_d->set_dim0_size_per_sample(table.num_features());
+ }
+
+ // Create table output locations.
+ for (int table_id = 0; table_id < config->table_descriptor_size();
+ ++table_id) {
+ TPUEmbeddingOutputLayout::TableDescriptor* output_table =
+ layout->add_table();
+ const auto& table = config->table_descriptor(table_id);
+ for (int feature_index = 0; feature_index < table.num_features();
+ ++feature_index) {
+ TPUEmbeddingOutputLayout::FeatureDescriptor* output_feature =
+ output_table->add_feature();
+ TPUEmbeddingOutputLayout::OutputLocation* output_location =
+ output_feature->add_output_location();
+ output_location->set_tensor_index(table_id);
+ output_location->set_dim0_offset(feature_index);
+ output_location->set_dim1_offset(0);
+ }
+ }
+}
+
+Status ComputeOutputTensorShapes(const TPUEmbeddingConfiguration& config,
+ std::vector<TensorShapeProto>* shapes) {
+ if (!config.has_output_layout()) {
+ return errors::InvalidArgument(
+ "TPUEmbeddingConfiguration is missing output layout.");
+ }
+ const TPUEmbeddingOutputLayout& layout = config.output_layout();
+ int batch_size = config.batch_size_per_tensor_core();
+
+ for (int i = 0; i < layout.output_size(); ++i) {
+ const auto& output = layout.output(i);
+ TensorShapeProto shape;
+ switch (output.output_format_case()) {
+ case TPUEmbeddingOutputLayout::EmbeddingOutputTensor::OutputFormatCase::
+ kTwoD: {
+ auto* dim0 = shape.add_dim();
+ dim0->set_size(output.two_d().dim0_size_per_sample() * batch_size);
+ auto* dim1 = shape.add_dim();
+ dim1->set_size(output.two_d().dim1_size());
+ break;
+ }
+ case TPUEmbeddingOutputLayout::EmbeddingOutputTensor::OutputFormatCase::
+ OUTPUT_FORMAT_NOT_SET: {
+ return errors::InvalidArgument(
+ "Output layout in TPUEmbeddingConfiguration has unset embedding "
+ "output tensor format.");
+ }
+ default: {
+ return errors::InvalidArgument(
+ "Output layout in TPUEmbeddingConfiguration has invalid or "
+ "unhandled embedding output tensor format.");
+ }
+ }
+ shapes->push_back(shape);
+ }
+ return Status::OK();
+}
+
+} // namespace tpu
+} // namespace tensorflow
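
For the default two-D layout, the shapes produced by `ComputeOutputTensorShapes` are simply `[num_features * batch_size_per_tensor_core, dimension]` per table. A small Python sketch of that rule with assumed table descriptors:

```
# Assumed table descriptors for illustration.
tables = [
    {'dimension': 64, 'num_features': 3},
    {'dimension': 128, 'num_features': 1},
]
batch_size_per_tensor_core = 8

shapes = [[t['num_features'] * batch_size_per_tensor_core, t['dimension']]
          for t in tables]
print(shapes)  # [[24, 64], [8, 128]]
```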
diff --git a/tensorflow/contrib/tpu/utils/tpu_embedding_output_layout_utils.h b/tensorflow/contrib/tpu/utils/tpu_embedding_output_layout_utils.h
new file mode 100644
index 0000000000..c10fbeeff2
--- /dev/null
+++ b/tensorflow/contrib/tpu/utils/tpu_embedding_output_layout_utils.h
@@ -0,0 +1,38 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_TPU_UTILS_TPU_EMBEDDING_OUTPUT_LAYOUT_UTILS_H_
+#define TENSORFLOW_CONTRIB_TPU_UTILS_TPU_EMBEDDING_OUTPUT_LAYOUT_UTILS_H_
+
+#include "tensorflow/contrib/tpu/proto/tpu_embedding_configuration.pb.h"
+#include "tensorflow/core/framework/tensor_shape.pb.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+namespace tpu {
+
+// Creates a default output layout for compatibility if none was provided by the
+// model.
+void AddDefaultEmbeddingOutputLayoutIfNeeded(TPUEmbeddingConfiguration* config);
+
+// Computes the shape of the output tensors from an output layout.
+Status ComputeOutputTensorShapes(
+ const TPUEmbeddingConfiguration& config,
+ std::vector<tensorflow::TensorShapeProto>* shapes);
+
+} // namespace tpu
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CONTRIB_TPU_UTILS_TPU_EMBEDDING_OUTPUT_LAYOUT_UTILS_H_
diff --git a/tensorflow/contrib/training/python/training/device_setter_test.py b/tensorflow/contrib/training/python/training/device_setter_test.py
index 20746d911c..3bb2dce83d 100644
--- a/tensorflow/contrib/training/python/training/device_setter_test.py
+++ b/tensorflow/contrib/training/python/training/device_setter_test.py
@@ -98,10 +98,10 @@ class GreedyLoadBalancingStrategyTest(test.TestCase):
cluster=_CLUSTER_SPEC,
ps_strategy=device_setter_lib.GreedyLoadBalancingStrategy(
2, device_setter_lib.byte_size_load_fn))):
- u = variables.Variable(array_ops.zeros([2, 2]))
- v = variables.Variable(array_ops.zeros([2, 1]))
- w = variables.Variable(array_ops.zeros([2, 2]))
- x = variables.Variable(array_ops.zeros([1, 3]))
+ u = variables.VariableV1(array_ops.zeros([2, 2]))
+ v = variables.VariableV1(array_ops.zeros([2, 1]))
+ w = variables.VariableV1(array_ops.zeros([2, 2]))
+ x = variables.VariableV1(array_ops.zeros([1, 3]))
a = v + w
self.assertDeviceEqual("/job:ps/task:0", u.device)
self.assertDeviceEqual("/job:ps/task:0", u.initializer.device)
diff --git a/tensorflow/contrib/training/python/training/tensor_queue_dataset.py b/tensorflow/contrib/training/python/training/tensor_queue_dataset.py
index f46d03209c..8896a95327 100644
--- a/tensorflow/contrib/training/python/training/tensor_queue_dataset.py
+++ b/tensorflow/contrib/training/python/training/tensor_queue_dataset.py
@@ -29,7 +29,7 @@ from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.util import nest as tf_nest
-class _PrependFromQueueAndPaddedBatchDataset(dataset_ops.Dataset):
+class _PrependFromQueueAndPaddedBatchDataset(dataset_ops.UnaryDataset):
"""A `Dataset` that prepends a queue to another `Dataset`.
A vector of handles to the queue is returned as the first component of
@@ -39,7 +39,7 @@ class _PrependFromQueueAndPaddedBatchDataset(dataset_ops.Dataset):
def __init__(self, input_dataset, batch_size, padded_shapes, padding_values):
"""Initialize `PrependFromQueueAndPaddedBatchDataset`."""
- super(_PrependFromQueueAndPaddedBatchDataset, self).__init__()
+ super(_PrependFromQueueAndPaddedBatchDataset, self).__init__(input_dataset)
if sparse.any_sparse(input_dataset.output_classes):
raise TypeError(
"Batching of padded sparse tensors is not currently supported")
diff --git a/tensorflow/contrib/verbs/rdma_mgr.cc b/tensorflow/contrib/verbs/rdma_mgr.cc
index 3cb5e61fac..2784bf124c 100644
--- a/tensorflow/contrib/verbs/rdma_mgr.cc
+++ b/tensorflow/contrib/verbs/rdma_mgr.cc
@@ -20,7 +20,6 @@ limitations under the License.
#include <vector>
#include "tensorflow/contrib/verbs/grpc_verbs_client.h"
#include "tensorflow/contrib/verbs/verbs_service.pb.h"
-#include "tensorflow/core/common_runtime/bfc_allocator.h"
#include "tensorflow/core/common_runtime/gpu/gpu_process_state.h"
#include "tensorflow/core/common_runtime/gpu/gpu_util.h"
#include "tensorflow/core/common_runtime/pool_allocator.h"
@@ -29,6 +28,7 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/session_mgr.h"
#include "tensorflow/core/framework/allocator_registry.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/strings/strcat.h"
namespace tensorflow {
@@ -256,74 +256,41 @@ void MRDeleter(ibv_mr* mr) {
}
}
-// TODO(byronyi): remove this class and its registration when the default
-// cpu_allocator() returns visitable allocator, or cpu_allocator() is no
-// longer in use.
-class BFCRdmaAllocator : public BFCAllocator {
- public:
- BFCRdmaAllocator()
- : BFCAllocator(new BasicCPUAllocator(port::kNUMANoAffinity), 1LL << 36,
- true, "cpu_rdma_bfc") {}
-};
-class BFCRdmaAllocatorFactory : public AllocatorFactory {
- public:
- Allocator* CreateAllocator() { return new BFCRdmaAllocator; }
-
- SubAllocator* CreateSubAllocator(int numa_node) {
- return new BasicCPUAllocator(numa_node);
- }
-};
-
-REGISTER_MEM_ALLOCATOR("BFCRdmaAllocator", 101, BFCRdmaAllocatorFactory);
-
void RdmaMgr::InitAllocators() {
- RdmaMemoryMgr::Singleton().pd_ = rdma_adapter_->pd_;
+ static std::once_flag flag;
+ std::call_once(
+ flag, [this]() { RdmaMemoryMgr::Singleton().pd_ = rdma_adapter_->pd_; });
+}
- Allocator* allocators[] = {
-#if GOOGLE_CUDA
- GPUProcessState::singleton()->GetCUDAHostAllocator(0),
-#endif // GOOGLE_CUDA
- ProcessState::singleton()->GetCPUAllocator(0),
- cpu_allocator(),
+/*static*/ void RdmaMgr::RegMemVisitors() {
+ SubAllocator::Visitor alloc_visitor = [](void* ptr, int numa_node,
+ size_t num_bytes) {
+ RdmaMemoryMgr::Singleton().InsertMemoryRegion(
+ ptr, num_bytes, strings::StrCat("CPU:", numa_node));
+ };
+ SubAllocator::Visitor free_visitor = [](void* ptr, int numa_node,
+ size_t num_bytes) {
+ RdmaMemoryMgr::Singleton().EvictMemoryRegion(ptr, num_bytes);
};
- using namespace std::placeholders;
-
- std::set<Allocator*> instrumented_;
-
- // Host memory allocators
- for (Allocator* allocator : allocators) {
- VisitableAllocator::Visitor alloc_visitor =
- std::bind(&RdmaMemoryMgr::InsertMemoryRegion,
- &RdmaMemoryMgr::Singleton(), _1, _2, allocator->Name());
- VisitableAllocator::Visitor free_visitor = std::bind(
- &RdmaMemoryMgr::EvictMemoryRegion, &RdmaMemoryMgr::Singleton(), _1, _2);
-
- auto* visitable_allocator = dynamic_cast<VisitableAllocator*>(allocator);
- CHECK(visitable_allocator)
- << "is not visitable for instrumentation" << allocator->Name();
- // Make sure we don't instrument the same allocator twice
- if (instrumented_.find(allocator) == std::end(instrumented_)) {
- visitable_allocator->AddAllocVisitor(alloc_visitor);
- visitable_allocator->AddFreeVisitor(free_visitor);
- instrumented_.insert(allocator);
- LOG(INFO) << "Instrumenting CPU allocator " << allocator->Name();
- }
- }
+ ProcessState::singleton()->AddCPUAllocVisitor(alloc_visitor);
+ ProcessState::singleton()->AddCPUFreeVisitor(free_visitor);
#if GOOGLE_CUDA
if (IsGDRAvailable()) {
// Note we don't free allocated GPU memory so there is no free visitor
int32_t bus_id = TryToReadNumaNode(rdma_adapter_->context_->device) + 1;
- char buf[8];
- sprintf(buf, "gpu");
- VisitableAllocator::Visitor cuda_alloc_visitor =
- std::bind(&RdmaMemoryMgr::InsertMemoryRegion,
- &RdmaMemoryMgr::Singleton(), _1, _2, std::string(buf));
-
+ SubAllocator::Visitor cuda_alloc_visitor = [](void* ptr, int gpu_id,
+ size_t num_bytes) {
+ RdmaMemoryMgr::Singleton().InsertMemoryRegion(
+ ptr, num_bytes, strings::StrCat("GPU:", gpu_id));
+ };
GPUProcessState::singleton()->AddGPUAllocVisitor(bus_id,
cuda_alloc_visitor);
+ GPUProcessState::singleton()->AddCUDAHostAllocVisitor(bus_id,
+ alloc_visitor);
+ GPUProcessState::singleton()->AddCUDAHostFreeVisitor(bus_id, free_visitor);
LOG(INFO) << "Instrumenting GPU allocator with bus_id " << bus_id;
}
#endif // GOOGLE_CUDA
diff --git a/tensorflow/contrib/verbs/rdma_mgr.h b/tensorflow/contrib/verbs/rdma_mgr.h
index 9fffc335bb..74b92cc9a6 100644
--- a/tensorflow/contrib/verbs/rdma_mgr.h
+++ b/tensorflow/contrib/verbs/rdma_mgr.h
@@ -39,6 +39,7 @@ class RdmaMgr {
void SetupChannels();
bool ConnectivityCheck();
void InitAllocators();
+ static void RegMemVisitors();
const string& local_worker() { return local_worker_; }
private:
diff --git a/tensorflow/contrib/verbs/verbs_server_lib.cc b/tensorflow/contrib/verbs/verbs_server_lib.cc
index 1a0b5028fe..5b72b1604a 100644
--- a/tensorflow/contrib/verbs/verbs_server_lib.cc
+++ b/tensorflow/contrib/verbs/verbs_server_lib.cc
@@ -76,8 +76,13 @@ Status VerbsServer::ChannelCacheFactory(const ServerDef& server_def,
return Status::OK();
}
+namespace {
+std::once_flag reg_mem_visitors_call;
+} // namespace
+
Status VerbsServer::Init(ServiceInitFunction service_func,
RendezvousMgrCreationFunction rendezvous_mgr_func) {
+ std::call_once(reg_mem_visitors_call, []() { RdmaMgr::RegMemVisitors(); });
Status s = GrpcServer::Init(service_func, rendezvous_mgr_func);
{
mutex_lock l(mu_);
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 55715bb3a6..50fe308b73 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -85,11 +85,12 @@ load(
"tf_cc_tests",
"tf_copts",
"tf_cuda_library",
+ "tf_features_nomodules_if_android",
"tf_gen_op_libs",
"tf_generate_proto_text_sources",
"tf_genrule_cmd_append_to_srcs",
"tf_opts_nortti_if_android",
- "tf_features_nomodules_if_android",
+ "transitive_hdrs",
)
load("//tensorflow:tensorflow.bzl", "tf_cc_test_mkl")
load("//tensorflow:tensorflow.bzl", "tf_cc_test_gpu")
@@ -120,16 +121,16 @@ load(
"tf_additional_libdevice_srcs",
"tf_additional_minimal_lib_srcs",
"tf_additional_mpi_lib_defines",
- "tf_additional_proto_hdrs",
"tf_additional_proto_compiler_hdrs",
+ "tf_additional_proto_hdrs",
"tf_additional_proto_srcs",
"tf_additional_test_deps",
"tf_additional_test_srcs",
"tf_additional_verbs_lib_defines",
"tf_jspb_proto_library",
"tf_kernel_tests_linkstatic",
- "tf_lib_proto_parsing_deps",
"tf_lib_proto_compiler_deps",
+ "tf_lib_proto_parsing_deps",
"tf_nano_proto_library",
"tf_platform_hdrs",
"tf_platform_srcs",
@@ -143,10 +144,12 @@ load(
)
load(
"//tensorflow/core:platform/default/build_config_root.bzl",
+ "if_dynamic_kernels",
"if_static",
"tf_cuda_tests_tags",
)
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
+load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured")
load("@io_bazel_rules_closure//closure:defs.bzl", "closure_proto_library")
load(
"//third_party/mkl:build_defs.bzl",
@@ -178,7 +181,6 @@ COMMON_PROTO_SRCS = [
"framework/iterator.proto",
"framework/kernel_def.proto",
"framework/log_memory.proto",
- "framework/model.proto",
"framework/node_def.proto",
"framework/op_def.proto",
"framework/reader_base.proto",
@@ -706,14 +708,11 @@ cc_library(
cc_library(
name = "feature_util",
srcs = ["example/feature_util.cc"],
- hdrs = [
- "example/feature_util.h",
- "platform/types.h",
- ],
+ hdrs = ["example/feature_util.h"],
visibility = ["//visibility:public"],
deps = [
":core_stringpiece",
- ":platform_protobuf",
+ ":lib_proto_parsing",
":protos_all_cc",
],
)
@@ -842,7 +841,6 @@ tf_cuda_library(
"framework/log_memory.h",
"framework/lookup_interface.h",
"framework/memory_types.h",
- "framework/model.h",
"framework/node_def_builder.h",
"framework/node_def_util.h",
"framework/numeric_op.h",
@@ -1041,6 +1039,7 @@ tf_gen_op_libs(
"dataset_ops",
"decode_proto_ops",
"encode_proto_ops",
+ "experimental_dataset_ops",
"function_ops",
"functional_ops",
"image_ops",
@@ -1057,7 +1056,6 @@ tf_gen_op_libs(
"random_grad",
"random_ops",
"remote_fused_graph_ops",
- "resource_variable_ops",
"rpc_ops",
"scoped_allocator_ops",
"sdca_ops",
@@ -1068,7 +1066,6 @@ tf_gen_op_libs(
"spectral_ops",
"state_ops",
"stateless_random_ops",
- "string_ops",
"summary_ops",
"training_ops",
],
@@ -1076,6 +1073,13 @@ tf_gen_op_libs(
tf_gen_op_libs(
op_lib_names = [
+ "string_ops",
+ ],
+ deps = ["@com_google_absl//absl/strings"],
+)
+
+tf_gen_op_libs(
+ op_lib_names = [
"array_ops",
],
deps = [":protos_all_cc"],
@@ -1093,6 +1097,14 @@ tf_gen_op_libs(
deps = ["//tensorflow/core/kernels:debug_ops"],
)
+tf_gen_op_libs(
+ is_external = False,
+ op_lib_names = [
+ "resource_variable_ops",
+ ],
+ deps = [":lib"],
+)
+
# And one for all user ops
cc_library(
name = "user_ops_op_lib",
@@ -1158,6 +1170,7 @@ cc_library(
":dataset_ops_op_lib",
":decode_proto_ops_op_lib",
":encode_proto_ops_op_lib",
+ ":experimental_dataset_ops_op_lib",
":function_ops_op_lib",
":functional_ops_op_lib",
":image_ops_op_lib",
@@ -1287,8 +1300,8 @@ cc_library(
# This includes implementations of all kernels built into TensorFlow.
cc_library(
- name = "all_kernels",
- visibility = ["//visibility:public"],
+ name = "all_kernels_statically_linked",
+ visibility = ["//visibility:private"],
deps = [
"//tensorflow/core/kernels:array",
"//tensorflow/core/kernels:audio",
@@ -1331,6 +1344,7 @@ cc_library(
"//tensorflow/core/kernels:rpc_op",
"//tensorflow/core/kernels:scoped_allocator_ops",
"//tensorflow/core/kernels:sdca_ops",
+ "//tensorflow/core/kernels:searchsorted_op",
"//tensorflow/core/kernels:set_kernels",
"//tensorflow/core/kernels:sparse",
"//tensorflow/core/kernels:state",
@@ -1356,6 +1370,7 @@ cc_library(
"//tensorflow/core/kernels:mkl_pooling_ops",
"//tensorflow/core/kernels:mkl_relu_op",
"//tensorflow/core/kernels:mkl_reshape_op",
+ "//tensorflow/core/kernels:mkl_slice_op",
"//tensorflow/core/kernels:mkl_softmax_op",
"//tensorflow/core/kernels:mkl_transpose_op",
"//tensorflow/core/kernels:mkl_tfconv_op",
@@ -1366,6 +1381,15 @@ cc_library(
]),
)
+cc_library(
+ name = "all_kernels",
+ visibility = ["//visibility:public"],
+ deps = if_dynamic_kernels(
+ [],
+ otherwise = [":all_kernels_statically_linked"],
+ ),
+)
+
tf_cuda_library(
name = "tensorflow_opensource",
copts = tf_copts(),
@@ -2096,6 +2120,7 @@ cc_library(
deps = tf_additional_lib_deps() + [
"@com_google_absl//absl/strings",
"//third_party/eigen3",
+ "@com_google_absl//absl/base:core_headers",
"//tensorflow/core/platform/default/build_config:platformlib",
] + if_static([":lib_internal_impl"]),
)
@@ -2288,6 +2313,7 @@ cc_library(
deps = [
"//tensorflow/core/platform/default/build_config:jpeg",
"//tensorflow/core/platform/default/build_config:logging",
+ "@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/strings",
],
)
@@ -2320,6 +2346,7 @@ cc_library(
deps = [
"//tensorflow/core/platform/default/build_config:gif",
"//tensorflow/core/platform/default/build_config:logging",
+ "@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/strings",
],
)
@@ -2492,7 +2519,12 @@ tf_cuda_library(
cc_header_only_library(
name = "framework_internal_headers_lib",
- includes = ["../../external/com_google_absl"],
+ # Fully depend on external repositories, because identifying the headers
+ # is fragile.
+ extra_deps = [
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
+ ],
deps = [
":lib",
":lib_internal",
@@ -2530,6 +2562,7 @@ tf_cuda_library(
"**/*test*",
"**/*main.cc",
"example/example_parser_configuration.*",
+ "example/feature_util.cc",
"util/reporter.cc",
"framework/fake_input.*",
"framework/op_gen_lib.*",
@@ -2559,6 +2592,7 @@ tf_cuda_library(
],
}),
deps = [
+ ":feature_util",
":lib",
":lib_internal",
":protos_all_proto_text",
@@ -2578,11 +2612,12 @@ tf_cuda_library(
cc_header_only_library(
name = "framework_headers_lib",
+ # Fully depend on external repositories, because identifying the headers
+ # is fragile.
extra_deps = [
- # ABSL headers get dropped, so we add them back here.
"@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
],
- includes = ["../../external/com_google_absl"],
visibility = ["//visibility:public"],
deps = [
":framework",
@@ -2592,7 +2627,12 @@ cc_header_only_library(
cc_header_only_library(
name = "stream_executor_headers_lib",
- includes = ["../../external/com_google_absl"],
+ # Fully depend on external repositories, because identifying the headers
+ # is fragile.
+ extra_deps = [
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
+ ],
visibility = ["//visibility:public"],
deps = [
":stream_executor",
@@ -2783,8 +2823,6 @@ CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [
"common_runtime/stats_publisher_interface.h",
"common_runtime/step_stats_collector.h",
"common_runtime/threadpool_device.h",
- "common_runtime/tracing_device.h",
- "common_runtime/visitable_allocator.h",
"common_runtime/process_state.h",
"common_runtime/pool_allocator.h",
"graph/gradients.h",
@@ -2971,7 +3009,7 @@ tf_cuda_library(
"platform/device_tracer.h",
],
copts = tf_copts(),
- cuda_deps = tf_additional_cupti_wrapper_deps() + tf_additional_device_tracer_cuda_deps(),
+ cuda_deps = if_cuda_is_configured(tf_additional_cupti_wrapper_deps() + tf_additional_device_tracer_cuda_deps()),
visibility = ["//visibility:private"],
deps = [
":core_cpu_internal",
@@ -2980,12 +3018,16 @@ tf_cuda_library(
] + tf_additional_device_tracer_deps(),
)
-cc_library(
- name = "session_ref",
- srcs = ["common_runtime/session_ref.cc"],
- hdrs = ["common_runtime/session_ref.h"],
- copts = tf_copts(),
- deps = [":core_cpu_base"],
+tf_proto_library_cc(
+ name = "replay_log_proto",
+ srcs = ["protobuf/replay_log.proto"],
+ cc_api_version = 2,
+ protodeps = [
+ ":master_proto",
+ ] + tf_additional_all_protos(),
+ visibility = [
+ "//tensorflow:internal",
+ ],
)
cc_library(
@@ -3789,6 +3831,7 @@ tf_cc_test_mkl(
"//tensorflow/core/kernels:mkl_pooling_ops",
"//tensorflow/core/kernels:mkl_relu_op",
"//tensorflow/core/kernels:mkl_reshape_op",
+ "//tensorflow/core/kernels:mkl_slice_op",
"//tensorflow/core/kernels:mkl_softmax_op",
"//tensorflow/core/kernels:mkl_tfconv_op",
]),
@@ -4722,6 +4765,18 @@ cc_library(
] + tf_additional_libdevice_deps(),
)
+transitive_hdrs(
+ name = "headers",
+ visibility = ["//tensorflow:__subpackages__"],
+ deps = [
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:stream_executor",
+ ],
+)
+
# -----------------------------------------------------------------------------
# Google-internal targets go here (must be at the end).
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalAssertNextDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalAssertNextDataset.pbtxt
new file mode 100644
index 0000000000..fa8fc96bb2
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalAssertNextDataset.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "ExperimentalAssertNextDataset"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalCSVDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalCSVDataset.pbtxt
new file mode 100644
index 0000000000..5fd88e7a0c
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalCSVDataset.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "ExperimentalCSVDataset"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalDirectedInterleaveDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalDirectedInterleaveDataset.pbtxt
new file mode 100644
index 0000000000..ac1f9719fe
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalDirectedInterleaveDataset.pbtxt
@@ -0,0 +1,21 @@
+op {
+ graph_op_name: "ExperimentalDirectedInterleaveDataset"
+ in_arg {
+ name: "selector_input_dataset"
+ description: <<END
+A dataset of scalar `DT_INT64` elements that determines which of the
+`N` data inputs should produce the next output element.
+END
+ }
+ in_arg {
+ name: "data_input_datasets"
+ description: <<END
+`N` datasets with the same type that will be interleaved according to
+the values of `selector_input_dataset`.
+END
+ }
+ summary: <<END
+A substitute for `InterleaveDataset` on a fixed list of `N` datasets.
+END
+ visibility: HIDDEN
+}
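A minimal usage sketch for the op documented above, assuming it is surfaced through the contrib helper `tf.contrib.data.choose_from_datasets` (the wrapper name is an assumption; only the hidden op is defined in this change):

```python
import tensorflow as tf

# Three data inputs with the same element type.
datasets = [tf.data.Dataset.from_tensors("foo").repeat(),
            tf.data.Dataset.from_tensors("bar").repeat(),
            tf.data.Dataset.from_tensors("baz").repeat()]

# The selector dataset yields scalar int64 indices; each index picks which of
# the N data inputs produces the next output element: foo, bar, baz, foo, ...
selector = tf.data.Dataset.range(3).repeat()

# Assumed contrib wrapper backed by ExperimentalDirectedInterleaveDataset.
choice = tf.contrib.data.choose_from_datasets(datasets, selector)
```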
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalFunctionBufferingResource.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalFunctionBufferingResource.pbtxt
new file mode 100644
index 0000000000..66511eff60
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalFunctionBufferingResource.pbtxt
@@ -0,0 +1,58 @@
+op {
+ graph_op_name: "ExperimentalFunctionBufferingResource"
+ in_arg {
+ name: "string_arg"
+ description: <<END
+String argument to the function call.
+END
+ }
+ in_arg {
+ name: "target_device"
+ description: <<END
+Target device to execute the function on.
+END
+ }
+ out_arg {
+ name: "resource"
+ description: <<END
+Handle to the resource created.
+END
+ }
+ attr {
+ name: "shared_name"
+ description: <<END
+If non-empty, this resource will be shared under the given name across
+multiple sessions.
+END
+ }
+ attr {
+ name: "container"
+ description: <<END
+If non-empty, this resource is placed in the given container.
+Otherwise, a default container is used.
+END
+ }
+ attr {
+ name: "f"
+ description: <<END
+Function to be executed.
+END
+ }
+ attr {
+ name: "buffer_size"
+ description: <<END
+Size of the buffer.
+END
+ }
+ attr {
+ name: "output_types"
+ description: <<END
+The type list for the return values.
+END
+ }
+ summary: <<END
+Creates a resource that fills up a buffer by making function calls.
+END
+ visibility: HIDDEN
+}
+
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalFunctionBufferingResourceGetNext.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalFunctionBufferingResourceGetNext.pbtxt
new file mode 100644
index 0000000000..bf4b66b22b
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalFunctionBufferingResourceGetNext.pbtxt
@@ -0,0 +1,25 @@
+op {
+ graph_op_name: "ExperimentalFunctionBufferingResourceGetNext"
+ in_arg {
+ name: "function_buffer_resource"
+ description: <<END
+The FunctionBufferingResource handle.
+END
+ }
+ out_arg {
+ name: "output"
+ description: <<END
+A list of return values.
+END
+ }
+ attr {
+ name: "output_types"
+ description: <<END
+The type list for the return values.
+END
+ }
+ summary: <<END
+Gets the next element from a FunctionBufferingResource.
+END
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalFunctionBufferingResourceReset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalFunctionBufferingResourceReset.pbtxt
new file mode 100644
index 0000000000..729718ddb3
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalFunctionBufferingResourceReset.pbtxt
@@ -0,0 +1,13 @@
+op {
+ graph_op_name: "ExperimentalFunctionBufferingResourceReset"
+ in_arg {
+ name: "function_buffer_resource"
+ description: <<END
+The FunctionBufferingResource handle.
+END
+ }
+ summary: <<END
+Resets the FunctionBufferingResource.
+END
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalIdentityIndexedDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalIdentityIndexedDataset.pbtxt
new file mode 100644
index 0000000000..fe266c111f
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalIdentityIndexedDataset.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "ExperimentalIdentityIndexedDataset"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalIgnoreErrorsDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalIgnoreErrorsDataset.pbtxt
new file mode 100644
index 0000000000..d42546516d
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalIgnoreErrorsDataset.pbtxt
@@ -0,0 +1,8 @@
+op {
+ graph_op_name: "ExperimentalIgnoreErrorsDataset"
+ summary: <<END
+Creates a dataset that contains the elements of `input_dataset` ignoring errors.
+END
+ visibility: HIDDEN
+}
+
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalIndexedDatasetGet.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalIndexedDatasetGet.pbtxt
new file mode 100644
index 0000000000..e285f87e10
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalIndexedDatasetGet.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "ExperimentalIndexedDatasetGet"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalIndexedDatasetMaterialize.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalIndexedDatasetMaterialize.pbtxt
new file mode 100644
index 0000000000..60c32473b5
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalIndexedDatasetMaterialize.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "ExperimentalIndexedDatasetMaterialize"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalIteratorGetDevice.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalIteratorGetDevice.pbtxt
new file mode 100644
index 0000000000..b72b229e9a
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalIteratorGetDevice.pbtxt
@@ -0,0 +1,8 @@
+op {
+ graph_op_name: "ExperimentalIteratorGetDevice"
+ summary: <<END
+Returns the name of the device on which `resource` has been placed.
+END
+ visibility: HIDDEN
+}
+
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalLMDBDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalLMDBDataset.pbtxt
new file mode 100644
index 0000000000..b38b23a51d
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalLMDBDataset.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "ExperimentalLMDBDataset"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalMaterializedIndexDatasetHandle.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalMaterializedIndexDatasetHandle.pbtxt
new file mode 100644
index 0000000000..9676b9d284
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalMaterializedIndexDatasetHandle.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "ExperimentalMaterializedIndexDatasetHandle"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalThreadPoolDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalThreadPoolDataset.pbtxt
new file mode 100644
index 0000000000..d73b5bfda3
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalThreadPoolDataset.pbtxt
@@ -0,0 +1,13 @@
+op {
+ graph_op_name: "ExperimentalThreadPoolDataset"
+ in_arg {
+ name: "thread_pool"
+ description: <<END
+A resource produced by the ThreadPoolHandle op.
+END
+ }
+ summary: <<END
+Creates a dataset that uses a custom thread pool to compute `input_dataset`.
+END
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalThreadPoolHandle.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalThreadPoolHandle.pbtxt
new file mode 100644
index 0000000000..48bf93406c
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalThreadPoolHandle.pbtxt
@@ -0,0 +1,35 @@
+op {
+ graph_op_name: "ExperimentalThreadPoolHandle"
+ out_arg {
+ name: "handle"
+ description: <<END
+A resource that can be consumed by one or more ExperimentalThreadPoolDataset
+ops.
+END
+ }
+ attr {
+ name: "num_threads"
+ description: <<END
+The number of threads in the thread pool.
+END
+ }
+ attr {
+ name: "max_intra_op_parallelism"
+ description: <<END
+The maximum degree of parallelism to use within operations that execute on this
+threadpool.
+END
+ }
+ attr {
+ name: "display_name"
+ description: <<END
+A human-readable name for the threads that may be visible in some
+visualizations of the
+threadpool.
+END
+ }
+ summary: <<END
+Creates a custom thread pool handle, to be consumed by ExperimentalThreadPoolDataset ops.
+END
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalUniqueDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalUniqueDataset.pbtxt
new file mode 100644
index 0000000000..68ed797a0c
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalUniqueDataset.pbtxt
@@ -0,0 +1,8 @@
+op {
+ graph_op_name: "ExperimentalUniqueDataset"
+ summary: <<END
+Creates a dataset that contains the unique elements of `input_dataset`.
+END
+ visibility: HIDDEN
+}
+
diff --git a/tensorflow/core/api_def/base_api/api_def_ExtractVolumePatches.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExtractVolumePatches.pbtxt
new file mode 100644
index 0000000000..3c8a455983
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExtractVolumePatches.pbtxt
@@ -0,0 +1,49 @@
+op {
+ graph_op_name: "ExtractVolumePatches"
+ in_arg {
+ name: "input"
+ description: <<END
+5-D Tensor with shape `[batch, in_planes, in_rows, in_cols, depth]`.
+END
+ }
+ out_arg {
+ name: "patches"
+ description: <<END
+5-D Tensor with shape `[batch, out_planes, out_rows, out_cols,
+ksize_planes * ksize_rows * ksize_cols * depth]` containing patches
+with size `ksize_planes x ksize_rows x ksize_cols x depth` vectorized
+in the "depth" dimension. Note `out_planes`, `out_rows` and `out_cols`
+are the dimensions of the output patches.
+END
+ }
+ attr {
+ name: "ksizes"
+ description: <<END
+The size of the sliding window for each dimension of `input`.
+END
+ }
+ attr {
+ name: "strides"
+ description: <<END
+1-D of length 5. How far the centers of two consecutive patches are in
+`input`. Must be: `[1, stride_planes, stride_rows, stride_cols, 1]`.
+END
+ }
+ attr {
+ name: "padding"
+ description: <<END
+The type of padding algorithm to use.
+
+We specify the size-related attributes as:
+
+```python
+ ksizes = [1, ksize_planes, ksize_rows, ksize_cols, 1]
+ strides = [1, stride_planes, stride_rows, stride_cols, 1]
+```
+END
+ }
+ summary: <<END
+Extract `patches` from `input` and put them in the "depth" output
+dimension. 3D extension of `extract_image_patches`.
+END
+}
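A short sketch of the shape semantics described above, assuming the op is exposed to Python as `tf.extract_volume_patches` (the endpoint name is an assumption; only the op definition appears in this change):

```python
import tensorflow as tf

# 5-D input: [batch, in_planes, in_rows, in_cols, depth].
x = tf.reshape(tf.range(4 * 4 * 4, dtype=tf.float32), [1, 4, 4, 4, 1])

# Assumed Python endpoint for the ExtractVolumePatches op defined above.
patches = tf.extract_volume_patches(x,
                                    ksizes=[1, 2, 2, 2, 1],
                                    strides=[1, 2, 2, 2, 1],
                                    padding="VALID")

# Each output location holds one flattened 2x2x2 patch, so the last dimension
# is ksize_planes * ksize_rows * ksize_cols * depth = 8.
print(patches.shape)  # (1, 2, 2, 2, 8)
```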
diff --git a/tensorflow/core/api_def/base_api/api_def_Igamma.pbtxt b/tensorflow/core/api_def/base_api/api_def_Igamma.pbtxt
index 40d7d371ca..7142a0e3f2 100644
--- a/tensorflow/core/api_def/base_api/api_def_Igamma.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_Igamma.pbtxt
@@ -9,7 +9,7 @@ The lower regularized incomplete Gamma function is defined as:
where
-\\(gamma(a, x) = int_{0}^{x} t^{a-1} exp(-t) dt\\)
+\\(gamma(a, x) = \\int_{0}^{x} t^{a-1} exp(-t) dt\\)
is the lower incomplete Gamma function.
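For reference, the surrounding docstring defines the lower regularized incomplete Gamma function; written out in full, the two quantities involved are:

```latex
P(a, x) = \frac{\gamma(a, x)}{\Gamma(a)}, \qquad
\gamma(a, x) = \int_{0}^{x} t^{a-1} e^{-t} \, dt
```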
diff --git a/tensorflow/core/api_def/base_api/api_def_LowerBound.pbtxt b/tensorflow/core/api_def/base_api/api_def_LowerBound.pbtxt
new file mode 100644
index 0000000000..5ce825ae04
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_LowerBound.pbtxt
@@ -0,0 +1,45 @@
+op {
+ graph_op_name: "LowerBound"
+ visibility: HIDDEN
+ in_arg {
+ name: "sorted_inputs"
+ description: <<END
+2-D Tensor where each row is ordered.
+END
+ }
+ in_arg {
+ name: "values"
+ description: <<END
+2-D Tensor with the same number of rows as `sorted_inputs`. Contains the
+values that will be searched for in `sorted_inputs`.
+END
+ }
+ out_arg {
+ name: "output"
+ description: <<END
+A `Tensor` with the same shape as `values`. It contains the first scalar index
+into the last dimension where values can be inserted without changing the
+ordered property.
+END
+ }
+ summary: "Applies lower_bound(sorted_inputs, values) along each row."
+ description: <<END
+Each set of rows with the same index in (sorted_inputs, values) is treated
+independently. The resulting row is the equivalent of calling
+`np.searchsorted(sorted_inputs, values, side='left')`.
+
+The result is not a global index to the entire
+`Tensor`, but rather just the index in the last dimension.
+
+A 2-D example:
+ sorted_sequence = [[0, 3, 9, 9, 10],
+ [1, 2, 3, 4, 5]]
+ values = [[2, 4, 9],
+ [0, 2, 6]]
+
+ result = LowerBound(sorted_sequence, values)
+
+ result == [[1, 2, 2],
+ [0, 1, 5]]
+END
+}
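The NumPy equivalence stated above can be checked directly; this is an illustration of the documented semantics, not the TensorFlow kernel itself:

```python
import numpy as np

sorted_sequence = np.array([[0, 3, 9, 9, 10],
                            [1, 2, 3, 4, 5]])
values = np.array([[2, 4, 9],
                   [0, 2, 6]])

# Row-wise lower_bound: leftmost insertion index that keeps each row sorted.
result = np.stack([np.searchsorted(row, vals, side="left")
                   for row, vals in zip(sorted_sequence, values)])
print(result)  # [[1 2 2]
               #  [0 1 5]]
```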
diff --git a/tensorflow/core/api_def/base_api/api_def_MultiDeviceIterator.pbtxt b/tensorflow/core/api_def/base_api/api_def_MultiDeviceIterator.pbtxt
new file mode 100644
index 0000000000..4b0a5d8f65
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_MultiDeviceIterator.pbtxt
@@ -0,0 +1,43 @@
+op {
+ graph_op_name: "MultiDeviceIterator"
+ out_arg {
+ name: "handle"
+ description: <<END
+Handle to the resource created.
+END
+ }
+ attr {
+ name: "devices"
+ description: <<END
+A list of devices the iterator works across.
+END
+ }
+ attr {
+ name: "shared_name"
+ description: <<END
+If non-empty, this resource will be shared under the given name
+across multiple sessions.
+END
+ }
+ attr {
+ name: "container"
+ description: <<END
+If non-empty, this resource is placed in the given container.
+Otherwise, a default container is used.
+END
+ }
+ attr {
+ name: "output_types"
+ description: <<END
+The type list for the return values.
+END
+ }
+ attr {
+ name: "output_shapes"
+ description: <<END
+The list of shapes being produced.
+END
+ }
+ summary: "Creates a MultiDeviceIterator resource."
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_MultiDeviceIteratorFromStringHandle.pbtxt b/tensorflow/core/api_def/base_api/api_def_MultiDeviceIteratorFromStringHandle.pbtxt
new file mode 100644
index 0000000000..adaacd8ab7
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_MultiDeviceIteratorFromStringHandle.pbtxt
@@ -0,0 +1,29 @@
+op {
+ graph_op_name: "MultiDeviceIteratorFromStringHandle"
+ in_arg {
+ name: "string_handle"
+ description: <<END
+String representing the resource.
+END
+ }
+ out_arg {
+ name: "multi_device_iterator"
+ description: <<END
+A MultiDeviceIterator resource.
+END
+ }
+ attr {
+ name: "output_types"
+ description: <<END
+The type list for the return values.
+END
+ }
+ attr {
+ name: "output_shapes"
+ description: <<END
+The list of shapes being produced.
+END
+ }
+ summary: "Generates a MultiDeviceIterator resource from its provided string handle."
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_MultiDeviceIteratorGetNextFromShard.pbtxt b/tensorflow/core/api_def/base_api/api_def_MultiDeviceIteratorGetNextFromShard.pbtxt
new file mode 100644
index 0000000000..f9be9188cc
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_MultiDeviceIteratorGetNextFromShard.pbtxt
@@ -0,0 +1,41 @@
+op {
+ graph_op_name: "MultiDeviceIteratorGetNextFromShard"
+ in_arg {
+ name: "multi_device_iterator"
+ description: <<END
+A MultiDeviceIterator resource.
+END
+ }
+ in_arg {
+ name: "shard_num"
+ description: <<END
+Integer representing which shard to fetch data for.
+END
+ }
+ in_arg {
+ name: "incarnation_id"
+ description: <<END
+Which incarnation of the MultiDeviceIterator is running.
+END
+ }
+ out_arg {
+ name: "components"
+ description: <<END
+Result of the get_next on the dataset.
+END
+ }
+ attr {
+ name: "output_types"
+ description: <<END
+The type list for the return values.
+END
+ }
+ attr {
+ name: "output_shapes"
+ description: <<END
+The list of shapes being produced.
+END
+ }
+ summary: "Gets next element for the provided shard number."
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_MultiDeviceIteratorInit.pbtxt b/tensorflow/core/api_def/base_api/api_def_MultiDeviceIteratorInit.pbtxt
new file mode 100644
index 0000000000..6b54fa1307
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_MultiDeviceIteratorInit.pbtxt
@@ -0,0 +1,30 @@
+op {
+ graph_op_name: "MultiDeviceIteratorInit"
+ in_arg {
+ name: "dataset"
+ description: <<END
+Dataset to be iterated upon.
+END
+ }
+ in_arg {
+ name: "multi_device_iterator"
+ description: <<END
+A MultiDeviceIteratorResource.
+END
+ }
+ in_arg {
+ name: "max_buffer_size"
+ description: <<END
+The maximum size of the host-side, per-device buffer to keep.
+END
+ }
+ out_arg {
+ name: "incarnation_id"
+ description: <<END
+An int64 indicating which incarnation of the MultiDeviceIterator
+is running.
+END
+ }
+ summary: "Initializes the multi device iterator with the given dataset."
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_MultiDeviceIteratorToStringHandle.pbtxt b/tensorflow/core/api_def/base_api/api_def_MultiDeviceIteratorToStringHandle.pbtxt
new file mode 100644
index 0000000000..1f1fdf99b4
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_MultiDeviceIteratorToStringHandle.pbtxt
@@ -0,0 +1,17 @@
+op {
+ graph_op_name: "MultiDeviceIteratorToStringHandle"
+ in_arg {
+ name: "multi_device_iterator"
+ description: <<END
+A MultiDeviceIterator resource.
+END
+ }
+ out_arg {
+ name: "string_handle"
+ description: <<END
+A string representing the resource.
+END
+ }
+ summary: "Produces a string handle for the given MultiDeviceIterator."
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_PrintV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_PrintV2.pbtxt
new file mode 100644
index 0000000000..4cb8955dcb
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_PrintV2.pbtxt
@@ -0,0 +1,19 @@
+op {
+ graph_op_name: "PrintV2"
+ in_arg {
+ name: "input"
+ description: <<END
+The string scalar to print.
+END
+ }
+ attr {
+ name: "output_stream"
+ description: <<END
+A string specifying the output stream or logging level to print to.
+END
+ }
+ summary: "Prints a string scalar."
+ description: <<END
+Prints a string scalar to the desired output_stream.
+END
+}
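A minimal sketch of how this op is typically reached from Python, assuming it backs a `tf.print`-style wrapper that accepts an `output_stream` argument (the wrapper itself is not part of this change):

```python
import sys
import tensorflow as tf

x = tf.constant([[1, 2], [3, 4]])

# Assumed high-level wrapper over PrintV2: formats its arguments into a single
# string scalar and prints it to the requested output stream.
print_op = tf.print("x is:", x, output_stream=sys.stderr)

with tf.Session() as sess:
  sess.run(print_op)
```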
diff --git a/tensorflow/core/api_def/base_api/api_def_ReduceDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ReduceDataset.pbtxt
new file mode 100644
index 0000000000..08414b3e68
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ReduceDataset.pbtxt
@@ -0,0 +1,26 @@
+op {
+ visibility: HIDDEN
+ graph_op_name: "ReduceDataset"
+ in_arg {
+ name: "input_dataset"
+ description: <<END
+A variant tensor representing the input dataset.
+END
+ }
+ in_arg {
+ name: "initial_state"
+ description: <<END
+A nested structure of tensors, representing the initial state of the
+transformation.
+END
+ }
+ attr {
+ name: "f"
+ description: <<END
+A function that maps `(old_state, input_element)` to `new_state`. It must take
+two arguments and return a nested structure of tensors. The structure of
+`new_state` must match the structure of `initial_state`.
+END
+ }
+ summary: "Reduces the input dataset to a singleton using a reduce function."
+}
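A small sketch of the reduce contract described above, assuming the hidden op is surfaced as `Dataset.reduce` (an assumption; only the op definition is added here):

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(5)

# f maps (old_state, input_element) -> new_state; the state structure must
# match the initial state, here a single int64 scalar.
total = dataset.reduce(tf.constant(0, dtype=tf.int64),
                       lambda state, x: state + x)

with tf.Session() as sess:
  print(sess.run(total))  # 10
```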
diff --git a/tensorflow/core/api_def/base_api/api_def_StringFormat.pbtxt b/tensorflow/core/api_def/base_api/api_def_StringFormat.pbtxt
new file mode 100644
index 0000000000..a82dae9e48
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_StringFormat.pbtxt
@@ -0,0 +1,38 @@
+op {
+ graph_op_name: "StringFormat"
+ in_arg {
+ name: "inputs"
+ description: <<END
+The list of tensors to format into the placeholder string.
+END
+ }
+
+ out_arg {
+ name: "output"
+ description: <<END
+The resulting string scalar.
+END
+ }
+ attr {
+ name: "template"
+ description: <<END
+A string, the template to format tensor summaries into.
+END
+ }
+ attr {
+ name: "placeholder"
+ description: <<END
+A string; at each occurrence of this placeholder in the template, the next tensor summary is inserted.
+END
+ }
+ attr {
+ name: "summarize"
+ description: <<END
+When formatting the tensor summaries, print the first and last `summarize` entries of each tensor dimension.
+END
+ }
+ summary: "Formats a string template using a list of tensors."
+ description: <<END
+Formats a string template using a list of tensors, pretty-printing tensor summaries.
+END
+}
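A hedged usage sketch, assuming the op is exposed as `tf.strings.format` (the python_api entry later in this change hides the raw op, so the wrapper name is an assumption):

```python
import tensorflow as tf

t = tf.constant([[1, 2], [3, 4]])

# Each "{}" placeholder in the template is replaced, in order, by a
# pretty-printed summary of the corresponding tensor.
msg = tf.strings.format("tensor: {}, shape: {}", (t, tf.shape(t)))

with tf.Session() as sess:
  print(sess.run(msg))
```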
diff --git a/tensorflow/core/api_def/base_api/api_def_StringLength.pbtxt b/tensorflow/core/api_def/base_api/api_def_StringLength.pbtxt
index cc21ddc815..7d2fbcd00b 100644
--- a/tensorflow/core/api_def/base_api/api_def_StringLength.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_StringLength.pbtxt
@@ -1,5 +1,15 @@
op {
graph_op_name: "StringLength"
+ attr {
+ name: "unit"
+ description: <<END
+The unit that is counted to compute string length. One of: `"BYTE"` (for
+the number of bytes in each string) or `"UTF8_CHAR"` (for the number of UTF-8
+encoded Unicode code points in each string). Results are undefined
+if `unit=UTF8_CHAR` and the `input` strings do not contain structurally
+valid UTF-8.
+END
+ }
in_arg {
name: "input"
description: <<END
diff --git a/tensorflow/core/api_def/base_api/api_def_UnicodeScript.pbtxt b/tensorflow/core/api_def/base_api/api_def_UnicodeScript.pbtxt
new file mode 100644
index 0000000000..7898fe8d6b
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_UnicodeScript.pbtxt
@@ -0,0 +1,28 @@
+op {
+ graph_op_name: "UnicodeScript"
+ endpoint {
+ name: "UnicodeScript"
+ }
+ in_arg {
+ name: "input"
+ description: <<END
+A Tensor of int32 Unicode code points.
+END
+ }
+ out_arg {
+ name: "output"
+ description: <<END
+A Tensor of int32 script codes corresponding to each input code point.
+END
+ }
+ summary: <<END
+Determine the script codes of a given tensor of Unicode integer code points.
+END
+ description: <<END
+This operation converts Unicode code points to script codes corresponding to
+each code point. Script codes correspond to International Components for
+Unicode (ICU) UScriptCode values. See http://icu-project.org/apiref/icu4c/uscript_8h.html.
+Returns -1 (USCRIPT_INVALID_CODE) for invalid codepoints. Output shape will
+match input shape.
+END
+}
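The python_api entry added later in this change exposes the op as `tf.strings.unicode_script`; a small sketch (the printed script codes are ICU UScriptCode values and are shown only as an expectation):

```python
import tensorflow as tf

# Code points for 'A' (Latin), U+042F Cyrillic Ya, and '9' (Common).
code_points = tf.constant([65, 1071, 57], dtype=tf.int32)

scripts = tf.strings.unicode_script(code_points)

with tf.Session() as sess:
  print(sess.run(scripts))  # e.g. [25 8 0]: USCRIPT_LATIN, USCRIPT_CYRILLIC, USCRIPT_COMMON
```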
diff --git a/tensorflow/core/api_def/base_api/api_def_UpperBound.pbtxt b/tensorflow/core/api_def/base_api/api_def_UpperBound.pbtxt
new file mode 100644
index 0000000000..0630f6e82a
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_UpperBound.pbtxt
@@ -0,0 +1,45 @@
+op {
+ graph_op_name: "UpperBound"
+ visibility: HIDDEN
+ in_arg {
+ name: "sorted_inputs"
+ description: <<END
+2-D Tensor where each row is ordered.
+END
+ }
+ in_arg {
+ name: "values"
+ description: <<END
+2-D Tensor with the same number of rows as `sorted_inputs`. Contains the
+values that will be searched for in `sorted_inputs`.
+END
+ }
+ out_arg {
+ name: "output"
+ description: <<END
+A `Tensor` with the same shape as `values`. It contains the last scalar index
+into the last dimension where values can be inserted without changing the
+ordered property.
+END
+ }
+ summary: "Applies upper_bound(sorted_inputs, values) along each row."
+ description: <<END
+Each set of rows with the same index in (sorted_inputs, values) is treated
+independently. The resulting row is the equivalent of calling
+`np.searchsorted(sorted_inputs, values, side='right')`.
+
+The result is not a global index to the entire
+`Tensor`, but rather just the index in the last dimension.
+
+A 2-D example:
+ sorted_sequence = [[0, 3, 9, 9, 10],
+ [1, 2, 3, 4, 5]]
+ values = [[2, 4, 9],
+ [0, 2, 6]]
+
+ result = UpperBound(sorted_sequence, values)
+
+ result == [[1, 2, 4],
+ [0, 2, 5]]
+END
+}
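The companion NumPy check for the `side='right'` variant; with the duplicate 9s in the first row the insertion point lands after both of them, which is the only place it differs from the LowerBound example earlier:

```python
import numpy as np

sorted_sequence = np.array([[0, 3, 9, 9, 10],
                            [1, 2, 3, 4, 5]])
values = np.array([[2, 4, 9],
                   [0, 2, 6]])

# Row-wise upper_bound: rightmost insertion index that keeps each row sorted.
result = np.stack([np.searchsorted(row, vals, side="right")
                   for row, vals in zip(sorted_sequence, values)])
print(result)  # [[1 2 4]
               #  [0 2 5]]
```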
diff --git a/tensorflow/core/api_def/base_api/api_def_WindowDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_WindowDataset.pbtxt
index 1bc3660479..01387b7527 100644
--- a/tensorflow/core/api_def/base_api/api_def_WindowDataset.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_WindowDataset.pbtxt
@@ -2,10 +2,31 @@ op {
visibility: HIDDEN
graph_op_name: "WindowDataset"
in_arg {
- name: "window_size"
+ name: "size"
description: <<END
A scalar representing the number of elements to accumulate in a window.
END
}
+ in_arg {
+ name: "shift"
+ description: <<END
+A scalar representing the number of steps by which the sliding window moves
+forward in each iteration. It must be positive.
+END
+ }
+ in_arg {
+ name: "stride"
+ description: <<END
+A scalar representing the stride of the input elements of the sliding window.
+It must be positive.
+END
+ }
+ in_arg {
+ name: "drop_remainder"
+ description: <<END
+A scalar representing whether a window should be dropped if its size is
+smaller than `size`.
+END
+ }
summary: "A dataset that creates window datasets from the input dataset."
}
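The new op arguments line up with a `Dataset.window(size, shift, stride, drop_remainder)` method; a sketch under that assumption:

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(7)

# Windows of 3 elements, advancing by 2 each iteration, taking every input
# element (stride=1) and dropping the short final window.
windows = dataset.window(size=3, shift=2, stride=1, drop_remainder=True)

# Each element of `windows` is itself a dataset; flatten for inspection.
flat = windows.flat_map(lambda w: w.batch(3))
# Yields [0 1 2], [2 3 4], [4 5 6]
```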
diff --git a/tensorflow/core/api_def/base_api/api_def_Xdivy.pbtxt b/tensorflow/core/api_def/base_api/api_def_Xdivy.pbtxt
new file mode 100644
index 0000000000..ca107abc6b
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_Xdivy.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "Xdivy"
+ summary: "Returns 0 if x == 0, and x / y otherwise, elementwise."
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_Xlogy.pbtxt b/tensorflow/core/api_def/base_api/api_def_Xlogy.pbtxt
new file mode 100644
index 0000000000..da625f7836
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_Xlogy.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "Xlogy"
+ summary: "Returns 0 if x == 0, and x * log(y) otherwise, elementwise."
+}
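Both of these ops receive `tf.math.*` endpoints later in this change; a short sketch of the zero-handling that motivates them (plain `x / y` and `x * log(y)` would produce NaN for the first element):

```python
import tensorflow as tf

x = tf.constant([0.0, 2.0])
y = tf.constant([0.0, 4.0])

safe_div = tf.math.xdivy(x, y)   # [0.0, 0.5]     (0 / 0 would be NaN)
safe_log = tf.math.xlogy(x, y)   # [0.0, 2.7726]  i.e. [0, 2 * log(4)]
```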
diff --git a/tensorflow/core/api_def/python_api/api_def_BatchToSpaceND.pbtxt b/tensorflow/core/api_def/python_api/api_def_BatchToSpaceND.pbtxt
index 9552fc92e3..e395e333bf 100644
--- a/tensorflow/core/api_def/python_api/api_def_BatchToSpaceND.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_BatchToSpaceND.pbtxt
@@ -1,10 +1,10 @@
op {
graph_op_name: "BatchToSpaceND"
endpoint {
- name: "manip.batch_to_space_nd"
+ name: "batch_to_space_nd"
}
endpoint {
- name: "batch_to_space_nd"
+ name: "manip.batch_to_space_nd"
deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_GatherNd.pbtxt b/tensorflow/core/api_def/python_api/api_def_GatherNd.pbtxt
index 71257c8855..598f23bde3 100644
--- a/tensorflow/core/api_def/python_api/api_def_GatherNd.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_GatherNd.pbtxt
@@ -1,10 +1,10 @@
op {
graph_op_name: "GatherNd"
endpoint {
- name: "manip.gather_nd"
+ name: "gather_nd"
}
endpoint {
- name: "gather_nd"
+ name: "manip.gather_nd"
deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_PrintV2.pbtxt b/tensorflow/core/api_def/python_api/api_def_PrintV2.pbtxt
new file mode 100644
index 0000000000..e22d980424
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_PrintV2.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "PrintV2"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Reshape.pbtxt b/tensorflow/core/api_def/python_api/api_def_Reshape.pbtxt
index c469665b66..b3d596de7a 100644
--- a/tensorflow/core/api_def/python_api/api_def_Reshape.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_Reshape.pbtxt
@@ -1,10 +1,10 @@
op {
graph_op_name: "Reshape"
endpoint {
- name: "manip.reshape"
+ name: "reshape"
}
endpoint {
- name: "reshape"
+ name: "manip.reshape"
deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_ReverseV2.pbtxt b/tensorflow/core/api_def/python_api/api_def_ReverseV2.pbtxt
index 77f595927b..51478b7c34 100644
--- a/tensorflow/core/api_def/python_api/api_def_ReverseV2.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_ReverseV2.pbtxt
@@ -1,10 +1,10 @@
op {
graph_op_name: "ReverseV2"
endpoint {
- name: "manip.reverse"
+ name: "reverse"
}
endpoint {
- name: "reverse"
+ name: "manip.reverse"
deprecated: true
}
endpoint {
diff --git a/tensorflow/core/api_def/python_api/api_def_ScatterNd.pbtxt b/tensorflow/core/api_def/python_api/api_def_ScatterNd.pbtxt
index a65a19b542..85888da45a 100644
--- a/tensorflow/core/api_def/python_api/api_def_ScatterNd.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_ScatterNd.pbtxt
@@ -1,10 +1,10 @@
op {
graph_op_name: "ScatterNd"
endpoint {
- name: "manip.scatter_nd"
+ name: "scatter_nd"
}
endpoint {
- name: "scatter_nd"
+ name: "manip.scatter_nd"
deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_SpaceToBatchND.pbtxt b/tensorflow/core/api_def/python_api/api_def_SpaceToBatchND.pbtxt
index af323a6cf3..146b97f444 100644
--- a/tensorflow/core/api_def/python_api/api_def_SpaceToBatchND.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_SpaceToBatchND.pbtxt
@@ -1,10 +1,10 @@
op {
graph_op_name: "SpaceToBatchND"
endpoint {
- name: "manip.space_to_batch_nd"
+ name: "space_to_batch_nd"
}
endpoint {
- name: "space_to_batch_nd"
+ name: "manip.space_to_batch_nd"
deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_StringFormat.pbtxt b/tensorflow/core/api_def/python_api/api_def_StringFormat.pbtxt
new file mode 100644
index 0000000000..8f0b1db45d
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_StringFormat.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "StringFormat"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_StringLength.pbtxt b/tensorflow/core/api_def/python_api/api_def_StringLength.pbtxt
index 01c02e1f70..df012414e3 100644
--- a/tensorflow/core/api_def/python_api/api_def_StringLength.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_StringLength.pbtxt
@@ -1,6 +1,4 @@
op {
graph_op_name: "StringLength"
- endpoint {
- name: "strings.length"
- }
+ visibility: HIDDEN
}
diff --git a/tensorflow/core/api_def/python_api/api_def_Tile.pbtxt b/tensorflow/core/api_def/python_api/api_def_Tile.pbtxt
index c34061c941..1d8695f1fd 100644
--- a/tensorflow/core/api_def/python_api/api_def_Tile.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_Tile.pbtxt
@@ -1,10 +1,10 @@
op {
graph_op_name: "Tile"
endpoint {
- name: "manip.tile"
+ name: "tile"
}
endpoint {
- name: "tile"
+ name: "manip.tile"
deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_UnicodeScript.pbtxt b/tensorflow/core/api_def/python_api/api_def_UnicodeScript.pbtxt
new file mode 100644
index 0000000000..a884a46143
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_UnicodeScript.pbtxt
@@ -0,0 +1,6 @@
+op {
+ graph_op_name: "UnicodeScript"
+ endpoint {
+ name: "strings.unicode_script"
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Xdivy.pbtxt b/tensorflow/core/api_def/python_api/api_def_Xdivy.pbtxt
new file mode 100644
index 0000000000..984442ba2b
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Xdivy.pbtxt
@@ -0,0 +1,6 @@
+op {
+ graph_op_name: "Xdivy"
+ endpoint {
+ name: "math.xdivy"
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Xlogy.pbtxt b/tensorflow/core/api_def/python_api/api_def_Xlogy.pbtxt
new file mode 100644
index 0000000000..b4a5299256
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Xlogy.pbtxt
@@ -0,0 +1,6 @@
+op {
+ graph_op_name: "Xlogy"
+ endpoint {
+ name: "math.xlogy"
+ }
+}
diff --git a/tensorflow/core/common_runtime/bfc_allocator.cc b/tensorflow/core/common_runtime/bfc_allocator.cc
index 84c6285bbe..3843ea9e60 100644
--- a/tensorflow/core/common_runtime/bfc_allocator.cc
+++ b/tensorflow/core/common_runtime/bfc_allocator.cc
@@ -31,7 +31,7 @@ namespace tensorflow {
BFCAllocator::BFCAllocator(SubAllocator* sub_allocator, size_t total_memory,
bool allow_growth, const string& name)
- : suballocator_(sub_allocator),
+ : sub_allocator_(sub_allocator),
name_(name),
free_chunks_list_(kInvalidChunkHandle),
next_allocation_id_(1) {
@@ -72,7 +72,7 @@ BFCAllocator::~BFCAllocator() {
VLOG(2) << "Number of regions allocated: "
<< region_manager_.regions().size();
for (const auto& region : region_manager_.regions()) {
- suballocator_->Free(region.ptr(), region.memory_size());
+ sub_allocator_->Free(region.ptr(), region.memory_size());
}
for (BinNum b = 0; b < kNumBins; b++) {
@@ -108,7 +108,7 @@ bool BFCAllocator::Extend(size_t alignment, size_t rounded_bytes) {
// Try allocating.
size_t bytes = std::min(curr_region_allocation_bytes_, available_bytes);
- void* mem_addr = suballocator_->Alloc(alignment, bytes);
+ void* mem_addr = sub_allocator_->Alloc(alignment, bytes);
if (mem_addr == nullptr && !started_backpedal_) {
// Only backpedal once.
started_backpedal_ = true;
@@ -119,7 +119,7 @@ bool BFCAllocator::Extend(size_t alignment, size_t rounded_bytes) {
while (mem_addr == nullptr) {
bytes = RoundedBytes(bytes * kBackpedalFactor);
if (bytes < rounded_bytes) break;
- mem_addr = suballocator_->Alloc(alignment, bytes);
+ mem_addr = sub_allocator_->Alloc(alignment, bytes);
}
}
@@ -158,10 +158,6 @@ bool BFCAllocator::Extend(size_t alignment, size_t rounded_bytes) {
// Insert the chunk into the right bin.
InsertFreeChunkIntoBin(h);
- // Invoke visitors on newly allocated region.
- for (const auto& visitor : region_visitors_) {
- visitor(mem_addr, bytes);
- }
return true;
}
@@ -490,15 +486,6 @@ void BFCAllocator::FreeAndMaybeCoalesce(BFCAllocator::ChunkHandle h) {
InsertFreeChunkIntoBin(coalesced_chunk);
}
-void BFCAllocator::AddAllocVisitor(Visitor visitor) {
- VLOG(1) << "AddVisitor";
- mutex_lock l(lock_);
- region_visitors_.push_back(visitor);
- for (const auto& region : region_manager_.regions()) {
- visitor(region.ptr(), region.memory_size());
- }
-}
-
bool BFCAllocator::TracksAllocationSizes() { return true; }
size_t BFCAllocator::RequestedSize(const void* ptr) {
diff --git a/tensorflow/core/common_runtime/bfc_allocator.h b/tensorflow/core/common_runtime/bfc_allocator.h
index 20e1dab1d5..2d74bf2b28 100644
--- a/tensorflow/core/common_runtime/bfc_allocator.h
+++ b/tensorflow/core/common_runtime/bfc_allocator.h
@@ -23,7 +23,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/common_runtime/allocator_retry.h"
-#include "tensorflow/core/common_runtime/visitable_allocator.h"
+#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/macros.h"
@@ -42,7 +42,7 @@ namespace tensorflow {
// coalescing. One assumption we make is that the process using this
// allocator owns pretty much all of the memory, and that nearly
// all requests to allocate memory go through this interface.
-class BFCAllocator : public VisitableAllocator {
+class BFCAllocator : public Allocator {
public:
// Takes ownership of sub_allocator.
BFCAllocator(SubAllocator* sub_allocator, size_t total_memory,
@@ -55,11 +55,6 @@ class BFCAllocator : public VisitableAllocator {
const AllocationAttributes& allocation_attr) override;
void DeallocateRaw(void* ptr) override;
- void AddAllocVisitor(Visitor visitor) override;
-
- // Does nothing, because memory is never freed.
- void AddFreeVisitor(Visitor visitor) override {}
-
bool TracksAllocationSizes() override;
size_t RequestedSize(const void* ptr) override;
@@ -309,7 +304,7 @@ class BFCAllocator : public VisitableAllocator {
};
// Returns 'bytes' rounded up to the next highest kMinAllocationSize.
- size_t RoundedBytes(size_t bytes);
+ static size_t RoundedBytes(size_t bytes);
// Try to add a new memory region that can satisfy an allocation of
// 'rounded_bytes' bytes. Returns true on success and false on
@@ -423,7 +418,7 @@ class BFCAllocator : public VisitableAllocator {
// of the available memory.
bool started_backpedal_ = false;
- std::unique_ptr<SubAllocator> suballocator_;
+ std::unique_ptr<SubAllocator> sub_allocator_;
string name_;
// Structures mutable after construction
@@ -435,9 +430,6 @@ class BFCAllocator : public VisitableAllocator {
// Pointer to head of linked list of free Chunks
ChunkHandle free_chunks_list_ GUARDED_BY(lock_);
- // Called once on each region, ASAP.
- std::vector<Visitor> region_visitors_ GUARDED_BY(lock_);
-
// Counter containing the next unique identifier to assign to a
// newly-created chunk.
int64 next_allocation_id_ GUARDED_BY(lock_);
diff --git a/tensorflow/core/common_runtime/constant_folding.cc b/tensorflow/core/common_runtime/constant_folding.cc
index 97b6971c5b..419867ff58 100644
--- a/tensorflow/core/common_runtime/constant_folding.cc
+++ b/tensorflow/core/common_runtime/constant_folding.cc
@@ -61,6 +61,7 @@ bool ReadPartialShapesFromShapeMap(
shape_map,
std::vector<PartialTensorShape>* input_shapes) {
CHECK(shape_map != nullptr);
+ input_shapes->resize(n->num_inputs());
for (const Edge* in : n->in_edges()) {
// Don't need to check if incoming control edges have known shapes.
if (in->IsControlEdge()) continue;
@@ -71,7 +72,9 @@ bool ReadPartialShapesFromShapeMap(
}
const auto& known_shape = known_shape_iter->second;
CHECK_GT(known_shape.size(), in->src_output()) << known_shape_iter->first;
- input_shapes->push_back(known_shape[in->src_output()]);
+ DCHECK_GE(in->dst_input(), 0);
+ DCHECK_LT(in->dst_input(), input_shapes->size());
+ (*input_shapes)[in->dst_input()] = known_shape[in->src_output()];
}
return true;
}
@@ -467,19 +470,19 @@ bool ReplaceTensorWithConstant(
const ConstantFoldNameGenerator& generate_new_name) {
// Be conservative when replacing a tensor with a constant, when not
// running on CPU.
- // 1) If the destination tensor is not an int32 tensor, and has HOST_MEMORY
+ // 1) Do not replace another constant.
+ // 2) If the destination tensor is not an int32 tensor, and has HOST_MEMORY
// constraint, do not replace it.
- // 2) If the destination tensor is an int32 tensor, but has DEVICE_MEMORY
- // constraint, do not replace it.
- // 3) If the constant op created does not have a kernel implementation
- // for the device, do not use it.
- // 4) If the size of the constant in bytes is too large (>
+ // 3) If the size of the constant in bytes is too large (>
// max_constant_in_bytes), do not replace it. This prevents the size of the
// Graph from growing too large.
+ // 4) If the constant op created does not have a kernel implementation
+ // for the device, do not use it.
// TODO(keveman): Consider adding a new constant op that has a kernel
// implementation for all types, but with HostMemory constraint on it's
// output.
- // 5) Do not replace another constant.
+ // 5) If the constant op for the device has different output memory type
+ // from the original op output memory type, do not replace it.
if (tensor.first->IsConstant()) {
return false;
}
@@ -494,8 +497,7 @@ bool ReplaceTensorWithConstant(
return false;
}
bool is_int32 = tensor.first->output_type(tensor.second) == DT_INT32;
- if ((memory_type == HOST_MEMORY && !is_int32) ||
- (memory_type == DEVICE_MEMORY && is_int32)) {
+ if (memory_type == HOST_MEMORY && !is_int32) {
return false;
}
}
@@ -533,6 +535,23 @@ bool ReplaceTensorWithConstant(
if (!NodeBuilder(builder).Finalize(graph, &constant_node).ok()) {
return false;
}
+ if (partition_device && device_type != DEVICE_CPU) {
+ MemoryType original_output_memory_type;
+ if (!MemoryTypeForOutput(device_type, graph, tensor.first, tensor.second,
+ &original_output_memory_type)
+ .ok()) {
+ return false;
+ }
+ MemoryType const_output_memory_type;
+ if (!MemoryTypeForOutput(device_type, graph, constant_node, 0,
+ &const_output_memory_type)
+ .ok()) {
+ return false;
+ }
+ if (original_output_memory_type != const_output_memory_type) {
+ return false;
+ }
+ }
for (auto edge : edges_to_remove) {
graph->AddEdge(constant_node, 0, edge->dst(), edge->dst_input());
graph->RemoveEdge(edge);
diff --git a/tensorflow/core/common_runtime/copy_tensor.cc b/tensorflow/core/common_runtime/copy_tensor.cc
index cf3d1f0b79..6e2eb66b94 100644
--- a/tensorflow/core/common_runtime/copy_tensor.cc
+++ b/tensorflow/core/common_runtime/copy_tensor.cc
@@ -61,26 +61,33 @@ void CopyHostToDevice(const Tensor* input, Allocator* cpu_allocator,
status_cb->Unref();
};
auto copier = std::bind(
- [dst, recv_dev_context, out_allocator, status_cb](
- StatusCallback wrapped_done_,
- // Begin unbound arguments
- const Tensor& from, Tensor* to) {
- if (!DMAHelper::CanUseDMA(&from)) {
- Status err = errors::InvalidArgument(
- "During Variant Host->Device Copy: "
- "non-DMA-copy attempted of tensor type: ",
- DataTypeString(from.dtype()));
- status_cb->UpdateStatus(err);
- return err;
- }
- if (status_cb->ok()) {
+ [dst, recv_dev_context, out_allocator, status_cb, cpu_allocator,
+ edge_name](StatusCallback wrapped_done_,
+ // Begin unbound arguments
+ const Tensor& from, Tensor* to) {
+ if (from.dtype() == DT_VARIANT) {
status_cb->Ref();
- *to = Tensor(out_allocator, from.dtype(), from.shape());
- recv_dev_context->CopyCPUTensorToDevice(&from, dst, to,
- wrapped_done_);
+ CopyHostToDevice(&from, cpu_allocator, out_allocator, edge_name,
+ dst, to, recv_dev_context, wrapped_done_);
return Status::OK();
} else {
- return status_cb->status();
+ if (!DMAHelper::CanUseDMA(&from)) {
+ Status err = errors::InvalidArgument(
+ "During Variant Host->Device Copy: "
+ "non-DMA-copy attempted of tensor type: ",
+ DataTypeString(from.dtype()));
+ status_cb->UpdateStatus(err);
+ return err;
+ }
+ if (status_cb->ok()) {
+ status_cb->Ref();
+ *to = Tensor(out_allocator, from.dtype(), from.shape());
+ recv_dev_context->CopyCPUTensorToDevice(&from, dst, to,
+ wrapped_done_);
+ return Status::OK();
+ } else {
+ return status_cb->status();
+ }
}
},
std::move(wrapped_done), std::placeholders::_1, std::placeholders::_2);
@@ -119,26 +126,33 @@ void CopyDeviceToHost(const Tensor* input, Allocator* cpu_allocator,
status_cb->Unref();
};
auto copier = std::bind(
- [edge_name, src, send_dev_context, out_allocator, status_cb](
- StatusCallback wrapped_done_,
- // Begin unbound arguments
- const Tensor& from, Tensor* to) {
- if (!DMAHelper::CanUseDMA(&from)) {
- Status err = errors::InvalidArgument(
- "During Variant Device->Host Copy: "
- "non-DMA-copy attempted of tensor type: ",
- DataTypeString(from.dtype()));
- status_cb->UpdateStatus(err);
- return err;
- }
- if (status_cb->ok()) {
+ [edge_name, src, send_dev_context, out_allocator, status_cb,
+ cpu_allocator](StatusCallback wrapped_done_,
+ // Begin unbound arguments
+ const Tensor& from, Tensor* to) {
+ if (from.dtype() == DT_VARIANT) {
status_cb->Ref();
- *to = Tensor(out_allocator, from.dtype(), from.shape());
- send_dev_context->CopyDeviceTensorToCPU(&from, edge_name, src, to,
- wrapped_done_);
+ CopyDeviceToHost(&from, cpu_allocator, out_allocator, edge_name,
+ src, to, send_dev_context, wrapped_done_);
return Status::OK();
} else {
- return status_cb->status();
+ if (!DMAHelper::CanUseDMA(&from)) {
+ Status err = errors::InvalidArgument(
+ "During Variant Device->Host Copy: "
+ "non-DMA-copy attempted of tensor type: ",
+ DataTypeString(from.dtype()));
+ status_cb->UpdateStatus(err);
+ return err;
+ }
+ if (status_cb->ok()) {
+ status_cb->Ref();
+ *to = Tensor(out_allocator, from.dtype(), from.shape());
+ send_dev_context->CopyDeviceTensorToCPU(&from, edge_name, src, to,
+ wrapped_done_);
+ return Status::OK();
+ } else {
+ return status_cb->status();
+ }
}
},
std::move(wrapped_done), std::placeholders::_1, std::placeholders::_2);
@@ -347,7 +361,12 @@ namespace {
static Status WrappedTensorDeviceCopy(
const Tensor& from, Tensor* to,
const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy) {
- if (DMAHelper::CanUseDMA(&from)) {
+ if (from.dtype() == DT_VARIANT) {
+ // TODO(b/116349787): Implement support for nested variants.
+ return errors::Unimplemented(
+ "Support for copying nested variants to device has not yet been "
+ "implemented.");
+ } else if (DMAHelper::CanUseDMA(&from)) {
TF_RETURN_IF_ERROR(copy(from, to));
} else {
*to = from;
diff --git a/tensorflow/core/common_runtime/device.h b/tensorflow/core/common_runtime/device.h
index 81d68e3be4..2ef1547cd9 100644
--- a/tensorflow/core/common_runtime/device.h
+++ b/tensorflow/core/common_runtime/device.h
@@ -101,11 +101,21 @@ class Device : public DeviceBase {
}
}
+ // If true, and tracing is enabled, the `tracing::ScopedAnnotation()` tracing
+ // mechanism will be used instead of `tracing::ScopedActivity()`. Some devices
+ // may override this method to use annotations, which enable child activities
+ // (such as GPU kernel launches) to be related to the OpKernel invocation.
+ virtual bool TraceUsingAnnotations() const { return false; }
+
// Blocks until all operations queued on the device at the time of
// the call have completed. Returns any error pending on the device
// at completion.
virtual Status Sync() = 0;
+ // Override this to return true for devices that require a Sync() call before
+ // session completion.
+ virtual bool RequiresSyncOnCompletion() const { return false; }
+
// Optionally modify the device's GraphDef before execution.
//
// This method should be considered experimental and is supplied to enable
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index af5d5b17e7..841181f8c3 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -363,7 +363,7 @@ Status DirectSession::MaybeInitializeExecutionState(
Status DirectSession::Create(const GraphDef& graph) {
TF_RETURN_IF_ERROR(init_error_);
if (graph.node_size() > 0) {
- mutex_lock l(graph_def_lock_);
+ mutex_lock l(graph_state_lock_);
if (graph_created_) {
return errors::AlreadyExists(
"A Graph has already been created for this session.");
@@ -375,7 +375,7 @@ Status DirectSession::Create(const GraphDef& graph) {
Status DirectSession::Extend(const GraphDef& graph) {
TF_RETURN_IF_ERROR(CheckNotClosed());
- mutex_lock l(graph_def_lock_);
+ mutex_lock l(graph_state_lock_);
return ExtendLocked(graph);
}
@@ -1172,7 +1172,7 @@ Status DirectSession::CreateExecutors(
int graph_def_version;
{
- mutex_lock l(graph_def_lock_);
+ mutex_lock l(graph_state_lock_);
graph_def_version =
execution_state_->original_graph_def().versions().producer();
}
@@ -1400,7 +1400,7 @@ Status DirectSession::CreateGraphs(
std::unique_ptr<FunctionLibraryDefinition>* flib_def,
RunStateArgs* run_state_args, DataTypeVector* input_types,
DataTypeVector* output_types, int64* collective_graph_key) {
- mutex_lock l(graph_def_lock_);
+ mutex_lock l(graph_state_lock_);
std::unique_ptr<ClientGraph> client_graph;
std::unique_ptr<GraphExecutionState> temp_exec_state_holder;
diff --git a/tensorflow/core/common_runtime/direct_session.h b/tensorflow/core/common_runtime/direct_session.h
index c2cf3c7fd7..4a6a921ea7 100644
--- a/tensorflow/core/common_runtime/direct_session.h
+++ b/tensorflow/core/common_runtime/direct_session.h
@@ -215,7 +215,7 @@ class DirectSession : public Session {
// if not already initialized.
Status MaybeInitializeExecutionState(const GraphDef& graph,
bool* out_already_initialized)
- EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_);
+ EXCLUSIVE_LOCKS_REQUIRED(graph_state_lock_);
// Retrieves an already existing set of executors to run 'inputs' and
// 'outputs', or creates and caches them for future use.
@@ -248,7 +248,7 @@ class DirectSession : public Session {
RunMetadata* run_metadata);
::tensorflow::Status ExtendLocked(const GraphDef& graph)
- EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_);
+ EXCLUSIVE_LOCKS_REQUIRED(graph_state_lock_);
::tensorflow::Status ResourceHandleToInputTensor(
const Tensor& resource_tensor, Tensor* retrieved_tensor);
@@ -289,7 +289,7 @@ class DirectSession : public Session {
}
::tensorflow::Status CheckGraphCreated(const char* method) {
- mutex_lock l(graph_def_lock_);
+ mutex_lock l(graph_state_lock_);
if (!graph_created_) {
return errors::InvalidArgument(
"Session was not created with a graph before ", method, "!");
@@ -313,10 +313,8 @@ class DirectSession : public Session {
DeviceSet device_set_;
string session_handle_;
- bool graph_created_ GUARDED_BY(graph_def_lock_) = false;
-
- mutex graph_def_lock_;
- GraphDef graph_def_ GUARDED_BY(graph_def_lock_);
+ mutex graph_state_lock_;
+ bool graph_created_ GUARDED_BY(graph_state_lock_) = false;
// The thread-pools to use for running ops, with a bool indicating if the pool
// is owned.
@@ -367,11 +365,11 @@ class DirectSession : public Session {
// nodes can not be moved to a different device. Maps node names to
// device names.
std::unordered_map<string, string> stateful_placements_
- GUARDED_BY(graph_def_lock_);
+ GUARDED_BY(graph_state_lock_);
// Execution_state; used when placing the entire graph.
std::unique_ptr<GraphExecutionState> execution_state_
- GUARDED_BY(graph_def_lock_);
+ GUARDED_BY(graph_state_lock_);
// The function library, before any rewrites or optimizations have been
// performed. In particular, CreateGraphs() may need to modify the function
@@ -386,7 +384,7 @@ class DirectSession : public Session {
std::atomic<int64> edge_name_counter_ = {0};
std::atomic<int64> handle_name_counter_ = {0};
- // For generating step ids that are unique across all sessions.
+ // For generating step ids that are unique within this session.
static std::atomic_int_fast64_t step_id_counter_;
// Global timeout for all blocking operations in this session.
@@ -395,8 +393,6 @@ class DirectSession : public Session {
// Manages all the cost models for the graphs executed in this session.
CostModelManager cost_model_manager_;
- Executor::Args::NodeOutputsCallback node_outputs_callback_ = nullptr;
-
// For testing collective graph key generation.
mutex collective_graph_key_lock_;
int64 collective_graph_key_ GUARDED_BY(collective_graph_key_lock_) = -1;
diff --git a/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc b/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc
index 0b096a14a3..2c63b8704e 100644
--- a/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc
+++ b/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc
@@ -77,6 +77,9 @@ TEST(DirectSessionWithTrackingAllocTest, CostModelTest) {
options.config.mutable_graph_options()
->mutable_rewrite_options()
->set_min_graph_nodes(-1);
+ options.config.mutable_graph_options()
+ ->mutable_rewrite_options()
+ ->set_pin_to_host_optimization(RewriterConfig::OFF);
std::unique_ptr<Session> session(NewSession(options));
TF_ASSERT_OK(session->Create(def));
std::vector<std::pair<string, Tensor>> inputs;
@@ -105,7 +108,7 @@ TEST(DirectSessionWithTrackingAllocTest, CostModelTest) {
EXPECT_EQ(2, shape.dim(0).size());
EXPECT_EQ(1, shape.dim(1).size());
if (node->name() == y->name()) {
-#ifdef INTEL_MKL
+#if defined(INTEL_MKL) && defined(ENABLE_MKL)
// if MKL is used, it goes through various additional
// graph rewrite pass. In TF, everytime a graph pass
// happens, "constant" nodes are allocated
@@ -114,16 +117,16 @@ TEST(DirectSessionWithTrackingAllocTest, CostModelTest) {
// which increments the value of AllocationId.
// Thus AllocationId becomes more than TF if MKL
// is used. Now IDs for MKL are 8 more than TF.
- EXPECT_EQ(29, cm->AllocationId(node, 0));
-#else
EXPECT_EQ(21, cm->AllocationId(node, 0));
-#endif
- } else {
-#ifdef INTEL_MKL
- EXPECT_EQ(30, cm->AllocationId(node, 0));
#else
+ EXPECT_EQ(13, cm->AllocationId(node, 0));
+#endif // INTEL_MKL && ENABLE_MKL
+ } else {
+#if defined(INTEL_MKL) && defined(ENABLE_MKL)
EXPECT_EQ(22, cm->AllocationId(node, 0));
-#endif
+#else
+ EXPECT_EQ(14, cm->AllocationId(node, 0));
+#endif // INTEL_MKL && ENABLE_MKL
}
}
EXPECT_LE(0, cm->MaxExecutionTime(node));
diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc
index 263467a5b6..18420b60fd 100644
--- a/tensorflow/core/common_runtime/eager/context.cc
+++ b/tensorflow/core/common_runtime/eager/context.cc
@@ -32,6 +32,18 @@ bool ReadBoolFromEnvVar(StringPiece env_var_name, bool default_val) {
return default_val;
}
+std::unique_ptr<thread::ThreadPool> EagerThreadPool(
+ const SessionOptions& opts) {
+ SessionOptions opts_copy(opts);
+ if (opts_copy.config.inter_op_parallelism_threads() == 0) {
+ // Eager defaults to a single thread when no threads are specified.
+ opts_copy.config.set_inter_op_parallelism_threads(1);
+ }
+
+ return std::unique_ptr<thread::ThreadPool>(
+ NewThreadPoolFromSessionOptions(opts_copy));
+}
+
} // namespace
EagerContext::EagerContext(const SessionOptions& opts,
@@ -49,7 +61,7 @@ EagerContext::EagerContext(const SessionOptions& opts,
: policy_(default_policy),
devices_(device_mgr->ListDevices()),
rendezvous_(rendezvous),
- thread_pool_(NewThreadPoolFromSessionOptions(opts)),
+ thread_pool_(EagerThreadPool(opts)),
pflr_(new ProcessFunctionLibraryRuntime(
device_mgr, opts.env, TF_GRAPH_DEF_VERSION, &func_lib_def_, {},
thread_pool_.get())),
@@ -67,7 +79,7 @@ EagerContext::EagerContext(const SessionOptions& opts,
}
InitDeviceMapAndAsync();
runner_ = [this](std::function<void()> closure) {
- this->thread_pool_->Schedule(closure);
+ this->thread_pool_->Schedule(std::move(closure));
};
}
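With EagerThreadPool, an unset inter_op_parallelism_threads (value 0) now means a single-threaded pool for eager execution rather than the system default. Illustrative sketch only (not part of the patch): a caller that wants more eager inter-op threads sets the field explicitly before constructing the EagerContext.

#include "tensorflow/core/public/session_options.h"

// Sketch: opt out of the single-thread eager default by choosing a pool size.
tensorflow::SessionOptions MakeEagerOptions(int inter_op_threads) {
  tensorflow::SessionOptions opts;
  opts.config.set_inter_op_parallelism_threads(inter_op_threads);
  return opts;
}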
diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc
index 1da1326a9a..1bc63616d0 100644
--- a/tensorflow/core/common_runtime/eager/execute.cc
+++ b/tensorflow/core/common_runtime/eager/execute.cc
@@ -251,26 +251,6 @@ Status EagerLocalExecute(EagerOperation* op,
EagerContext* ctx = op->EagerContext();
auto status = ctx->GetStatus();
if (!status.ok()) return status;
- // Ensure all resource-touching ops run in the device the resource is,
- // regardless of anything else that has been specified. This is identical to
- // the graph mode behavior.
- for (int i = 0; i < op->Inputs().size(); ++i) {
- Device* input_op_device = nullptr;
- status = op->Inputs()[i]->OpDevice(&input_op_device);
- if (!status.ok()) return status;
- VLOG(2) << "for op " << op->Name() << " input " << i << " "
- << DataTypeString(op->Inputs()[i]->dtype) << " "
- << (input_op_device == nullptr ? "cpu" : input_op_device->name())
- << " " << (op->Device() == nullptr ? "cpu" : op->Device()->name());
- if (op->Inputs()[i]->dtype == DT_RESOURCE &&
- (input_op_device != op->Device() || input_op_device == nullptr)) {
- Device* d = input_op_device == nullptr ? ctx->HostCPU() : input_op_device;
- VLOG(1) << "Changing device of operation " << op->Name() << " to "
- << d->name() << " because input #" << i
- << " is a resource in this device.";
- op->SetDevice(d);
- }
- }
Device* device = op->Device();
Fprint128 cache_key = op->MutableAttrs()->CacheKey(
@@ -604,6 +584,27 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
Status EagerExecute(EagerOperation* op,
gtl::InlinedVector<TensorHandle*, 2>* retvals,
int* num_retvals) {
+ // Ensure all resource-touching ops run on the device where the resource
+ // lives, regardless of anything else that has been specified. This is
+ // identical to the graph-mode behavior.
+ EagerContext* ctx = op->EagerContext();
+ for (int i = 0; i < op->Inputs().size(); ++i) {
+ Device* input_op_device = nullptr;
+ auto status = op->Inputs()[i]->OpDevice(&input_op_device);
+ if (!status.ok()) return status;
+ VLOG(2) << "for op " << op->Name() << " input " << i << " "
+ << DataTypeString(op->Inputs()[i]->dtype) << " "
+ << (input_op_device == nullptr ? "cpu" : input_op_device->name())
+ << " " << (op->Device() == nullptr ? "cpu" : op->Device()->name());
+ if (op->Inputs()[i]->dtype == DT_RESOURCE &&
+ (input_op_device != op->Device() || input_op_device == nullptr)) {
+ Device* d = input_op_device == nullptr ? ctx->HostCPU() : input_op_device;
+ VLOG(1) << "Changing device of operation " << op->Name() << " to "
+ << d->name() << " because input #" << i
+ << " is a resource in this device.";
+ op->SetDevice(d);
+ }
+ }
bool op_is_local = IsLocal(op->EagerContext(), op->Device());
if (op_is_local) {
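The resource-placement loop is moved from EagerLocalExecute into EagerExecute so the rule also applies before the local/remote split. A hypothetical helper (not in the patch) spelling out the condition the relocated block checks:

// A DT_RESOURCE input pins the op to the resource's device, treating an
// unplaced resource as living on the host CPU.
bool ShouldRepinToResourceDevice(DataType dtype, const Device* op_device,
                                 const Device* input_op_device) {
  return dtype == DT_RESOURCE &&
         (input_op_device != op_device || input_op_device == nullptr);
}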
diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.cc b/tensorflow/core/common_runtime/eager/tensor_handle.cc
index b912f7d37b..d58724cbfa 100644
--- a/tensorflow/core/common_runtime/eager/tensor_handle.cc
+++ b/tensorflow/core/common_runtime/eager/tensor_handle.cc
@@ -125,7 +125,6 @@ Status TensorHandle::Shape(tensorflow::TensorShape* shape) {
Status TensorHandle::NumDims(int* num_dims) {
if (IsRemote()) {
TF_RETURN_IF_ERROR(WaitForNode(remote_shape_node_id_, false));
- CHECK(remote_shape_ != nullptr);
*num_dims = remote_shape_->dims();
} else {
TF_RETURN_IF_ERROR(WaitReady());
@@ -153,6 +152,21 @@ Status TensorHandle::Dim(int dim_index, int64* dim) {
return Status::OK();
}
+Status TensorHandle::NumElements(int64* num_elements) {
+ if (IsRemote()) {
+ TF_RETURN_IF_ERROR(WaitForNode(remote_shape_node_id_, false));
+ *num_elements = remote_shape_->num_elements();
+ } else {
+ TF_RETURN_IF_ERROR(WaitReady());
+ DCHECK(IsReady());
+ DCHECK(num_elements != nullptr);
+
+ *num_elements = tensor_.NumElements();
+ }
+
+ return Status::OK();
+}
+
Status TensorHandle::RemoteAddress(int64* op_id, int32* output_num) {
if (!IsRemote()) {
return errors::FailedPrecondition(
diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.h b/tensorflow/core/common_runtime/eager/tensor_handle.h
index 1bc9c6531a..e55f1a0338 100644
--- a/tensorflow/core/common_runtime/eager/tensor_handle.h
+++ b/tensorflow/core/common_runtime/eager/tensor_handle.h
@@ -113,6 +113,7 @@ class TensorHandle : public core::RefCounted {
Status NumDims(int* num_dims);
Status Dim(int dim_index, int64* dim);
+ Status NumElements(int64* num_elements);
// Return the op_id and output num if the handle refers to a remote tensor.
Status RemoteAddress(int64* op_id, int32* output_num);
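The new NumElements() accessor follows the same Status-returning convention as NumDims() and Dim(). Usage sketch only (caller code assumed, not in the patch), where `handle` is a TensorHandle* inside a Status-returning function:

int64 num_elements = 0;
TF_RETURN_IF_ERROR(handle->NumElements(&num_elements));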
diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc
index 84865397bc..2c48084cab 100644
--- a/tensorflow/core/common_runtime/executor.cc
+++ b/tensorflow/core/common_runtime/executor.cc
@@ -76,56 +76,47 @@ bool IsInitializationOp(const Node* node) {
namespace nodestats {
inline int64 NowInNsec() { return Env::Default()->NowNanos(); }
-void SetScheduled(NodeExecStatsWrapper* stats, int64 micros) {
+void SetScheduled(NodeExecStatsInterface* stats, int64 micros) {
if (!stats) return;
stats->SetScheduled(micros * EnvTime::kMicrosToNanos);
}
-void SetAllStart(NodeExecStatsWrapper* stats) {
+void SetAllStart(NodeExecStatsInterface* stats) {
if (!stats) return;
stats->RecordExecutorStarted();
}
-void SetOpStart(NodeExecStatsWrapper* stats) {
+void SetOpStart(NodeExecStatsInterface* stats) {
if (!stats) return;
stats->RecordComputeStarted();
}
-void SetOpEnd(NodeExecStatsWrapper* stats) {
+void SetOpEnd(NodeExecStatsInterface* stats) {
if (!stats) return;
stats->RecordComputeEnded();
}
-void SetAllEnd(NodeExecStatsWrapper* stats) {
+void SetAllEnd(NodeExecStatsInterface* stats) {
if (!stats) return;
stats->RecordExecutorEnded();
}
-void SetOutput(NodeExecStatsWrapper* stats, int slot, const Tensor* v) {
+void SetOutput(NodeExecStatsInterface* stats, int slot, const Tensor* v) {
if (!stats) return;
stats->SetOutput(slot, v);
}
-void SetMemory(NodeExecStatsWrapper* stats, OpKernelContext* ctx) {
+void SetMemory(NodeExecStatsInterface* stats, OpKernelContext* ctx) {
if (!stats) return;
stats->SetMemory(ctx);
}
-void SetReferencedTensors(NodeExecStatsWrapper* stats,
+void SetReferencedTensors(NodeExecStatsInterface* stats,
const TensorReferenceVector& tensors) {
if (!stats) return;
stats->SetReferencedTensors(tensors);
}
-// Sets the timeline_label field of *stats, using data from *node.
-// Returns true iff the node is a transfer node.
-bool SetTimelineLabel(const Node* node, NodeExecStatsWrapper* stats) {
- if (!stats) {
- return false;
- }
- return stats->SetTimelineLabel(node);
-}
-
} // namespace nodestats
class ExecutorImpl;
@@ -152,6 +143,8 @@ struct NodeItem {
bool kernel_is_async : 1; // True iff kernel->AsAsync() != nullptr
bool is_merge : 1; // True iff IsMerge(node)
bool is_enter : 1; // True iff IsEnter(node)
+ bool is_constant_enter : 1; // True iff IsEnter(node) and
+ // node->GetAttr("is_constant") == true.
bool is_exit : 1; // True iff IsExit(node)
bool is_control_trigger : 1; // True iff IsControlTrigger(node)
bool is_sink : 1; // True iff IsSink(node)
@@ -635,6 +628,14 @@ Status ExecutorImpl::Initialize() {
item->kernel_is_async = (item->kernel->AsAsync() != nullptr);
item->is_merge = IsMerge(n);
item->is_enter = IsEnter(n);
+ if (item->is_enter) {
+ bool is_constant_enter;
+ TF_RETURN_IF_ERROR(
+ GetNodeAttr(n->attrs(), "is_constant", &is_constant_enter));
+ item->is_constant_enter = is_constant_enter;
+ } else {
+ item->is_constant_enter = false;
+ }
item->is_exit = IsExit(n);
item->is_control_trigger = IsControlTrigger(n);
item->is_sink = IsSink(n);
@@ -1237,6 +1238,9 @@ class ExecutorState {
// Step-local container.
ScopedStepContainer* step_container_;
StepStatsCollectorInterface* const stats_collector_;
+ const tracing::TraceCollector* const trace_collector_;
+ const tracing::EventCollector* const event_collector_;
+
// QUESTION: Make it a checkpoint::TensorSliceReaderCacheWrapper
// instead of a pointer? (avoids having to delete).
checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache_;
@@ -1245,6 +1249,7 @@ class ExecutorState {
CancellationManager* cancellation_manager_;
Executor::Args::Runner runner_;
bool sync_on_finish_;
+ const bool trace_using_annotations_;
// Owned.
@@ -1301,7 +1306,7 @@ class ExecutorState {
// After item->kernel computation is done, processes its outputs.
Status ProcessOutputs(const NodeItem& item, OpKernelContext* ctx,
- EntryVector* outputs, NodeExecStatsWrapper* stats);
+ EntryVector* outputs, NodeExecStatsInterface* stats);
// After processing the outputs, propagates the outputs to their dsts.
// Contents of *outputs are left in an indeterminate state after
@@ -1312,7 +1317,7 @@ class ExecutorState {
// "node" just finishes. Takes ownership of "stats". Returns true if
// execution has completed.
bool NodeDone(const Status& s, const Node* node, const TaggedNodeSeq& ready,
- NodeExecStatsWrapper* stats,
+ NodeExecStatsInterface* stats,
TaggedNodeReadyQueue* inline_ready);
// Schedule all the expensive nodes in 'ready', and put all the inexpensive
@@ -1359,12 +1364,16 @@ ExecutorState::ExecutorState(const Executor::Args& args, ExecutorImpl* impl)
tensor_store_(args.tensor_store),
step_container_(args.step_container),
stats_collector_(args.stats_collector),
+ trace_collector_(tracing::GetTraceCollector()),
+ event_collector_(
+ tracing::GetEventCollector(tracing::EventCategory::kCompute)),
slice_reader_cache_(new checkpoint::TensorSliceReaderCacheWrapper),
call_frame_(args.call_frame),
impl_(impl),
cancellation_manager_(args.cancellation_manager),
runner_(args.runner),
sync_on_finish_(args.sync_on_finish),
+ trace_using_annotations_(impl->params_.device->TraceUsingAnnotations()),
num_outstanding_ops_(0) {
// We start the entire execution in iteration 0 of the root frame
// so let us create the root frame and the state for iteration 0.
@@ -1513,7 +1522,7 @@ void ExecutorState::RunAsync(Executor::DoneCallback done) {
struct ExecutorState::AsyncState {
AsyncState(const OpKernelContext::Params& p, const TaggedNode& _tagged_node,
const NodeItem* _item, Entry* _first_input,
- NodeExecStatsWrapper* _stats)
+ NodeExecStatsInterface* _stats)
: saved_inputs(*p.inputs),
saved_input_device_contexts(*p.input_device_contexts),
saved_input_alloc_attrs(*p.input_alloc_attrs),
@@ -1538,7 +1547,7 @@ struct ExecutorState::AsyncState {
const NodeItem* item;
Entry* first_input;
OpKernelContext ctx;
- NodeExecStatsWrapper* stats;
+ NodeExecStatsInterface* stats;
private:
OpKernelContext::Params* ParamsButClearingEigenGPUDevice(
@@ -1550,6 +1559,32 @@ struct ExecutorState::AsyncState {
}
};
+// Returns true if `item` might be traced by the given trace and event
+// collectors. Returns false only if `item` definitely will not be traced.
+bool MightTrace(const NodeItem& item,
+ const tracing::TraceCollector* trace_collector,
+ const tracing::EventCollector* event_collector,
+ bool using_annotations) {
+ // Tracing will only be enabled if either `event_collector` is non-null,
+ // or `trace_collector` is non-null and enabled for this particular kernel.
+ // Although `tracing::ScopedActivity`,
+ // `tracing::ScopedAnnotation`, and `tracing::ScopedRegion` check subsets of
+ // these properties internally in their constructors, the cost of passing the
+ // necessary arguments to them can be significant, so we avoid constructing
+ // them in the common case (when we know they will not be used).
+ if (event_collector != nullptr) {
+ return true;
+ }
+ if (trace_collector) {
+ if (using_annotations) {
+ return trace_collector->IsEnabledForAnnotations();
+ } else {
+ return trace_collector->IsEnabledForActivities(item.kernel_is_expensive);
+ }
+ }
+ return false;
+}
+
void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) {
const GraphView& gview = impl_->gview_;
TaggedNodeSeq ready;
@@ -1583,7 +1618,8 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) {
params.stats_collector = stats_collector_;
Status s;
- NodeExecStatsWrapper* stats = nullptr;
+ NodeExecStatsInterface* stats = nullptr;
+
EntryVector outputs;
bool completed = false;
inline_ready.push_back(tagged_node);
@@ -1613,7 +1649,7 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) {
if (stats_collector_ && !tagged_node.is_dead) {
// track allocations if and only if we are collecting statistics
params.track_allocations = true;
- stats = new NodeExecStatsWrapper(node->name());
+ stats = stats_collector_->CreateNodeExecStats(node);
nodestats::SetScheduled(stats, scheduled_nsec);
nodestats::SetAllStart(stats);
}
@@ -1671,7 +1707,7 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) {
auto done = [this, state]() {
Device* device = impl_->params_.device;
- NodeExecStatsWrapper* stats = state->stats; // Shorthand
+ NodeExecStatsInterface* stats = state->stats; // Shorthand
Entry* first_input = state->first_input; // Shorthand
nodestats::SetOpEnd(stats);
@@ -1720,7 +1756,32 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) {
// Synchronous computes.
OpKernelContext ctx(&params, item.num_outputs);
nodestats::SetOpStart(stats);
- device->Compute(CHECK_NOTNULL(op_kernel), &ctx);
+
+ if (TF_PREDICT_FALSE(MightTrace(item, trace_collector_,
+ event_collector_,
+ trace_using_annotations_))) {
+ const string& op_name = op_kernel->name();
+ tracing::ScopedRegion region(tracing::EventCategory::kCompute,
+ op_name);
+ if (trace_using_annotations_) {
+ // The OpKernel may create child activities (such as GPU kernel
+ // launches), so use a `ScopedAnnotation` to relate these activities
+ // in the trace.
+ tracing::ScopedAnnotation activity(op_name,
+ op_kernel->type_string());
+ device->Compute(op_kernel, &ctx);
+ } else {
+ // Use the cheaper `ScopedActivity` to trace just the OpKernel
+ // execution.
+ tracing::ScopedActivity activity(op_name, op_kernel->type_string(),
+ item.kernel_is_expensive);
+ device->Compute(op_kernel, &ctx);
+ }
+ } else {
+ // In the common case, avoid creating any tracing objects.
+ device->Compute(op_kernel, &ctx);
+ }
+
nodestats::SetOpEnd(stats);
s = ProcessOutputs(item, &ctx, &outputs, stats);
if (s.ok() && impl_->device_record_tensor_accesses_) {
@@ -1862,7 +1923,7 @@ Status ExecutorState::PrepareInputs(const NodeItem& item, Entry* first_input,
Status ExecutorState::ProcessOutputs(const NodeItem& item, OpKernelContext* ctx,
EntryVector* outputs,
- NodeExecStatsWrapper* stats) {
+ NodeExecStatsInterface* stats) {
const Node* node = item.node;
DCHECK_EQ(0, outputs->size());
outputs->resize(item.num_outputs);
@@ -1997,15 +2058,12 @@ void ExecutorState::PropagateOutputs(const TaggedNode& tagged_node,
is_frame_done = input_frame->DecrementOutstandingOpsLocked(
&impl_->gview_, input_iter, ready);
} else if (item->is_enter) {
- bool is_constant;
- const Status s = GetNodeAttr(node->attrs(), "is_constant", &is_constant);
- DCHECK(s.ok()) << s;
FindOrCreateChildFrame(input_frame, input_iter, node, &output_frame);
output_iter = 0;
{
const NodeItem* item = impl_->gview_.node(node->id());
mutex_lock l(output_frame->mu);
- if (is_constant) {
+ if (item->is_constant_enter) {
// Propagate to all active iterations if this is a loop invariant.
output_frame->AddLoopInv(item, (*outputs)[0], ready);
} else {
@@ -2080,16 +2138,15 @@ void ExecutorState::PropagateOutputs(const TaggedNode& tagged_node,
bool ExecutorState::NodeDone(const Status& s, const Node* node,
const TaggedNodeSeq& ready,
- NodeExecStatsWrapper* stats,
+ NodeExecStatsInterface* stats,
TaggedNodeReadyQueue* inline_ready) {
nodestats::SetAllEnd(stats);
- if (stats_collector_ != nullptr &&
- !nodestats::SetTimelineLabel(node, stats)) {
- // Only record non-transfer nodes.
- // Transfers 'stats' ownership to 'stats_collector_'.
- stats_collector_->Save(impl_->params_.device->name(), stats);
- } else if (stats) {
- delete stats;
+ if (stats) {
+ if (stats_collector_) {
+ stats->Done(impl_->params_.device->name());
+ } else {
+ delete stats;
+ }
}
bool abort_run = false;
@@ -2311,13 +2368,15 @@ void ExecutorState::Finish() {
auto done_cb = std::move(done_cb_);
auto runner = std::move(runner_);
mu_.unlock();
- if (sync_on_finish_ && status.ok()) {
+ Device* device = impl_->params_.device;
+ if ((sync_on_finish_ && status.ok()) || device->RequiresSyncOnCompletion()) {
// Block until the device has finished all queued operations. For
// devices like GPUs that continue to execute Ops after their Compute
// methods have completed, this ensures that control is not returned to
// the user until the step (and its side-effects) has actually completed.
- status = impl_->params_.device->Sync();
+ status.Update(device->Sync());
}
+
delete this;
CHECK(done_cb != nullptr);
runner([=]() { done_cb(status); });
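NodeDone now relies on the NodeExecStatsInterface ownership contract: Done() hands the stats object to the collector, and the caller only deletes it when no collector is installed. A condensed sketch of that contract (hypothetical helper, not in the patch):

void FinishStats(NodeExecStatsInterface* stats,
                 StepStatsCollectorInterface* collector,
                 const string& device_name) {
  if (stats == nullptr) return;
  if (collector != nullptr) {
    stats->Done(device_name);  // transfers ownership to the collector
  } else {
    delete stats;
  }
}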
diff --git a/tensorflow/core/common_runtime/executor.h b/tensorflow/core/common_runtime/executor.h
index 6cd4fd22ea..34bf73972f 100644
--- a/tensorflow/core/common_runtime/executor.h
+++ b/tensorflow/core/common_runtime/executor.h
@@ -97,12 +97,6 @@ class Executor {
typedef std::function<void()> Closure;
typedef std::function<void(Closure)> Runner;
Runner runner = nullptr;
-
- // A callback that is invoked each time a node has finished executing.
- typedef std::function<Status(const string& node_name, const int output_slot,
- const Tensor* tensor, const bool is_ref,
- OpKernelContext* ctx)>
- NodeOutputsCallback;
};
typedef std::function<void(const Status&)> DoneCallback;
virtual void RunAsync(const Args& args, DoneCallback done) = 0;
diff --git a/tensorflow/core/common_runtime/gpu/cuda_host_allocator.h b/tensorflow/core/common_runtime/gpu/cuda_host_allocator.h
index 636cd43575..6bd29ef775 100644
--- a/tensorflow/core/common_runtime/gpu/cuda_host_allocator.h
+++ b/tensorflow/core/common_runtime/gpu/cuda_host_allocator.h
@@ -26,8 +26,12 @@ namespace tensorflow {
class CUDAHostAllocator : public SubAllocator {
public:
// Note: stream_exec cannot be null.
- explicit CUDAHostAllocator(se::StreamExecutor* stream_exec)
- : stream_exec_(stream_exec) {
+ explicit CUDAHostAllocator(se::StreamExecutor* stream_exec, int numa_node,
+ const std::vector<Visitor>& alloc_visitors,
+ const std::vector<Visitor>& free_visitors)
+ : SubAllocator(alloc_visitors, free_visitors),
+ stream_exec_(stream_exec),
+ numa_node_(numa_node) {
CHECK(stream_exec_ != nullptr);
}
~CUDAHostAllocator() override {}
@@ -39,19 +43,23 @@ class CUDAHostAllocator : public SubAllocator {
if (ptr == nullptr) {
LOG(WARNING) << "could not allocate pinned host memory of size: "
<< num_bytes;
+ return ptr;
}
+ VisitAlloc(ptr, numa_node_, num_bytes);
}
return ptr;
}
void Free(void* ptr, size_t num_bytes) override {
if (ptr != nullptr) {
+ VisitFree(ptr, numa_node_, num_bytes);
stream_exec_->HostMemoryDeallocate(ptr);
}
}
private:
se::StreamExecutor* stream_exec_; // not owned, non-null
+ const int numa_node_;
TF_DISALLOW_COPY_AND_ASSIGN(CUDAHostAllocator);
};
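Visitors are now constructor arguments of the sub-allocator, which invokes VisitAlloc/VisitFree itself, instead of being registered later via AddAllocVisitor(). Sketch only; the Visitor signature (void*, int index, size_t) is assumed from the calls above, and `stream_exec` is assumed to be a valid se::StreamExecutor*:

std::vector<SubAllocator::Visitor> alloc_visitors = {
    [](void* ptr, int numa_node, size_t num_bytes) {
      VLOG(2) << "pinned alloc of " << num_bytes << " bytes on node "
              << numa_node;
    }};
CUDAHostAllocator* host_allocator = new CUDAHostAllocator(
    stream_exec, /*numa_node=*/0, alloc_visitors, /*free_visitors=*/{});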
diff --git a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc
index 2d4c8d0201..42021e51f3 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc
@@ -22,18 +22,48 @@ limitations under the License.
namespace tensorflow {
-GPUBFCAllocator::GPUBFCAllocator(CudaGpuId cuda_gpu_id, size_t total_memory,
- const string& name)
- : GPUBFCAllocator(cuda_gpu_id, total_memory, GPUOptions(), name) {}
+bool GPUBFCAllocator::GetAllowGrowthValue(const GPUOptions& gpu_options) {
+ const char* force_allow_growth_string =
+ std::getenv("TF_FORCE_GPU_ALLOW_GROWTH");
+ if (force_allow_growth_string == nullptr) {
+ return gpu_options.allow_growth();
+ }
+
+ if (strcmp("false", force_allow_growth_string) == 0) {
+ if (gpu_options.allow_growth()) {
+ LOG(WARNING)
+ << "Overriding allow_growth setting because the"
+ << " TF_FORCE_GPU_ALLOW_GROWTH environment variable is set. Original"
+ << " config value was " << gpu_options.allow_growth() << ".";
+ }
+ return false;
+ } else if (strcmp("true", force_allow_growth_string) == 0) {
+ if (!gpu_options.allow_growth()) {
+ LOG(WARNING)
+ << "Overriding allow_growth setting because the"
+ << " TF_FORCE_GPU_ALLOW_GROWTH environment variable is set. Original"
+ << " config value was " << gpu_options.allow_growth() << ".";
+ }
+ return true;
+ }
+
+ LOG(ERROR)
+ << "The TF_FORCE_GPU_ALLOW_GROWTH environment variable is set but could"
+ << " not be parsed: \"" << force_allow_growth_string << "\". Valid"
+ << " values are \"true\" or \"false\". Using original config value"
+ << " of " << gpu_options.allow_growth() << ".";
+ return gpu_options.allow_growth();
+}
+
+GPUBFCAllocator::GPUBFCAllocator(GPUMemAllocator* sub_allocator,
+ size_t total_memory, const string& name)
+ : GPUBFCAllocator(sub_allocator, total_memory, GPUOptions(), name) {}
-GPUBFCAllocator::GPUBFCAllocator(CudaGpuId cuda_gpu_id, size_t total_memory,
+GPUBFCAllocator::GPUBFCAllocator(GPUMemAllocator* sub_allocator,
+ size_t total_memory,
const GPUOptions& gpu_options,
const string& name)
- : BFCAllocator(
- new GPUMemAllocator(
- GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie(),
- gpu_options.per_process_gpu_memory_fraction() > 1.0 ||
- gpu_options.experimental().use_unified_memory()),
- total_memory, gpu_options.allow_growth(), name) {}
+ : BFCAllocator(sub_allocator, total_memory,
+ GPUBFCAllocator::GetAllowGrowthValue(gpu_options), name) {}
} // namespace tensorflow
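Behavior sketch derived from GetAllowGrowthValue() above (illustrative only): the environment variable overrides the config, and unparseable values fall back to the config.

//   TF_FORCE_GPU_ALLOW_GROWTH unset      -> use gpu_options.allow_growth()
//   TF_FORCE_GPU_ALLOW_GROWTH == "true"  -> true  (warns if config disagrees)
//   TF_FORCE_GPU_ALLOW_GROWTH == "false" -> false (warns if config disagrees)
//   anything else                        -> error logged, config value used
setenv("TF_FORCE_GPU_ALLOW_GROWTH", "true", /*overwrite=*/1);  // example only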
diff --git a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h
index f1cc2eace1..d4c9cee89a 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h
@@ -31,28 +31,20 @@ limitations under the License.
namespace tensorflow {
-// A GPU memory allocator that implements a 'best-fit with coalescing'
-// algorithm.
-class GPUBFCAllocator : public BFCAllocator {
- public:
- // 'cuda_gpu_id' refers to the ID of the GPU device within
- // the process and must reference a valid ID in the process.
- GPUBFCAllocator(CudaGpuId cuda_gpu_id, size_t total_memory,
- const string& name);
- GPUBFCAllocator(CudaGpuId cuda_gpu_id, size_t total_memory,
- const GPUOptions& gpu_options, const string& name);
- virtual ~GPUBFCAllocator() {}
-
- TF_DISALLOW_COPY_AND_ASSIGN(GPUBFCAllocator);
-};
-
// Suballocator for GPU memory.
class GPUMemAllocator : public SubAllocator {
public:
+ // 'platform_gpu_id' refers to the ID of the GPU device within
+ // the process and must reference a valid ID in the process.
// Note: stream_exec cannot be null.
explicit GPUMemAllocator(se::StreamExecutor* stream_exec,
- bool use_unified_memory)
- : stream_exec_(stream_exec), use_unified_memory_(use_unified_memory) {
+ PlatformGpuId gpu_id, bool use_unified_memory,
+ const std::vector<Visitor>& alloc_visitors,
+ const std::vector<Visitor>& free_visitors)
+ : SubAllocator(alloc_visitors, free_visitors),
+ stream_exec_(stream_exec),
+ gpu_id_(gpu_id),
+ use_unified_memory_(use_unified_memory) {
CHECK(stream_exec_ != nullptr);
}
~GPUMemAllocator() override {}
@@ -65,12 +57,14 @@ class GPUMemAllocator : public SubAllocator {
} else {
ptr = stream_exec_->AllocateArray<char>(num_bytes).opaque();
}
+ VisitAlloc(ptr, gpu_id_.value(), num_bytes);
}
return ptr;
}
void Free(void* ptr, size_t num_bytes) override {
if (ptr != nullptr) {
+ VisitFree(ptr, gpu_id_.value(), num_bytes);
if (use_unified_memory_) {
stream_exec_->UnifiedMemoryDeallocate(ptr);
} else {
@@ -82,11 +76,28 @@ class GPUMemAllocator : public SubAllocator {
private:
se::StreamExecutor* stream_exec_; // not owned, non-null
+ const PlatformGpuId gpu_id_;
const bool use_unified_memory_ = false;
TF_DISALLOW_COPY_AND_ASSIGN(GPUMemAllocator);
};
+// A GPU memory allocator that implements a 'best-fit with coalescing'
+// algorithm.
+class GPUBFCAllocator : public BFCAllocator {
+ public:
+ GPUBFCAllocator(GPUMemAllocator* sub_allocator, size_t total_memory,
+ const string& name);
+ GPUBFCAllocator(GPUMemAllocator* sub_allocator, size_t total_memory,
+ const GPUOptions& gpu_options, const string& name);
+ ~GPUBFCAllocator() override {}
+
+ TF_DISALLOW_COPY_AND_ASSIGN(GPUBFCAllocator);
+
+ private:
+ static bool GetAllowGrowthValue(const GPUOptions& gpu_options);
+};
+
} // namespace tensorflow
#endif // TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_BFC_ALLOCATOR_H_
diff --git a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc
index 67caeb3495..60e82ed13b 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/common_runtime/gpu/gpu_id.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_id_utils.h"
#include "tensorflow/core/common_runtime/gpu/gpu_init.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
@@ -46,7 +47,11 @@ static void CheckStats(Allocator* a, int64 num_allocs, int64 bytes_in_use,
}
TEST(GPUBFCAllocatorTest, NoDups) {
- GPUBFCAllocator a(CudaGpuId(0), 1 << 30, "GPU_0_bfc");
+ PlatformGpuId platform_gpu_id(0);
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUBFCAllocator a(sub_allocator, 1 << 30, "GPU_0_bfc");
CheckStats(&a, 0, 0, 0, 0);
// Allocate a lot of raw pointers
@@ -75,7 +80,11 @@ TEST(GPUBFCAllocatorTest, NoDups) {
}
TEST(GPUBFCAllocatorTest, AllocationsAndDeallocations) {
- GPUBFCAllocator a(CudaGpuId(0), 1 << 30, "GPU_0_bfc");
+ PlatformGpuId platform_gpu_id(0);
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUBFCAllocator a(sub_allocator, 1 << 30, "GPU_0_bfc");
// Allocate 256 raw pointers of sizes between 100 bytes and about
// a meg
random::PhiloxRandom philox(123, 17);
@@ -133,7 +142,11 @@ TEST(GPUBFCAllocatorTest, AllocationsAndDeallocations) {
}
TEST(GPUBFCAllocatorTest, ExerciseCoalescing) {
- GPUBFCAllocator a(CudaGpuId(0), 1 << 30, "GPU_0_bfc");
+ PlatformGpuId platform_gpu_id(0);
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUBFCAllocator a(sub_allocator, 1 << 30, "GPU_0_bfc");
CheckStats(&a, 0, 0, 0, 0);
float* first_ptr = a.Allocate<float>(1024);
@@ -168,18 +181,30 @@ TEST(GPUBFCAllocatorTest, ExerciseCoalescing) {
}
TEST(GPUBFCAllocatorTest, AllocateZeroBufSize) {
- GPUBFCAllocator a(CudaGpuId(0), 1 << 30, "GPU_0_bfc");
+ PlatformGpuId platform_gpu_id(0);
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUBFCAllocator a(sub_allocator, 1 << 30, "GPU_0_bfc");
float* ptr = a.Allocate<float>(0);
EXPECT_EQ(nullptr, ptr);
}
TEST(GPUBFCAllocatorTest, TracksSizes) {
- GPUBFCAllocator a(CudaGpuId(0), 1 << 30, "GPU_0_bfc");
+ PlatformGpuId platform_gpu_id(0);
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUBFCAllocator a(sub_allocator, 1 << 30, "GPU_0_bfc");
EXPECT_EQ(true, a.TracksAllocationSizes());
}
TEST(GPUBFCAllocatorTest, AllocatedVsRequested) {
- GPUBFCAllocator a(CudaGpuId(0), 1 << 30, "GPU_0_bfc");
+ PlatformGpuId platform_gpu_id(0);
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUBFCAllocator a(sub_allocator, 1 << 30, "GPU_0_bfc");
float* t1 = a.Allocate<float>(1);
EXPECT_EQ(4, a.RequestedSize(t1));
EXPECT_EQ(256, a.AllocatedSize(t1));
@@ -187,8 +212,12 @@ TEST(GPUBFCAllocatorTest, AllocatedVsRequested) {
}
TEST(GPUBFCAllocatorTest, TestCustomMemoryLimit) {
+ PlatformGpuId platform_gpu_id(0);
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
// Configure a 1MiB byte limit
- GPUBFCAllocator a(CudaGpuId(0), 1 << 20, "GPU_0_bfc");
+ GPUBFCAllocator a(sub_allocator, 1 << 20, "GPU_0_bfc");
float* first_ptr = a.Allocate<float>(1 << 6);
float* second_ptr = a.Allocate<float>(1 << 20);
@@ -203,7 +232,11 @@ TEST(GPUBFCAllocatorTest, AllocationsAndDeallocationsWithGrowth) {
options.set_allow_growth(true);
// Max of 2GiB, but starts out small.
- GPUBFCAllocator a(CudaGpuId(0), 1LL << 31, options, "GPU_0_bfc");
+ PlatformGpuId platform_gpu_id(0);
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUBFCAllocator a(sub_allocator, 1LL << 31, "GPU_0_bfc");
// Allocate 10 raw pointers of sizes between 100 bytes and about
// 64 megs.
@@ -264,8 +297,15 @@ TEST(GPUBFCAllocatorTest, AllocationsAndDeallocationsWithGrowth) {
}
TEST(GPUBFCAllocatorTest, DISABLED_AllocatorReceivesZeroMemory) {
- GPUBFCAllocator a(CudaGpuId(0), 1UL << 60, "GPU_0_bfc");
- GPUBFCAllocator b(CudaGpuId(0), 1UL << 60, "GPU_0_bfc");
+ PlatformGpuId platform_gpu_id(0);
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUBFCAllocator a(sub_allocator, 1UL << 60, "GPU_0_bfc");
+ sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUBFCAllocator b(sub_allocator, 1UL << 60, "GPU_0_bfc");
void* amem = a.AllocateRaw(1, 1);
void* bmem = b.AllocateRaw(1, 1 << 30);
a.DeallocateRaw(amem);
@@ -273,7 +313,11 @@ TEST(GPUBFCAllocatorTest, DISABLED_AllocatorReceivesZeroMemory) {
}
static void BM_Allocation(int iters) {
- GPUBFCAllocator a(CudaGpuId(0), 1uLL << 33, "GPU_0_bfc");
+ PlatformGpuId platform_gpu_id(0);
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUBFCAllocator a(sub_allocator, 1uLL << 33, "GPU_0_bfc");
// Exercise a few different allocation sizes
std::vector<size_t> sizes = {256, 4096, 16384, 524288,
512, 1048576, 10485760, 104857600,
@@ -289,7 +333,11 @@ static void BM_Allocation(int iters) {
BENCHMARK(BM_Allocation);
static void BM_AllocationThreaded(int iters, int num_threads) {
- GPUBFCAllocator a(CudaGpuId(0), 1uLL << 33, "GPU_0_bfc");
+ PlatformGpuId platform_gpu_id(0);
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUBFCAllocator a(sub_allocator, 1uLL << 33, "GPU_0_bfc");
thread::ThreadPool pool(Env::Default(), "test", num_threads);
std::atomic_int_fast32_t count(iters);
mutex done_lock;
@@ -325,7 +373,11 @@ BENCHMARK(BM_AllocationThreaded)->Arg(1)->Arg(4)->Arg(16);
// A more complex benchmark that defers deallocation of an object for
// "delay" allocations.
static void BM_AllocationDelayed(int iters, int delay) {
- GPUBFCAllocator a(CudaGpuId(0), 1 << 30, "GPU_0_bfc");
+ PlatformGpuId platform_gpu_id(0);
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUBFCAllocator a(sub_allocator, 1 << 30, "GPU_0_bfc");
// Exercise a few different allocation sizes
std::vector<int> sizes = {256, 4096, 16384, 4096, 512, 1024, 1024};
int size_index = 0;
@@ -358,12 +410,18 @@ BENCHMARK(BM_AllocationDelayed)->Arg(1)->Arg(10)->Arg(100)->Arg(1000);
class GPUBFCAllocatorPrivateMethodsTest : public ::testing::Test {
protected:
+ void SetUp() override { CHECK_EQ(unsetenv("TF_FORCE_GPU_ALLOW_GROWTH"), 0); }
+
// The following test methods are called from tests. The reason for this is
// that this class is a friend class to BFCAllocator, but tests are not, so
// only methods inside this class can access private members of BFCAllocator.
void TestBinDebugInfo() {
- GPUBFCAllocator a(CudaGpuId(0), 1 << 30, "GPU_0_bfc");
+ PlatformGpuId platform_gpu_id(0);
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUBFCAllocator a(sub_allocator, 1 << 30, "GPU_0_bfc");
std::vector<void*> initial_ptrs;
std::vector<size_t> initial_ptrs_allocated_sizes;
@@ -441,7 +499,11 @@ class GPUBFCAllocatorPrivateMethodsTest : public ::testing::Test {
}
void TestLog2FloorNonZeroSlow() {
- GPUBFCAllocator a(CudaGpuId(0), 1 /* total_memory */, "GPU_0_bfc");
+ PlatformGpuId platform_gpu_id(0);
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUBFCAllocator a(sub_allocator, 1 /* total_memory */, "GPU_0_bfc");
EXPECT_EQ(-1, a.Log2FloorNonZeroSlow(0));
EXPECT_EQ(0, a.Log2FloorNonZeroSlow(1));
EXPECT_EQ(1, a.Log2FloorNonZeroSlow(2));
@@ -450,6 +512,56 @@ class GPUBFCAllocatorPrivateMethodsTest : public ::testing::Test {
EXPECT_EQ(10, a.Log2FloorNonZeroSlow(1024));
EXPECT_EQ(10, a.Log2FloorNonZeroSlow(1025));
}
+
+ void TestForceAllowGrowth() {
+ PlatformGpuId platform_gpu_id(0);
+ GPUOptions options;
+ // Unset flag value uses provided option.
+ unsetenv("TF_FORCE_GPU_ALLOW_GROWTH");
+ options.set_allow_growth(true);
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUBFCAllocator unset_flag_allocator(sub_allocator, 1LL << 31, options,
+ "GPU_0_bfc");
+ EXPECT_EQ(GPUBFCAllocator::RoundedBytes(size_t{1048576}),
+ unset_flag_allocator.curr_region_allocation_bytes_);
+
+ // Unparseable flag value uses provided option.
+ setenv("TF_FORCE_GPU_ALLOW_GROWTH", "unparseable", 1);
+ options.set_allow_growth(true);
+ sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUBFCAllocator unparsable_flag_allocator(sub_allocator, 1LL << 31, options,
+ "GPU_1_bfc");
+ EXPECT_EQ(GPUBFCAllocator::RoundedBytes(size_t{1048576}),
+ unparsable_flag_allocator.curr_region_allocation_bytes_);
+
+ // Max of 2GiB total memory. Env variable set forces allow_growth, which
+ // does an initial allocation of 1MiB.
+ setenv("TF_FORCE_GPU_ALLOW_GROWTH", "true", 1);
+ options.set_allow_growth(false);
+ sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUBFCAllocator force_allow_growth_allocator(sub_allocator, 1LL << 31,
+ options, "GPU_2_bfc");
+ EXPECT_EQ(GPUBFCAllocator::RoundedBytes(size_t{1048576}),
+ force_allow_growth_allocator.curr_region_allocation_bytes_);
+
+ // If env variable forces allow_growth disabled, all available memory is
+ // allocated.
+ setenv("TF_FORCE_GPU_ALLOW_GROWTH", "false", 1);
+ options.set_allow_growth(true);
+ sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUBFCAllocator force_no_allow_growth_allocator(sub_allocator, 1LL << 31,
+ options, "GPU_3_bfc");
+ EXPECT_EQ(GPUBFCAllocator::RoundedBytes(1LL << 31),
+ force_no_allow_growth_allocator.curr_region_allocation_bytes_);
+ }
};
TEST_F(GPUBFCAllocatorPrivateMethodsTest, BinDebugInfo) { TestBinDebugInfo(); }
@@ -458,6 +570,10 @@ TEST_F(GPUBFCAllocatorPrivateMethodsTest, Log2FloorNonZeroSlow) {
TestLog2FloorNonZeroSlow();
}
+TEST_F(GPUBFCAllocatorPrivateMethodsTest, ForceAllowGrowth) {
+ TestForceAllowGrowth();
+}
+
} // namespace tensorflow
#endif // GOOGLE_CUDA
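The tests above repeat the same GPUMemAllocator construction before every GPUBFCAllocator. A hypothetical test helper (not in the patch) that factors out that boilerplate, using only names introduced by this change:

GPUMemAllocator* CreateSubAllocator(int gpu_ordinal = 0) {
  PlatformGpuId platform_gpu_id(gpu_ordinal);
  return new GPUMemAllocator(
      GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
      platform_gpu_id, /*use_unified_memory=*/false, {}, {});
}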
diff --git a/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.cc b/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.cc
index 934a57a5fb..d85ca8892f 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.cc
@@ -27,10 +27,11 @@ limitations under the License.
namespace tensorflow {
-GPUcudaMallocAllocator::GPUcudaMallocAllocator(VisitableAllocator* allocator,
- CudaGpuId cuda_gpu_id)
+GPUcudaMallocAllocator::GPUcudaMallocAllocator(Allocator* allocator,
+ PlatformGpuId platform_gpu_id)
: base_allocator_(allocator) {
- stream_exec_ = GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie();
+ stream_exec_ =
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie();
}
GPUcudaMallocAllocator::~GPUcudaMallocAllocator() { delete base_allocator_; }
@@ -60,14 +61,6 @@ void GPUcudaMallocAllocator::DeallocateRaw(void* ptr) {
#endif // GOOGLE_CUDA
}
-void GPUcudaMallocAllocator::AddAllocVisitor(Visitor visitor) {
- return base_allocator_->AddAllocVisitor(visitor);
-}
-
-void GPUcudaMallocAllocator::AddFreeVisitor(Visitor visitor) {
- return base_allocator_->AddFreeVisitor(visitor);
-}
-
bool GPUcudaMallocAllocator::TracksAllocationSizes() { return false; }
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.h b/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.h
index 856fdc34b4..8df3724bc4 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.h
@@ -19,7 +19,7 @@ limitations under the License.
#include <memory>
#include "tensorflow/core/common_runtime/gpu/gpu_id.h"
-#include "tensorflow/core/common_runtime/visitable_allocator.h"
+#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/platform/types.h"
@@ -29,20 +29,18 @@ namespace tensorflow {
// An allocator that wraps a GPU allocator and adds debugging
// functionality that verifies that users do not write outside their
// allocated memory.
-class GPUcudaMallocAllocator : public VisitableAllocator {
+class GPUcudaMallocAllocator : public Allocator {
public:
- explicit GPUcudaMallocAllocator(VisitableAllocator* allocator,
- CudaGpuId cuda_gpu_id);
+ explicit GPUcudaMallocAllocator(Allocator* allocator,
+ PlatformGpuId platform_gpu_id);
~GPUcudaMallocAllocator() override;
string Name() override { return "gpu_debug"; }
void* AllocateRaw(size_t alignment, size_t num_bytes) override;
void DeallocateRaw(void* ptr) override;
- void AddAllocVisitor(Visitor visitor) override;
- void AddFreeVisitor(Visitor visitor) override;
bool TracksAllocationSizes() override;
private:
- VisitableAllocator* base_allocator_ = nullptr; // owned
+ Allocator* base_allocator_ = nullptr; // owned
se::StreamExecutor* stream_exec_; // Not owned.
diff --git a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.cc b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.cc
index e4c834b30d..989ddbe4af 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.cc
@@ -73,10 +73,11 @@ void InitMask(se::StreamExecutor* exec, void* ptr, int64* mask) {
// -----------------------------------------------------------------------------
// GPUDebugAllocator
// -----------------------------------------------------------------------------
-GPUDebugAllocator::GPUDebugAllocator(VisitableAllocator* allocator,
- CudaGpuId cuda_gpu_id)
+GPUDebugAllocator::GPUDebugAllocator(Allocator* allocator,
+ PlatformGpuId platform_gpu_id)
: base_allocator_(allocator) {
- stream_exec_ = GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie();
+ stream_exec_ =
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie();
}
GPUDebugAllocator::~GPUDebugAllocator() { delete base_allocator_; }
@@ -111,14 +112,6 @@ void GPUDebugAllocator::DeallocateRaw(void* ptr) {
base_allocator_->DeallocateRaw(ptr);
}
-void GPUDebugAllocator::AddAllocVisitor(Visitor visitor) {
- return base_allocator_->AddAllocVisitor(visitor);
-}
-
-void GPUDebugAllocator::AddFreeVisitor(Visitor visitor) {
- return base_allocator_->AddFreeVisitor(visitor);
-}
-
bool GPUDebugAllocator::TracksAllocationSizes() { return true; }
size_t GPUDebugAllocator::RequestedSize(const void* ptr) {
@@ -158,10 +151,11 @@ bool GPUDebugAllocator::CheckFooter(void* ptr) {
// -----------------------------------------------------------------------------
// GPUNanResetAllocator
// -----------------------------------------------------------------------------
-GPUNanResetAllocator::GPUNanResetAllocator(VisitableAllocator* allocator,
- CudaGpuId cuda_gpu_id)
+GPUNanResetAllocator::GPUNanResetAllocator(Allocator* allocator,
+ PlatformGpuId platform_gpu_id)
: base_allocator_(allocator) {
- stream_exec_ = GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie();
+ stream_exec_ =
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie();
}
GPUNanResetAllocator::~GPUNanResetAllocator() { delete base_allocator_; }
@@ -200,14 +194,6 @@ void GPUNanResetAllocator::DeallocateRaw(void* ptr) {
base_allocator_->DeallocateRaw(ptr);
}
-void GPUNanResetAllocator::AddAllocVisitor(Visitor visitor) {
- return base_allocator_->AddAllocVisitor(visitor);
-}
-
-void GPUNanResetAllocator::AddFreeVisitor(Visitor visitor) {
- return base_allocator_->AddFreeVisitor(visitor);
-}
-
size_t GPUNanResetAllocator::RequestedSize(const void* ptr) {
return base_allocator_->RequestedSize(ptr);
}
diff --git a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h
index 0f9b72040c..17757a106c 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h
@@ -21,7 +21,7 @@ limitations under the License.
#include <unordered_map>
#include "tensorflow/core/common_runtime/gpu/gpu_id.h"
-#include "tensorflow/core/common_runtime/visitable_allocator.h"
+#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/platform/types.h"
@@ -31,16 +31,14 @@ namespace tensorflow {
// An allocator that wraps a GPU allocator and adds debugging
// functionality that verifies that users do not write outside their
// allocated memory.
-class GPUDebugAllocator : public VisitableAllocator {
+class GPUDebugAllocator : public Allocator {
public:
- explicit GPUDebugAllocator(VisitableAllocator* allocator,
- CudaGpuId cuda_gpu_id);
+ explicit GPUDebugAllocator(Allocator* allocator,
+ PlatformGpuId platform_gpu_id);
~GPUDebugAllocator() override;
string Name() override { return "gpu_debug"; }
void* AllocateRaw(size_t alignment, size_t num_bytes) override;
void DeallocateRaw(void* ptr) override;
- void AddAllocVisitor(Visitor visitor) override;
- void AddFreeVisitor(Visitor visitor) override;
bool TracksAllocationSizes() override;
size_t RequestedSize(const void* ptr) override;
size_t AllocatedSize(const void* ptr) override;
@@ -53,7 +51,7 @@ class GPUDebugAllocator : public VisitableAllocator {
bool CheckFooter(void* ptr);
private:
- VisitableAllocator* base_allocator_ = nullptr; // owned
+ Allocator* base_allocator_ = nullptr; // owned
se::StreamExecutor* stream_exec_; // Not owned.
@@ -63,23 +61,21 @@ class GPUDebugAllocator : public VisitableAllocator {
// An allocator that wraps a GPU allocator and resets the memory on
// allocation and free to 'NaN', helping to identify cases where the
// user forgets to initialize the memory.
-class GPUNanResetAllocator : public VisitableAllocator {
+class GPUNanResetAllocator : public Allocator {
public:
- explicit GPUNanResetAllocator(VisitableAllocator* allocator,
- CudaGpuId cuda_gpu_id);
+ explicit GPUNanResetAllocator(Allocator* allocator,
+ PlatformGpuId platform_gpu_id);
~GPUNanResetAllocator() override;
string Name() override { return "gpu_nan_reset"; }
void* AllocateRaw(size_t alignment, size_t num_bytes) override;
void DeallocateRaw(void* ptr) override;
- void AddAllocVisitor(Visitor visitor) override;
- void AddFreeVisitor(Visitor visitor) override;
size_t RequestedSize(const void* ptr) override;
size_t AllocatedSize(const void* ptr) override;
void GetStats(AllocatorStats* stats) override;
void ClearStats() override;
private:
- VisitableAllocator* base_allocator_ = nullptr; // owned
+ Allocator* base_allocator_ = nullptr; // owned
se::StreamExecutor* stream_exec_; // Not owned.
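With the debug and NaN-reset wrappers reduced to plain Allocators, visitor hooks now attach to the GPUMemAllocator underneath the BFC allocator rather than to the wrappers. Composition sketch (illustrative; `stream_exec`, `platform_gpu_id`, `alloc_visitors`, `free_visitors`, and `total_memory` are assumed to be in scope):

GPUMemAllocator* sub_allocator = new GPUMemAllocator(
    stream_exec, platform_gpu_id, /*use_unified_memory=*/false,
    alloc_visitors, free_visitors);
Allocator* gpu_allocator = new GPUDebugAllocator(
    new GPUBFCAllocator(sub_allocator, total_memory, "GPU_0_bfc"),
    platform_gpu_id);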
diff --git a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator_test.cc b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator_test.cc
index 236a0afa0b..aca08a7e33 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator_test.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator_test.cc
@@ -34,10 +34,14 @@ namespace tensorflow {
namespace {
TEST(GPUDebugAllocatorTest, OverwriteDetection_None) {
- const CudaGpuId cuda_gpu_id(0);
- GPUDebugAllocator a(new GPUBFCAllocator(cuda_gpu_id, 1 << 30, ""),
- cuda_gpu_id);
- auto stream_exec = GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie();
+ const PlatformGpuId platform_gpu_id(0);
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUDebugAllocator a(new GPUBFCAllocator(sub_allocator, 1 << 30, ""),
+ platform_gpu_id);
+ auto stream_exec =
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie();
for (int s : {8}) {
std::vector<int64> cpu_array(s);
@@ -58,11 +62,14 @@ TEST(GPUDebugAllocatorTest, OverwriteDetection_Header) {
for (int s : {8, 211}) {
EXPECT_DEATH(
{
- const CudaGpuId cuda_gpu_id(0);
- GPUDebugAllocator a(new GPUBFCAllocator(cuda_gpu_id, 1 << 30, ""),
- cuda_gpu_id);
+ const PlatformGpuId platform_gpu_id(0);
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUDebugAllocator a(new GPUBFCAllocator(sub_allocator, 1 << 30, ""),
+ platform_gpu_id);
auto stream_exec =
- GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie();
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie();
std::vector<int64> cpu_array(s);
memset(&cpu_array[0], 0, cpu_array.size() * sizeof(int64));
@@ -91,11 +98,14 @@ TEST(GPUDebugAllocatorTest, OverwriteDetection_Footer) {
for (int s : {8, 22}) {
EXPECT_DEATH(
{
- const CudaGpuId cuda_gpu_id(0);
- GPUDebugAllocator a(new GPUBFCAllocator(cuda_gpu_id, 1 << 30, ""),
- cuda_gpu_id);
+ const PlatformGpuId platform_gpu_id(0);
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUDebugAllocator a(new GPUBFCAllocator(sub_allocator, 1 << 30, ""),
+ platform_gpu_id);
auto stream_exec =
- GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie();
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie();
std::vector<int64> cpu_array(s);
memset(&cpu_array[0], 0, cpu_array.size() * sizeof(int64));
@@ -121,10 +131,14 @@ TEST(GPUDebugAllocatorTest, OverwriteDetection_Footer) {
}
TEST(GPUDebugAllocatorTest, ResetToNan) {
- const CudaGpuId cuda_gpu_id(0);
- GPUNanResetAllocator a(new GPUBFCAllocator(cuda_gpu_id, 1 << 30, ""),
- cuda_gpu_id);
- auto stream_exec = GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie();
+ const PlatformGpuId platform_gpu_id(0);
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUNanResetAllocator a(new GPUBFCAllocator(sub_allocator, 1 << 30, ""),
+ platform_gpu_id);
+ auto stream_exec =
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie();
std::vector<float> cpu_array(1024);
std::vector<float> cpu_array_result(1024);
@@ -161,13 +175,17 @@ TEST(GPUDebugAllocatorTest, ResetToNan) {
}
TEST(GPUDebugAllocatorTest, ResetToNanWithHeaderFooter) {
- const CudaGpuId cuda_gpu_id(0);
+ const PlatformGpuId platform_gpu_id(0);
// NaN reset must be the outer-most allocator.
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
GPUNanResetAllocator a(
- new GPUDebugAllocator(new GPUBFCAllocator(cuda_gpu_id, 1 << 30, ""),
- cuda_gpu_id),
- cuda_gpu_id);
- auto stream_exec = GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie();
+ new GPUDebugAllocator(new GPUBFCAllocator(sub_allocator, 1 << 30, ""),
+ platform_gpu_id),
+ platform_gpu_id);
+ auto stream_exec =
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie();
std::vector<float> cpu_array(1024);
std::vector<float> cpu_array_result(1024);
@@ -204,18 +222,24 @@ TEST(GPUDebugAllocatorTest, ResetToNanWithHeaderFooter) {
}
TEST(GPUDebugAllocatorTest, TracksSizes) {
- const CudaGpuId cuda_gpu_id(0);
- GPUDebugAllocator a(new GPUBFCAllocator(cuda_gpu_id, 1 << 30, ""),
- cuda_gpu_id);
+ const PlatformGpuId platform_gpu_id(0);
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUDebugAllocator a(new GPUBFCAllocator(sub_allocator, 1 << 30, ""),
+ platform_gpu_id);
EXPECT_EQ(true, a.TracksAllocationSizes());
}
TEST(GPUDebugAllocatorTest, AllocatedVsRequested) {
- const CudaGpuId cuda_gpu_id(0);
+ const PlatformGpuId platform_gpu_id(0);
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
GPUNanResetAllocator a(
- new GPUDebugAllocator(new GPUBFCAllocator(cuda_gpu_id, 1 << 30, ""),
- cuda_gpu_id),
- cuda_gpu_id);
+ new GPUDebugAllocator(new GPUBFCAllocator(sub_allocator, 1 << 30, ""),
+ platform_gpu_id),
+ platform_gpu_id);
float* t1 = a.Allocate<float>(1);
EXPECT_EQ(4, a.RequestedSize(t1));
EXPECT_EQ(256, a.AllocatedSize(t1));
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.cc b/tensorflow/core/common_runtime/gpu/gpu_device.cc
index 2763ac0d4a..d8ebdeff5d 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_device.cc
@@ -41,7 +41,6 @@ limitations under the License.
#include "tensorflow/core/common_runtime/gpu/gpu_util.h"
#include "tensorflow/core/common_runtime/gpu_device_context.h"
#include "tensorflow/core/common_runtime/local_device.h"
-#include "tensorflow/core/common_runtime/visitable_allocator.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -105,9 +104,9 @@ class EigenCudaStreamDevice : public ::Eigen::StreamInterface {
reinterpret_cast<unsigned int*>(scratch + Eigen::kCudaScratchSize);
stream_ = cuda_stream;
allocator_ = alloc;
- CudaGpuId cuda_gpu_id;
- TF_CHECK_OK(GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id));
- device_prop_ = &Eigen::m_deviceProperties[cuda_gpu_id.value()];
+ PlatformGpuId platform_gpu_id;
+ TF_CHECK_OK(GpuIdManager::TfToPlatformGpuId(tf_gpu_id, &platform_gpu_id));
+ device_prop_ = &Eigen::m_deviceProperties[platform_gpu_id.value()];
}
const cudaStream_t& stream() const override { return *stream_; }
@@ -285,6 +284,38 @@ BaseGPUDevice::~BaseGPUDevice() {
for (auto ctx : device_contexts_) ctx->Unref();
}
+// This should be idempotent if already initialized.
+Status BaseGPUDevice::InitScratchBuffers() {
+ mutex_lock l(scratch_init_mutex_);
+ if (scratch_.size() < max_streams_) {
+ for (int i = 0; i < max_streams_; i++) {
+ DCHECK(streams_[i]);
+ if (scratch_.size() > i && scratch_[i]) continue;
+ size_t scratch_buffer_size =
+ Eigen::kCudaScratchSize + sizeof(unsigned int);
+ void* scratch_buffer = gpu_allocator_->AllocateRaw(
+ Allocator::kAllocatorAlignment, scratch_buffer_size);
+ if (scratch_buffer == nullptr) {
+ return errors::FailedPrecondition(
+ "Failed to allocate scratch buffer for device ",
+ tf_gpu_id_.value());
+ }
+ se::DeviceMemory<char> mem(
+ se::DeviceMemoryBase(scratch_buffer, scratch_buffer_size));
+
+ bool ok = executor_->SynchronousMemZero(
+ &mem, Eigen::kCudaScratchSize + sizeof(unsigned int));
+ if (!ok) {
+ return errors::FailedPrecondition(
+ "Failed to memcopy into scratch buffer for device ",
+ tf_gpu_id_.value());
+ }
+ scratch_.push_back(static_cast<char*>(scratch_buffer));
+ }
+ }
+ return Status::OK();
+}
+
Status BaseGPUDevice::Init(const SessionOptions& options) {
auto executor_status = GpuIdUtil::ExecutorForTfGpuId(tf_gpu_id_);
if (!executor_status.status().ok()) {
@@ -303,27 +334,6 @@ Status BaseGPUDevice::Init(const SessionOptions& options) {
for (int i = 0; i < max_streams_; i++) {
streams_.push_back(StreamGroupFactory::Global().GetOrCreate(
tf_gpu_id_, i, executor_, options.config.gpu_options()));
-
- size_t scratch_buffer_size = Eigen::kCudaScratchSize + sizeof(unsigned int);
- void* scratch_buffer = gpu_allocator_->AllocateRaw(
- Allocator::kAllocatorAlignment, scratch_buffer_size);
- if (scratch_buffer == nullptr) {
- return errors::FailedPrecondition(
- "Failed to allocate scratch buffer for device ", tf_gpu_id_.value());
- }
- scratch_.push_back(static_cast<char*>(scratch_buffer));
-
- se::DeviceMemory<char> mem(
- se::DeviceMemoryBase(scratch_buffer, scratch_buffer_size));
-
- bool ok = executor_->SynchronousMemZero(
- &mem, Eigen::kCudaScratchSize + sizeof(unsigned int));
- if (!ok) {
- return errors::FailedPrecondition(
- "Failed to memcopy into scratch buffer for device ",
- tf_gpu_id_.value());
- }
-
device_contexts_.push_back(new GPUDeviceContext(
i, streams_.back()->compute, streams_.back()->host_to_device,
streams_.back()->device_to_host, streams_.back()->device_to_device));
@@ -332,9 +342,10 @@ Status BaseGPUDevice::Init(const SessionOptions& options) {
gpu_device_info_->stream = streams_[0]->compute;
gpu_device_info_->default_context = device_contexts_[0];
gpu_device_info_->event_mgr = em_.get();
- CudaGpuId cuda_gpu_id;
- TF_RETURN_IF_ERROR(GpuIdManager::TfToCudaGpuId(tf_gpu_id_, &cuda_gpu_id));
- gpu_device_info_->gpu_id = cuda_gpu_id.value();
+ PlatformGpuId platform_gpu_id;
+ TF_RETURN_IF_ERROR(
+ GpuIdManager::TfToPlatformGpuId(tf_gpu_id_, &platform_gpu_id));
+ gpu_device_info_->gpu_id = platform_gpu_id.value();
set_tensorflow_gpu_device_info(gpu_device_info_);
// Whether and how the GPU device uses its own threadpool.
@@ -423,9 +434,6 @@ Status BaseGPUDevice::FillContextMap(const Graph* graph,
}
void BaseGPUDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) {
- tracing::ScopedRegion region(tracing::EventCategory::kCompute,
- op_kernel->name());
-
// NOTE(tucker): We need to discriminate between Eigen GPU
// operations and all others. If an operation is Eigen
// implemented (or otherwise tries to launch a cuda kernel
@@ -439,8 +447,6 @@ void BaseGPUDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) {
context->SetStatus(errors::Internal(
"Invalid synchronous 'Compute' on GPU for '_Recv' op"));
} else {
- tracing::ScopedAnnotation annotation(op_kernel->name(),
- op_kernel->type_string());
ComputeHelper(op_kernel, context);
}
}
@@ -690,9 +696,9 @@ class ConcretePerOpGpuDevice : public PerOpGpuDevice {
Eigen::GpuDevice device_;
};
-// Parse 'visible_device_list' into a list of CUDA GPU ids.
+// Parse 'visible_device_list' into a list of platform GPU ids.
Status ParseVisibleDeviceList(const string& visible_device_list,
- std::vector<CudaGpuId>* visible_gpu_order) {
+ std::vector<PlatformGpuId>* visible_gpu_order) {
visible_gpu_order->clear();
se::Platform* gpu_manager = GPUMachineManager();
@@ -707,26 +713,28 @@ Status ParseVisibleDeviceList(const string& visible_device_list,
} else {
const std::vector<string> order_str =
str_util::Split(visible_device_list, ',');
- for (const string& cuda_gpu_id_str : order_str) {
- int32 cuda_gpu_id;
- if (!strings::safe_strto32(cuda_gpu_id_str, &cuda_gpu_id)) {
+ for (const string& platform_gpu_id_str : order_str) {
+ int32 platform_gpu_id;
+ if (!strings::safe_strto32(platform_gpu_id_str, &platform_gpu_id)) {
return errors::InvalidArgument(
"Could not parse entry in 'visible_device_list': '",
- cuda_gpu_id_str, "'. visible_device_list = ", visible_device_list);
+ platform_gpu_id_str, "'. visible_device_list = ",
+ visible_device_list);
}
- if (cuda_gpu_id < 0 || cuda_gpu_id >= gpu_manager->VisibleDeviceCount()) {
+ if (platform_gpu_id < 0 ||
+ platform_gpu_id >= gpu_manager->VisibleDeviceCount()) {
return errors::InvalidArgument(
- "'visible_device_list' listed an invalid GPU id '", cuda_gpu_id,
+ "'visible_device_list' listed an invalid GPU id '", platform_gpu_id,
"' but visible device count is ",
gpu_manager->VisibleDeviceCount());
}
- visible_gpu_order->push_back(CudaGpuId(cuda_gpu_id));
+ visible_gpu_order->push_back(PlatformGpuId(platform_gpu_id));
}
}
// Validate no repeats.
- std::set<CudaGpuId> visible_device_set(visible_gpu_order->begin(),
- visible_gpu_order->end());
+ std::set<PlatformGpuId> visible_device_set(visible_gpu_order->begin(),
+ visible_gpu_order->end());
if (visible_device_set.size() != visible_gpu_order->size()) {
return errors::InvalidArgument(
"visible_device_list contained a duplicate entry: ",
@@ -737,8 +745,8 @@ Status ParseVisibleDeviceList(const string& visible_device_list,
Status VerifyVirtualDeviceSettings(
const size_t num_gpus_to_use, const GPUOptions& gpu_options,
- const std::vector<CudaGpuId>& visible_gpu_order,
- const std::vector<CudaGpuId>& valid_cuda_gpu_ids) {
+ const std::vector<PlatformGpuId>& visible_gpu_order,
+ const std::vector<PlatformGpuId>& valid_platform_gpu_ids) {
const auto& virtual_devices = gpu_options.experimental().virtual_devices();
CHECK(!virtual_devices.empty());
if (gpu_options.per_process_gpu_memory_fraction() > 0) {
@@ -760,11 +768,11 @@ Status VerifyVirtualDeviceSettings(
" #GPUs in visible_device_list: ", visible_gpu_order.size(),
" virtual_devices.size(): ", virtual_devices.size());
}
- if (valid_cuda_gpu_ids.size() != virtual_devices.size()) {
+ if (valid_platform_gpu_ids.size() != virtual_devices.size()) {
return errors::Unknown(
"The number of valid GPUs doesn't match the number of elements in "
"the virtual_devices list.",
- " #valid GPUs: ", valid_cuda_gpu_ids.size(),
+ " #valid GPUs: ", valid_platform_gpu_ids.size(),
" virtual_devices.size(): ", virtual_devices.size());
}
return Status::OK();
@@ -806,18 +814,18 @@ int64 MinSystemMemory(int64 available_memory) {
}
// Get the memory limit for the virtual device being created on GPU with
-// 'cuda_gpu_id', when that virtual device is the only virtual device being
+// 'platform_gpu_id', when that virtual device is the only virtual device being
// created on that GPU.
Status SingleVirtualDeviceMemoryLimit(const GPUOptions& gpu_options,
- CudaGpuId cuda_gpu_id,
+ PlatformGpuId platform_gpu_id,
int64* memory_limit) {
int64 total_memory = 0;
int64 available_memory = 0;
se::StreamExecutor* se =
- GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie();
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie();
if (!se->DeviceMemoryUsage(&available_memory, &total_memory)) {
return errors::Unknown("Failed to query available memory for GPU ",
- cuda_gpu_id.value());
+ platform_gpu_id.value());
}
int64 allocated_memory = 0;
@@ -867,10 +875,11 @@ PerOpGpuDevice* BaseGPUDevice::MakeGpuDevice() {
return new ConcretePerOpGpuDevice();
}
-void BaseGPUDevice::ReinitializeGpuDevice(OpKernelContext* context,
- PerOpGpuDevice* device,
- DeviceContext* dc,
- Allocator* allocator) {
+Status BaseGPUDevice::ReinitializeGpuDevice(OpKernelContext* context,
+ PerOpGpuDevice* device,
+ DeviceContext* dc,
+ Allocator* allocator) {
+ TF_RETURN_IF_ERROR(InitScratchBuffers());
if (dc) {
const GPUDeviceContext* gpu_dc = static_cast<GPUDeviceContext*>(dc);
const int stream_id = gpu_dc->stream_id();
@@ -881,6 +890,7 @@ void BaseGPUDevice::ReinitializeGpuDevice(OpKernelContext* context,
} else {
ReinitializeDevice(context, device, 0, allocator);
}
+ return Status::OK();
}
Allocator* BaseGPUDevice::GetScopedAllocator(AllocatorAttributes attr,
@@ -916,8 +926,8 @@ Status BaseGPUDeviceFactory::CreateDevices(const SessionOptions& options,
num_gpus_to_use = iter->second;
}
const auto& gpu_options = options.config.gpu_options();
- std::vector<CudaGpuId> visible_gpu_order;
- std::vector<CudaGpuId> valid_cuda_gpu_ids;
+ std::vector<PlatformGpuId> visible_gpu_order;
+ std::vector<PlatformGpuId> valid_platform_gpu_ids;
// If we aren't going to use any GPUs, don't initialize them.
// We don't want to call ParseVisibleDeviceList if num_gpus_to_use is 0,
// because it treats an empty gpu_options.visible_device_list as 'all GPUs are
@@ -926,12 +936,12 @@ Status BaseGPUDeviceFactory::CreateDevices(const SessionOptions& options,
TF_RETURN_IF_ERROR(ParseVisibleDeviceList(gpu_options.visible_device_list(),
&visible_gpu_order));
TF_RETURN_IF_ERROR(
- GetValidDeviceIds(visible_gpu_order, &valid_cuda_gpu_ids));
+ GetValidDeviceIds(visible_gpu_order, &valid_platform_gpu_ids));
}
- if (num_gpus_to_use > valid_cuda_gpu_ids.size()) {
- num_gpus_to_use = valid_cuda_gpu_ids.size();
+ if (num_gpus_to_use > valid_platform_gpu_ids.size()) {
+ num_gpus_to_use = valid_platform_gpu_ids.size();
}
- if (!valid_cuda_gpu_ids.empty()) {
+ if (!valid_platform_gpu_ids.empty()) {
// Save the original device.
int original_device = 0;
cudaError_t err = cudaGetDevice(&original_device);
@@ -941,17 +951,18 @@ Status BaseGPUDeviceFactory::CreateDevices(const SessionOptions& options,
}
// Force to implicitly initialize CUDA runtime on each valid GPU before
// CreateGPUDevice().
- for (CudaGpuId cuda_gpu_id : valid_cuda_gpu_ids) {
- err = cudaSetDevice(cuda_gpu_id.value());
+ for (PlatformGpuId platform_gpu_id : valid_platform_gpu_ids) {
+ err = cudaSetDevice(platform_gpu_id.value());
if (err != cudaSuccess) {
- return errors::Internal("cudaSetDevice() on GPU:", cuda_gpu_id.value(),
- " failed. Status: ", cudaGetErrorString(err));
+ return errors::Internal("cudaSetDevice() on GPU:",
+ platform_gpu_id.value(), " failed. Status: ",
+ cudaGetErrorString(err));
}
err = cudaFree(nullptr);
if (err != cudaSuccess) {
- return errors::Internal(
- "CUDA runtime implicit initialization on GPU:", cuda_gpu_id.value(),
- " failed. Status: ", cudaGetErrorString(err));
+ return errors::Internal("CUDA runtime implicit initialization on GPU:",
+ platform_gpu_id.value(), " failed. Status: ",
+ cudaGetErrorString(err));
}
}
// Reset to the original device.
@@ -977,10 +988,10 @@ Status BaseGPUDeviceFactory::CreateDevices(const SessionOptions& options,
LOG(INFO) << line_buf;
for (int i = 0; i < visible_gpu_order.size(); ++i) {
line_buf = strings::StrCat(visible_gpu_order[i].value(), ": ");
- CudaGpuId cuda_id_i = visible_gpu_order[i];
+ PlatformGpuId gpu_id_i = visible_gpu_order[i];
for (int j = 0; j < visible_gpu_order.size(); ++j) {
- CudaGpuId cuda_id_j = visible_gpu_order[j];
- if (im.directed_links.find({cuda_id_i, cuda_id_j}) !=
+ PlatformGpuId gpu_id_j = visible_gpu_order[j];
+ if (im.directed_links.find({gpu_id_i, gpu_id_j}) !=
im.directed_links.end()) {
line_buf.append("Y ");
} else {
@@ -993,22 +1004,23 @@ Status BaseGPUDeviceFactory::CreateDevices(const SessionOptions& options,
const auto& virtual_devices = gpu_options.experimental().virtual_devices();
if (!virtual_devices.empty()) {
- TF_RETURN_IF_ERROR(VerifyVirtualDeviceSettings(
- num_gpus_to_use, gpu_options, visible_gpu_order, valid_cuda_gpu_ids));
+ TF_RETURN_IF_ERROR(VerifyVirtualDeviceSettings(num_gpus_to_use, gpu_options,
+ visible_gpu_order,
+ valid_platform_gpu_ids));
// We've verified that num_gpus_to_use >= virtual_devices.size().
num_gpus_to_use = virtual_devices.size();
CHECK(gpu_options.visible_device_list().empty() ||
- valid_cuda_gpu_ids == visible_gpu_order);
+ valid_platform_gpu_ids == visible_gpu_order);
}
int next_tf_gpu_id = 0;
std::vector<int64> memory_limit_bytes;
for (int i = 0; i < num_gpus_to_use; ++i) {
- const CudaGpuId cuda_gpu_id = valid_cuda_gpu_ids[i];
+ const PlatformGpuId platform_gpu_id = valid_platform_gpu_ids[i];
if (virtual_devices.empty() ||
virtual_devices.Get(i).memory_limit_mb_size() == 0) {
int64 single_virtual_device_memory_limit = 0;
TF_RETURN_IF_ERROR(SingleVirtualDeviceMemoryLimit(
- gpu_options, cuda_gpu_id, &single_virtual_device_memory_limit));
+ gpu_options, platform_gpu_id, &single_virtual_device_memory_limit));
memory_limit_bytes.push_back(single_virtual_device_memory_limit);
} else {
const auto& memory_limit_mb = virtual_devices.Get(i).memory_limit_mb();
@@ -1021,7 +1033,7 @@ Status BaseGPUDeviceFactory::CreateDevices(const SessionOptions& options,
TfGpuId tf_gpu_id(next_tf_gpu_id);
++next_tf_gpu_id;
TF_RETURN_IF_ERROR(
- GpuIdManager::InsertTfCudaGpuIdPair(tf_gpu_id, cuda_gpu_id));
+ GpuIdManager::InsertTfPlatformGpuIdPair(tf_gpu_id, platform_gpu_id));
}
}
const int num_tf_gpus = next_tf_gpu_id;
@@ -1046,7 +1058,7 @@ Status BaseGPUDeviceFactory::CreateDevices(const SessionOptions& options,
return Status::OK();
}
-static string GetShortDeviceDescription(CudaGpuId cuda_gpu_id,
+static string GetShortDeviceDescription(PlatformGpuId platform_gpu_id,
const se::DeviceDescription& desc) {
int cc_major;
int cc_minor;
@@ -1055,9 +1067,8 @@ static string GetShortDeviceDescription(CudaGpuId cuda_gpu_id,
cc_minor = 0;
}
// LINT.IfChange
- return strings::StrCat("device: ", cuda_gpu_id.value(),
- ", name: ", desc.name(),
- ", pci bus id: ", desc.pci_bus_id(),
+ return strings::StrCat("device: ", platform_gpu_id.value(), ", name: ",
+ desc.name(), ", pci bus id: ", desc.pci_bus_id(),
", compute capability: ", cc_major, ".", cc_minor);
// LINT.ThenChange(//tensorflow/python/platform/test.py)
}
@@ -1072,12 +1083,13 @@ Status BaseGPUDeviceFactory::CreateGPUDevice(const SessionOptions& options,
const string device_name =
strings::StrCat(name_prefix, "/device:GPU:", tf_gpu_id.value());
GpuIdUtil::CheckValidTfGpuId(tf_gpu_id);
- CudaGpuId cuda_gpu_id;
- TF_RETURN_IF_ERROR(GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id));
+ PlatformGpuId platform_gpu_id;
+ TF_RETURN_IF_ERROR(
+ GpuIdManager::TfToPlatformGpuId(tf_gpu_id, &platform_gpu_id));
int numa_node = dev_locality.numa_node();
se::StreamExecutor* se =
- GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie();
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie();
const se::DeviceDescription& desc = se->GetDeviceDescription();
GPUProcessState* process_state = GPUProcessState::singleton();
Allocator* gpu_allocator = process_state->GetGPUAllocator(
@@ -1098,11 +1110,11 @@ Status BaseGPUDeviceFactory::CreateGPUDevice(const SessionOptions& options,
// TODO(laigd): report error if memory_limit doesn't match stats.bytes_limit.
BaseGPUDevice* gpu_device = CreateGPUDevice(
options, device_name, static_cast<Bytes>(stats.bytes_limit), dev_locality,
- tf_gpu_id, GetShortDeviceDescription(cuda_gpu_id, desc), gpu_allocator,
- ProcessState::singleton()->GetCPUAllocator(numa_node));
+ tf_gpu_id, GetShortDeviceDescription(platform_gpu_id, desc),
+ gpu_allocator, ProcessState::singleton()->GetCPUAllocator(numa_node));
LOG(INFO) << "Created TensorFlow device (" << device_name << " with "
<< (stats.bytes_limit >> 20) << " MB memory) -> physical GPU ("
- << GetShortDeviceDescription(cuda_gpu_id, desc) << ")";
+ << GetShortDeviceDescription(platform_gpu_id, desc) << ")";
TF_RETURN_IF_ERROR(gpu_device->Init(options));
devices->push_back(gpu_device);
@@ -1110,18 +1122,21 @@ Status BaseGPUDeviceFactory::CreateGPUDevice(const SessionOptions& options,
}
namespace {
-std::unique_ptr<std::map<std::pair<CudaGpuId, CudaGpuId>, bool>>
+std::unique_ptr<std::map<std::pair<PlatformGpuId, PlatformGpuId>, bool>>
GetPeerAccessMap(se::Platform* platform,
- const std::vector<CudaGpuId>& visible_gpu_order) {
- std::unique_ptr<std::map<std::pair<CudaGpuId, CudaGpuId>, bool>> map(
- new std::map<std::pair<CudaGpuId, CudaGpuId>, bool>);
- for (CudaGpuId cuda_gpu_i : visible_gpu_order) {
- for (CudaGpuId cuda_gpu_j : visible_gpu_order) {
+ const std::vector<PlatformGpuId>& visible_gpu_order) {
+ std::unique_ptr<std::map<std::pair<PlatformGpuId, PlatformGpuId>, bool>> map(
+ new std::map<std::pair<PlatformGpuId, PlatformGpuId>, bool>);
+ for (PlatformGpuId platform_gpu_i : visible_gpu_order) {
+ for (PlatformGpuId platform_gpu_j : visible_gpu_order) {
se::StreamExecutor* from =
- GpuIdUtil::ExecutorForCudaGpuId(platform, cuda_gpu_i).ValueOrDie();
+ GpuIdUtil::ExecutorForPlatformGpuId(platform, platform_gpu_i)
+ .ValueOrDie();
se::StreamExecutor* to =
- GpuIdUtil::ExecutorForCudaGpuId(platform, cuda_gpu_j).ValueOrDie();
- (*map)[{cuda_gpu_i, cuda_gpu_j}] = from->CanEnablePeerAccessTo(to);
+ GpuIdUtil::ExecutorForPlatformGpuId(platform, platform_gpu_j)
+ .ValueOrDie();
+ (*map)[{platform_gpu_i, platform_gpu_j}] =
+ from->CanEnablePeerAccessTo(to);
}
}
@@ -1131,19 +1146,19 @@ GetPeerAccessMap(se::Platform* platform,
} // namespace
Status BaseGPUDeviceFactory::GetInterconnectMaps(
- const std::vector<CudaGpuId>& visible_gpu_order, se::Platform* gpu_manager,
- std::vector<InterconnectMap>* maps) {
+ const std::vector<PlatformGpuId>& visible_gpu_order,
+ se::Platform* gpu_manager, std::vector<InterconnectMap>* maps) {
// The default interconnect map is obtained from the StreamExecutor.
auto access_map = GetPeerAccessMap(gpu_manager, visible_gpu_order);
maps->resize(1);
InterconnectMap& imap = maps->at(0);
imap.name = "StreamExecutor";
imap.strength = InterconnectMap::kStreamExecutorStrength;
- for (CudaGpuId cuda_id_i : visible_gpu_order) {
- for (CudaGpuId cuda_id_j : visible_gpu_order) {
- if (cuda_id_i == cuda_id_j) continue;
- if ((*access_map)[{cuda_id_i, cuda_id_j}]) {
- imap.directed_links.insert({cuda_id_i, cuda_id_j});
+ for (PlatformGpuId gpu_id_i : visible_gpu_order) {
+ for (PlatformGpuId gpu_id_j : visible_gpu_order) {
+ if (gpu_id_i == gpu_id_j) continue;
+ if ((*access_map)[{gpu_id_i, gpu_id_j}]) {
+ imap.directed_links.insert({gpu_id_i, gpu_id_j});
}
}
}
@@ -1158,13 +1173,14 @@ Status BaseGPUDeviceFactory::GetDeviceLocalities(
all_tf_gpu_ids.push_back(TfGpuId(i));
}
for (TfGpuId tf_gpu_id : all_tf_gpu_ids) {
- CudaGpuId cuda_gpu_id;
- TF_RETURN_IF_ERROR(GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id));
+ PlatformGpuId platform_gpu_id;
+ TF_RETURN_IF_ERROR(
+ GpuIdManager::TfToPlatformGpuId(tf_gpu_id, &platform_gpu_id));
// Get GPU bus_id from its reported NUMA affinity. Because GPUs are
// virtualized in some environments, we can't just use the GPU id.
// NUMA locales are indexed from 0, buses are indexed from 1.
se::StreamExecutor* se =
- GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie();
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie();
const se::DeviceDescription& desc = se->GetDeviceDescription();
int numa_node = desc.numa_node();
if (numa_node < 0) {
@@ -1174,7 +1190,8 @@ Status BaseGPUDeviceFactory::GetDeviceLocalities(
// may run into trouble later with data transfer operations. The
// trouble may manifest as slower than expected performance, or
// outright failures.
- LOG(INFO) << "Could not identify NUMA node of CUDA gpu id " << cuda_gpu_id
+ LOG(INFO) << "Could not identify NUMA node of platform GPU id "
+ << platform_gpu_id
<< ", defaulting to 0. Your kernel may not have been built "
<< "with NUMA support.";
numa_node = 0;
@@ -1187,10 +1204,10 @@ Status BaseGPUDeviceFactory::GetDeviceLocalities(
LocalLinks* links = dev_locality.mutable_links();
for (const InterconnectMap& imap : interconnects) {
for (TfGpuId tf_gpu_dst : all_tf_gpu_ids) {
- CudaGpuId cuda_gpu_dst;
+ PlatformGpuId platform_gpu_dst;
TF_RETURN_IF_ERROR(
- GpuIdManager::TfToCudaGpuId(tf_gpu_dst, &cuda_gpu_dst));
- if (imap.directed_links.find({cuda_gpu_id, cuda_gpu_dst}) !=
+ GpuIdManager::TfToPlatformGpuId(tf_gpu_dst, &platform_gpu_dst));
+ if (imap.directed_links.find({platform_gpu_id, platform_gpu_dst}) !=
imap.directed_links.end()) {
InterconnectLink* ilink = links->add_link();
ilink->set_device_id(tf_gpu_dst.value());
@@ -1204,10 +1221,10 @@ Status BaseGPUDeviceFactory::GetDeviceLocalities(
// add high strength links to the others.
for (TfGpuId tf_gpu_dst : all_tf_gpu_ids) {
if (tf_gpu_id == tf_gpu_dst) continue;
- CudaGpuId cuda_gpu_dst;
+ PlatformGpuId platform_gpu_dst;
TF_RETURN_IF_ERROR(
- GpuIdManager::TfToCudaGpuId(tf_gpu_dst, &cuda_gpu_dst));
- if (cuda_gpu_id == cuda_gpu_dst) {
+ GpuIdManager::TfToPlatformGpuId(tf_gpu_dst, &platform_gpu_dst));
+ if (platform_gpu_id == platform_gpu_dst) {
InterconnectLink* ilink = links->add_link();
ilink->set_device_id(tf_gpu_dst.value());
ilink->set_type("SAME_DEVICE");
@@ -1216,9 +1233,9 @@ Status BaseGPUDeviceFactory::GetDeviceLocalities(
}
(*localities)[tf_gpu_id] = dev_locality;
- VLOG(1) << "GPUDevice CudaGpuId " << cuda_gpu_id << " TfGpuId " << tf_gpu_id
- << " on bus " << dev_locality.bus_id() << " numa: " << numa_node
- << " pci: " << desc.pci_bus_id()
+ VLOG(1) << "GPUDevice PlatformGpuId " << platform_gpu_id << " TfGpuId "
+ << tf_gpu_id << " on bus " << dev_locality.bus_id()
+ << " numa: " << numa_node << " pci: " << desc.pci_bus_id()
<< " DeviceLocality: " << dev_locality.DebugString();
}
return Status::OK();
@@ -1226,14 +1243,14 @@ Status BaseGPUDeviceFactory::GetDeviceLocalities(
static int GetDefaultMinGPUMultiprocessorCount(
se::Platform* gpu_manager,
- const std::vector<CudaGpuId>& visible_gpu_order) {
+ const std::vector<PlatformGpuId>& visible_gpu_order) {
static const int kDefaultMinGPUMultiprocessorCount = 8;
// Find the highest multi-processor count across all visible GPUs.
int max_count = -1;
for (int i = 0; i < visible_gpu_order.size(); ++i) {
auto exec_status =
- GpuIdUtil::ExecutorForCudaGpuId(gpu_manager, visible_gpu_order[i]);
+ GpuIdUtil::ExecutorForPlatformGpuId(gpu_manager, visible_gpu_order[i]);
if (!exec_status.ok()) {
continue;
}
@@ -1252,7 +1269,7 @@ static int GetDefaultMinGPUMultiprocessorCount(
static int GetMinGPUMultiprocessorCount(
se::Platform* gpu_manager,
- const std::vector<CudaGpuId>& visible_gpu_order) {
+ const std::vector<PlatformGpuId>& visible_gpu_order) {
const char* tf_min_gpu_core_count = getenv("TF_MIN_GPU_MULTIPROCESSOR_COUNT");
if (tf_min_gpu_core_count == nullptr ||
@@ -1330,18 +1347,20 @@ std::vector<CudaVersion> GetSupportedCudaComputeCapabilities() {
}
Status EnablePeerAccess(se::Platform* platform,
- const std::vector<CudaGpuId>& visible_gpu_order) {
+ const std::vector<PlatformGpuId>& visible_gpu_order) {
int possible_peer_count = 0;
int enabled_peer_count = 0;
for (int i = 0; i < visible_gpu_order.size(); ++i) {
- const CudaGpuId cuda_gpu_i = visible_gpu_order[i];
+ const PlatformGpuId platform_gpu_i = visible_gpu_order[i];
for (int j = 0; j < visible_gpu_order.size(); ++j) {
- const CudaGpuId cuda_gpu_j = visible_gpu_order[j];
+ const PlatformGpuId platform_gpu_j = visible_gpu_order[j];
// We have already validated that ExecutorForDevice() calls return OK.
se::StreamExecutor* from =
- GpuIdUtil::ExecutorForCudaGpuId(platform, cuda_gpu_i).ValueOrDie();
+ GpuIdUtil::ExecutorForPlatformGpuId(platform, platform_gpu_i)
+ .ValueOrDie();
se::StreamExecutor* to =
- GpuIdUtil::ExecutorForCudaGpuId(platform, cuda_gpu_j).ValueOrDie();
+ GpuIdUtil::ExecutorForPlatformGpuId(platform, platform_gpu_j)
+ .ValueOrDie();
if (from->CanEnablePeerAccessTo(to)) {
++possible_peer_count;
@@ -1349,7 +1368,8 @@ Status EnablePeerAccess(se::Platform* platform,
if (!status.ok()) {
LOG(WARNING)
<< "Unable to enable peer access between device ordinals "
- << cuda_gpu_i << " and " << cuda_gpu_j << ", status: " << status;
+ << platform_gpu_i << " and " << platform_gpu_j
+ << ", status: " << status;
} else {
++enabled_peer_count;
}
@@ -1372,22 +1392,23 @@ Status EnablePeerAccess(se::Platform* platform,
} // namespace
Status BaseGPUDeviceFactory::GetValidDeviceIds(
- const std::vector<CudaGpuId>& visible_gpu_order,
- std::vector<CudaGpuId>* ids) {
+ const std::vector<PlatformGpuId>& visible_gpu_order,
+ std::vector<PlatformGpuId>* ids) {
se::Platform* gpu_manager = GPUMachineManager();
bool new_gpu_found = false;
for (int i = 0; i < visible_gpu_order.size(); ++i) {
- const CudaGpuId cuda_gpu_id = visible_gpu_order[i];
+ const PlatformGpuId visible_gpu_id = visible_gpu_order[i];
- // Only perform this once per visible cuda gpu id.
- if (visible_gpu_initialized_[cuda_gpu_id.value()]) {
+ // Only perform this once per visible platform gpu id.
+ if (visible_gpu_initialized_[visible_gpu_id.value()]) {
continue;
}
- visible_gpu_initialized_[cuda_gpu_id.value()] = true;
+ visible_gpu_initialized_[visible_gpu_id.value()] = true;
new_gpu_found = true;
- auto executor = GpuIdUtil::ExecutorForCudaGpuId(gpu_manager, cuda_gpu_id);
+ auto executor =
+ GpuIdUtil::ExecutorForPlatformGpuId(gpu_manager, visible_gpu_id);
if (!executor.ok()) {
return executor.status();
}
@@ -1435,9 +1456,9 @@ Status BaseGPUDeviceFactory::GetValidDeviceIds(
// Filter out devices that don't have the right capability or power.
for (int i = 0; i < visible_gpu_order.size(); ++i) {
- const CudaGpuId visible_gpu_id = visible_gpu_order[i];
+ const PlatformGpuId visible_gpu_id = visible_gpu_order[i];
auto exec_status =
- GpuIdUtil::ExecutorForCudaGpuId(gpu_manager, visible_gpu_id);
+ GpuIdUtil::ExecutorForPlatformGpuId(gpu_manager, visible_gpu_id);
if (!exec_status.ok()) {
LOG(INFO) << "Ignoring visible gpu device " << visible_gpu_id
<< " whose executor is in invalid state: "
@@ -1486,7 +1507,7 @@ Status BaseGPUDeviceFactory::GetValidDeviceIds(
if (!ids->empty()) {
std::vector<int> raw_ids(ids->size());
std::transform(ids->begin(), ids->end(), raw_ids.begin(),
- [](CudaGpuId id) -> int { return id.value(); });
+ [](PlatformGpuId id) -> int { return id.value(); });
LOG(INFO) << "Adding visible gpu devices: "
<< str_util::Join(raw_ids, ", ");
}
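The hunks above make scratch-buffer setup lazy: allocation moves out of BaseGPUDevice::Init() into the mutex-guarded, idempotent InitScratchBuffers(), which ReinitializeGpuDevice() now calls and whose failure is reported through the new Status return value instead of failing device startup. A minimal self-contained sketch of that lazy-init pattern, using plain malloc/memset as stand-ins for AllocateRaw and SynchronousMemZero (none of the names below are TensorFlow APIs):

#include <cstddef>
#include <cstdlib>
#include <cstring>
#include <mutex>
#include <vector>

class LazyScratchBuffers {
 public:
  // Idempotent: the first successful call allocates one buffer per stream;
  // later calls return immediately. Allocation failure is reported to the
  // caller rather than aborting initialization.
  bool EnsureInitialized(int num_streams, std::size_t bytes_per_stream) {
    std::lock_guard<std::mutex> lock(mu_);
    while (static_cast<int>(buffers_.size()) < num_streams) {
      void* buf = std::malloc(bytes_per_stream);  // stand-in for AllocateRaw
      if (buf == nullptr) return false;
      std::memset(buf, 0, bytes_per_stream);      // stand-in for SynchronousMemZero
      buffers_.push_back(static_cast<char*>(buf));
    }
    return true;
  }

 private:
  std::mutex mu_;
  std::vector<char*> buffers_;
};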
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.h b/tensorflow/core/common_runtime/gpu/gpu_device.h
index 56d03d7a8c..674e8384d5 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_device.h
@@ -65,6 +65,11 @@ class BaseGPUDevice : public LocalDevice {
// completes.
bool RequiresRecordingAccessedTensors() const override;
+ // GPU kernel execution requires us to use `tracing::ScopedAnnotation()`
+ // rather than `tracing::ScopedActivity()`, in order to relate asynchronously
+ // launched GPU kernels to the OpKernel.
+ bool TraceUsingAnnotations() const { return true; }
+
void ConsumeListOfAccessedTensors(
DeviceContext* device_context,
const TensorReferenceVector& tensor_refs) override;
@@ -86,15 +91,16 @@ class BaseGPUDevice : public LocalDevice {
// The caller owns the returned device.
PerOpGpuDevice* MakeGpuDevice() override;
- void ReinitializeGpuDevice(OpKernelContext* context, PerOpGpuDevice* device,
- DeviceContext* dc, Allocator* allocator) override;
+ Status ReinitializeGpuDevice(OpKernelContext* context, PerOpGpuDevice* device,
+ DeviceContext* dc,
+ Allocator* allocator) override;
- // Returns the CUDA GPU id of this device within the native driver system;
+ // Returns the platform GPU id of this device within the native driver system;
// e.g., for CUDA this is the ordinal of the GPU within the system.
int gpu_id() const {
- CudaGpuId cuda_gpu_id;
- TF_CHECK_OK(GpuIdManager::TfToCudaGpuId(tf_gpu_id_, &cuda_gpu_id));
- return cuda_gpu_id.value();
+ PlatformGpuId platform_gpu_id;
+ TF_CHECK_OK(GpuIdManager::TfToPlatformGpuId(tf_gpu_id_, &platform_gpu_id));
+ return platform_gpu_id.value();
}
// The executor that provides control for the device; e.g., for CUDA this
@@ -125,6 +131,7 @@ class BaseGPUDevice : public LocalDevice {
class StreamGroupFactory;
gtl::InlinedVector<StreamGroup*, 4> streams_;
+ mutex scratch_init_mutex_;
gtl::InlinedVector<char*, 4> scratch_;
std::vector<GPUDeviceContext*> device_contexts_;
GpuDeviceInfo* gpu_device_info_ = nullptr;
@@ -135,6 +142,9 @@ class BaseGPUDevice : public LocalDevice {
std::unique_ptr<EventMgr> em_;
std::unique_ptr<thread::ThreadPool> thread_pool_;
+ // Initialize scratch buffers used by Eigen.
+ Status InitScratchBuffers();
+
void ReinitializeDevice(OpKernelContext* context, PerOpGpuDevice* device,
int stream_id, Allocator* allocator);
@@ -168,14 +178,14 @@ class BaseGPUDeviceFactory : public DeviceFactory {
int32 strength;
static const int kSameDeviceStrength;
static const int kStreamExecutorStrength;
- std::set<std::pair<CudaGpuId, CudaGpuId>> directed_links;
+ std::set<std::pair<PlatformGpuId, PlatformGpuId>> directed_links;
};
protected:
// Populates *maps with interconnect maps for all local direct access
// pathways between GPUs.
virtual Status GetInterconnectMaps(
- const std::vector<CudaGpuId>& visible_gpu_order,
+ const std::vector<PlatformGpuId>& visible_gpu_order,
se::Platform* gpu_manager, std::vector<InterconnectMap>* maps);
struct TfGpuIdHash {
@@ -207,16 +217,16 @@ class BaseGPUDeviceFactory : public DeviceFactory {
Allocator* gpu_allocator,
Allocator* cpu_allocator) = 0;
- // Returns into 'ids' the list of valid CUDA GPU ids, in the order that
+ // Returns into 'ids' the list of valid platform GPU ids, in the order that
// they should map to TF GPU ids "/device:GPU:0", "/device:GPU:1", etc,
// based upon 'visible_gpu_order' which was generated by parsing
// GPUOptions::visible_device_list which is a comma-separated list of CUDA GPU
// ids.
- Status GetValidDeviceIds(const std::vector<CudaGpuId>& visible_gpu_order,
- std::vector<CudaGpuId>* ids);
+ Status GetValidDeviceIds(const std::vector<PlatformGpuId>& visible_gpu_order,
+ std::vector<PlatformGpuId>* ids);
- // visible_gpu_initialized_[cuda_gpu_id] is true if visible GPU cuda_gpu_id
- // has been initialized by the process.
+ // visible_gpu_initialized_[platform_gpu_id] is true if visible GPU
+ // platform_gpu_id has been initialized by the process.
std::unordered_map<int, bool> visible_gpu_initialized_;
};
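The TraceUsingAnnotations() accessor added above signals that kernel launches on this device are asynchronous, so tracing should use tracing::ScopedAnnotation (which the Compute() path relied on before this change) rather than a synchronous activity scope. A hedged caller-side sketch of how that flag might be consulted; this is illustrative and not code from this patch:

// Sketch only: relates asynchronously launched GPU kernels back to the
// OpKernel by annotating the launch, mirroring the ScopedAnnotation use
// that previously lived inside BaseGPUDevice::Compute().
void TraceAndCompute(BaseGPUDevice* device, OpKernel* kernel,
                     OpKernelContext* ctx) {
  if (device->TraceUsingAnnotations()) {
    tracing::ScopedAnnotation annotation(kernel->name(), kernel->type_string());
    device->Compute(kernel, ctx);
  } else {
    device->Compute(kernel, ctx);  // caller applies its usual activity tracing
  }
}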
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device_test.cc b/tensorflow/core/common_runtime/gpu/gpu_device_test.cc
index daf59f0560..36294094e9 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device_test.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_device_test.cc
@@ -30,18 +30,21 @@ namespace tensorflow {
namespace {
const char* kDeviceNamePrefix = "/job:localhost/replica:0/task:0";
-int64 GetTotalGPUMemory(CudaGpuId gpu_id) {
+int64 GetTotalGPUMemory(PlatformGpuId gpu_id) {
se::StreamExecutor* se =
- GpuIdUtil::ExecutorForCudaGpuId(GPUMachineManager(), gpu_id).ValueOrDie();
+ GpuIdUtil::ExecutorForPlatformGpuId(GPUMachineManager(), gpu_id)
+ .ValueOrDie();
int64 total_memory, available_memory;
CHECK(se->DeviceMemoryUsage(&available_memory, &total_memory));
return total_memory;
}
-Status GetComputeCapability(CudaGpuId gpu_id, int* cc_major, int* cc_minor) {
+Status GetComputeCapability(PlatformGpuId gpu_id, int* cc_major,
+ int* cc_minor) {
se::StreamExecutor* se =
- GpuIdUtil::ExecutorForCudaGpuId(GPUMachineManager(), gpu_id).ValueOrDie();
+ GpuIdUtil::ExecutorForPlatformGpuId(GPUMachineManager(), gpu_id)
+ .ValueOrDie();
if (!se->GetDeviceDescription().cuda_compute_capability(cc_major, cc_minor)) {
*cc_major = 0;
*cc_minor = 0;
@@ -223,7 +226,7 @@ TEST_F(GPUDeviceTest, MultipleVirtualDevices) {
// error.
TEST_F(GPUDeviceTest, UnifiedMemoryUnavailableOnPrePascalGpus) {
int cc_major, cc_minor;
- TF_ASSERT_OK(GetComputeCapability(CudaGpuId(0), &cc_major, &cc_minor));
+ TF_ASSERT_OK(GetComputeCapability(PlatformGpuId(0), &cc_major, &cc_minor));
// Exit early while running on Pascal or later GPUs.
if (cc_major >= 6) {
return;
@@ -244,10 +247,10 @@ TEST_F(GPUDeviceTest, UnifiedMemoryUnavailableOnPrePascalGpus) {
// more memory than what is available on the device.
TEST_F(GPUDeviceTest, UnifiedMemoryAllocation) {
static constexpr double kGpuMemoryFraction = 1.2;
- static constexpr CudaGpuId kCudaGpuId(0);
+ static constexpr PlatformGpuId kPlatformGpuId(0);
int cc_major, cc_minor;
- TF_ASSERT_OK(GetComputeCapability(kCudaGpuId, &cc_major, &cc_minor));
+ TF_ASSERT_OK(GetComputeCapability(kPlatformGpuId, &cc_major, &cc_minor));
// Exit early if running on pre-Pascal GPUs.
if (cc_major < 6) {
LOG(INFO)
@@ -262,7 +265,7 @@ TEST_F(GPUDeviceTest, UnifiedMemoryAllocation) {
ASSERT_EQ(1, devices.size());
int64 memory_limit = devices[0]->attributes().memory_limit();
- ASSERT_EQ(memory_limit, static_cast<int64>(GetTotalGPUMemory(kCudaGpuId) *
+ ASSERT_EQ(memory_limit, static_cast<int64>(GetTotalGPUMemory(kPlatformGpuId) *
kGpuMemoryFraction));
AllocatorAttributes allocator_attributes = AllocatorAttributes();
diff --git a/tensorflow/core/common_runtime/gpu/gpu_id.h b/tensorflow/core/common_runtime/gpu/gpu_id.h
index 2a6caea296..f0d9022821 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_id.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_id.h
@@ -25,10 +25,10 @@ namespace tensorflow {
// physical machine, it can be filtered by CUDA environment variable
// CUDA_VISIBLE_DEVICES. Note that this id is not visible to Tensorflow, but
// result after filtering by CUDA_VISIBLE_DEVICES is visible to TF and is
-// called CUDA GPU id as below. See
+// called platform GPU id as below. See
// http://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#env-vars
// for more details.
-// - CUDA GPU id (also called *visible* GPU id in
+// - *platform* GPU id (also called *visible* GPU id in
// third_party/tensorflow/core/protobuf/config.proto): this is the id that is
// visible to Tensorflow after filtering by CUDA_VISIBLE_DEVICES, and is
// generated by the CUDA GPU driver. It starts from 0 and is used for CUDA API
@@ -39,14 +39,14 @@ namespace tensorflow {
// field of the device name "/device:GPU:<id>", and is also the identifier of
// a BaseGPUDevice. Note that the configuration allows us to create multiple
// BaseGPUDevice per GPU hardware in order to use multi CUDA streams on the
-// hardware, so the mapping between TF GPU id and CUDA GPU id is not a 1:1
+// hardware, so the mapping between TF GPU id and platform GPU id is not a 1:1
// mapping, see the example below.
//
// For example, assuming that in the machine we have GPU device with index 0, 1,
// 2 and 3 (physical GPU id). Setting "CUDA_VISIBLE_DEVICES=1,2,3" will create
-// the following mapping between CUDA GPU id and physical GPU id:
+// the following mapping between platform GPU id and physical GPU id:
//
-// CUDA GPU id -> physical GPU id
+// platform GPU id -> physical GPU id
// 0 -> 1
// 1 -> 2
// 2 -> 3
@@ -56,32 +56,32 @@ namespace tensorflow {
//
// Assuming we configure the Session to create one BaseGPUDevice per GPU
// hardware, then setting GPUOptions::visible_device_list to "2,0" will create
-// the following mappting between TF GPU id and CUDA GPU id:
+// the following mapping between TF GPU id and platform GPU id:
//
-// TF GPU id -> CUDA GPU ID
+// TF GPU id -> platform GPU ID
// 0 (i.e. /device:GPU:0) -> 2
// 1 (i.e. /device:GPU:1) -> 0
//
-// Note that CUDA GPU id 1 is filtered out by GPUOptions::visible_device_list,
-// so it won't be used by the TF process.
+// Note that platform GPU id 1 is filtered out by
+// GPUOptions::visible_device_list, so it won't be used by the TF process.
//
// On the other hand, if we configure it to create 2 BaseGPUDevice per GPU
// hardware, then setting GPUOptions::visible_device_list to "2,0" will create
-// the following mappting between TF GPU id and CUDA GPU id:
+// the following mapping between TF GPU id and platform GPU id:
//
-// TF GPU id -> CUDA GPU ID
+// TF GPU id -> platform GPU ID
// 0 (i.e. /device:GPU:0) -> 2
// 1 (i.e. /device:GPU:1) -> 2
// 2 (i.e. /device:GPU:2) -> 0
// 3 (i.e. /device:GPU:3) -> 0
//
-// We create strong-typed integer classes for both TF GPU id and CUDA GPU id to
-// minimize programming errors and improve code readability. Except for the
+// We create strongly typed integer classes for both TF GPU id and platform GPU id
+// to minimize programming errors and improve code readability. Except for the
// StreamExecutor interface (as we don't change its API), whenever we need a
-// TF GPU id (or CUDA GPU id) we should use TfGpuId (or CudaGpuId) instead of a
-// raw integer.
+// TF GPU id (or platform GPU id) we should use TfGpuId (or PlatformGpuId)
+// instead of a raw integer.
TF_LIB_GTL_DEFINE_INT_TYPE(TfGpuId, int32);
-TF_LIB_GTL_DEFINE_INT_TYPE(CudaGpuId, int32);
+TF_LIB_GTL_DEFINE_INT_TYPE(PlatformGpuId, int32);
} // namespace tensorflow
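The comment block above walks through how TF GPU ids map to platform GPU ids under CUDA_VISIBLE_DEVICES and visible_device_list. A hedged usage sketch of the renamed GpuIdManager API from this change, reproducing the "2,0" example with one BaseGPUDevice per GPU (error handling is minimal; this is not code from the patch):

#include "tensorflow/core/common_runtime/gpu/gpu_id.h"
#include "tensorflow/core/common_runtime/gpu/gpu_id_manager.h"

namespace tensorflow {

// visible_device_list = "2,0" with one BaseGPUDevice per GPU hardware:
//   /device:GPU:0 -> platform GPU 2,  /device:GPU:1 -> platform GPU 0
Status RecordExampleMapping() {
  TF_RETURN_IF_ERROR(
      GpuIdManager::InsertTfPlatformGpuIdPair(TfGpuId(0), PlatformGpuId(2)));
  TF_RETURN_IF_ERROR(
      GpuIdManager::InsertTfPlatformGpuIdPair(TfGpuId(1), PlatformGpuId(0)));

  PlatformGpuId platform_gpu_id;
  TF_RETURN_IF_ERROR(
      GpuIdManager::TfToPlatformGpuId(TfGpuId(0), &platform_gpu_id));
  // platform_gpu_id.value() == 2 at this point.
  return Status::OK();
}

}  // namespace tensorflow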
diff --git a/tensorflow/core/common_runtime/gpu/gpu_id_manager.cc b/tensorflow/core/common_runtime/gpu/gpu_id_manager.cc
index b5099dc8ef..2b40730119 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_id_manager.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_id_manager.cc
@@ -26,26 +26,27 @@ limitations under the License.
namespace tensorflow {
namespace {
-// Manages the map between TfGpuId and CUDA GPU id.
-class TfToCudaGpuIdMap {
+// Manages the map between TfGpuId and platform GPU id.
+class TfToPlatformGpuIdMap {
public:
- static TfToCudaGpuIdMap* singleton() {
- static auto* id_map = new TfToCudaGpuIdMap;
+ static TfToPlatformGpuIdMap* singleton() {
+ static auto* id_map = new TfToPlatformGpuIdMap;
return id_map;
}
- Status Insert(TfGpuId tf_gpu_id, CudaGpuId cuda_gpu_id) LOCKS_EXCLUDED(mu_) {
+ Status Insert(TfGpuId tf_gpu_id, PlatformGpuId platform_gpu_id)
+ LOCKS_EXCLUDED(mu_) {
std::pair<IdMapType::iterator, bool> result;
{
mutex_lock lock(mu_);
- result = id_map_.insert({tf_gpu_id.value(), cuda_gpu_id.value()});
+ result = id_map_.insert({tf_gpu_id.value(), platform_gpu_id.value()});
}
- if (!result.second && cuda_gpu_id.value() != result.first->second) {
+ if (!result.second && platform_gpu_id.value() != result.first->second) {
return errors::AlreadyExists(
"TensorFlow device (GPU:", tf_gpu_id.value(),
") is being mapped to "
"multiple CUDA devices (",
- cuda_gpu_id.value(), " now, and ", result.first->second,
+ platform_gpu_id.value(), " now, and ", result.first->second,
" previously), which is not supported. "
"This may be the result of providing different GPU configurations "
"(ConfigProto.gpu_options, for example different visible_device_list)"
@@ -56,17 +57,17 @@ class TfToCudaGpuIdMap {
return Status::OK();
}
- bool Find(TfGpuId tf_gpu_id, CudaGpuId* cuda_gpu_id) const
+ bool Find(TfGpuId tf_gpu_id, PlatformGpuId* platform_gpu_id) const
LOCKS_EXCLUDED(mu_) {
mutex_lock lock(mu_);
auto result = id_map_.find(tf_gpu_id.value());
if (result == id_map_.end()) return false;
- *cuda_gpu_id = result->second;
+ *platform_gpu_id = result->second;
return true;
}
private:
- TfToCudaGpuIdMap() = default;
+ TfToPlatformGpuIdMap() = default;
void TestOnlyReset() LOCKS_EXCLUDED(mu_) {
mutex_lock lock(mu_);
@@ -78,17 +79,18 @@ class TfToCudaGpuIdMap {
IdMapType id_map_ GUARDED_BY(mu_);
friend class ::tensorflow::GpuIdManager;
- TF_DISALLOW_COPY_AND_ASSIGN(TfToCudaGpuIdMap);
+ TF_DISALLOW_COPY_AND_ASSIGN(TfToPlatformGpuIdMap);
};
} // namespace
-Status GpuIdManager::InsertTfCudaGpuIdPair(TfGpuId tf_gpu_id,
- CudaGpuId cuda_gpu_id) {
- return TfToCudaGpuIdMap::singleton()->Insert(tf_gpu_id, cuda_gpu_id);
+Status GpuIdManager::InsertTfPlatformGpuIdPair(TfGpuId tf_gpu_id,
+ PlatformGpuId platform_gpu_id) {
+ return TfToPlatformGpuIdMap::singleton()->Insert(tf_gpu_id, platform_gpu_id);
}
-Status GpuIdManager::TfToCudaGpuId(TfGpuId tf_gpu_id, CudaGpuId* cuda_gpu_id) {
- if (TfToCudaGpuIdMap::singleton()->Find(tf_gpu_id, cuda_gpu_id)) {
+Status GpuIdManager::TfToPlatformGpuId(TfGpuId tf_gpu_id,
+ PlatformGpuId* platform_gpu_id) {
+ if (TfToPlatformGpuIdMap::singleton()->Find(tf_gpu_id, platform_gpu_id)) {
return Status::OK();
}
return errors::NotFound("TensorFlow device GPU:", tf_gpu_id.value(),
@@ -96,7 +98,7 @@ Status GpuIdManager::TfToCudaGpuId(TfGpuId tf_gpu_id, CudaGpuId* cuda_gpu_id) {
}
void GpuIdManager::TestOnlyReset() {
- TfToCudaGpuIdMap::singleton()->TestOnlyReset();
+ TfToPlatformGpuIdMap::singleton()->TestOnlyReset();
}
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/gpu/gpu_id_manager.h b/tensorflow/core/common_runtime/gpu/gpu_id_manager.h
index 491d92ccdd..62df4310c4 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_id_manager.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_id_manager.h
@@ -21,15 +21,17 @@ limitations under the License.
namespace tensorflow {
-// Class that maintains a map from TfGpuId to CudaGpuId, and manages the
+// Class that maintains a map from TfGpuId to PlatformGpuId, and manages the
// translation between them.
class GpuIdManager {
public:
- // Adds a mapping from tf_gpu_id to cuda_gpu_id.
- static Status InsertTfCudaGpuIdPair(TfGpuId tf_gpu_id, CudaGpuId cuda_gpu_id);
+ // Adds a mapping from tf_gpu_id to platform_gpu_id.
+ static Status InsertTfPlatformGpuIdPair(TfGpuId tf_gpu_id,
+ PlatformGpuId platform_gpu_id);
- // Gets the cuda_gpu_id associated with tf_gpu_id. Returns OK if found.
- static Status TfToCudaGpuId(TfGpuId tf_gpu_id, CudaGpuId* cuda_gpu_id);
+ // Gets the platform_gpu_id associated with tf_gpu_id. Returns OK if found.
+ static Status TfToPlatformGpuId(TfGpuId tf_gpu_id,
+ PlatformGpuId* platform_gpu_id);
// Clears the map. Used in unit tests only.
static void TestOnlyReset();
diff --git a/tensorflow/core/common_runtime/gpu/gpu_id_manager_test.cc b/tensorflow/core/common_runtime/gpu/gpu_id_manager_test.cc
index a663ec7051..8bf3c6a308 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_id_manager_test.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_id_manager_test.cc
@@ -22,38 +22,38 @@ limitations under the License.
namespace tensorflow {
namespace {
-CudaGpuId TfToCudaGpuId(TfGpuId tf) {
- CudaGpuId cuda;
- TF_CHECK_OK(GpuIdManager::TfToCudaGpuId(tf, &cuda));
- return cuda;
+PlatformGpuId TfToPlatformGpuId(TfGpuId tf) {
+ PlatformGpuId platform_gpu_id;
+ TF_CHECK_OK(GpuIdManager::TfToPlatformGpuId(tf, &platform_gpu_id));
+ return platform_gpu_id;
}
TEST(GpuIdManagerTest, Basics) {
TfGpuId key_0(0);
- CudaGpuId value_0(0);
- TF_ASSERT_OK(GpuIdManager::InsertTfCudaGpuIdPair(key_0, value_0));
- EXPECT_EQ(value_0, TfToCudaGpuId(key_0));
+ PlatformGpuId value_0(0);
+ TF_ASSERT_OK(GpuIdManager::InsertTfPlatformGpuIdPair(key_0, value_0));
+ EXPECT_EQ(value_0, TfToPlatformGpuId(key_0));
// Multiple calls to map the same value is ok.
- TF_ASSERT_OK(GpuIdManager::InsertTfCudaGpuIdPair(key_0, value_0));
- EXPECT_EQ(value_0, TfToCudaGpuId(key_0));
+ TF_ASSERT_OK(GpuIdManager::InsertTfPlatformGpuIdPair(key_0, value_0));
+ EXPECT_EQ(value_0, TfToPlatformGpuId(key_0));
// Map a different TfGpuId to a different value.
TfGpuId key_1(3);
- CudaGpuId value_1(2);
- TF_ASSERT_OK(GpuIdManager::InsertTfCudaGpuIdPair(key_1, value_1));
- EXPECT_EQ(value_1, TfToCudaGpuId(key_1));
+ PlatformGpuId value_1(2);
+ TF_ASSERT_OK(GpuIdManager::InsertTfPlatformGpuIdPair(key_1, value_1));
+ EXPECT_EQ(value_1, TfToPlatformGpuId(key_1));
// Mapping a different TfGpuId to the same value is ok.
TfGpuId key_2(10);
- TF_ASSERT_OK(GpuIdManager::InsertTfCudaGpuIdPair(key_2, value_1));
- EXPECT_EQ(value_1, TfToCudaGpuId(key_2));
+ TF_ASSERT_OK(GpuIdManager::InsertTfPlatformGpuIdPair(key_2, value_1));
+ EXPECT_EQ(value_1, TfToPlatformGpuId(key_2));
// Mapping the same TfGpuId to a different value.
- ASSERT_FALSE(GpuIdManager::InsertTfCudaGpuIdPair(key_2, value_0).ok());
+ ASSERT_FALSE(GpuIdManager::InsertTfPlatformGpuIdPair(key_2, value_0).ok());
// Getting a nonexistent mapping.
- ASSERT_FALSE(GpuIdManager::TfToCudaGpuId(TfGpuId(100), &value_0).ok());
+ ASSERT_FALSE(GpuIdManager::TfToPlatformGpuId(TfGpuId(100), &value_0).ok());
}
} // namespace
diff --git a/tensorflow/core/common_runtime/gpu/gpu_id_utils.h b/tensorflow/core/common_runtime/gpu/gpu_id_utils.h
index b9c66b3328..b1f10fb1dc 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_id_utils.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_id_utils.h
@@ -24,34 +24,37 @@ limitations under the License.
namespace tensorflow {
-// Utility methods for translation between Tensorflow GPU ids and CUDA GPU ids.
+// Utility methods for translation between Tensorflow GPU ids and platform GPU
+// ids.
class GpuIdUtil {
public:
// Convenient methods for getting the associated executor given a TfGpuId or
- // CudaGpuId.
- static se::port::StatusOr<se::StreamExecutor*> ExecutorForCudaGpuId(
- se::Platform* gpu_manager, CudaGpuId cuda_gpu_id) {
- return gpu_manager->ExecutorForDevice(cuda_gpu_id.value());
+ // PlatformGpuId.
+ static se::port::StatusOr<se::StreamExecutor*> ExecutorForPlatformGpuId(
+ se::Platform* gpu_manager, PlatformGpuId platform_gpu_id) {
+ return gpu_manager->ExecutorForDevice(platform_gpu_id.value());
}
- static se::port::StatusOr<se::StreamExecutor*> ExecutorForCudaGpuId(
- CudaGpuId cuda_gpu_id) {
- return ExecutorForCudaGpuId(GPUMachineManager(), cuda_gpu_id);
+ static se::port::StatusOr<se::StreamExecutor*> ExecutorForPlatformGpuId(
+ PlatformGpuId platform_gpu_id) {
+ return ExecutorForPlatformGpuId(GPUMachineManager(), platform_gpu_id);
}
static se::port::StatusOr<se::StreamExecutor*> ExecutorForTfGpuId(
TfGpuId tf_gpu_id) {
- CudaGpuId cuda_gpu_id;
- TF_RETURN_IF_ERROR(GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id));
- return ExecutorForCudaGpuId(cuda_gpu_id);
+ PlatformGpuId platform_gpu_id;
+ TF_RETURN_IF_ERROR(
+ GpuIdManager::TfToPlatformGpuId(tf_gpu_id, &platform_gpu_id));
+ return ExecutorForPlatformGpuId(platform_gpu_id);
}
- // Verify that the cuda_gpu_id associated with a TfGpuId is legitimate.
+ // Verify that the platform_gpu_id associated with a TfGpuId is legitimate.
static void CheckValidTfGpuId(TfGpuId tf_gpu_id) {
- CudaGpuId cuda_gpu_id;
- TF_CHECK_OK(GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id));
+ PlatformGpuId platform_gpu_id;
+ TF_CHECK_OK(GpuIdManager::TfToPlatformGpuId(tf_gpu_id, &platform_gpu_id));
const int visible_device_count = GPUMachineManager()->VisibleDeviceCount();
- CHECK_LT(cuda_gpu_id.value(), visible_device_count)
- << "cuda_gpu_id is outside discovered device range."
- << " TF GPU id: " << tf_gpu_id << " CUDA GPU id: " << cuda_gpu_id
+ CHECK_LT(platform_gpu_id.value(), visible_device_count)
+ << "platform_gpu_id is outside discovered device range."
+ << " TF GPU id: " << tf_gpu_id
+ << " platform GPU id: " << platform_gpu_id
<< " visible device count: " << visible_device_count;
}
};
diff --git a/tensorflow/core/common_runtime/gpu/gpu_process_state.cc b/tensorflow/core/common_runtime/gpu/gpu_process_state.cc
index b18688174d..3e95374fda 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_process_state.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_process_state.cc
@@ -76,12 +76,16 @@ GPUProcessState::GPUProcessState() : gpu_device_enabled_(false) {
// This function is defined for debugging problems with the allocators.
GPUProcessState::~GPUProcessState() {
CHECK_EQ(this, instance_);
- for (auto p : gpu_allocators_) {
- delete p;
- }
instance_ = nullptr;
}
+int GPUProcessState::BusIdForGPU(TfGpuId tf_gpu_id) {
+ // Return the NUMA node associated with the GPU's StreamExecutor.
+ se::StreamExecutor* se =
+ GpuIdUtil::ExecutorForTfGpuId(tf_gpu_id).ValueOrDie();
+ return se->GetDeviceDescription().numa_node();
+}
+
Allocator* GPUProcessState::GetGPUAllocator(const GPUOptions& options,
TfGpuId tf_gpu_id,
size_t total_bytes) {
@@ -93,64 +97,63 @@ Allocator* GPUProcessState::GetGPUAllocator(const GPUOptions& options,
if (tf_gpu_id.value() >= static_cast<int64>(gpu_allocators_.size())) {
gpu_allocators_.resize(tf_gpu_id.value() + 1);
- if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types)
- gpu_al_.resize(tf_gpu_id.value() + 1);
}
- if (gpu_allocators_[tf_gpu_id.value()] == nullptr) {
- VisitableAllocator* gpu_allocator;
-
+ AllocatorParts& allocator_parts = gpu_allocators_[tf_gpu_id.value()];
+ if (allocator_parts.allocator.get() == nullptr) {
// Validate allocator types.
if (!allocator_type.empty() && allocator_type != "BFC") {
LOG(ERROR) << "Invalid allocator type: " << allocator_type;
return nullptr;
}
- CudaGpuId cuda_gpu_id;
- TF_CHECK_OK(GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id));
- gpu_allocator =
- new GPUBFCAllocator(cuda_gpu_id, total_bytes, options,
+ PlatformGpuId platform_gpu_id;
+ TF_CHECK_OK(GpuIdManager::TfToPlatformGpuId(tf_gpu_id, &platform_gpu_id));
+ int bus_id = BusIdForGPU(tf_gpu_id);
+ while (bus_id >= gpu_visitors_.size()) {
+ gpu_visitors_.push_back({});
+ }
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id,
+ (options.per_process_gpu_memory_fraction() > 1.0 ||
+ options.experimental().use_unified_memory()),
+ gpu_visitors_[bus_id], {});
+ Allocator* gpu_allocator =
+ new GPUBFCAllocator(sub_allocator, total_bytes, options,
strings::StrCat("GPU_", tf_gpu_id.value(), "_bfc"));
// If true, checks for memory overwrites by writing
// distinctive patterns on both ends of allocated memory.
if (useCudaMemoryGuardAllocator()) {
- gpu_allocator = new GPUDebugAllocator(gpu_allocator, cuda_gpu_id);
- gpu_allocator = new GPUNanResetAllocator(gpu_allocator, cuda_gpu_id);
+ gpu_allocator = new GPUDebugAllocator(gpu_allocator, platform_gpu_id);
+ gpu_allocator = new GPUNanResetAllocator(gpu_allocator, platform_gpu_id);
} else if (useCudaMallocAllocator()) {
// If true, passes all allocation requests through to cudaMalloc
// useful for doing memory debugging with tools like cuda-memcheck
// **WARNING** probably will not work in a multi-gpu scenario
- gpu_allocator = new GPUcudaMallocAllocator(gpu_allocator, cuda_gpu_id);
- }
- gpu_allocators_[tf_gpu_id.value()] = gpu_allocator;
-
- // If there are any pending AllocVisitors for this bus, add
- // them now.
- se::StreamExecutor* se =
- GpuIdUtil::ExecutorForTfGpuId(tf_gpu_id).ValueOrDie();
- int bus_id = se->GetDeviceDescription().numa_node();
- if (bus_id >= 0 && bus_id < static_cast<int64>(gpu_visitors_.size())) {
- for (const auto& v : gpu_visitors_[bus_id]) {
- gpu_allocator->AddAllocVisitor(v);
- }
+ gpu_allocator =
+ new GPUcudaMallocAllocator(gpu_allocator, platform_gpu_id);
}
+
+ Allocator* recording_allocator = nullptr;
if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types) {
ProcessState::MemDesc md;
md.loc = ProcessState::MemDesc::GPU;
- md.dev_index = cuda_gpu_id.value();
+ md.dev_index = platform_gpu_id.value();
md.gpu_registered = false;
md.nic_registered = true;
- if (static_cast<int64>(gpu_al_.size()) <= tf_gpu_id.value()) {
- gpu_al_.resize(tf_gpu_id.value() + 1);
- }
- gpu_al_[tf_gpu_id.value()] = new internal::RecordingAllocator(
+ recording_allocator = new internal::RecordingAllocator(
&process_state_->mem_desc_map_, gpu_allocator, md, &mu_);
}
+ allocator_parts = {std::unique_ptr<Allocator>(gpu_allocator), sub_allocator,
+ std::unique_ptr<Allocator>(recording_allocator)};
+ }
+ if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types) {
+ return allocator_parts.recording_allocator.get();
+ } else {
+ return allocator_parts.allocator.get();
}
- if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types)
- return gpu_al_[tf_gpu_id.value()];
- return gpu_allocators_[tf_gpu_id.value()];
#else
LOG(FATAL) << "GPUAllocator unavailable. Not compiled with --config=cuda.";
return nullptr;
@@ -172,11 +175,12 @@ Allocator* GPUProcessState::GetCUDAHostAllocator(int numa_node) {
tf_shared_lock lock(mu_);
if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types &&
- static_cast<int>(cuda_al_.size()) > 0) {
- return cuda_al_[0];
+ !cuda_host_allocators_.empty() &&
+ cuda_host_allocators_[0].recording_allocator != nullptr) {
+ return cuda_host_allocators_[0].recording_allocator.get();
}
if (static_cast<int>(cuda_host_allocators_.size()) > numa_node) {
- return cuda_host_allocators_[0];
+ return cuda_host_allocators_[0].allocator.get();
}
}
@@ -190,7 +194,7 @@ Allocator* GPUProcessState::GetCUDAHostAllocator(int numa_node) {
// it knows is valid.
se::StreamExecutor* se = nullptr;
for (int i = 0; i < static_cast<int>(gpu_allocators_.size()); ++i) {
- if (gpu_allocators_[i] != nullptr) {
+ if (gpu_allocators_[i].allocator != nullptr) {
se = GpuIdUtil::ExecutorForTfGpuId(TfGpuId(i)).ValueOrDie();
break;
}
@@ -199,6 +203,15 @@ Allocator* GPUProcessState::GetCUDAHostAllocator(int numa_node) {
CHECK_NE(nullptr, se);
while (static_cast<int>(cuda_host_allocators_.size()) <= numa_node) {
+ while (cuda_host_alloc_visitors_.size() <= numa_node) {
+ cuda_host_alloc_visitors_.push_back({});
+ }
+ while (cuda_host_free_visitors_.size() <= numa_node) {
+ cuda_host_free_visitors_.push_back({});
+ }
+ SubAllocator* sub_allocator = new CUDAHostAllocator(
+ se, numa_node, cuda_host_alloc_visitors_[numa_node],
+ cuda_host_free_visitors_[numa_node]);
// TODO(zheng-xq): evaluate whether 64GB by default is the best choice.
int64 cuda_host_mem_limit_in_mb = -1;
Status status = ReadInt64FromEnvVar("TF_CUDA_HOST_MEM_LIMIT_IN_MB",
@@ -208,62 +221,92 @@ Allocator* GPUProcessState::GetCUDAHostAllocator(int numa_node) {
LOG(ERROR) << "GetCUDAHostAllocator: " << status.error_message();
}
int64 cuda_host_mem_limit = cuda_host_mem_limit_in_mb * (1LL << 20);
- VisitableAllocator* allocator =
- new BFCAllocator(new CUDAHostAllocator(se), cuda_host_mem_limit,
+ Allocator* allocator =
+ new BFCAllocator(sub_allocator, cuda_host_mem_limit,
true /*allow_growth*/, "cuda_host_bfc" /*name*/);
- if (LogMemory::IsEnabled()) {
+ if (LogMemory::IsEnabled() && !allocator->TracksAllocationSizes()) {
// Wrap the allocator to track allocation ids for better logging
// at the cost of performance.
- allocator = new TrackingVisitableAllocator(allocator, true);
+ allocator = new TrackingAllocator(allocator, true);
}
- cuda_host_allocators_.push_back(allocator);
+ cuda_host_allocators_.push_back({std::unique_ptr<Allocator>(allocator),
+ sub_allocator,
+ std::unique_ptr<Allocator>(nullptr)});
+ AllocatorParts& allocator_parts = cuda_host_allocators_.back();
if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types) {
ProcessState::MemDesc md;
md.loc = ProcessState::MemDesc::CPU;
md.dev_index = 0;
md.gpu_registered = true;
md.nic_registered = false;
- cuda_al_.push_back(new internal::RecordingAllocator(
- &process_state_->mem_desc_map_, cuda_host_allocators_.back(), md,
- &mu_));
+ allocator_parts.recording_allocator.reset(
+ new internal::RecordingAllocator(&process_state_->mem_desc_map_,
+ allocator_parts.allocator.get(), md,
+ &mu_));
}
}
- if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types)
- return cuda_al_[0];
- return cuda_host_allocators_[0];
+ if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types) {
+ return cuda_host_allocators_[0].recording_allocator.get();
+ } else {
+ return cuda_host_allocators_[0].allocator.get();
+ }
}
void GPUProcessState::AddGPUAllocVisitor(int bus_id,
- const AllocVisitor& visitor) {
- CHECK(process_state_);
+ const SubAllocator::Visitor& visitor) {
#if GOOGLE_CUDA
mutex_lock lock(mu_);
- for (int i = 0; i < static_cast<int64>(gpu_allocators_.size()); ++i) {
- se::StreamExecutor* se =
- GpuIdUtil::ExecutorForTfGpuId(TfGpuId(i)).ValueOrDie();
- if (gpu_allocators_[i] &&
- (se->GetDeviceDescription().numa_node() + 1) == bus_id) {
- gpu_allocators_[i]->AddAllocVisitor(visitor);
- }
- }
+ CHECK(gpu_allocators_.empty()) // Crash OK
+ << "AddGPUAllocVisitor must be called before "
+ "first call to GetGPUAllocator.";
while (bus_id >= static_cast<int64>(gpu_visitors_.size())) {
- gpu_visitors_.push_back(std::vector<AllocVisitor>());
+ gpu_visitors_.push_back(std::vector<SubAllocator::Visitor>());
}
gpu_visitors_[bus_id].push_back(visitor);
#endif // GOOGLE_CUDA
}
+void GPUProcessState::AddCUDAHostAllocVisitor(
+ int numa_node, const SubAllocator::Visitor& visitor) {
+#if GOOGLE_CUDA
+ mutex_lock lock(mu_);
+ CHECK(cuda_host_allocators_.empty()) // Crash OK
+ << "AddCUDAHostAllocVisitor must be called before "
+ "first call to GetCUDAHostAllocator.";
+ while (numa_node >= static_cast<int64>(cuda_host_alloc_visitors_.size())) {
+ cuda_host_alloc_visitors_.push_back(std::vector<SubAllocator::Visitor>());
+ }
+ cuda_host_alloc_visitors_[numa_node].push_back(visitor);
+#endif // GOOGLE_CUDA
+}
+
+void GPUProcessState::AddCUDAHostFreeVisitor(
+ int numa_node, const SubAllocator::Visitor& visitor) {
+#if GOOGLE_CUDA
+ mutex_lock lock(mu_);
+ CHECK(cuda_host_allocators_.empty()) // Crash OK
+ << "AddCUDAHostFreeVisitor must be called before "
+ "first call to GetCUDAHostAllocator.";
+ while (numa_node >= static_cast<int64>(cuda_host_free_visitors_.size())) {
+ cuda_host_free_visitors_.push_back(std::vector<SubAllocator::Visitor>());
+ }
+ cuda_host_free_visitors_[numa_node].push_back(visitor);
+#endif // GOOGLE_CUDA
+}
+
void GPUProcessState::TestOnlyReset() {
- process_state_->ProcessState::TestOnlyReset();
+ if (process_state_) {
+ process_state_->ProcessState::TestOnlyReset();
+ }
{
mutex_lock lock(mu_);
gpu_device_enabled_ = false;
+ gpu_allocators_.clear();
gpu_visitors_.clear();
- gtl::STLDeleteElements(&gpu_allocators_);
- gtl::STLDeleteElements(&cuda_host_allocators_);
- gtl::STLDeleteElements(&gpu_al_);
- gtl::STLDeleteElements(&cuda_al_);
+ cuda_host_allocators_.clear();
+ cuda_host_alloc_visitors_.clear();
+ cuda_host_free_visitors_.clear();
}
}
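A behavioral consequence of the rewrite above: SubAllocator visitors are now baked into the allocator at construction time, so AddGPUAllocVisitor (and the new CUDA-host variants) must be called before the first GetGPUAllocator / GetCUDAHostAllocator call; the new CHECKs crash otherwise. A hedged sketch of the required ordering, with an illustrative bus id, memory limit, and visitor body; the include path is assumed and the visitor signature follows the SubAllocator::Visitor usage in pool_allocator_test.cc below:

#include "tensorflow/core/common_runtime/gpu/gpu_process_state.h"

namespace tensorflow {

void ConfigureGpuMemoryRegistration() {
  GPUProcessState* ps = GPUProcessState::singleton();
  // 1) Register visitors first.
  ps->AddGPUAllocVisitor(/*bus_id=*/0,
                         [](void* ptr, int bus_id, int64 size) {
                           // e.g. register [ptr, ptr + size) with a NIC driver.
                         });
  // 2) Only then create the allocator; reversing the order trips the CHECK.
  Allocator* gpu_allocator =
      ps->GetGPUAllocator(GPUOptions(), TfGpuId(0), /*total_bytes=*/1 << 30);
  (void)gpu_allocator;
}

}  // namespace tensorflow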
diff --git a/tensorflow/core/common_runtime/gpu/gpu_process_state.h b/tensorflow/core/common_runtime/gpu/gpu_process_state.h
index cb41c3c6bd..43e9a31660 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_process_state.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_process_state.h
@@ -32,7 +32,6 @@ limitations under the License.
namespace tensorflow {
class Allocator;
-class VisitableAllocator;
class PoolAllocator;
// Singleton that manages per-process state when GPUs are present.
@@ -72,18 +71,30 @@ class GPUProcessState {
virtual Allocator* GetCUDAHostAllocator(int numa_node);
- // Registers a function to be called once on every new Region
- // allocated by every GPURegionAllocator proximate to the specified
- // bus. The AllocVisitor is provided with a memory pointer and the
- // size of the area it identifies. The pointer is not guaranteed to
- // be valid after the call terminates. The intention is for this
- // interface to be used for network device memory registration.
- // "bus_id" is platform-specific. On many platforms it
- // should be 0. On machines with multiple PCIe buses, it should be
- // the index of one of the PCIe buses. If the bus_id is invalid,
- // results are undefined.
- typedef std::function<void(void*, size_t)> AllocVisitor;
- virtual void AddGPUAllocVisitor(int bus_id, const AllocVisitor& visitor);
+ // Registers a Visitor to be invoked on new chunks of memory allocated by the
+ // SubAllocator of every GPU proximate to the specified bus. The AllocVisitor
+ // is provided with a memory pointer, a GPU id, and the size of the area it
+ // identifies. The pointer is not guaranteed to be valid after the call
+ // terminates. The intention is for this interface to be used for network
+ // device memory registration. "bus_id" is platform-specific. On many
+ // platforms it should be 0. On machines with multiple PCIe buses, it should
+ // be the index of one of the PCIe buses (maybe the NUMA node at which the
+ // PCIe is rooted). If the bus_id is invalid, results are undefined.
+ virtual void AddGPUAllocVisitor(int bus_id,
+ const SubAllocator::Visitor& visitor);
+
+ // Registers a Visitor to be invoked on new chunks of memory allocated by
+ // the SubAllocator of the CUDAHostAllocator for the given numa_node.
+ virtual void AddCUDAHostAllocVisitor(int numa_node,
+ const SubAllocator::Visitor& visitor);
+
+ // Registers a Visitor to be invoked on each chunk handed back for freeing to
+ // the SubAllocator of the CUDAHostAllocator for the given numa_node.
+ virtual void AddCUDAHostFreeVisitor(int numa_node,
+ const SubAllocator::Visitor& visitor);
+
+ // Returns bus_id for the given GPU id.
+ virtual int BusIdForGPU(TfGpuId tf_gpu_id);
protected:
GPUProcessState();
@@ -103,16 +114,21 @@ class GPUProcessState {
mutex mu_;
- std::vector<VisitableAllocator*> gpu_allocators_ GUARDED_BY(mu_);
- std::vector<std::vector<AllocVisitor>> gpu_visitors_ GUARDED_BY(mu_);
- std::vector<Allocator*> cuda_host_allocators_ GUARDED_BY(mu_);
+ struct AllocatorParts {
+ std::unique_ptr<Allocator> allocator;
+ SubAllocator* sub_allocator; // owned by allocator
+ std::unique_ptr<Allocator> recording_allocator;
+ };
+ std::vector<AllocatorParts> gpu_allocators_ GUARDED_BY(mu_);
+ std::vector<std::vector<SubAllocator::Visitor>> gpu_visitors_ GUARDED_BY(mu_);
- virtual ~GPUProcessState();
+ std::vector<AllocatorParts> cuda_host_allocators_ GUARDED_BY(mu_);
+ std::vector<std::vector<SubAllocator::Visitor>> cuda_host_alloc_visitors_
+ GUARDED_BY(mu_);
+ std::vector<std::vector<SubAllocator::Visitor>> cuda_host_free_visitors_
+ GUARDED_BY(mu_);
- // Optional RecordingAllocators that wrap the corresponding
- // Allocators for runtime attribute use analysis.
- std::vector<Allocator*> gpu_al_ GUARDED_BY(mu_);
- std::vector<Allocator*> cuda_al_ GUARDED_BY(mu_);
+ virtual ~GPUProcessState();
friend class GPUDeviceTest;
};
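Note: a minimal sketch, not part of the patch, of how a transport layer could use the new registration hooks above; GPUProcessState::singleton() as the accessor, numa_node 0, and the RegisterHostMemoryVisitors helper are assumptions made only for illustration.

#include "tensorflow/core/common_runtime/gpu/gpu_process_state.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {

void RegisterHostMemoryVisitors() {  // hypothetical helper, called at startup
  GPUProcessState* gps = GPUProcessState::singleton();
  // Runs on every new chunk handed out by the CUDA host SubAllocator,
  // e.g. to pin/register the region with a NIC.
  gps->AddCUDAHostAllocVisitor(
      /*numa_node=*/0, [](void* ptr, int numa_node, int64 num_bytes) {
        VLOG(1) << "registered " << num_bytes << " host bytes at " << ptr;
      });
  // Runs on each chunk just before it is handed back to the SubAllocator.
  gps->AddCUDAHostFreeVisitor(
      /*numa_node=*/0, [](void* ptr, int numa_node, int64 num_bytes) {
        VLOG(1) << "deregistered " << num_bytes << " host bytes at " << ptr;
      });
}

}  // namespace tensorflow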
diff --git a/tensorflow/core/common_runtime/gpu/pool_allocator_test.cc b/tensorflow/core/common_runtime/gpu/pool_allocator_test.cc
index 583bff2c07..6b2f6547b0 100644
--- a/tensorflow/core/common_runtime/gpu/pool_allocator_test.cc
+++ b/tensorflow/core/common_runtime/gpu/pool_allocator_test.cc
@@ -31,7 +31,8 @@ TEST(PoolAllocatorTest, ZeroSizeBuffers) {
2 /*pool_size_limit*/, false /*auto_resize*/,
new CUDAHostAllocator(
platform->GetExecutor(se::StreamExecutorConfig(/*ordinal=*/0))
- .ValueOrDie()),
+ .ValueOrDie(),
+ 0 /*numa_node*/, {}, {}),
new NoopRounder, "pool");
EXPECT_EQ(nullptr, pool.AllocateRaw(4 /*alignment*/, 0 /*num_bytes*/));
@@ -49,7 +50,8 @@ TEST(PoolAllocatorTest, ZeroSizePool) {
0 /*pool_size_limit*/, false /*auto_resize*/,
new CUDAHostAllocator(
platform->GetExecutor(se::StreamExecutorConfig(/*ordinal=*/0))
- .ValueOrDie()),
+ .ValueOrDie(),
+ 0 /*numa_node*/, {}, {}),
new NoopRounder, "pool");
EXPECT_EQ(0, pool.get_from_pool_count());
@@ -82,7 +84,8 @@ TEST(PoolAllocatorTest, Alignment) {
0 /*pool_size_limit*/, false /*auto_resize*/,
new CUDAHostAllocator(
platform->GetExecutor(se::StreamExecutorConfig(/*ordinal=*/0))
- .ValueOrDie()),
+ .ValueOrDie(),
+ 0 /*numa_node*/, {}, {}),
new NoopRounder, "pool");
for (int i = 0; i < 16; ++i) {
size_t alignment = 1 << i;
@@ -97,8 +100,8 @@ TEST(PoolAllocatorTest, Alignment) {
TEST(PoolAllocatorTest, AutoResize) {
PoolAllocator pool(2 /*pool_size_limit*/, true /*auto_resize*/,
- new BasicCPUAllocator(0 /*numa_node*/), new NoopRounder,
- "pool");
+ new BasicCPUAllocator(0 /*numa_node*/, {}, {}),
+ new NoopRounder, "pool");
// Alloc/dealloc 10 sizes just a few times, confirming pool size
// stays at 2.
@@ -123,14 +126,32 @@ TEST(PoolAllocatorTest, AutoResize) {
}
TEST(PoolAllocatorTest, CudaHostAllocator) {
+ int alloc_count = 0;
+ int64 alloc_size = 0;
+ SubAllocator::Visitor alloc_visitor =
+ [&alloc_count, &alloc_size](void* ptr, int numa_node, int64 size) {
+ ++alloc_count;
+ alloc_size += size;
+ };
+ int free_count = 0;
+ int64 free_size = 0;
+ SubAllocator::Visitor free_visitor =
+ [&free_count, &free_size](void* ptr, int numa_node, int64 size) {
+ ++free_count;
+ free_size += size;
+ };
se::Platform* platform =
se::MultiPlatformManager::PlatformWithName("cuda").ValueOrDie();
- PoolAllocator pool(
- 2 /*pool_size_limit*/, false /*auto_resize*/,
- new CUDAHostAllocator(
- platform->GetExecutor(se::StreamExecutorConfig(/*ordinal=*/0))
- .ValueOrDie()),
- new NoopRounder, "pool");
+ CUDAHostAllocator* sub_allocator = new CUDAHostAllocator(
+ platform->GetExecutor(se::StreamExecutorConfig(/*ordinal=*/0))
+ .ValueOrDie(),
+ 0 /*numa_node*/, {alloc_visitor}, {free_visitor});
+ PoolAllocator pool(2 /*pool_size_limit*/, false /*auto_resize*/,
+ sub_allocator, new NoopRounder, "pool");
+ EXPECT_EQ(0, alloc_count);
+ EXPECT_EQ(0, alloc_size);
+ EXPECT_EQ(0, free_count);
+ EXPECT_EQ(0, free_size);
// Repeatedly Get a 16-byte value, confirming that there's only
// one real allocation.
@@ -138,6 +159,10 @@ TEST(PoolAllocatorTest, CudaHostAllocator) {
EXPECT_EQ(0, pool.get_from_pool_count());
EXPECT_EQ(1, pool.allocated_count());
EXPECT_NE(nullptr, p1_16);
+ EXPECT_EQ(1, alloc_count); // Underlying suballoc of 16 bytes
+ // Each suballocation includes a 16B ChunkPrefix.
+ static const int kChunkPrefixSize = 16;
+ EXPECT_EQ(16 + (alloc_count * kChunkPrefixSize), alloc_size);
pool.DeallocateRaw(p1_16);
// Pool contents {16}
EXPECT_EQ(1, pool.put_count());
@@ -148,6 +173,9 @@ TEST(PoolAllocatorTest, CudaHostAllocator) {
pool.DeallocateRaw(p2_16); // Put it back.
// Pool contents {16}
EXPECT_EQ(2, pool.put_count());
+ EXPECT_EQ(1, alloc_count); // Underlying suballoc of 16 bytes
+ EXPECT_EQ(16 + (alloc_count * kChunkPrefixSize), alloc_size);
+ EXPECT_EQ(0, free_count);
// Get two more values of different sizes.
void* p3_4 = pool.AllocateRaw(4, 4);
@@ -160,6 +188,9 @@ TEST(PoolAllocatorTest, CudaHostAllocator) {
void* p4_2 = pool.AllocateRaw(4, 2); // Get a third size buffer.
EXPECT_NE(nullptr, p4_2);
EXPECT_EQ(0, pool.evicted_count());
+ EXPECT_EQ(3, alloc_count);
+ EXPECT_EQ(16 + 4 + 2 + (alloc_count * kChunkPrefixSize), alloc_size);
+ EXPECT_EQ(0, free_count);
// The pool is full: when we put back p4_2, the 16-byte buffer
// should be evicted since it was least recently inserted.
@@ -167,6 +198,10 @@ TEST(PoolAllocatorTest, CudaHostAllocator) {
// Pool contents {2, 4}
EXPECT_EQ(4, pool.put_count());
EXPECT_EQ(1, pool.evicted_count());
+ EXPECT_EQ(3, alloc_count);
+ EXPECT_EQ(16 + 4 + 2 + (alloc_count * kChunkPrefixSize), alloc_size);
+ EXPECT_EQ(1, free_count);
+ EXPECT_EQ(16 + (free_count * kChunkPrefixSize), free_size);
// Re-getting and putting size 2 or 4 should not alter pool size or
// num-evicted.
@@ -180,12 +215,20 @@ TEST(PoolAllocatorTest, CudaHostAllocator) {
EXPECT_EQ(6, pool.put_count());
EXPECT_EQ(3, pool.allocated_count());
EXPECT_EQ(1, pool.evicted_count());
+ EXPECT_EQ(3, alloc_count);
+ EXPECT_EQ(16 + 4 + 2 + (alloc_count * kChunkPrefixSize), alloc_size);
+ EXPECT_EQ(1, free_count);
+ EXPECT_EQ(16 + (free_count * kChunkPrefixSize), free_size);
pool.Clear();
EXPECT_EQ(0, pool.get_from_pool_count());
EXPECT_EQ(0, pool.put_count());
EXPECT_EQ(0, pool.allocated_count());
EXPECT_EQ(0, pool.evicted_count());
+ EXPECT_EQ(3, alloc_count);
+ EXPECT_EQ(16 + 4 + 2 + (alloc_count * kChunkPrefixSize), alloc_size);
+ EXPECT_EQ(3, free_count);
+ EXPECT_EQ(16 + 4 + 2 + (free_count * kChunkPrefixSize), free_size);
}
TEST(PoolAllocatorTest, Pow2Rounder) {
@@ -206,7 +249,8 @@ TEST(PoolAllocatorTest, Name) {
2 /*pool_size_limit*/, false /*auto_resize*/,
new CUDAHostAllocator(
platform->GetExecutor(se::StreamExecutorConfig(/*ordinal=*/0))
- .ValueOrDie()),
+ .ValueOrDie(),
+ 0 /*numa_node*/, {}, {}),
new NoopRounder, "pool");
EXPECT_EQ("pool", pool.Name());
}
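For reference, the accounting asserted by the updated CudaHostAllocator test works out as follows: each visitor callback sees the requested bytes plus the 16-byte ChunkPrefix, so after the three distinct suballocations of 16, 4, and 2 bytes, alloc_size = (16 + 4 + 2) + 3 * 16 = 70 bytes; once Clear() returns everything to the suballocator, free_count reaches 3 and free_size reaches the same 70 bytes.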
diff --git a/tensorflow/core/common_runtime/graph_optimizer.cc b/tensorflow/core/common_runtime/graph_optimizer.cc
index 96ecfb41d4..37a979a8f1 100644
--- a/tensorflow/core/common_runtime/graph_optimizer.cc
+++ b/tensorflow/core/common_runtime/graph_optimizer.cc
@@ -38,7 +38,8 @@ void GraphOptimizer::Optimize(
std::unique_ptr<Graph>* graph,
const std::unordered_map<string, std::vector<PartialTensorShape>>*
shape_map,
- const std::function<bool(const Node*)>& cse_consider_fn) {
+ const std::function<bool(const Node*)>& cse_consider_fn,
+ const std::function<bool(const Node*)>& cf_consider_fn) {
Graph* g = graph->get();
DumpGraph("Initial", g);
@@ -62,6 +63,7 @@ void GraphOptimizer::Optimize(
if (opts_.do_constant_folding()) {
ConstantFoldingOptions cf_opts;
cf_opts.shape_map = shape_map;
+ cf_opts.consider = cf_consider_fn;
if (opts_.max_folded_constant_in_bytes() > 0) {
cf_opts.max_constant_size_in_bytes =
opts_.max_folded_constant_in_bytes();
diff --git a/tensorflow/core/common_runtime/graph_optimizer.h b/tensorflow/core/common_runtime/graph_optimizer.h
index 80246281cd..789cc56942 100644
--- a/tensorflow/core/common_runtime/graph_optimizer.h
+++ b/tensorflow/core/common_runtime/graph_optimizer.h
@@ -45,12 +45,15 @@ class GraphOptimizer {
//
// If cse_consider_fn is not null then only nodes for which cse_consider_fn
// returns true will be considered for CSE.
+ // If cf_consider_fn is not null then only nodes for which cf_consider_fn
+ // returns true will be considered for constant folding (CF).
void Optimize(
FunctionLibraryRuntime* runtime, Env* env, Device* device,
std::unique_ptr<Graph>* graph,
const std::unordered_map<string, std::vector<PartialTensorShape>>*
shape_map,
- const std::function<bool(const Node*)>& cse_consider_fn = nullptr);
+ const std::function<bool(const Node*)>& cse_consider_fn = nullptr,
+ const std::function<bool(const Node*)>& cf_consider_fn = nullptr);
const OptimizerOptions& options() { return opts_; }
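A brief sketch of how a caller might use the new cf_consider_fn hook to keep constant folding away from selected nodes; the "_do_not_fold" attribute and the wrapper function are illustrative assumptions, not part of this change.

#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/graph/graph.h"

namespace tensorflow {

// Hypothetical wrapper: constant-fold everything except nodes the caller has
// tagged, while leaving CSE unrestricted.
void OptimizeWithoutFoldingTaggedNodes(GraphOptimizer* optimizer,
                                       FunctionLibraryRuntime* runtime,
                                       Env* env, Device* device,
                                       std::unique_ptr<Graph>* graph) {
  auto cf_consider_fn = [](const Node* n) {
    // Example criterion only: skip nodes carrying a "_do_not_fold" attr.
    return n->attrs().Find("_do_not_fold") == nullptr;
  };
  optimizer->Optimize(runtime, env, device, graph, /*shape_map=*/nullptr,
                      /*cse_consider_fn=*/nullptr, cf_consider_fn);
}

}  // namespace tensorflow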
diff --git a/tensorflow/core/common_runtime/local_device.cc b/tensorflow/core/common_runtime/local_device.cc
index db5022d56e..873182371e 100644
--- a/tensorflow/core/common_runtime/local_device.cc
+++ b/tensorflow/core/common_runtime/local_device.cc
@@ -62,7 +62,7 @@ struct LocalDevice::EigenThreadPoolInfo {
LocalDevice::LocalDevice(const SessionOptions& options,
const DeviceAttributes& attributes)
- : TracingDevice(options.env, attributes), owned_tp_info_(nullptr) {
+ : Device(options.env, attributes), owned_tp_info_(nullptr) {
// Log info messages if TensorFlow is not compiled with instructions that
// could speed up performance and are available on the current CPU.
port::InfoAboutUnusedCPUFeatures();
diff --git a/tensorflow/core/common_runtime/local_device.h b/tensorflow/core/common_runtime/local_device.h
index 9a82fb7204..226f121bf3 100644
--- a/tensorflow/core/common_runtime/local_device.h
+++ b/tensorflow/core/common_runtime/local_device.h
@@ -17,7 +17,6 @@ limitations under the License.
#define TENSORFLOW_CORE_COMMON_RUNTIME_LOCAL_DEVICE_H_
#include "tensorflow/core/common_runtime/device.h"
-#include "tensorflow/core/common_runtime/tracing_device.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/platform/macros.h"
@@ -32,7 +31,7 @@ struct SessionOptions;
// initializes a shared Eigen compute device used by both. This
// should eventually be removed once we refactor ThreadPoolDevice and
// GPUDevice into more 'process-wide' abstractions.
-class LocalDevice : public TracingDevice {
+class LocalDevice : public Device {
public:
LocalDevice(const SessionOptions& options,
const DeviceAttributes& attributes);
diff --git a/tensorflow/core/common_runtime/mkl_cpu_allocator.h b/tensorflow/core/common_runtime/mkl_cpu_allocator.h
index df9c3a686c..429b19599b 100644
--- a/tensorflow/core/common_runtime/mkl_cpu_allocator.h
+++ b/tensorflow/core/common_runtime/mkl_cpu_allocator.h
@@ -23,12 +23,11 @@ limitations under the License.
#include <cstdlib>
#include "tensorflow/core/common_runtime/bfc_allocator.h"
-#include "tensorflow/core/common_runtime/visitable_allocator.h"
-#include "tensorflow/core/framework/allocator_registry.h"
+#include "tensorflow/core/common_runtime/pool_allocator.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/mem.h"
-#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/numa.h"
#ifndef INTEL_MKL_DNN_ONLY
#include "i_malloc.h"
@@ -40,20 +39,16 @@ typedef unsigned int uint;
namespace tensorflow {
-class MklSubAllocator : public SubAllocator {
+class MklSubAllocator : public BasicCPUAllocator {
public:
+ MklSubAllocator() : BasicCPUAllocator(port::kNUMANoAffinity, {}, {}) {}
~MklSubAllocator() override {}
-
- void* Alloc(size_t alignment, size_t num_bytes) override {
- return port::AlignedMalloc(num_bytes, alignment);
- }
- void Free(void* ptr, size_t num_bytes) override { port::AlignedFree(ptr); }
};
// CPU allocator that handles small-size allocations by calling
// suballocator directly. Mostly, it is just a wrapper around a suballocator
// (that calls malloc and free directly) with support for bookkeeping.
-class MklSmallSizeAllocator : public VisitableAllocator {
+class MklSmallSizeAllocator : public Allocator {
public:
MklSmallSizeAllocator(SubAllocator* sub_allocator, size_t total_memory,
const string& name)
@@ -75,10 +70,6 @@ class MklSmallSizeAllocator : public VisitableAllocator {
CHECK(map_.insert(map_val).second);
// Increment statistics for small-size allocations.
IncrementStats(num_bytes);
- // Call alloc visitors.
- for (const auto& visitor : alloc_visitors_) {
- visitor(ptr, num_bytes);
- }
}
return ptr;
}
@@ -94,9 +85,6 @@ class MklSmallSizeAllocator : public VisitableAllocator {
if (map_iter != map_.end()) {
// Call free visitors.
size_t dealloc_bytes = map_iter->second;
- for (const auto& visitor : free_visitors_) {
- visitor(ptr, dealloc_bytes);
- }
sub_allocator_->Free(ptr, dealloc_bytes);
DecrementStats(dealloc_bytes);
map_.erase(map_iter);
@@ -121,16 +109,6 @@ class MklSmallSizeAllocator : public VisitableAllocator {
stats_.Clear();
}
- void AddAllocVisitor(Visitor visitor) override {
- mutex_lock l(mutex_);
- alloc_visitors_.push_back(visitor);
- }
-
- void AddFreeVisitor(Visitor visitor) override {
- mutex_lock l(mutex_);
- free_visitors_.push_back(visitor);
- }
-
private:
// Increment statistics for the allocator handling small allocations.
inline void IncrementStats(size_t alloc_size)
@@ -163,15 +141,11 @@ class MklSmallSizeAllocator : public VisitableAllocator {
// Allocator stats for small allocs
AllocatorStats stats_ GUARDED_BY(mutex_);
-
- // Visitors
- std::vector<Visitor> alloc_visitors_ GUARDED_BY(mutex_);
- std::vector<Visitor> free_visitors_ GUARDED_BY(mutex_);
};
/// CPU allocator for MKL that wraps BFC allocator and intercepts
/// and redirects memory allocation calls from MKL.
-class MklCPUAllocator : public VisitableAllocator {
+class MklCPUAllocator : public Allocator {
public:
// Constructor and other standard functions
@@ -277,6 +251,7 @@ class MklCPUAllocator : public VisitableAllocator {
// max_alloc_size from large_size_allocator would be the maximum
// size allocated by MklCPUAllocator.
stats->max_alloc_size = l_stats.max_alloc_size;
+ stats->bytes_limit = std::max(s_stats.bytes_limit, l_stats.bytes_limit);
}
void ClearStats() override {
@@ -284,16 +259,6 @@ class MklCPUAllocator : public VisitableAllocator {
large_size_allocator_->ClearStats();
}
- void AddAllocVisitor(Visitor visitor) override {
- small_size_allocator_->AddAllocVisitor(visitor);
- large_size_allocator_->AddAllocVisitor(visitor);
- }
-
- void AddFreeVisitor(Visitor visitor) override {
- small_size_allocator_->AddFreeVisitor(visitor);
- large_size_allocator_->AddFreeVisitor(visitor);
- }
-
private:
// Hooks provided by this allocator for memory allocation routines from MKL
@@ -330,7 +295,7 @@ class MklCPUAllocator : public VisitableAllocator {
// The alignment that we need for the allocations
static constexpr const size_t kAlignment = 64;
- VisitableAllocator* large_size_allocator_; // owned by this class
+ Allocator* large_size_allocator_; // owned by this class
MklSmallSizeAllocator* small_size_allocator_; // owned by this class.
SubAllocator* sub_allocator_; // not owned by this class
diff --git a/tensorflow/core/common_runtime/mkl_cpu_allocator_test.cc b/tensorflow/core/common_runtime/mkl_cpu_allocator_test.cc
index a67411cd2e..e08ab57638 100644
--- a/tensorflow/core/common_runtime/mkl_cpu_allocator_test.cc
+++ b/tensorflow/core/common_runtime/mkl_cpu_allocator_test.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifdef INTEL_MKL
+#if defined(INTEL_MKL) && defined(ENABLE_MKL)
#include "tensorflow/core/common_runtime/mkl_cpu_allocator.h"
@@ -50,4 +50,4 @@ TEST(MKLBFCAllocatorTest, TestMaxLimit) {
} // namespace tensorflow
-#endif // INTEL_MKL
+#endif // INTEL_MKL && ENABLE_MKL
diff --git a/tensorflow/core/common_runtime/parallel_concat_optimizer.cc b/tensorflow/core/common_runtime/parallel_concat_optimizer.cc
index f9f36443a8..6af4ca4d96 100644
--- a/tensorflow/core/common_runtime/parallel_concat_optimizer.cc
+++ b/tensorflow/core/common_runtime/parallel_concat_optimizer.cc
@@ -50,8 +50,8 @@ class ParallelConcatRemovePass : public GraphOptimizationPass {
}
for (Node* n : matches) {
AttrSlice n_attrs = n->attrs();
- auto base_make_node = [n, g, &n_attrs](const string& op,
- const string& name) {
+ auto base_make_node = [n, &n_attrs](const string& op,
+ const string& name) {
NodeBuilder node_builder(name, op);
node_builder.Device(n->requested_device());
string colo;
@@ -60,7 +60,7 @@ class ParallelConcatRemovePass : public GraphOptimizationPass {
}
return node_builder;
};
- auto make_node = [n, g, &n_attrs, &base_make_node](string op) {
+ auto make_node = [n, g, &base_make_node](string op) {
return base_make_node(
op, g->NewName(strings::StrCat(n->name(), "/Internal")));
};
diff --git a/tensorflow/core/common_runtime/pool_allocator.cc b/tensorflow/core/common_runtime/pool_allocator.cc
index fdad8de8d6..66dc8f3322 100644
--- a/tensorflow/core/common_runtime/pool_allocator.cc
+++ b/tensorflow/core/common_runtime/pool_allocator.cc
@@ -40,8 +40,7 @@ PoolAllocator::PoolAllocator(size_t pool_size_limit, bool auto_resize,
auto_resize_(auto_resize),
pool_size_limit_(pool_size_limit),
allocator_(allocator),
- size_rounder_(size_rounder),
- allocation_begun_(false) {
+ size_rounder_(size_rounder) {
if (auto_resize) {
CHECK_LT(size_t{0}, pool_size_limit)
<< "size limit must be > 0 if auto_resize is true.";
@@ -93,7 +92,6 @@ ChunkPrefix* FindPrefix(void* user_ptr) {
} // namespace
void* PoolAllocator::AllocateRaw(size_t alignment, size_t num_bytes) {
- if (!allocation_begun_) allocation_begun_ = true;
if (num_bytes == 0) return nullptr;
// If alignment is larger than kPoolAlignment, increase num_bytes so that we
@@ -129,9 +127,6 @@ void* PoolAllocator::AllocateRaw(size_t alignment, size_t num_bytes) {
return PrepareChunk(r, alignment, num_bytes);
} else {
void* ptr = allocator_->Alloc(kPoolAlignment, num_bytes);
- for (const auto& v : alloc_visitors_) {
- v(ptr, num_bytes);
- }
return PrepareChunk(ptr, alignment, num_bytes);
}
}
@@ -141,9 +136,6 @@ void PoolAllocator::DeallocateRaw(void* ptr) {
ChunkPrefix* cp = FindPrefix(ptr);
CHECK_LE((void*)cp, (void*)ptr);
if (!has_size_limit_ && !auto_resize_) {
- for (const auto& v : free_visitors_) {
- v(cp, cp->num_bytes);
- }
allocator_->Free(cp, cp->num_bytes);
} else {
mutex_lock lock(mutex_);
@@ -164,9 +156,6 @@ void PoolAllocator::Clear() {
mutex_lock lock(mutex_);
for (auto iter : pool_) {
PtrRecord* pr = iter.second;
- for (const auto& v : free_visitors_) {
- v(pr->ptr, pr->num_bytes);
- }
allocator_->Free(pr->ptr, pr->num_bytes);
delete pr;
}
@@ -221,9 +210,6 @@ void PoolAllocator::EvictOne() {
DCHECK(iter != pool_.end());
}
pool_.erase(iter);
- for (const auto& v : free_visitors_) {
- v(prec->ptr, prec->num_bytes);
- }
allocator_->Free(prec->ptr, prec->num_bytes);
delete prec;
++evicted_count_;
@@ -269,28 +255,19 @@ void PoolAllocator::EvictOne() {
}
}
-void PoolAllocator::AddAllocVisitor(Visitor visitor) {
- mutex_lock lock(mutex_);
- CHECK(!allocation_begun_)
- << "AddAllocVisitor may not be called after pool allocation "
- << "has begun.";
- alloc_visitors_.push_back(visitor);
-}
-
-void PoolAllocator::AddFreeVisitor(Visitor visitor) {
- mutex_lock lock(mutex_);
- CHECK(!allocation_begun_)
- << "AddFreeVisitor may not be called after pool allocation "
- << "has begun.";
- free_visitors_.push_back(visitor);
-}
-
void* BasicCPUAllocator::Alloc(size_t alignment, size_t num_bytes) {
- return port::AlignedMalloc(num_bytes, static_cast<int>(alignment));
+ void* ptr = nullptr;
+ if (num_bytes > 0) {
+ ptr = port::AlignedMalloc(num_bytes, static_cast<int>(alignment));
+ VisitAlloc(ptr, numa_node_, num_bytes);
+ }
+ return ptr;
}
void BasicCPUAllocator::Free(void* ptr, size_t num_bytes) {
- port::AlignedFree(ptr);
+ if (num_bytes > 0) {
+ VisitFree(ptr, numa_node_, num_bytes);
+ port::AlignedFree(ptr);
+ }
}
-
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/pool_allocator.h b/tensorflow/core/common_runtime/pool_allocator.h
index 607734445b..5b4623ba10 100644
--- a/tensorflow/core/common_runtime/pool_allocator.h
+++ b/tensorflow/core/common_runtime/pool_allocator.h
@@ -16,14 +16,13 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_POOL_ALLOCATOR_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_POOL_ALLOCATOR_H_
-// Simple LRU pool allocators for various flavors of CPU RAM that
-// implement the VisitableAllocator interface.
+// Simple LRU pool allocators for various flavors of CPU RAM.
#include <atomic>
#include <map>
#include <memory>
#include <vector>
-#include "tensorflow/core/common_runtime/visitable_allocator.h"
+#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/lib/core/bits.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
@@ -41,7 +40,7 @@ class RoundUpInterface {
// Size-limited pool of memory buffers obtained from a SubAllocator
// instance. Pool eviction policy is LRU.
-class PoolAllocator : public VisitableAllocator {
+class PoolAllocator : public Allocator {
public:
// "pool_size_limit" is the maximum number of returned, re-usable
// memory buffers to keep in the pool. If pool_size_limit == 0, the
@@ -64,14 +63,6 @@ class PoolAllocator : public VisitableAllocator {
void DeallocateRaw(void* ptr) override;
- // REQUIRES: The following functions may only be called prior
- // to the first Allocate*() call. Once allocation has begun, it is
- // illegal to register another visitor.
-
- void AddAllocVisitor(Visitor visitor) override;
-
- void AddFreeVisitor(Visitor visitor) override;
-
// Allocate an unused memory region of size "num_bytes". Fetch from
// the pool if available, otherwise call allocator_.
void* Get(size_t num_bytes);
@@ -141,12 +132,6 @@ class PoolAllocator : public VisitableAllocator {
int64 put_count_ GUARDED_BY(mutex_) = 0;
int64 allocated_count_ GUARDED_BY(mutex_) = 0;
int64 evicted_count_ GUARDED_BY(mutex_) = 0;
- // Write access to these is guarded by mutex_, but not read
- // access. They may only be modified prior to the first
- // allocation. Later attempts to modify will fail.
- std::vector<Visitor> alloc_visitors_;
- std::vector<Visitor> free_visitors_;
- std::atomic<bool> allocation_begun_;
};
// Do-nothing rounder. Passes through sizes unchanged.
@@ -166,7 +151,9 @@ class Pow2Rounder : public RoundUpInterface {
class BasicCPUAllocator : public SubAllocator {
public:
// Argument numa_node is currently ignored.
- explicit BasicCPUAllocator(int numa_node) : numa_node_(numa_node) {}
+ BasicCPUAllocator(int numa_node, const std::vector<Visitor>& alloc_visitors,
+ const std::vector<Visitor>& free_visitors)
+ : SubAllocator(alloc_visitors, free_visitors), numa_node_(numa_node) {}
~BasicCPUAllocator() override {}
@@ -176,6 +163,8 @@ class BasicCPUAllocator : public SubAllocator {
private:
int numa_node_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(BasicCPUAllocator);
};
} // namespace tensorflow
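For illustration only, a counting visitor pair wired through the new BasicCPUAllocator constructor, mirroring what pool_allocator_test above does for CUDAHostAllocator; the MakeInstrumentedCpuPool helper and its bookkeeping are assumptions local to this sketch.

#include "tensorflow/core/common_runtime/pool_allocator.h"
#include "tensorflow/core/platform/numa.h"

namespace tensorflow {

PoolAllocator* MakeInstrumentedCpuPool(int64* bytes_in_use) {
  // Track net bytes obtained from the underlying suballocator.
  SubAllocator::Visitor alloc_visitor = [bytes_in_use](void* ptr, int numa_node,
                                                       int64 num_bytes) {
    *bytes_in_use += num_bytes;
  };
  SubAllocator::Visitor free_visitor = [bytes_in_use](void* ptr, int numa_node,
                                                      int64 num_bytes) {
    *bytes_in_use -= num_bytes;
  };
  return new PoolAllocator(
      /*pool_size_limit=*/100, /*auto_resize=*/true,
      new BasicCPUAllocator(port::kNUMANoAffinity, {alloc_visitor},
                            {free_visitor}),
      new NoopRounder, "instrumented_cpu_pool");
}

}  // namespace tensorflow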
diff --git a/tensorflow/core/common_runtime/process_state.cc b/tensorflow/core/common_runtime/process_state.cc
index 447338e7bd..bcaa37fc8a 100644
--- a/tensorflow/core/common_runtime/process_state.cc
+++ b/tensorflow/core/common_runtime/process_state.cc
@@ -71,20 +71,28 @@ ProcessState::MemDesc ProcessState::PtrType(const void* ptr) {
return MemDesc();
}
-VisitableAllocator* ProcessState::GetCPUAllocator(int numa_node) {
+Allocator* ProcessState::GetCPUAllocator(int numa_node) {
CHECK_GE(numa_node, 0);
if (!numa_enabled_) numa_node = 0;
mutex_lock lock(mu_);
while (cpu_allocators_.size() <= static_cast<size_t>(numa_node)) {
+ // If visitors have been defined we need an Allocator built from a
+ // SubAllocator. BFCAllocator is the default in that case, but the
+ // TF_CPU_ALLOCATOR_USE_BFC env var can still select PoolAllocator instead.
+ const bool alloc_visitors_defined =
+ (!cpu_alloc_visitors_.empty() || !cpu_free_visitors_.empty());
bool use_bfc_allocator = false;
- // TODO(reedwm): Switch default to BGFAllocator if it's at least as fast and
- // efficient.
- Status status = ReadBoolFromEnvVar("TF_CPU_ALLOCATOR_USE_BFC", false,
- &use_bfc_allocator);
+ Status status = ReadBoolFromEnvVar(
+ "TF_CPU_ALLOCATOR_USE_BFC", alloc_visitors_defined, &use_bfc_allocator);
if (!status.ok()) {
LOG(ERROR) << "GetCPUAllocator: " << status.error_message();
}
- VisitableAllocator* allocator;
+ Allocator* allocator = nullptr;
+ SubAllocator* sub_allocator =
+ (alloc_visitors_defined || use_bfc_allocator)
+ ? new BasicCPUAllocator(numa_enabled_ ? numa_node : -1,
+ cpu_alloc_visitors_, cpu_free_visitors_)
+ : nullptr;
if (use_bfc_allocator) {
// TODO(reedwm): evaluate whether 64GB by default is the best choice.
int64 cpu_mem_limit_in_mb = -1;
@@ -95,34 +103,63 @@ VisitableAllocator* ProcessState::GetCPUAllocator(int numa_node) {
LOG(ERROR) << "GetCPUAllocator: " << status.error_message();
}
int64 cpu_mem_limit = cpu_mem_limit_in_mb * (1LL << 20);
- allocator = new BFCAllocator(
- new BasicCPUAllocator(numa_enabled_ ? numa_node : -1), cpu_mem_limit,
- true /*allow_growth*/, "bfc_cpu_allocator_for_gpu" /*name*/);
+ DCHECK(sub_allocator);
+ allocator =
+ new BFCAllocator(sub_allocator, cpu_mem_limit, true /*allow_growth*/,
+ "bfc_cpu_allocator_for_gpu" /*name*/);
VLOG(2) << "Using BFCAllocator with memory limit of "
<< cpu_mem_limit_in_mb << " MB for ProcessState CPU allocator";
- } else {
- allocator = new PoolAllocator(
- 100 /*pool_size_limit*/, true /*auto_resize*/,
- new BasicCPUAllocator(numa_enabled_ ? numa_node : -1),
- new NoopRounder, "cpu_pool");
+ } else if (alloc_visitors_defined) {
+ DCHECK(sub_allocator);
+ allocator =
+ new PoolAllocator(100 /*pool_size_limit*/, true /*auto_resize*/,
+ sub_allocator, new NoopRounder, "cpu_pool");
VLOG(2) << "Using PoolAllocator for ProcessState CPU allocator "
<< "numa_enabled_=" << numa_enabled_
<< " numa_node=" << numa_node;
+ } else {
+ DCHECK(!sub_allocator);
+ allocator = cpu_allocator();
}
- if (LogMemory::IsEnabled()) {
+ if (LogMemory::IsEnabled() && !allocator->TracksAllocationSizes()) {
// Wrap the allocator to track allocation ids for better logging
// at the cost of performance.
- allocator = new TrackingVisitableAllocator(allocator, true);
+ allocator = new TrackingAllocator(allocator, true);
}
cpu_allocators_.push_back(allocator);
+ if (!sub_allocator) {
+ DCHECK(cpu_alloc_visitors_.empty() && cpu_free_visitors_.empty());
+ }
}
return cpu_allocators_[numa_node];
}
+void ProcessState::AddCPUAllocVisitor(SubAllocator::Visitor visitor) {
+ VLOG(1) << "AddCPUAllocVisitor";
+ mutex_lock lock(mu_);
+ CHECK_EQ(0, cpu_allocators_.size()) // Crash OK
+ << "AddCPUAllocVisitor must be called prior to first call to "
+ "ProcessState::GetCPUAllocator";
+ cpu_alloc_visitors_.push_back(std::move(visitor));
+}
+
+void ProcessState::AddCPUFreeVisitor(SubAllocator::Visitor visitor) {
+ mutex_lock lock(mu_);
+ CHECK_EQ(0, cpu_allocators_.size()) // Crash OK
+ << "AddCPUFreeVisitor must be called prior to first call to "
+ "ProcessState::GetCPUAllocator";
+ cpu_free_visitors_.push_back(std::move(visitor));
+}
+
void ProcessState::TestOnlyReset() {
mutex_lock lock(mu_);
+ // Don't delete this value because it's static.
+ Allocator* default_cpu_allocator = cpu_allocator();
mem_desc_map_.clear();
- gtl::STLDeleteElements(&cpu_allocators_);
+ for (Allocator* a : cpu_allocators_) {
+ if (a != default_cpu_allocator) delete a;
+ }
+ cpu_allocators_.clear();
gtl::STLDeleteElements(&cpu_al_);
}
diff --git a/tensorflow/core/common_runtime/process_state.h b/tensorflow/core/common_runtime/process_state.h
index 2892677333..cac312d849 100644
--- a/tensorflow/core/common_runtime/process_state.h
+++ b/tensorflow/core/common_runtime/process_state.h
@@ -30,7 +30,6 @@ limitations under the License.
namespace tensorflow {
class Allocator;
-class VisitableAllocator;
class PoolAllocator;
// Singleton that manages per-process state, e.g. allocation of
@@ -65,7 +64,15 @@ class ProcessState {
// Returns the one CPUAllocator used for the given numa_node.
// TEMPORARY: ignores numa_node.
- VisitableAllocator* GetCPUAllocator(int numa_node);
+ Allocator* GetCPUAllocator(int numa_node);
+
+ // Registers alloc visitor for the CPU allocator(s).
+ // REQUIRES: must be called before GetCPUAllocator.
+ void AddCPUAllocVisitor(SubAllocator::Visitor v);
+
+ // Registers free visitor for the CPU allocator(s).
+ // REQUIRES: must be called before GetCPUAllocator.
+ void AddCPUFreeVisitor(SubAllocator::Visitor v);
typedef std::unordered_map<const void*, MemDesc> MDMap;
@@ -87,7 +94,9 @@ class ProcessState {
mutex mu_;
- std::vector<VisitableAllocator*> cpu_allocators_ GUARDED_BY(mu_);
+ std::vector<Allocator*> cpu_allocators_ GUARDED_BY(mu_);
+ std::vector<SubAllocator::Visitor> cpu_alloc_visitors_ GUARDED_BY(mu_);
+ std::vector<SubAllocator::Visitor> cpu_free_visitors_ GUARDED_BY(mu_);
virtual ~ProcessState();
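A hedged example of the intended call order for the new process-level hooks: visitors must be registered before the first GetCPUAllocator call. ProcessState::singleton() as the accessor and the logging bodies are assumptions made for this sketch.

#include "tensorflow/core/common_runtime/process_state.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {

void LogCpuRegionAllocations() {  // hypothetical helper; run early at startup
  ProcessState* ps = ProcessState::singleton();
  ps->AddCPUAllocVisitor([](void* ptr, int numa_node, int64 num_bytes) {
    VLOG(2) << "CPU suballoc of " << num_bytes << " bytes on node " << numa_node;
  });
  ps->AddCPUFreeVisitor([](void* ptr, int numa_node, int64 num_bytes) {
    VLOG(2) << "CPU subfree of " << num_bytes << " bytes on node " << numa_node;
  });
}

}  // namespace tensorflow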
diff --git a/tensorflow/core/common_runtime/renamed_device.h b/tensorflow/core/common_runtime/renamed_device.h
index 103eee03b3..c00789a556 100644
--- a/tensorflow/core/common_runtime/renamed_device.h
+++ b/tensorflow/core/common_runtime/renamed_device.h
@@ -58,6 +58,15 @@ class RenamedDevice : public Device {
return underlying_->GetAllocator(attr);
}
+ Allocator* GetScopedAllocator(AllocatorAttributes attr,
+ int64 step_id) override {
+ return underlying_->GetScopedAllocator(attr, step_id);
+ }
+
+ ScopedAllocatorMgr* GetScopedAllocatorMgr() const override {
+ return underlying_->GetScopedAllocatorMgr();
+ }
+
const Eigen::ThreadPoolDevice* eigen_cpu_device() override {
return underlying_->eigen_cpu_device();
}
@@ -72,9 +81,10 @@ class RenamedDevice : public Device {
return underlying_->MakeGpuDevice();
}
- void ReinitializeGpuDevice(OpKernelContext* context, PerOpGpuDevice* device,
- DeviceContext* dc, Allocator* allocator) override {
- underlying_->ReinitializeGpuDevice(context, device, dc, allocator);
+ Status ReinitializeGpuDevice(OpKernelContext* context, PerOpGpuDevice* device,
+ DeviceContext* dc,
+ Allocator* allocator) override {
+ return underlying_->ReinitializeGpuDevice(context, device, dc, allocator);
}
Status MakeTensorFromProto(const TensorProto& tensor_proto,
diff --git a/tensorflow/core/common_runtime/ring_reducer.cc b/tensorflow/core/common_runtime/ring_reducer.cc
index a81f8650bf..b1fe928ba7 100644
--- a/tensorflow/core/common_runtime/ring_reducer.cc
+++ b/tensorflow/core/common_runtime/ring_reducer.cc
@@ -41,6 +41,16 @@ limitations under the License.
// Set true for greater intelligibility of debug mode log messages.
#define READABLE_KEYS false
+// The RingReduce algorithm exchanges tensor chunks between devices. The chunk
+// size depends on the number of subdivisions specified in the algorithm. If
+// the user does not specify the number of subdivisions, we infer the number
+// dynamically so that the resulting chunk size does not exceed
+// kMaxChunkSizeBytes, empirically set at 4 MiB.
+constexpr size_t kMaxChunkSizeBytes = (4 * 1024 * 1024);
+// kMaxSubdivsPerDevice gives an upper bound on the number of subdivisions
+// dynamically generated. A reasonable value would be a small multiple of the
+// number of NICs adjacent to each device.
+constexpr int kMaxSubdivsPerDevice = 2;
namespace tensorflow {
namespace {
@@ -92,7 +102,62 @@ RingReducer::RingReducer()
RingReducer::~RingReducer() { group_size_tensor_ready_.WaitForNotification(); }
+Status GenerateSubdivsInCollectiveParams(CollectiveParams* col_params) {
+ if (col_params->instance.shape.num_elements() == 0) {
+ return errors::Internal("shape in CollectiveParams should be non-empty");
+ }
+ const int kAvgDevPerTask =
+ col_params->group.group_size / col_params->group.num_tasks;
+ const int kMaxNumSubdivs = kMaxSubdivsPerDevice * kAvgDevPerTask;
+ if (kMaxNumSubdivs <= 0) {
+ return errors::Internal("Unexpected kMaxNumSubdivs ", kMaxNumSubdivs,
+ " in RingReducer");
+ }
+ // NOTE(ayushd): If no subdiv_offsets have been specified, dynamically add
+ // as many offsets as needed so that the size of tensor chunks <=
+ // kMaxChunkSizeBytes. Empirically, chunks that are too small or too large
+ // lead to worse performance.
+ int num_subdivs = 0;
+ const size_t tensor_size = col_params->instance.shape.num_elements() *
+ DataTypeSize(col_params->instance.data_type);
+ size_t chunk_size;
+ do {
+ ++num_subdivs;
+ int num_chunks = col_params->group.group_size * num_subdivs;
+ chunk_size = tensor_size / num_chunks;
+ VLOG(2) << "num_subdivs " << num_subdivs << " num_chunks " << num_chunks
+ << " chunk_size " << chunk_size;
+ } while (chunk_size > kMaxChunkSizeBytes && num_subdivs < kMaxNumSubdivs);
+ if (num_subdivs <= 0) {
+ return errors::Internal("Unexpected num_subdivs ", num_subdivs,
+ " in RingReducer");
+ }
+
+ int subdiv_stride = kAvgDevPerTask / num_subdivs;
+ if (subdiv_stride == 0) subdiv_stride = 1;
+ col_params->instance.impl_details.subdiv_offsets.reserve(num_subdivs);
+ for (int sdi = 0; sdi < num_subdivs; ++sdi) {
+ int subdiv_offset = subdiv_stride * sdi;
+ if (sdi % 2 == 1) subdiv_offset *= -1;
+ col_params->instance.impl_details.subdiv_offsets.push_back(subdiv_offset);
+ }
+
+ if (VLOG_IS_ON(2)) {
+ string subdiv_buf;
+ for (const int subdiv_offset :
+ col_params->instance.impl_details.subdiv_offsets) {
+ strings::StrAppend(&subdiv_buf, " ", subdiv_offset);
+ }
+ VLOG(2) << "Dynamically generated " << num_subdivs
+ << " subdiv_offsets:" << subdiv_buf << " tensor_size "
+ << tensor_size << " chunk_size " << chunk_size;
+ }
+
+ return Status::OK();
+}
+
Status RingReducer::InitializeCollectiveParams(CollectiveParams* col_params) {
+ // TODO(b/113171733): change CHECKs to return errors.
CHECK_EQ(col_params->instance.type, REDUCTION_COLLECTIVE);
CHECK_EQ(col_params->instance.impl_details.collective_name, "RingReduce");
const string& device_name =
@@ -123,12 +188,11 @@ Status RingReducer::InitializeCollectiveParams(CollectiveParams* col_params) {
dev_per_task.push_back(dev_count);
CHECK_EQ(col_params->group.num_tasks, dev_per_task.size());
- // Generate a ring permutation for each requested offset.
if (col_params->instance.impl_details.subdiv_offsets.empty()) {
- return errors::Internal(
- "Subdiv offsets should be non-empty for ring reducer, size=",
- col_params->instance.impl_details.subdiv_offsets.size());
+ TF_RETURN_IF_ERROR(GenerateSubdivsInCollectiveParams(col_params));
}
+
+ // Generate a ring permutation for each requested offset.
VLOG(2) << "Setting up perms for col_params " << col_params
<< " subdiv_permutations "
<< &col_params->instance.impl_details.subdiv_permutations;
@@ -646,7 +710,8 @@ bool RingReducer::RunAsyncParts() {
case RF_SEND:
--send_pending_count;
break;
- default: {} // Ignore any other actions
+ default: {
+ } // Ignore any other actions
}
}
}
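A standalone arithmetic check of the subdivision search above, with numbers chosen to match the AutomaticSubdivs test that follows (24 devices, a 144 MiB tensor); the program is illustrative and omits the kMaxSubdivsPerDevice bound for brevity.

#include <cstddef>
#include <cstdio>

int main() {
  const int group_size = 24;                          // devices in the ring
  const size_t kMaxChunkSizeBytes = 4 * 1024 * 1024;  // 4 MiB cap
  const size_t tensor_size = 144 * 1024 * 1024;       // tensor bytes
  int num_subdivs = 0;
  size_t chunk_size = 0;
  do {
    ++num_subdivs;
    chunk_size = tensor_size / (group_size * num_subdivs);
  } while (chunk_size > kMaxChunkSizeBytes);
  // Prints num_subdivs=2 chunk_size=3145728 (3 MiB), i.e. two subdiv offsets
  // are generated, matching the test expectation.
  std::printf("num_subdivs=%d chunk_size=%zu\n", num_subdivs, chunk_size);
  return 0;
}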
diff --git a/tensorflow/core/common_runtime/ring_reducer_test.cc b/tensorflow/core/common_runtime/ring_reducer_test.cc
index 28df85399e..75aba43572 100644
--- a/tensorflow/core/common_runtime/ring_reducer_test.cc
+++ b/tensorflow/core/common_runtime/ring_reducer_test.cc
@@ -549,37 +549,38 @@ class RingReducerTest : public ::testing::Test {
int32 reduce_counter_ GUARDED_BY(mu_) = 0;
};
-TEST_F(RingReducerTest, InitializeParams) {
- static const int kNumDevsPerTask = 8;
- static const int kNumTasks = 3;
- static const int kNumDevs = kNumDevsPerTask * kNumTasks;
+CollectiveParams SetUpCollectiveParams(const int num_devs_per_task,
+ const int num_tasks) {
CollectiveParams cp;
- std::vector<string> device_names;
- std::vector<string> task_names;
+ const int kNumDevs = num_devs_per_task * num_tasks;
cp.group.group_key = 1;
cp.group.group_size = kNumDevs;
cp.group.device_type = DeviceType("GPU");
- cp.group.num_tasks = kNumTasks;
+ cp.group.num_tasks = num_tasks;
cp.instance.instance_key = 3;
cp.instance.type = REDUCTION_COLLECTIVE;
cp.instance.data_type = DataType(DT_FLOAT);
- cp.instance.shape = TensorShape({5});
+ cp.instance.shape = TensorShape({kNumDevs});
cp.instance.impl_details.collective_name = "RingReduce";
cp.instance.impl_details.subdiv_offsets.push_back(0);
cp.is_source = false;
for (int i = 0; i < kNumDevs; ++i) {
- int task_id = i / kNumDevsPerTask;
- int dev_id = i % kNumDevsPerTask;
+ int task_id = i / num_devs_per_task;
+ int dev_id = i % num_devs_per_task;
string task_name = strings::StrCat("/job:worker/replica:0/task:", task_id);
- task_names.push_back(task_name);
string device_name = strings::StrCat(task_name, "/device:GPU:", dev_id);
- device_names.push_back(device_name);
cp.instance.task_names.push_back(task_name);
cp.instance.device_names.push_back(device_name);
}
+ return cp;
+}
- int test_rank = 0;
- cp.default_rank = test_rank;
+TEST_F(RingReducerTest, InitializeParams) {
+ const int kNumDevsPerTask = 8;
+ const int kNumTasks = 3;
+ CollectiveParams cp = SetUpCollectiveParams(kNumDevsPerTask, kNumTasks);
+
+ cp.default_rank = 0;
cp.instance.impl_details.subdiv_offsets = {0, 4};
RunSubdivPermsTest(&cp,
{{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
@@ -588,8 +589,15 @@ TEST_F(RingReducerTest, InitializeParams) {
8, 9, 10, 11, 20, 21, 22, 23, 16, 17, 18, 19}},
{0, 4});
- test_rank = 3;
- cp.default_rank = test_rank;
+ cp.instance.impl_details.subdiv_offsets = {0, -4};
+ RunSubdivPermsTest(&cp,
+ {{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
+ 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
+ {3, 2, 1, 0, 7, 6, 5, 4, 11, 10, 9, 8,
+ 15, 14, 13, 12, 19, 18, 17, 16, 23, 22, 21, 20}},
+ {0, 3});
+
+ cp.default_rank = 3;
cp.instance.impl_details.subdiv_offsets = {3, -3};
RunSubdivPermsTest(&cp,
{{3, 4, 5, 6, 7, 0, 1, 2, 11, 12, 13, 14,
@@ -599,6 +607,49 @@ TEST_F(RingReducerTest, InitializeParams) {
{0, 1});
}
+TEST_F(RingReducerTest, AutomaticSubdivs) {
+ const int kNumDevsPerTask = 8;
+ const int kNumTasks = 3;
+ const int kNumDevs = kNumDevsPerTask * kNumTasks;
+ CollectiveParams cp = SetUpCollectiveParams(kNumDevsPerTask, kNumTasks);
+
+ // Test automatic generation of subdiv offsets.
+ cp.default_rank = 0;
+ cp.instance.impl_details.subdiv_offsets.clear();
+ RunSubdivPermsTest(&cp, {{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
+ 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}},
+ {0});
+
+ // Set shape so that with 2 subdivs chunk_size is 3 MiB. This should cause 2
+ // offsets, {0, -4}, to be generated.
+ {
+ int num_subdivs = 2;
+ int num_chunks = kNumDevs * num_subdivs;
+ size_t chunk_size = 3 * 1048576; // 3 MiB
+ size_t tensor_size = chunk_size * num_chunks;
+ cp.instance.shape =
+ TensorShape({static_cast<int64>(tensor_size / DataTypeSize(DT_FLOAT))});
+ }
+ cp.instance.impl_details.subdiv_offsets.clear();
+ RunSubdivPermsTest(&cp,
+ {{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
+ 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
+ {3, 2, 1, 0, 7, 6, 5, 4, 11, 10, 9, 8,
+ 15, 14, 13, 12, 19, 18, 17, 16, 23, 22, 21, 20}},
+ {0, 3});
+}
+
+TEST_F(RingReducerTest, AutomaticSubdivUpperBound) {
+ const int kNumDevsPerTask = 1;
+ const int kNumTasks = 4;
+ CollectiveParams cp = SetUpCollectiveParams(kNumDevsPerTask, kNumTasks);
+
+ cp.default_rank = 0;
+ cp.instance.impl_details.subdiv_offsets.clear();
+ cp.instance.shape = TensorShape({104857600 / DataTypeSize(DT_FLOAT)});
+ RunSubdivPermsTest(&cp, {{0, 1, 2, 3}, {0, 1, 2, 3}}, {0, 0});
+}
+
// TODO(b/113171733): change to use TEST_P.
#define DEF_TEST(B, T, W, D, S, L, A) \
TEST_F(RingReducerTest, \
diff --git a/tensorflow/core/common_runtime/session_ref.cc b/tensorflow/core/common_runtime/session_ref.cc
deleted file mode 100644
index b931ef4229..0000000000
--- a/tensorflow/core/common_runtime/session_ref.cc
+++ /dev/null
@@ -1,170 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "tensorflow/core/common_runtime/session_ref.h"
-
-#include <utility>
-
-namespace tensorflow {
-
-namespace {
-
-// Scope helper to track active calls and manage session lifetime.
-struct RunCounter {
- std::shared_ptr<Session> session;
- uint64* value;
- mutex* m;
- condition_variable* cv;
-
- explicit RunCounter(std::shared_ptr<Session> s, uint64* v, mutex* m,
- condition_variable* cv)
- : session(std::move(s)), value(v), m(m), cv(cv) {
- mutex_lock l(*m);
- ++*value;
- }
-
- ~RunCounter() {
- mutex_lock l(*m);
- if (--*value == 0) {
- cv->notify_all();
- }
- }
-};
-
-} // namespace
-
-Status SessionRef::CheckNotClosed() {
- mutex_lock l(run_lock_);
- if (session_ == nullptr) return errors::Cancelled("Session has been closed.");
- return ::tensorflow::Status::OK();
-}
-
-Status SessionRef::Run(const RunOptions& run_options,
- const std::vector<std::pair<string, Tensor> >& inputs,
- const std::vector<string>& output_tensor_names,
- const std::vector<string>& target_node_names,
- std::vector<Tensor>* outputs,
- RunMetadata* run_metadata) {
- TF_RETURN_IF_ERROR(CheckNotClosed());
- RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
- return rc.session->Run(run_options, inputs, output_tensor_names,
- target_node_names, outputs, run_metadata);
-}
-
-Status SessionRef::Create(const GraphDef& graph) {
- TF_RETURN_IF_ERROR(CheckNotClosed());
- RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
- return rc.session->Create(graph);
-}
-
-Status SessionRef::Create(const RunOptions& run_options,
- const GraphDef& graph) {
- TF_RETURN_IF_ERROR(CheckNotClosed());
- RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
- return rc.session->Create(run_options, graph);
-}
-
-Status SessionRef::Extend(const RunOptions& run_options,
- const GraphDef& graph) {
- TF_RETURN_IF_ERROR(CheckNotClosed());
- RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
- return rc.session->Extend(run_options, graph);
-}
-
-Status SessionRef::Extend(const GraphDef& graph) {
- TF_RETURN_IF_ERROR(CheckNotClosed());
- RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
- return rc.session->Extend(graph);
-}
-
-Status SessionRef::Close(const RunOptions& run_options) {
- TF_RETURN_IF_ERROR(CheckNotClosed());
- mutex_lock l(run_lock_);
- Status status = session_->Close(run_options);
- session_.reset();
- while (run_count_ > 0) {
- run_finished_.wait(l);
- }
- return status;
-}
-
-Status SessionRef::Close() {
- TF_RETURN_IF_ERROR(CheckNotClosed());
- mutex_lock l(run_lock_);
- Status status = session_->Close();
- session_.reset();
- while (run_count_ > 0) {
- run_finished_.wait(l);
- }
- return status;
-}
-
-Status SessionRef::Run(const std::vector<std::pair<string, Tensor> >& inputs,
- const std::vector<string>& output_tensor_names,
- const std::vector<string>& target_node_names,
- std::vector<Tensor>* outputs) {
- TF_RETURN_IF_ERROR(CheckNotClosed());
- RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
- return rc.session->Run(inputs, output_tensor_names, target_node_names,
- outputs);
-}
-
-Status SessionRef::ListDevices(std::vector<DeviceAttributes>* response) {
- TF_RETURN_IF_ERROR(CheckNotClosed());
- RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
- return rc.session->ListDevices(response);
-}
-
-Status SessionRef::PRunSetup(const std::vector<string>& input_names,
- const std::vector<string>& output_names,
- const std::vector<string>& target_nodes,
- string* handle) {
- TF_RETURN_IF_ERROR(CheckNotClosed());
- RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
- return rc.session->PRunSetup(input_names, output_names, target_nodes, handle);
-}
-
-Status SessionRef::PRun(const string& handle,
- const std::vector<std::pair<string, Tensor> >& inputs,
- const std::vector<string>& output_names,
- std::vector<Tensor>* outputs) {
- TF_RETURN_IF_ERROR(CheckNotClosed());
- RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
- return rc.session->PRun(handle, inputs, output_names, outputs);
-}
-
-Status SessionRef::MakeCallable(const CallableOptions& callable_options,
- CallableHandle* out_handle) {
- TF_RETURN_IF_ERROR(CheckNotClosed());
- RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
- return rc.session->MakeCallable(callable_options, out_handle);
-}
-
-Status SessionRef::RunCallable(CallableHandle handle,
- const std::vector<Tensor>& feed_tensors,
- std::vector<Tensor>* fetch_tensors,
- RunMetadata* run_metadata) {
- TF_RETURN_IF_ERROR(CheckNotClosed());
- RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
- return rc.session->RunCallable(handle, feed_tensors, fetch_tensors,
- run_metadata);
-}
-
-Status SessionRef::ReleaseCallable(CallableHandle handle) {
- TF_RETURN_IF_ERROR(CheckNotClosed());
- RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
- return rc.session->ReleaseCallable(handle);
-}
-
-} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/step_stats_collector.cc b/tensorflow/core/common_runtime/step_stats_collector.cc
index 836cb8ed14..a70ab93d4a 100644
--- a/tensorflow/core/common_runtime/step_stats_collector.cc
+++ b/tensorflow/core/common_runtime/step_stats_collector.cc
@@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/core/lib/strings/scanner.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
namespace {
@@ -40,46 +41,24 @@ struct AllocStats {
};
} // namespace
-NodeExecStatsWrapper::NodeExecStatsWrapper(const string& node_name)
- : NodeExecStatsWrapper(new NodeExecStats) {
- stats_->set_node_name(node_name);
-}
-NodeExecStatsWrapper::NodeExecStatsWrapper(NodeExecStats* stats)
- : stats_(stats) {}
-
-void NodeExecStatsWrapper::SetOutput(int slot, const Tensor* v) {
- DCHECK(v);
- NodeOutput* no = stats_->add_output();
- no->set_slot(slot);
- v->FillDescription(no->mutable_tensor_description());
-}
-
-void NodeExecStatsWrapper::SetMemory(OpKernelContext* ctx) {
- for (const auto& allocator_pair : ctx->wrapped_allocators()) {
- AddAllocation(allocator_pair.first, allocator_pair.second);
- }
- auto* ms = stats_->mutable_memory_stats();
- ms->set_temp_memory_size(ctx->temp_memory_allocated());
- for (const auto& alloc_id : ctx->persistent_alloc_ids()) {
- ms->mutable_persistent_tensor_alloc_ids()->Add(alloc_id);
- }
- ms->set_persistent_memory_size(ctx->persistent_memory_allocated());
+NodeExecStatsWrapper::NodeExecStatsWrapper(
+ const Node* node, StepStatsCollector* step_stats_collector)
+ : NodeExecStatsWrapper(MakeUnique<NodeExecStats>(), node,
+ step_stats_collector) {
+ stats_->set_node_name(node->name());
}
-void NodeExecStatsWrapper::SetReferencedTensors(
- const TensorReferenceVector& tensors) {
- // be careful not to increment the reference count on any tensor
- // while recording the information
- for (size_t i = 0; i < tensors.size(); ++i) {
- AllocationDescription* description = stats_->add_referenced_tensor();
- tensors.at(i).FillDescription(description);
- }
-}
-
-// TODO(tucker): merge with the DetailText function in session.cc
-// in a common location.
-bool NodeExecStatsWrapper::SetTimelineLabel(const Node* node) {
- bool is_transfer_node = false;
+NodeExecStatsWrapper::NodeExecStatsWrapper(
+ std::unique_ptr<NodeExecStats> stats, const Node* node,
+ StepStatsCollector* step_stats_collector)
+ : stats_(std::move(stats)),
+ node_(node),
+ step_stats_collector_(step_stats_collector) {}
+
+void NodeExecStatsWrapper::Done(const string& device) {
+ // TODO(tucker): merge with the DetailText function in session.cc in a common
+ // location.
+ DCHECK(node_);
string memory;
for (auto& all : stats_->memory()) {
int64 tot = all.total_bytes();
@@ -96,31 +75,96 @@ bool NodeExecStatsWrapper::SetTimelineLabel(const Node* node) {
}
}
}
- const AttrSlice attrs = node->attrs();
+ const AttrSlice attrs = node_->attrs();
string text;
- if (IsSend(node)) {
+ if (IsSend(node_)) {
string tensor_name;
TF_CHECK_OK(GetNodeAttr(attrs, "tensor_name", &tensor_name));
string recv_device;
TF_CHECK_OK(GetNodeAttr(attrs, "recv_device", &recv_device));
- text = strings::StrCat(memory, node->name(), " = ", node->type_string(),
+ text = strings::StrCat(memory, node_->name(), " = ", node_->type_string(),
"(", tensor_name, " @", recv_device);
- is_transfer_node = true;
- } else if (IsRecv(node)) {
+ } else if (IsRecv(node_)) {
string tensor_name;
TF_CHECK_OK(GetNodeAttr(attrs, "tensor_name", &tensor_name));
string send_device;
TF_CHECK_OK(GetNodeAttr(attrs, "send_device", &send_device));
- text = strings::StrCat(memory, node->name(), " = ", node->type_string(),
+ text = strings::StrCat(memory, node_->name(), " = ", node_->type_string(),
"(", tensor_name, " @", send_device);
- is_transfer_node = true;
} else {
text =
- strings::StrCat(memory, node->name(), " = ", node->type_string(), "(",
- str_util::Join(node->requested_inputs(), ", "), ")");
+ strings::StrCat(memory, node_->name(), " = ", node_->type_string(), "(",
+ str_util::Join(node_->requested_inputs(), ", "), ")");
}
stats_->set_timeline_label(text);
- return is_transfer_node;
+ step_stats_collector_->Save(device, this);
+}
+
+void NodeExecStatsWrapper::RecordExecutorStarted() {
+ int64 now_nanos = Env::Default()->NowNanos();
+ stats_->set_all_start_micros(now_nanos / EnvTime::kMicrosToNanos);
+ stats_->set_all_start_nanos(now_nanos);
+}
+
+void NodeExecStatsWrapper::RecordComputeStarted() {
+ int64 now_nanos = Env::Default()->NowNanos();
+ DCHECK_NE(stats_->all_start_micros(), 0);
+ DCHECK_NE(stats_->all_start_nanos(), 0);
+ stats_->set_op_start_rel_micros(now_nanos / EnvTime::kMicrosToNanos -
+ stats_->all_start_micros());
+ stats_->set_op_start_rel_nanos(now_nanos - stats_->all_start_nanos());
+}
+
+void NodeExecStatsWrapper::RecordComputeEnded() {
+ int64 now_nanos = Env::Default()->NowNanos();
+ DCHECK_NE(stats_->all_start_micros(), 0);
+ DCHECK_NE(stats_->all_start_nanos(), 0);
+ stats_->set_op_end_rel_micros(now_nanos / EnvTime::kMicrosToNanos -
+ stats_->all_start_micros());
+ stats_->set_op_end_rel_nanos(now_nanos - stats_->all_start_nanos());
+}
+
+void NodeExecStatsWrapper::RecordExecutorEnded() {
+ int64 now_nanos = Env::Default()->NowNanos();
+ DCHECK_NE(stats_->all_start_micros(), 0);
+ DCHECK_NE(stats_->all_start_nanos(), 0);
+ stats_->set_all_end_rel_micros(now_nanos / EnvTime::kMicrosToNanos -
+ stats_->all_start_micros());
+ stats_->set_all_end_rel_nanos(now_nanos - stats_->all_start_nanos());
+}
+
+void NodeExecStatsWrapper::SetScheduled(int64 nanos) {
+ stats_->set_scheduled_micros(nanos / EnvTime::kMicrosToNanos);
+ stats_->set_scheduled_nanos(nanos);
+}
+
+void NodeExecStatsWrapper::SetMemory(OpKernelContext* ctx) {
+ for (const auto& allocator_pair : ctx->wrapped_allocators()) {
+ AddAllocation(allocator_pair.first, allocator_pair.second);
+ }
+ auto* ms = stats_->mutable_memory_stats();
+ ms->set_temp_memory_size(ctx->temp_memory_allocated());
+ for (const auto& alloc_id : ctx->persistent_alloc_ids()) {
+ ms->mutable_persistent_tensor_alloc_ids()->Add(alloc_id);
+ }
+ ms->set_persistent_memory_size(ctx->persistent_memory_allocated());
+}
+
+void NodeExecStatsWrapper::SetOutput(int slot, const Tensor* tensor) {
+ DCHECK(tensor);
+ NodeOutput* node_output = stats_->add_output();
+ node_output->set_slot(slot);
+ tensor->FillDescription(node_output->mutable_tensor_description());
+}
+
+void NodeExecStatsWrapper::SetReferencedTensors(
+ const TensorReferenceVector& tensors) {
+ // be careful not to increment the reference count on any tensor
+ // while recording the information
+ for (size_t i = 0; i < tensors.size(); ++i) {
+ AllocationDescription* description = stats_->add_referenced_tensor();
+ tensors.at(i).FillDescription(description);
+ }
}
void NodeExecStatsWrapper::AddAllocation(
@@ -150,8 +194,8 @@ void NodeExecStatsWrapper::Finalize() {
allocations_.clear();
}
-StepStatsCollector::StepStatsCollector(StepStats* ss)
- : finalized_(false), step_stats_(ss) {}
+StepStatsCollector::StepStatsCollector(StepStats* step_stats)
+ : finalized_(false), step_stats_(step_stats) {}
static int ExtractGpuWithStreamAll(string device_name) {
// Check if the device name matches the ".*gpu:(\\d+)/stream:all$" regexp,
@@ -338,28 +382,40 @@ void StepStatsCollector::BuildCostModel(
}
}
-void StepStatsCollector::Save(const string& device, NodeExecStats* nt) {
- Save(device, new NodeExecStatsWrapper(nt));
+void StepStatsCollector::Save(const string& device,
+ NodeExecStats* node_stats_pb) {
+ Save(device,
+ new NodeExecStatsWrapper(std::unique_ptr<NodeExecStats>(node_stats_pb),
+ nullptr, this));
}
void StepStatsCollector::Save(const string& device,
- NodeExecStatsWrapper* stats) {
- if (!stats) return;
- VLOG(1) << "Save dev " << device << " nt " << stats->stats();
+ NodeExecStatsWrapper* node_stats) {
+ if (!node_stats) return;
+ VLOG(1) << "Save dev " << device << " node stats " << node_stats->stats();
{
mutex_lock l(mu_);
if (finalized_) {
LOG(WARNING) << "stats saved after finalize will not be collected.";
}
- if (!step_stats_ || collectedNodes >= kMaxCollectedNodes) {
+ if (!step_stats_ || collected_nodes_ >= kMaxCollectedNodes) {
VLOG(1) << "step_stats_ nullptr or already collected too many nodes.";
- delete stats;
+ delete node_stats;
return;
}
- auto& dss = dev_stats_[device];
- dss.push_back(std::unique_ptr<NodeExecStatsWrapper>(stats));
- collectedNodes++;
+ auto& device_stats = dev_stats_[device];
+ device_stats.push_back(std::unique_ptr<NodeExecStatsWrapper>(node_stats));
+ collected_nodes_++;
+ }
+}
+
+NodeExecStatsInterface* StepStatsCollector::CreateNodeExecStats(
+ const Node* node) {
+ // Only collect statistics for non-transfer nodes.
+ if (IsSend(node) || IsRecv(node)) {
+ return nullptr;
}
+ return new NodeExecStatsWrapper(node, this);
}
string StepStatsCollector::ReportAllocsOnResourceExhausted(const string& err) {
@@ -446,12 +502,12 @@ void StepStatsCollector::Finalize() {
FinalizeInternal();
}
-void StepStatsCollector::FinalizeAndSwap(StepStats* ss) {
+void StepStatsCollector::FinalizeAndSwap(StepStats* step_stats) {
mutex_lock l(mu_);
CHECK(step_stats_);
FinalizeInternal();
- ss->Swap(step_stats_);
- collectedNodes = 0;
+ step_stats->Swap(step_stats_);
+ collected_nodes_ = 0;
}
void StepStatsCollector::FinalizeInternal() {
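To make the new interface/wrapper split concrete, here is a rough sketch of the per-node call sequence an executor performs for a synchronous kernel; RunNodeWithStats and its parameters are illustrative, not code from this change.

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/framework/op_kernel.h"

namespace tensorflow {

void RunNodeWithStats(StepStatsCollector* collector, const Node* node,
                      Device* device, OpKernel* op_kernel,
                      OpKernelContext* ctx) {
  NodeExecStatsInterface* stats = collector->CreateNodeExecStats(node);
  if (stats == nullptr) {  // Send/Recv nodes are not tracked.
    device->Compute(op_kernel, ctx);
    return;
  }
  stats->RecordExecutorStarted();
  stats->RecordComputeStarted();
  device->Compute(op_kernel, ctx);
  stats->RecordComputeEnded();
  stats->SetMemory(ctx);
  stats->RecordExecutorEnded();
  // Done() saves the record into the collector; the wrapper must not be
  // touched after this call.
  stats->Done(device->name());
}

}  // namespace tensorflow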
diff --git a/tensorflow/core/common_runtime/step_stats_collector.h b/tensorflow/core/common_runtime/step_stats_collector.h
index 7206fbf427..4365b11b19 100644
--- a/tensorflow/core/common_runtime/step_stats_collector.h
+++ b/tensorflow/core/common_runtime/step_stats_collector.h
@@ -36,81 +36,78 @@ class Node;
class NodeExecStats;
class OpKernelContext;
class StepStats;
+class StepStatsCollector;
class Tensor;
class TrackingAllocator;
-// Wraps NodeExecStats and adds allocation to it.
-class NodeExecStatsWrapper {
+// Statistics collection interface for individual node execution.
+//
+// See `NodeExecStatsWrapper` for a concrete implementation of this interface
+// that interfaces with the `Session` layer.
+class NodeExecStatsInterface {
public:
- NodeExecStatsWrapper(const string& node_name);
- // Owns 'stats'.
- NodeExecStatsWrapper(NodeExecStats* stats);
+ virtual ~NodeExecStatsInterface() {}
- // Destructor calls Finalize() to release the TrackingAllocators.
- ~NodeExecStatsWrapper() { Finalize(); }
-
- // Records the absolute time in nanoseconds at which this node became
- // runnable (i.e. was scheduled for execution).
- void SetScheduled(int64 nanos) {
- stats_->set_scheduled_micros(nanos / EnvTime::kMicrosToNanos);
- stats_->set_scheduled_nanos(nanos);
- }
+ // Called when the statistics collection for the node has finished. Once this
+ // method is called, the caller should not make assumptions about the validity
+ // of this object.
+ virtual void Done(const string& device) = 0;
// Called immediately after this node starts being processed by the executor.
- void RecordExecutorStarted() {
- int64 now_nanos = Env::Default()->NowNanos();
- stats_->set_all_start_micros(now_nanos / EnvTime::kMicrosToNanos);
- stats_->set_all_start_nanos(now_nanos);
- }
+ virtual void RecordExecutorStarted() = 0;
// Called immediately before this node's `Compute()` or `ComputeAsync()`
// method is called.
- void RecordComputeStarted() {
- int64 now_nanos = Env::Default()->NowNanos();
- DCHECK_NE(stats_->all_start_micros(), 0);
- DCHECK_NE(stats_->all_start_nanos(), 0);
- stats_->set_op_start_rel_micros(now_nanos / EnvTime::kMicrosToNanos -
- stats_->all_start_micros());
- stats_->set_op_start_rel_nanos(now_nanos - stats_->all_start_nanos());
- }
+ virtual void RecordComputeStarted() = 0;
// Called immediately after this node's `Compute()` method returned (or, for
// asynchronous operations, the callback passed to its `ComputeAsync()` method
// was called).
- void RecordComputeEnded() {
- int64 now_nanos = Env::Default()->NowNanos();
- DCHECK_NE(stats_->all_start_micros(), 0);
- DCHECK_NE(stats_->all_start_nanos(), 0);
- stats_->set_op_end_rel_micros(now_nanos / EnvTime::kMicrosToNanos -
- stats_->all_start_micros());
- stats_->set_op_end_rel_nanos(now_nanos - stats_->all_start_nanos());
- }
+ virtual void RecordComputeEnded() = 0;
// Called immediately after this executor finishes processing this node.
- void RecordExecutorEnded() {
- int64 now_nanos = Env::Default()->NowNanos();
- DCHECK_NE(stats_->all_start_micros(), 0);
- DCHECK_NE(stats_->all_start_nanos(), 0);
- stats_->set_all_end_rel_micros(now_nanos / EnvTime::kMicrosToNanos -
- stats_->all_start_micros());
- stats_->set_all_end_rel_nanos(now_nanos - stats_->all_start_nanos());
- }
-
- // Records information about the tensor produced by this node at the given
- // output slot.
- void SetOutput(int slot, const Tensor* v);
+ virtual void RecordExecutorEnded() = 0;
// Records information about the memory allocated during the execution of this
// node.
- void SetMemory(OpKernelContext* ctx);
+ virtual void SetMemory(OpKernelContext* ctx) = 0;
+
+ // Records information about the tensor produced by this node at the given
+ // output slot.
+ virtual void SetOutput(int slot, const Tensor* tensor) = 0;
// Records information about the tensors that were accessed during the
// execution of this node.
- void SetReferencedTensors(const TensorReferenceVector& tensors);
+ virtual void SetReferencedTensors(const TensorReferenceVector& tensors) = 0;
- // Sets the timeline_label field of the wrapped NodeExecStats, using data
- // from *node. Returns true iff the node is a transfer node.
- bool SetTimelineLabel(const Node* node);
+ // Records the absolute time in nanoseconds at which this node became
+ // runnable (i.e. was scheduled for execution).
+ virtual void SetScheduled(int64 nanos) = 0;
+};
+
+// Wraps NodeExecStats and adds allocation to it.
+class NodeExecStatsWrapper : public NodeExecStatsInterface {
+ public:
+ // Does not take ownership of `node` or `step_stats_collector`.
+ NodeExecStatsWrapper(const Node* node,
+ StepStatsCollector* step_stats_collector);
+
+ // Takes ownership of 'stats' but not `node` or `step_stats_collector`.
+ NodeExecStatsWrapper(std::unique_ptr<NodeExecStats> stats, const Node* node,
+ StepStatsCollector* step_stats_collector);
+
+ // Destructor calls Finalize() to release the TrackingAllocators.
+ ~NodeExecStatsWrapper() { Finalize(); }
+
+ void Done(const string& device) override;
+ void RecordExecutorStarted() override;
+ void RecordComputeStarted() override;
+ void RecordComputeEnded() override;
+ void RecordExecutorEnded() override;
+ void SetMemory(OpKernelContext* ctx) override;
+ void SetOutput(int slot, const Tensor* tensor) override;
+ void SetReferencedTensors(const TensorReferenceVector& tensors) override;
+ void SetScheduled(int64 nanos) override;
private:
friend class StepStatsCollector;
@@ -128,9 +125,11 @@ class NodeExecStatsWrapper {
gtl::InlinedVector<std::pair<AllocatorMemoryUsed*, TrackingAllocator*>, 2>
allocations_;
std::unique_ptr<NodeExecStats> stats_;
+ const Node* const node_; // Not owned.
+ StepStatsCollector* const step_stats_collector_; // Not owned.
};
-// Statistics collection interface for individual node execution.
+// Statistics collection interface for step execution.
//
// See `StepStatsCollector` for a concrete implementation of this interface
// that interfaces with the `Session` layer.
@@ -138,8 +137,9 @@ class StepStatsCollectorInterface {
public:
virtual ~StepStatsCollectorInterface() {}
- // Saves `stats` to the collector.
- virtual void Save(const string& device, NodeExecStatsWrapper* stats) = 0;
+ // Creates an instance of `NodeExecStatsInterface` that should be used for
+ // collecting statistics about individual node execution.
+ virtual NodeExecStatsInterface* CreateNodeExecStats(const Node* node) = 0;
// Generates a string reporting the currently used memory based
// on ResourceExhausted OOM `err` message.
@@ -154,8 +154,8 @@ class StepStatsCollectorInterface {
// Each DeviceStats object holds multiple NodeExecStats.
class StepStatsCollector : public StepStatsCollectorInterface {
public:
- // Does not take ownership of `ss`.
- explicit StepStatsCollector(StepStats* ss);
+ // Does not take ownership of `step_stats`.
+ explicit StepStatsCollector(StepStats* step_stats);
// BuildCostModel builds or updates a CostModel managed by cost_model_manager,
// using the currently collected DeviceStats associated with the devices in
@@ -164,11 +164,12 @@ class StepStatsCollector : public StepStatsCollectorInterface {
CostModelManager* cost_model_manager,
const std::unordered_map<string, const Graph*>& device_map);
- // Save saves nt to the DeviceStats object associated with device.
+ // Saves node statistics to the DeviceStats object associated with device.
// Should be called before Finalize.
- void Save(const string& device, NodeExecStats* nt);
- void Save(const string& device, NodeExecStatsWrapper* stats) override;
+ void Save(const string& device, NodeExecStats* node_stats_pb);
+ void Save(const string& device, NodeExecStatsWrapper* node_stats);
+ NodeExecStatsInterface* CreateNodeExecStats(const Node* node) override;
string ReportAllocsOnResourceExhausted(const string& err) override;
// The following 2 Finalize methods populate the StepStats passed
@@ -176,20 +177,22 @@ class StepStatsCollector : public StepStatsCollectorInterface {
// User shouldn't call Save() methods after Finalize.
void Finalize();
// swaps the content of StepStats* from constructor with 'ss'.
- void FinalizeAndSwap(StepStats* ss);
+ void FinalizeAndSwap(StepStats* step_stats);
private:
+ // TODO(suharshs): Make this configurable if it's not possible to find a value
+ // that works for all cases.
+ static const uint64 kMaxCollectedNodes = 1 << 20;
+
+ typedef std::vector<std::unique_ptr<NodeExecStatsWrapper>> NodeStatsVector;
+
void FinalizeInternal() EXCLUSIVE_LOCKS_REQUIRED(mu_);
- typedef std::vector<std::unique_ptr<NodeExecStatsWrapper>> NodeExecStatsVec;
- // TODO(suharshs): Make this configurable if its not possible to find a value
- // that works for all cases.
- const uint64 kMaxCollectedNodes = 1 << 20;
mutex mu_;
bool finalized_ GUARDED_BY(mu_);
- std::unordered_map<string, NodeExecStatsVec> dev_stats_ GUARDED_BY(mu_);
+ std::unordered_map<string, NodeStatsVector> dev_stats_ GUARDED_BY(mu_);
StepStats* step_stats_ GUARDED_BY(mu_);
- uint64 collectedNodes GUARDED_BY(mu_) = 0;
+ uint64 collected_nodes_ GUARDED_BY(mu_) = 0;
};
} // namespace tensorflow
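With this refactoring, executors no longer construct `NodeExecStatsWrapper` directly: they ask the collector for a `NodeExecStatsInterface` via `CreateNodeExecStats()` (which returns nullptr for Send/Recv nodes) and hand the object back through `Done()`. A minimal sketch of the intended call sequence; `stats_collector_`, `op_kernel`, `node`, `ctx`, and `device_name` are hypothetical caller-side variables, not part of this change:

    // Sketch only: illustrates the call sequence against the interface above.
    NodeExecStatsInterface* stats = stats_collector_->CreateNodeExecStats(node);
    if (stats != nullptr) {
      stats->RecordExecutorStarted();
      stats->RecordComputeStarted();
      op_kernel->Compute(ctx);  // run the kernel
      stats->RecordComputeEnded();
      stats->SetMemory(ctx);
      stats->RecordExecutorEnded();
      // Done() hands the statistics back to the collector; for
      // StepStatsCollector this transfers ownership of the wrapper, so the
      // pointer must not be used afterwards.
      stats->Done(device_name);
    }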
diff --git a/tensorflow/core/common_runtime/threadpool_device.cc b/tensorflow/core/common_runtime/threadpool_device.cc
index 0fbc20b34b..8587d1783a 100644
--- a/tensorflow/core/common_runtime/threadpool_device.cc
+++ b/tensorflow/core/common_runtime/threadpool_device.cc
@@ -113,8 +113,11 @@ class MklCPUAllocatorFactory : public AllocatorFactory {
}
};
+#ifdef ENABLE_MKL
REGISTER_MEM_ALLOCATOR("MklCPUAllocator", 200, MklCPUAllocatorFactory);
+#endif // ENABLE_MKL
+
} // namespace
-#endif
+#endif // INTEL_MKL
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/tracing_device.h b/tensorflow/core/common_runtime/tracing_device.h
deleted file mode 100644
index e1b163074f..0000000000
--- a/tensorflow/core/common_runtime/tracing_device.h
+++ /dev/null
@@ -1,60 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_TRACING_DEVICE_H_
-#define TENSORFLOW_CORE_COMMON_RUNTIME_TRACING_DEVICE_H_
-
-#include "tensorflow/core/common_runtime/device.h"
-#include "tensorflow/core/platform/macros.h"
-#include "tensorflow/core/platform/tracing.h"
-
-namespace tensorflow {
-
-namespace test {
-class Benchmark;
-}
-struct SessionOptions;
-
-// This class implements tracing functionality that is shared by its subclasses
-// (including ThreadPoolDevice and XlaDevice).
-class TracingDevice : public Device {
- public:
- TracingDevice(Env* env, const DeviceAttributes& attributes)
- : Device(env, attributes) {}
-
- void Compute(OpKernel* op_kernel, OpKernelContext* context) override {
- const tracing::TraceCollector* trace_collector =
- tracing::GetTraceCollector();
- if (TF_PREDICT_FALSE(
- (trace_collector &&
- trace_collector->IsEnabled(op_kernel->IsExpensive())) ||
- tracing::GetEventCollector(tracing::EventCategory::kCompute))) {
- const string& op_name = op_kernel->name();
- tracing::ScopedActivity activity(op_name, op_kernel->type_string(),
- op_kernel->IsExpensive());
- tracing::ScopedRegion region(tracing::EventCategory::kCompute, op_name);
- op_kernel->Compute(context);
- } else {
- op_kernel->Compute(context);
- }
- }
-
- private:
- TF_DISALLOW_COPY_AND_ASSIGN(TracingDevice);
-};
-
-} // namespace tensorflow
-
-#endif // TENSORFLOW_CORE_COMMON_RUNTIME_TRACING_DEVICE_H_
diff --git a/tensorflow/core/common_runtime/visitable_allocator.h b/tensorflow/core/common_runtime/visitable_allocator.h
deleted file mode 100644
index ae0563a96a..0000000000
--- a/tensorflow/core/common_runtime/visitable_allocator.h
+++ /dev/null
@@ -1,79 +0,0 @@
-/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_VISITABLE_ALLOCATOR_H_
-#define TENSORFLOW_CORE_COMMON_RUNTIME_VISITABLE_ALLOCATOR_H_
-
-#include <functional>
-#include "tensorflow/core/framework/allocator.h"
-#include "tensorflow/core/framework/tracking_allocator.h"
-
-namespace tensorflow {
-
-// Subclass VisitableAllocator instead of Allocator when a memory
-// allocator needs to enable some kind of registration/deregistration
-// of memory areas.
-class VisitableAllocator : public Allocator {
- public:
- // Visitor gets called with a pointer to a memory area and its
- // size in bytes.
- typedef std::function<void(void*, size_t)> Visitor;
-
- // Register a visitor guaranteed to be called exactly once on each
- // chunk of memory newly allocated from the underlying device.
- // Typically, chunks will be reused and possibly sub-divided by a
- // pool manager, so the calls will happen only once per process
- // execution, not once per tensor (re)allocation.
- virtual void AddAllocVisitor(Visitor visitor) = 0;
-
- // Register a visitor guaranteed to be called on each chunk of
- // memory returned to the underlying device.
- virtual void AddFreeVisitor(Visitor visitor) = 0;
-};
-
-// Needed for cases when a VisitableAllocator gets wrapped for tracking.
-// Multiple-inheritance is considered acceptable in this case because
-// VisitableAllocator is a pure virtual interface and only TrackingAllocator
-// has default implementation.
-class TrackingVisitableAllocator : public TrackingAllocator,
- public VisitableAllocator {
- public:
- TrackingVisitableAllocator(VisitableAllocator* allocator, bool track_ids)
- : TrackingAllocator(allocator, track_ids), allocator_(allocator) {}
- ~TrackingVisitableAllocator() override {}
-
- string Name() override { return TrackingAllocator::Name(); }
-
- void* AllocateRaw(size_t alignment, size_t num_bytes) override {
- return TrackingAllocator::AllocateRaw(alignment, num_bytes);
- }
-
- void DeallocateRaw(void* ptr) override {
- TrackingAllocator::DeallocateRaw(ptr);
- }
-
- void AddAllocVisitor(Visitor visitor) override {
- allocator_->AddAllocVisitor(visitor);
- }
-
- void AddFreeVisitor(Visitor visitor) override {
- allocator_->AddFreeVisitor(visitor);
- }
-
- protected:
- VisitableAllocator* allocator_;
-};
-} // namespace tensorflow
-#endif // TENSORFLOW_CORE_COMMON_RUNTIME_VISITABLE_ALLOCATOR_H_
diff --git a/tensorflow/core/distributed_runtime/graph_mgr.cc b/tensorflow/core/distributed_runtime/graph_mgr.cc
index f7a2967d00..3361819e43 100644
--- a/tensorflow/core/distributed_runtime/graph_mgr.cc
+++ b/tensorflow/core/distributed_runtime/graph_mgr.cc
@@ -475,10 +475,7 @@ void GraphMgr::StartParallelExecutors(const string& handle, int64 step_id,
delete step_container;
});
Executor::Args args;
- {
- mutex_lock l(mu_);
- args.step_id = ++next_id_;
- }
+ args.step_id = step_id;
args.rendezvous = rendezvous;
args.collective_executor = ce_handle ? ce_handle->get() : nullptr;
args.cancellation_manager = cancellation_manager;
diff --git a/tensorflow/core/example/feature_util.h b/tensorflow/core/example/feature_util.h
index ec93b9aad9..016d1a92c1 100644
--- a/tensorflow/core/example/feature_util.h
+++ b/tensorflow/core/example/feature_util.h
@@ -103,6 +103,7 @@ limitations under the License.
#include <iterator>
#include <type_traits>
+#include "absl/base/macros.h"
#include "tensorflow/core/example/example.pb.h"
#include "tensorflow/core/example/feature.pb.h"
#include "tensorflow/core/lib/core/stringpiece.h"
@@ -113,10 +114,10 @@ namespace tensorflow {
namespace internal {
-// DEPRECATED: Use GetFeature instead.
// TODO(gorban): Update all clients in a followup CL.
// Returns a reference to a feature corresponding to the name.
// Note: it will create a new Feature if it is missing in the example.
+ABSL_DEPRECATED("Use GetFeature instead.")
Feature& ExampleFeature(const string& name, Example* example);
// Specializations of RepeatedFieldTrait define a type of RepeatedField
@@ -314,9 +315,9 @@ bool HasFeature(const string& key, const Example& example) {
return HasFeature<FeatureType...>(key, GetFeatures(example));
}
-// DEPRECATED: use HasFeature instead.
// TODO(gorban): update all clients in a followup CL.
template <typename... FeatureType>
+ABSL_DEPRECATED("Use HasFeature instead.")
bool ExampleHasFeature(const string& key, const Example& example) {
return HasFeature<FeatureType...>(key, example);
}
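Marking the declarations with `ABSL_DEPRECATED` turns the former free-form "DEPRECATED:" comments into a compiler diagnostic. A minimal sketch of the annotation on a hypothetical function (`OldApi` and `NewApi` are illustrative names only):

    #include "absl/base/macros.h"

    ABSL_DEPRECATED("Use NewApi() instead.")
    int OldApi();

    // Callers of OldApi() now get a -Wdeprecated-declarations warning on
    // compilers that support the deprecated attribute.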
diff --git a/tensorflow/core/framework/allocator.cc b/tensorflow/core/framework/allocator.cc
index 2a7ee16a16..84cee5569c 100644
--- a/tensorflow/core/framework/allocator.cc
+++ b/tensorflow/core/framework/allocator.cc
@@ -196,7 +196,7 @@ class CPUAllocatorFactory : public AllocatorFactory {
class CPUSubAllocator : public SubAllocator {
public:
explicit CPUSubAllocator(CPUAllocator* cpu_allocator)
- : cpu_allocator_(cpu_allocator) {}
+ : SubAllocator({}, {}), cpu_allocator_(cpu_allocator) {}
void* Alloc(size_t alignment, size_t num_bytes) override {
return cpu_allocator_->AllocateRaw(alignment, num_bytes);
@@ -222,4 +222,22 @@ Allocator* cpu_allocator() {
}
return cpu_alloc;
}
+
+SubAllocator::SubAllocator(const std::vector<Visitor>& alloc_visitors,
+ const std::vector<Visitor>& free_visitors)
+ : alloc_visitors_(alloc_visitors), free_visitors_(free_visitors) {}
+
+void SubAllocator::VisitAlloc(void* ptr, int index, size_t num_bytes) {
+ for (const auto& v : alloc_visitors_) {
+ v(ptr, index, num_bytes);
+ }
+}
+
+void SubAllocator::VisitFree(void* ptr, int index, size_t num_bytes) {
+ // Although we don't guarantee any order of visitor application, strive
+ // to apply free visitors in reverse order of alloc visitors.
+ for (int i = free_visitors_.size() - 1; i >= 0; --i) {
+ free_visitors_[i](ptr, index, num_bytes);
+ }
+}
} // namespace tensorflow
diff --git a/tensorflow/core/framework/allocator.h b/tensorflow/core/framework/allocator.h
index ded120b704..8c23604625 100644
--- a/tensorflow/core/framework/allocator.h
+++ b/tensorflow/core/framework/allocator.h
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/framework/resource_handle.h"
#include "tensorflow/core/framework/type_traits.h"
#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
@@ -387,13 +388,36 @@ void EnableCPUAllocatorStats(bool enable);
// full statistics. By default, it's disabled.
void EnableCPUAllocatorFullStats(bool enable);
-// Abstract interface of an object that does the underlying suballoc/free of
-// memory for a higher-level allocator.
+// An object that does the underlying suballoc/free of memory for a higher-level
+// allocator. The expectation is that the higher-level allocator is doing some
+// kind of cache or pool management so that it will call SubAllocator::Alloc and
+// Free relatively infrequently, compared to the number of times its own
+// AllocateRaw and Free methods are called.
class SubAllocator {
public:
+ // Visitor gets called with a pointer to a memory area and its
+ // size in bytes. The index value will be numa_node for a CPU
+ // allocator and GPU id for a GPU allocator.
+ typedef std::function<void(void*, int index, size_t)> Visitor;
+
+ SubAllocator(const std::vector<Visitor>& alloc_visitors,
+ const std::vector<Visitor>& free_visitors);
+
virtual ~SubAllocator() {}
virtual void* Alloc(size_t alignment, size_t num_bytes) = 0;
virtual void Free(void* ptr, size_t num_bytes) = 0;
+
+ protected:
+ // Implementation of Alloc() method must call this on newly allocated
+ // value.
+ void VisitAlloc(void* ptr, int index, size_t num_bytes);
+
+ // Implementation of Free() method must call this on value to be
+ // freed immediately before deallocation.
+ void VisitFree(void* ptr, int index, size_t num_bytes);
+
+ const std::vector<Visitor> alloc_visitors_;
+ const std::vector<Visitor> free_visitors_;
};
} // namespace tensorflow
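The alloc/free visitors that used to be registered through `VisitableAllocator` (deleted above) now travel with the `SubAllocator` itself, and subclasses are expected to invoke `VisitAlloc()`/`VisitFree()` around the underlying allocation. A minimal sketch of a conforming subclass; the use of `port::AlignedMalloc`/`port::AlignedFree` and the index value of 0 are illustrative assumptions, not part of this change:

    // Sketch: a SubAllocator subclass wired to the new visitor hooks.
    class BasicCPUSubAllocator : public SubAllocator {
     public:
      BasicCPUSubAllocator(const std::vector<Visitor>& alloc_visitors,
                           const std::vector<Visitor>& free_visitors)
          : SubAllocator(alloc_visitors, free_visitors) {}

      void* Alloc(size_t alignment, size_t num_bytes) override {
        void* ptr = port::AlignedMalloc(num_bytes, static_cast<int>(alignment));
        // Implementations must notify the alloc visitors for each new chunk.
        VisitAlloc(ptr, /*index=*/0, num_bytes);
        return ptr;
      }

      void Free(void* ptr, size_t num_bytes) override {
        // ...and the free visitors immediately before returning the memory.
        VisitFree(ptr, /*index=*/0, num_bytes);
        port::AlignedFree(ptr);
      }
    };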
diff --git a/tensorflow/core/framework/cancellation.cc b/tensorflow/core/framework/cancellation.cc
index 1258e40c93..af59500aee 100644
--- a/tensorflow/core/framework/cancellation.cc
+++ b/tensorflow/core/framework/cancellation.cc
@@ -89,6 +89,16 @@ bool CancellationManager::DeregisterCallback(CancellationToken token) {
}
}
+bool CancellationManager::TryDeregisterCallback(CancellationToken token) {
+ mutex_lock lock(mu_);
+ if (is_cancelled_ || is_cancelling_) {
+ return false;
+ } else {
+ callbacks_.erase(token);
+ return true;
+ }
+}
+
CancellationManager::~CancellationManager() {
if (!callbacks_.empty()) {
StartCancel();
diff --git a/tensorflow/core/framework/cancellation.h b/tensorflow/core/framework/cancellation.h
index acdaaf6a90..7a5d942486 100644
--- a/tensorflow/core/framework/cancellation.h
+++ b/tensorflow/core/framework/cancellation.h
@@ -122,6 +122,15 @@ class CancellationManager {
// cancellation manager.
bool DeregisterCallback(CancellationToken token);
+ // Deregister the callback that, when registered, was associated
+ // with the given cancellation token. Returns true iff the callback
+ // was deregistered and will not be invoked; otherwise returns false
+ // immediately, with no guarantee that the callback has completed.
+ //
+ // This method is guaranteed to return true if StartCancel has not been
+ // called.
+ bool TryDeregisterCallback(CancellationToken token);
+
private:
bool is_cancelling_;
std::atomic_bool is_cancelled_;
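Unlike `DeregisterCallback`, `TryDeregisterCallback` never blocks on a concurrent `StartCancel`; a `false` return simply means the callback is running or has already run. A hedged sketch of the intended completion path for an asynchronous op (`cm`, `DoCleanup`, and the surrounding control flow are hypothetical):

    CancellationToken token = cm->get_cancellation_token();
    bool registered =
        cm->RegisterCallback(token, [] { /* abort the pending work */ });
    if (!registered) {
      return;  // cancellation already started before we could register
    }
    // ... perform the work ...
    if (cm->TryDeregisterCallback(token)) {
      // The callback will never run; it is safe to release its resources here.
      DoCleanup();
    } else {
      // Cancellation is in progress (or done); the callback owns cleanup.
    }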
diff --git a/tensorflow/core/framework/cancellation_test.cc b/tensorflow/core/framework/cancellation_test.cc
index e3f18240b5..bf7593bc5f 100644
--- a/tensorflow/core/framework/cancellation_test.cc
+++ b/tensorflow/core/framework/cancellation_test.cc
@@ -115,4 +115,56 @@ TEST(Cancellation, IsCancelled) {
delete cm;
}
+TEST(Cancellation, TryDeregisterWithoutCancel) {
+ bool is_cancelled = false;
+ CancellationManager* manager = new CancellationManager();
+ auto token = manager->get_cancellation_token();
+ bool registered = manager->RegisterCallback(
+ token, [&is_cancelled]() { is_cancelled = true; });
+ EXPECT_TRUE(registered);
+ bool deregistered = manager->TryDeregisterCallback(token);
+ EXPECT_TRUE(deregistered);
+ delete manager;
+ EXPECT_FALSE(is_cancelled);
+}
+
+TEST(Cancellation, TryDeregisterAfterCancel) {
+ bool is_cancelled = false;
+ CancellationManager* manager = new CancellationManager();
+ auto token = manager->get_cancellation_token();
+ bool registered = manager->RegisterCallback(
+ token, [&is_cancelled]() { is_cancelled = true; });
+ EXPECT_TRUE(registered);
+ manager->StartCancel();
+ EXPECT_TRUE(is_cancelled);
+ bool deregistered = manager->TryDeregisterCallback(token);
+ EXPECT_FALSE(deregistered);
+ delete manager;
+}
+
+TEST(Cancellation, TryDeregisterDuringCancel) {
+ Notification cancel_started, finish_callback, cancel_complete;
+ CancellationManager* manager = new CancellationManager();
+ auto token = manager->get_cancellation_token();
+ bool registered = manager->RegisterCallback(token, [&]() {
+ cancel_started.Notify();
+ finish_callback.WaitForNotification();
+ });
+ EXPECT_TRUE(registered);
+
+ thread::ThreadPool w(Env::Default(), "test", 1);
+ w.Schedule([&]() {
+ manager->StartCancel();
+ cancel_complete.Notify();
+ });
+ cancel_started.WaitForNotification();
+
+ bool deregistered = manager->TryDeregisterCallback(token);
+ EXPECT_FALSE(deregistered);
+
+ finish_callback.Notify();
+ cancel_complete.WaitForNotification();
+ delete manager;
+}
+
} // namespace tensorflow
diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc
index 20a07d86a2..50403b4004 100644
--- a/tensorflow/core/framework/common_shape_fns.cc
+++ b/tensorflow/core/framework/common_shape_fns.cc
@@ -1306,6 +1306,113 @@ Status RandomShape(shape_inference::InferenceContext* c) {
return Status::OK();
}
+namespace {
+
+// This SliceHelper computes the output shape of the slice
+// when the `sizes` tensor is available.
+template <typename T>
+Status SliceHelper(InferenceContext* c, ShapeHandle begin_value,
+ const Tensor* sizes_value,
+ std::vector<DimensionHandle>* dims) {
+ auto sizes_vec = sizes_value->vec<T>();
+ for (int i = 0; i < sizes_value->NumElements(); ++i) {
+ DimensionHandle dim = c->Dim(c->input(0), i);
+ if (sizes_vec(i) != -1) {
+ auto dim_val = c->Value(dim);
+ if (sizes_vec(i) < 0) {
+ return errors::InvalidArgument(
+ "Out of bounds slicing on dimension ", i, " of length ", dim_val,
+ ": sizes vector cannot be < -1, but was ", sizes_vec(i));
+ }
+
+ dims->emplace_back(c->MakeDim(sizes_vec(i)));
+ } else {
+ DimensionHandle result;
+ TF_RETURN_IF_ERROR(c->Subtract(dim, c->Dim(begin_value, i), &result));
+ dims->emplace_back(result);
+ }
+ }
+
+ return Status::OK();
+}
+} // namespace
+
+Status SliceShape(InferenceContext* c) {
+ ShapeHandle input = c->input(0);
+ ShapeHandle begin_shape;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &begin_shape));
+ ShapeHandle sizes_shape;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &sizes_shape));
+
+ // Merge to check compatibility of begin and sizes tensors.
+ TF_RETURN_IF_ERROR(c->Merge(begin_shape, sizes_shape, &begin_shape));
+
+ DimensionHandle ndims = c->Dim(begin_shape, 0);
+ if (c->ValueKnown(ndims)) {
+ TF_RETURN_IF_ERROR(c->WithRank(input, c->Value(ndims), &input));
+ }
+
+ // NOTE(mrry): Use MakeShapeFromShapeTensor to handle partially-known
+ // values, even though the `begin` value does not represent a shape.
+ ShapeHandle begin_value;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &begin_value));
+
+ // We check the tensor value here and will only use
+ // `MakeShapeFromShapeTensor` when `sizes_value` is null.
+ // The reason is that `sizes` might contain -1, which can't
+ // be represented (-1 in the ShapeHandle would mean "unknown").
+ const Tensor* sizes_value = c->input_tensor(2);
+
+ if (sizes_value != nullptr) {
+ TF_RETURN_IF_ERROR(
+ c->WithRank(begin_value, sizes_value->NumElements(), &begin_value));
+ std::vector<DimensionHandle> dims;
+ // If the begin and sizes tensors are available, then
+ // we can be precise about the shape of the output.
+ if (sizes_value->dtype() == DT_INT64) {
+ TF_RETURN_IF_ERROR(
+ SliceHelper<int64>(c, begin_value, sizes_value, &dims));
+ } else {
+ TF_RETURN_IF_ERROR(
+ SliceHelper<int32>(c, begin_value, sizes_value, &dims));
+ }
+ c->set_output(0, c->MakeShape(dims));
+ return Status::OK();
+ } else {
+ // In case `sizes` is not available (`sizes_value` is null),
+ // we could try to use `MakeShapeFromShapeTensor` here.
+ // If sizes contain -1, we will simply consider it as `Unknown`.
+ // This is less than ideal, but still an improvement to shape inference.
+ // The following is an example that returns [None, 1, None] with this
+ // code path:
+ // z = tf.zeros((1, 2, 3))
+ // m = tf.slice(z, [0, 0, 0], [tf.constant(1) + 0, 1, -1])
+ // m.get_shape().as_list()
+ ShapeHandle sizes_value;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &sizes_value));
+ if (c->RankKnown(sizes_value)) {
+ TF_RETURN_IF_ERROR(
+ c->WithRank(begin_value, c->Rank(sizes_value), &begin_value));
+ std::vector<DimensionHandle> dims;
+ dims.reserve(c->Rank(sizes_value));
+ for (int i = 0; i < c->Rank(sizes_value); ++i) {
+ dims.emplace_back(c->Dim(sizes_value, i));
+ }
+ c->set_output(0, c->MakeShape(dims));
+ return Status::OK();
+ }
+ // We might know the rank of the input.
+ if (c->RankKnown(input)) {
+ c->set_output(0, c->UnknownShapeOfRank(c->Rank(input)));
+ return Status::OK();
+ } else {
+ return shape_inference::UnknownShape(c);
+ }
+ }
+
+ return Status::OK();
+}
+
Status ValidateSparseTensor(InferenceContext* c, ShapeHandle indices_shape,
ShapeHandle values_shape, ShapeHandle shape_shape) {
// Validate ranks.
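`SliceShape` is meant to be installed as an op's shape function via `SetShapeFn`. A sketch of such a registration; the op name and attrs here are illustrative, since the real `Slice` registration lives elsewhere in the tree (tensorflow/core/ops/array_ops.cc):

    REGISTER_OP("MySlice")
        .Input("input: T")
        .Input("begin: Index")
        .Input("size: Index")
        .Output("output: T")
        .Attr("T: type")
        .Attr("Index: {int32,int64}")
        .SetShapeFn(shape_inference::SliceShape);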
diff --git a/tensorflow/core/framework/common_shape_fns.h b/tensorflow/core/framework/common_shape_fns.h
index e6f9f935f9..3a496e06ae 100644
--- a/tensorflow/core/framework/common_shape_fns.h
+++ b/tensorflow/core/framework/common_shape_fns.h
@@ -293,6 +293,9 @@ inline Status BroadcastBinaryOpShapeFn(InferenceContext* c) {
// Shape function for random operations.
Status RandomShape(shape_inference::InferenceContext* c);
+// Shape function for Slice operations.
+Status SliceShape(shape_inference::InferenceContext* c);
+
// Validates the 3 component tensors of a sparse tensor have the proper
// shapes. This mimics SparseTensor.__init__ in python/framework/ops.py.
Status ValidateSparseTensor(InferenceContext* c, ShapeHandle indices_shape,
diff --git a/tensorflow/core/framework/dataset.cc b/tensorflow/core/framework/dataset.cc
index 5281c56f04..284dafb886 100644
--- a/tensorflow/core/framework/dataset.cc
+++ b/tensorflow/core/framework/dataset.cc
@@ -20,7 +20,6 @@ limitations under the License.
namespace tensorflow {
namespace data {
-
namespace {
// A wrapper class for storing a `DatasetBase` instance in a DT_VARIANT tensor.
diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h
index 4ee6749eea..697e0604bf 100644
--- a/tensorflow/core/framework/dataset.h
+++ b/tensorflow/core/framework/dataset.h
@@ -47,6 +47,8 @@ class GraphDefBuilder;
class Node;
namespace data {
+// A constant that can be used to enable auto-tuning.
+constexpr int kAutoTune = -1;
class DatasetBase;
class SerializationContext;
@@ -527,25 +529,11 @@ class DatasetBase : public core::RefCounted {
std::unique_ptr<IteratorBase>* iterator) const {
*iterator = MakeIteratorInternal(prefix);
if (ctx->model()) {
- // The prefix might contain an index. We need to strip it to make it
- // possible for the model to successfully identify the output node.
- string sanitized_prefix = prefix;
- if (str_util::EndsWith(prefix, "]")) {
- sanitized_prefix = prefix.substr(0, prefix.rfind('['));
- }
- std::shared_ptr<model::Node> node =
- ctx->model()->AddNode((*iterator)->prefix(), sanitized_prefix);
- std::vector<string> tokens =
- str_util::Split((*iterator)->prefix(), ':', str_util::SkipEmpty());
- node->set_name(tokens[tokens.size() - 1]);
+ ctx->model()->AddNode((*iterator)->prefix(), prefix);
std::shared_ptr<model::Model> model = ctx->model();
const string& prefix = (*iterator)->prefix();
- (*iterator)->AddCleanupFunction([model, node, prefix]() {
- if (node->output()) {
- node->output()->remove_input(node);
- }
- model->RemoveNode(prefix);
- });
+ (*iterator)->AddCleanupFunction(
+ [model, prefix]() { model->RemoveNode(prefix); });
}
return (*iterator)->Initialize(ctx);
}
@@ -627,23 +615,10 @@ class DatasetBaseIterator : public IteratorBase {
Status GetNext(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
bool* end_of_sequence) final {
tracing::ScopedActivity activity(params_.prefix);
- Status s;
- if (ctx->model()) {
- std::shared_ptr<model::Node> node =
- ctx->model()->LookupNode(params_.prefix);
- if (node->output()) {
- node->output()->stop_work();
- }
- node->start_work();
- s = GetNextInternal(ctx, out_tensors, end_of_sequence);
- node->stop_work();
- node->add_element();
- if (node->output()) {
- node->output()->start_work();
- }
- } else {
- s = GetNextInternal(ctx, out_tensors, end_of_sequence);
- }
+ RecordStart(ctx, true /* stop_output */);
+ Status s = GetNextInternal(ctx, out_tensors, end_of_sequence);
+ if (s.ok() && !*end_of_sequence) RecordElement(ctx);
+ RecordStop(ctx, true /* start_output */);
if (TF_PREDICT_FALSE(errors::IsOutOfRange(s) && !*end_of_sequence)) {
s = errors::Internal(
"Iterator \"", params_.prefix,
@@ -670,36 +645,51 @@ class DatasetBaseIterator : public IteratorBase {
return strings::StrCat(params_.prefix, ":", name);
}
- // When performance modeling is enabled, this method sets metadata entry for
- // the model node corresponding to this iterator.
- void SetMetadata(IteratorContext* ctx, const string& key, int64 value) {
+ // When performance modeling is enabled, this method adds a constant parameter
+ // to the model node corresponding to this iterator.
+ void AddConstantParameter(IteratorContext* ctx, const string& name,
+ int64 value) {
if (ctx->model()) {
- std::shared_ptr<model::Node> node = ctx->model()->LookupNode(prefix());
- if (node) {
- node->set_metadata(key, value);
- }
+ ctx->model()->AddConstantParameter(prefix(), name, value);
+ }
+ }
+
+ // When performance modeling is enabled, this method adds a tunable parameter
+ // to the model node corresponding to this iterator.
+ //
+ // The performance modeling logic may use `value` to set the value of the
+ // tunable parameter at any point during the lifetime of this iterator. When
+ // it does, it notifies `cond_var`.
+ void AddTunableParameter(IteratorContext* ctx, const string& name,
+ std::atomic<int64>* value, int64 min, int64 max,
+ condition_variable* cond_var) {
+ if (ctx->model()) {
+ ctx->model()->AddTunableParameter(prefix(), name, value, min, max,
+ cond_var);
+ }
+ }
+
+ // When performance modeling is enabled, this method records the fact that
+ // this iterator has produced an element.
+ void RecordElement(IteratorContext* ctx) {
+ if (ctx->model()) {
+ ctx->model()->RecordElement(prefix());
}
}
// When performance modeling is enabled, this method records the fact that
// a thread of this iterator has started work.
- void StartWork(IteratorContext* ctx) {
+ void RecordStart(IteratorContext* ctx, bool stop_output = false) {
if (ctx->model()) {
- std::shared_ptr<model::Node> node = ctx->model()->LookupNode(prefix());
- if (node) {
- node->start_work();
- }
+ ctx->model()->RecordStart(prefix(), stop_output);
}
}
// When performance modeling is enabled, this method records the fact that
// a thread of this iterator has stopped work.
- void StopWork(IteratorContext* ctx) {
+ void RecordStop(IteratorContext* ctx, bool start_output = false) {
if (ctx->model()) {
- std::shared_ptr<model::Node> node = ctx->model()->LookupNode(prefix());
- if (node) {
- node->stop_work();
- }
+ ctx->model()->RecordStop(prefix(), start_output);
}
}
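The renamed helpers give transformation authors one consistent way to publish constant and tunable parameters and to account for time a worker thread spends blocked. A hedged sketch of how a parallel iterator might use them; `num_parallel_calls_` (assumed to be a `std::atomic<int64>`), `batch_size_`, `cond_var_`, `mu_`, `cancelled_`, `ShouldWait`, and `ProcessOneElement` are hypothetical members of the surrounding iterator:

    Status Initialize(IteratorContext* ctx) override {
      AddConstantParameter(ctx, "batch_size", batch_size_);
      if (num_parallel_calls_ == kAutoTune) {
        num_parallel_calls_ = 1;
        // The performance model may raise this value at any time; when it
        // does, it notifies cond_var_ so blocked workers can react.
        AddTunableParameter(ctx, "parallelism", &num_parallel_calls_,
                            /*min=*/1, /*max=*/port::NumSchedulableCPUs(),
                            &cond_var_);
      } else {
        AddConstantParameter(ctx, "parallelism", num_parallel_calls_);
      }
      return Status::OK();
    }

    void RunnerThread(const std::shared_ptr<IteratorContext>& ctx) {
      RecordStart(ctx.get());      // thread starts doing useful work
      mutex_lock l(*mu_);
      while (!cancelled_) {
        while (ShouldWait()) {
          RecordStop(ctx.get());   // about to block; stop charging time
          cond_var_.wait(l);
          RecordStart(ctx.get());  // woken up; resume charging time
        }
        ProcessOneElement();
      }
      RecordStop(ctx.get());
    }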
diff --git a/tensorflow/core/framework/device_base.h b/tensorflow/core/framework/device_base.h
index 794250a2c1..446c31b17f 100644
--- a/tensorflow/core/framework/device_base.h
+++ b/tensorflow/core/framework/device_base.h
@@ -20,6 +20,7 @@ limitations under the License.
#include <string>
#include <vector>
+#include "absl/base/macros.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
@@ -176,9 +177,9 @@ class DeviceBase {
return nullptr;
}
- // DEPRECATED: Use `this->GetAllocator()` or `this->GetScopedAllocator()`.
// This method is provided for backwards compatibility, and will be removed
// in a future release.
+ ABSL_DEPRECATED("Use `this->GetAllocator()` or `this->GetScopedAllocator()`.")
Allocator* GetStepAllocator(AllocatorAttributes attr, ResourceMgr*) {
return GetAllocator(attr);
}
@@ -214,10 +215,12 @@ class DeviceBase {
// This is overridden by GPU devices to reinitialize the derived
// type returned by MakeGpuDevice.
- virtual void ReinitializeGpuDevice(OpKernelContext* /*context*/,
- PerOpGpuDevice* /*device*/,
- DeviceContext* /*dc*/,
- Allocator* /*allocator*/) {}
+ virtual Status ReinitializeGpuDevice(OpKernelContext* /*context*/,
+ PerOpGpuDevice* /*device*/,
+ DeviceContext* /*dc*/,
+ Allocator* /*allocator*/) {
+ return Status::OK();
+ }
// Unimplemented by default
virtual const DeviceAttributes& attributes() const;
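Because `ReinitializeGpuDevice` now returns a `Status` instead of `void`, callers have to propagate failures rather than assume success. A minimal hedged sketch; the surrounding function and variable names are hypothetical:

    Status PrepareGpuDevice(DeviceBase* device, OpKernelContext* ctx,
                            PerOpGpuDevice* eigen_device, DeviceContext* dc,
                            Allocator* allocator) {
      // Propagate any reinitialization failure instead of silently continuing
      // with a half-initialized device.
      TF_RETURN_IF_ERROR(
          device->ReinitializeGpuDevice(ctx, eigen_device, dc, allocator));
      return Status::OK();
    }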
diff --git a/tensorflow/core/framework/function_testlib.cc b/tensorflow/core/framework/function_testlib.cc
index d5c203d276..0445c242e9 100644
--- a/tensorflow/core/framework/function_testlib.cc
+++ b/tensorflow/core/framework/function_testlib.cc
@@ -93,7 +93,6 @@ FunctionDef IsZero() {
FunctionDef RandomUniform() {
const Tensor kZero = test::AsScalar<int64>(0);
- const Tensor kTen = test::AsScalar<int64>(10);
return FDH::Define(
// Name
@@ -108,19 +107,11 @@ FunctionDef RandomUniform() {
"Const",
{},
{{"value", kZero}, {"dtype", DT_INT64}}},
- {{"random_uniform/min"},
- "Const",
- {},
- {{"value", kZero}, {"dtype", DT_INT64}}},
- {{"random_uniform/max"},
- "Const",
- {},
- {{"value", kTen}, {"dtype", DT_INT64}}},
{{"random_uniform"},
- "RandomUniformInt",
- {},
- {{"T", DT_INT64},
- {"Tout", DT_INT64},
+ "RandomUniform",
+ {"random_uniform/shape"},
+ {{"T", DT_INT32},
+ {"Tout", DT_FLOAT},
{"seed", 87654321},
{"seed2", 42}}}});
}
diff --git a/tensorflow/core/framework/model.cc b/tensorflow/core/framework/model.cc
index 250b006641..b0330ec990 100644
--- a/tensorflow/core/framework/model.cc
+++ b/tensorflow/core/framework/model.cc
@@ -1,4 +1,4 @@
-/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -15,52 +15,26 @@ limitations under the License.
#include "tensorflow/core/framework/model.h"
+#include <memory>
+
namespace tensorflow {
namespace data {
namespace model {
// TODO(jsimsa): Use `Node` subclassing instead of types and node statements.
-void Node::CollectKnobs(std::vector<Node::Knob>* knobs) {
- mutex_lock l(mu_);
+void Model::Node::CollectTunables(
+ std::vector<std::shared_ptr<Node::Tunable>>* tunables) {
+ tf_shared_lock l(mu_);
+ for (auto input : inputs_) {
+ input->CollectTunables(tunables);
+ }
switch (type_) {
- case Type::PARALLEL_INTERLEAVE_V2: {
- for (auto input : inputs_) {
- input->CollectKnobs(knobs);
- }
- int64 processing_time = static_cast<int64>(
- static_cast<double>(ProcessingTimeLocked() -
- inputs_.front()->ProcessingTime()) /
- static_cast<double>(inputs_.size() - 1));
- knobs->emplace_back(
- Node::Knob{this, processing_time, metadata_["parallelism"]});
- return;
- }
case Type::MAP_AND_BATCH:
+ case Type::PARALLEL_INTERLEAVE_V2:
case Type::PARALLEL_MAP: {
- for (auto input : inputs_) {
- input->CollectKnobs(knobs);
- }
- knobs->emplace_back(
- Node::Knob{this, NanosPerElementLocked(), metadata_["parallelism"]});
- return;
- }
- case Type::BATCH:
- case Type::CACHE:
- case Type::CONCATENATE:
- case Type::FILTER:
- case Type::FLAT_MAP:
- case Type::INTERLEAVE:
- case Type::MAP:
- case Type::PADDED_BATCH:
- case Type::PARALLEL_INTERLEAVE:
- case Type::PREFETCH:
- case Type::REPEAT:
- case Type::SHUFFLE:
- case Type::SKIP:
- case Type::TAKE:
- case Type::ZIP: {
- for (auto input : inputs_) {
- input->CollectKnobs(knobs);
+ if (auto* tunable_param =
+ gtl::FindOrNull(tunable_params_, "parallelism")) {
+ tunables->push_back(*tunable_param);
}
return;
}
@@ -69,12 +43,19 @@ void Node::CollectKnobs(std::vector<Node::Knob>* knobs) {
}
}
-int64 Node::ProcessingTimeLocked() {
+int64 Model::Node::GetParameterValue(const string& name) {
+ if (auto* tunable_param = gtl::FindOrNull(tunable_params_, name)) {
+ return (*tunable_param)->value;
+ }
+ return constant_params_[name];
+}
+
+int64 Model::Node::ProcessingTimeLocked() {
switch (type_) {
case Type::BATCH:
case Type::MAP_AND_BATCH:
case Type::PADDED_BATCH: {
- int64 batch_size = metadata_["batch_size"];
+ int64 batch_size = GetParameterValue("batch_size");
return NanosPerElementLocked() + batch_size * ProcessingTimeForInputs();
}
case Type::FILTER: {
@@ -118,11 +99,11 @@ int64 Node::ProcessingTimeLocked() {
}
}
-int64 Node::OutputTimeLocked(std::vector<int64>* input_times) {
+int64 Model::Node::OutputTimeLocked(std::vector<int64>* input_times) {
switch (type_) {
case Type::BATCH:
case Type::PADDED_BATCH: {
- double batch_size = metadata_["batch_size"];
+ double batch_size = GetParameterValue("batch_size");
int64 old_value = (*input_times)[input_times->size() - 1];
(*input_times)[input_times->size() - 1] = static_cast<int64>(
static_cast<double>(old_value + NanosPerElementLocked()) /
@@ -168,8 +149,8 @@ int64 Node::OutputTimeLocked(std::vector<int64>* input_times) {
static_cast<double>(inputs_.size() - 1);
}
case Type::MAP_AND_BATCH: {
- double batch_size = metadata_["batch_size"];
- double parallelism = metadata_["parallelism"];
+ double batch_size = GetParameterValue("batch_size");
+ double parallelism = GetParameterValue("parallelism");
int64 delta =
static_cast<int64>(static_cast<double>(NanosPerElementLocked()) /
(batch_size * parallelism));
@@ -182,22 +163,41 @@ int64 Node::OutputTimeLocked(std::vector<int64>* input_times) {
return std::max(0LL,
output_time - input_times->at(input_times->size() - 2));
}
- case Type::PARALLEL_INTERLEAVE:
+ case Type::PARALLEL_INTERLEAVE: {
+ // TODO(jsimsa): model the first input
+ if (inputs_.size() <= 1) {
+ return NanosPerElementLocked();
+ }
+ int64 delta = static_cast<double>(NanosPerElementLocked()) *
+ static_cast<double>(inputs_.size() - 1);
+ input_times->push_back(delta);
+ auto cleanup =
+ gtl::MakeCleanup([input_times]() { input_times->pop_back(); });
+ int64 inputs_output_time = OutputTimeForInputs(input_times) -
+ inputs_.front()->OutputTime(input_times);
+ double parallelism = GetParameterValue("parallelism");
+ int64 output_time =
+ NanosPerElementLocked() + ((static_cast<double>(inputs_output_time) /
+ static_cast<double>(inputs_.size() - 1)) /
+ parallelism);
+ return std::max(0LL,
+ output_time - input_times->at(input_times->size() - 2));
+ }
case Type::PARALLEL_INTERLEAVE_V2: {
// TODO(jsimsa): model the first input
if (inputs_.size() <= 1) {
return NanosPerElementLocked();
}
- int64 delta =
- static_cast<int64>(static_cast<double>(NanosPerElementLocked()) *
- static_cast<double>(inputs_.size() - 1));
+ int64 delta = static_cast<double>(NanosPerElementLocked()) *
+ static_cast<double>(inputs_.size() - 1);
input_times->push_back(delta);
auto cleanup =
gtl::MakeCleanup([input_times]() { input_times->pop_back(); });
int64 inputs_output_time = OutputTimeForInputs(input_times) -
inputs_.front()->OutputTime(input_times);
- double parallelism = std::min(port::NumSchedulableCPUs(),
- static_cast<int>(metadata_["parallelism"]));
+ double parallelism =
+ std::min(static_cast<int>(GetParameterValue("cycle_length")),
+ static_cast<int>(GetParameterValue("parallelism")));
int64 output_time =
NanosPerElementLocked() + ((static_cast<double>(inputs_output_time) /
static_cast<double>(inputs_.size() - 1)) /
@@ -206,8 +206,9 @@ int64 Node::OutputTimeLocked(std::vector<int64>* input_times) {
output_time - input_times->at(input_times->size() - 2));
}
case Type::PARALLEL_MAP: {
- double parallelism = std::min(port::NumSchedulableCPUs(),
- static_cast<int>(metadata_["parallelism"]));
+ double parallelism =
+ std::min(port::NumSchedulableCPUs(),
+ static_cast<int>(GetParameterValue("parallelism")));
int64 delta = static_cast<int64>(
static_cast<double>(NanosPerElementLocked()) / parallelism);
input_times->push_back(delta);
@@ -248,32 +249,34 @@ int64 Node::OutputTimeLocked(std::vector<int64>* input_times) {
}
}
-Model::Model(const proto::Model& model_proto) {
- id_counter_ = model_proto.id_counter();
- std::map<int64, std::shared_ptr<Node>> lookup_table;
- for (auto node_proto : model_proto.node()) {
- std::shared_ptr<Node> node(new Node(node_proto));
- lookup_table[node_proto.id()] = node;
- }
- for (auto node_proto : model_proto.node()) {
- std::shared_ptr<Node> node = lookup_table[node_proto.id()];
- for (int64 id : node_proto.input()) {
- node->add_input(lookup_table[id]);
- }
- node->set_output(lookup_table[node_proto.output()]);
+void Model::AddConstantParameter(const string& node_name,
+ const string& parameter_name, int64 value) {
+ tf_shared_lock l(mu_);
+ auto node = gtl::FindOrNull(lookup_table_, node_name);
+ if (node) {
+ (*node)->add_constant_param(parameter_name, value);
}
- output_ = lookup_table[model_proto.output()];
}
-std::shared_ptr<Node> Model::AddNode(const string& name,
- const string& output_name) {
- mutex_lock l(mu_);
+void Model::AddNode(const string& name, const string& output_name) {
+ // The name captures the sequence of iterators joined by `::`. We use the full
+ // sequence as the key in the lookup table, but only the last element of the
+ // sequence as the node name.
+ std::vector<string> tokens =
+ str_util::Split(name, ':', str_util::SkipEmpty());
+ // The output name might contain an index. We need to strip it to make it
+ // possible for the model to successfully identify the output node.
+ string sanitized_output_name = output_name;
+ if (str_util::EndsWith(output_name, "]")) {
+ sanitized_output_name = output_name.substr(0, output_name.rfind('['));
+ }
std::shared_ptr<Node> output;
- auto it = lookup_table_.find(output_name);
+ mutex_lock l(mu_);
+ auto it = lookup_table_.find(sanitized_output_name);
if (it != lookup_table_.end()) {
output = it->second;
}
- std::shared_ptr<Node> node(new Node(id_counter_++, output));
+ std::shared_ptr<Node> node(new Node(id_counter_++, tokens.back(), output));
if (!output_) {
output_ = node;
}
@@ -281,107 +284,127 @@ std::shared_ptr<Node> Model::AddNode(const string& name,
output->add_input(node);
}
lookup_table_.insert(std::make_pair(name, node));
- return node;
}
-std::shared_ptr<Node> Model::LookupNode(const string& name) {
+void Model::AddProcessingTime(const string& name, int64 delta) {
tf_shared_lock l(mu_);
- std::shared_ptr<Node> result;
- auto it = lookup_table_.find(name);
- if (it != lookup_table_.end()) {
- result = it->second;
+ auto node = gtl::FindOrNull(lookup_table_, name);
+ if (node) {
+ (*node)->add_processing_time(delta);
}
- return result;
}
-void Model::Optimize() {
- mutex_lock l(mu_);
- int64 processing_time = ProcessingTime();
- int64 num_cpus = port::NumSchedulableCPUs();
- std::vector<Node::Knob> knobs = CollectKnobs();
- // The optimization algorithm starts by setting all parallelism knobs to 1. It
- // then repeatedly identifies the knob that, when turned up by 1, decreases
- // the output time the most. This process is repeated until all knobs reach
- // the number of schedulable CPUs or the projected output time is less than or
- // equal to the processing time needed to produce an element divided by the
- // number of schedulable CPUs.
- for (auto& knob : knobs) {
- LOG(INFO) << knob.node->name() << " " << knob.processing_time;
- knob.value = 1;
- knob.node->set_metadata("parallelism", knob.value);
+void Model::AddTunableParameter(const string& node_name,
+ const string& parameter_name,
+ std::atomic<int64>* value, int64 min, int64 max,
+ condition_variable* cond_var) {
+ tf_shared_lock l(mu_);
+ auto node = *gtl::FindOrNull(lookup_table_, node_name);
+ DCHECK(node);
+ node->add_tunable_param(parameter_name, value, min, max, cond_var);
+}
+
+// The optimization algorithm starts by setting all tunable parallelism
+// parameters to 1. It then repeatedly identifies the parameter whose increase
+// in parallelism decreases the output time the most. This process is repeated
+// until all parameters reach their maximum values or the projected output time
+// is less than or equal to the processing time needed to produce an element
+// divided by CPU budget.
+void Model::Optimize(int64 cpu_budget) {
+ tf_shared_lock lock(mu_);
+ std::vector<std::shared_ptr<Model::Node::Tunable>> tunables;
+ const int64 processing_time = ProcessingTime();
+ tunables = CollectTunables();
+ for (auto tunable : tunables) {
+ tunable->value = 1;
}
while (true) {
- int64 output_time = OutputTime();
- bool all_knobs = true;
- for (auto knob : knobs) {
- if (knob.value < num_cpus) {
- all_knobs = false;
+ const int64 output_time = OutputTime();
+ bool all_tunables = true;
+ for (auto& tunable : tunables) {
+ if (tunable->value < tunable->max) {
+ all_tunables = false;
break;
}
}
- if (output_time < processing_time / num_cpus || all_knobs) {
+ if (output_time < processing_time / cpu_budget || all_tunables) {
break;
}
int64 best_delta = -1;
- int best_knob = -1;
- for (int i = 0; i < knobs.size(); ++i) {
- if (knobs[i].value == num_cpus) {
+ Model::Node::Tunable* best_tunable = nullptr;
+ for (auto& tunable : tunables) {
+ if (tunable->value == tunable->max) {
continue;
}
- knobs[i].node->set_metadata("parallelism", knobs[i].value + 1);
+ tunable->value++;
int64 delta = output_time - OutputTime();
if (delta > best_delta) {
best_delta = delta;
- best_knob = i;
+ best_tunable = tunable.get();
}
- knobs[i].node->set_metadata("parallelism", knobs[i].value);
+ tunable->value--;
}
- knobs[best_knob].value++;
- knobs[best_knob].node->set_metadata("parallelism", knobs[best_knob].value);
+ if (!best_tunable) {
+ // NOTE: This can happen because we are performing the optimization
+ // while the model data is changing. If this becomes an issue, we should
+ // look into performing the optimization using a model snapshot.
+ break;
+ }
+ best_tunable->value++;
}
- for (auto knob : knobs) {
- LOG(INFO) << knob.node->name() << " " << knob.value;
+ VLOG(2) << "Number of knobs: " << tunables.size();
+ for (auto& tunable : tunables) {
+ VLOG(2) << "Setting tunable parameter: " << tunable->value;
+ tunable->value_ptr->store(tunable->value);
+ if (tunable->cond_var) {
+ tunable->cond_var->notify_all();
+ }
}
- LOG(INFO) << "output time: " << OutputTime();
- LOG(INFO) << "processing time: " << ProcessingTime();
}
-void Model::OutputToFile() {
- proto::Model model_proto;
- ToProto(&model_proto);
- string filename;
- Env::Default()->LocalTempFilename(&filename);
- TF_CHECK_OK(WriteStringToFile(Env::Default(), filename,
- model_proto.SerializeAsString()));
- LOG(INFO) << filename;
+void Model::RecordElement(const string& name) {
+ tf_shared_lock l(mu_);
+ auto node = gtl::FindOrNull(lookup_table_, name);
+ if (node) {
+ (*node)->record_element();
+ }
}
-void Model::RemoveNode(const string& prefix) {
- mutex_lock l(mu_);
- lookup_table_.erase(prefix);
+void Model::RecordStart(const string& name, bool stop_output) {
+ tf_shared_lock l(mu_);
+ auto node = gtl::FindOrNull(lookup_table_, name);
+ if (node) {
+ if (stop_output && (*node)->output()) {
+ (*node)->output()->record_stop();
+ }
+ (*node)->record_start();
+ }
}
-void Model::ToProto(proto::Model* model_proto) {
- mutex_lock l(mu_);
- model_proto->set_id_counter(id_counter_);
- model_proto->set_output(output_->id());
- AddNodeToProto(output_, model_proto);
+void Model::RecordStop(const string& name, bool start_output) {
+ tf_shared_lock l(mu_);
+ auto node = gtl::FindOrNull(lookup_table_, name);
+ if (node) {
+ (*node)->record_stop();
+ if (start_output && (*node)->output()) {
+ (*node)->output()->record_start();
+ }
+ }
}
-// static
-void Model::AddNodeToProto(const std::shared_ptr<Node>& node,
- proto::Model* model_proto) {
- proto::Node* node_proto = model_proto->add_node();
- node->ToProto(node_proto);
- for (const std::shared_ptr<Node>& input : node->inputs()) {
- AddNodeToProto(input, model_proto);
+void Model::RemoveNode(const string& name) {
+ mutex_lock l(mu_);
+ auto node = gtl::FindOrNull(lookup_table_, name);
+ if (node && (*node)->output()) {
+ (*node)->output()->remove_input(*node);
}
+ lookup_table_.erase(name);
}
-std::vector<Node::Knob> Model::CollectKnobs() {
- std::vector<Node::Knob> knobs;
- output_->CollectKnobs(&knobs);
- return knobs;
+std::vector<std::shared_ptr<Model::Node::Tunable>> Model::CollectTunables() {
+ std::vector<std::shared_ptr<Model::Node::Tunable>> tunables;
+ output_->CollectTunables(&tunables);
+ return tunables;
}
int64 Model::OutputTime() {
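The comment above `Model::Optimize` describes a greedy hill-climbing search: start every tunable parallelism knob at 1, then repeatedly bump the knob whose increment reduces the estimated output time the most, stopping once every knob is at its maximum or the output time falls below the processing time divided by the CPU budget. A standalone sketch of that loop with a hypothetical `Knob` type and cost function, stripped of the TensorFlow plumbing; `output_time` is assumed to read the current knob values (e.g. a closure over `*knobs`):

    #include <cstdint>
    #include <functional>
    #include <vector>

    // Standalone sketch of the greedy search performed by Model::Optimize().
    struct Knob {
      int64_t value;
      int64_t max;
    };

    void GreedyOptimize(std::vector<Knob>* knobs, int64_t processing_time,
                        int64_t cpu_budget,
                        const std::function<int64_t()>& output_time) {
      for (Knob& knob : *knobs) knob.value = 1;
      while (true) {
        const int64_t current = output_time();
        bool all_maxed = true;
        for (const Knob& knob : *knobs) {
          if (knob.value < knob.max) {
            all_maxed = false;
            break;
          }
        }
        if (all_maxed || current < processing_time / cpu_budget) break;
        int64_t best_delta = -1;
        Knob* best = nullptr;
        for (Knob& knob : *knobs) {
          if (knob.value == knob.max) continue;
          knob.value++;  // trial increment
          const int64_t delta = current - output_time();
          if (delta > best_delta) {
            best_delta = delta;
            best = &knob;
          }
          knob.value--;  // undo the trial
        }
        if (best == nullptr) break;  // nothing improved; model may be changing
        best->value++;
      }
    }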
diff --git a/tensorflow/core/framework/model.h b/tensorflow/core/framework/model.h
index 98172909bf..26402f5cd3 100644
--- a/tensorflow/core/framework/model.h
+++ b/tensorflow/core/framework/model.h
@@ -22,9 +22,9 @@ limitations under the License.
#include <utility>
#include <vector>
-#include "tensorflow/core/framework/model.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/env.h"
@@ -33,356 +33,364 @@ namespace tensorflow {
namespace data {
namespace model {
-class Model;
-class Node;
-
-// Abstract representation of a TensorFlow input pipeline node. It collects
-// information about inputs to this node, processing time spent executing the
-// node logic, number of elements produced by the node, various other
-// information (e.g. batch size or execution parallelism).
+// Abstract representation of a TensorFlow input pipeline that can be used
+// for collecting runtime information and optimizing performance. It collects
+// runtime information about execution of the input pipeline that is used to
+// create a performance model, which is in turn used to identify optimal values
+// of tunable parameters.
//
// Developers of tf.data transformations are not expected to interact with this
// class directly. Boiler plate code for creating the abstract representation of
-// the input pipeline and collecting common information has been added to the
+// the input pipeline and collecting runtime information has been added to the
// implementation of `DatasetBase` and `DatasetBaseIterator` respectively.
-//
-// In addition, `DatasetBaseIterator` provides wrappers that can be used for
-// transformation-specific information collection. The `SetMetadata` wrapper can
-// be used to pass arbitrary metadata to the modeling framework, while the
-// `StartWork` and `StopWork` wrappers should be used to correctly account for
-// processing time of multi-threaded transformation that yield the CPU; such
-// transformations should invoke `StartWork()` when a transformation thread
-// starts executing (e.g. when created or woken up) and `StopWork()` when a
-// transformation thread stops executing (e.g. when returning or waiting).
-//
-// TODO(jsimsa): Create an API to capture the abstract semantics of each
-// tf.data transformation and replace switch-case blocks with inheritance.
-class Node {
+class Model {
public:
- Node(int64 id, std::shared_ptr<Node> output) : id_(id), output_(output) {}
-
- explicit Node(const proto::Node& node_proto) : id_(node_proto.id()) {
- name_ = node_proto.name();
- type_ = TypeFromName(node_proto.name());
- processing_time_ = node_proto.processing_time();
- num_elements_ = node_proto.num_elements();
- metadata_.insert(node_proto.metadata().begin(),
- node_proto.metadata().end());
- }
-
- // Records that the node produced an element.
- void add_element() LOCKS_EXCLUDED(mu_) {
- mutex_lock l(mu_);
- num_elements_++;
- }
-
- // Adds an input.
- void add_input(std::shared_ptr<Node> node) LOCKS_EXCLUDED(mu_) {
- mutex_lock l(mu_);
- inputs_.push_back(node);
- }
-
- // Increments the aggregate processing time by the given delta.
- void add_processing_time(int64 delta) LOCKS_EXCLUDED(mu_) {
- mutex_lock l(mu_);
- processing_time_ += delta;
- }
-
- // Returns the unique node ID.
- int64 id() LOCKS_EXCLUDED(mu_) { return id_; }
-
- // Returns the node inputs.
- std::list<std::shared_ptr<Node>> inputs() LOCKS_EXCLUDED(mu_) {
- tf_shared_lock l(mu_);
- return inputs_;
- }
-
- // Returns the node name.
- const string& name() LOCKS_EXCLUDED(mu_) {
- tf_shared_lock l(mu_);
- return name_;
- }
-
- // Returns the number of elements produced by the node.
- int64 num_elements() LOCKS_EXCLUDED(mu_) {
- tf_shared_lock l(mu_);
- return num_elements_;
- }
-
- // Returns the node output.
- std::shared_ptr<Node> output() LOCKS_EXCLUDED(mu_) {
- tf_shared_lock l(mu_);
- return output_;
- }
-
- // Removes an input.
- void remove_input(std::shared_ptr<Node> input) LOCKS_EXCLUDED(mu_) {
- mutex_lock l(mu_);
- inputs_.remove(input);
- }
-
- // Adds the given key-value pair to the node metadata.
- void set_metadata(const string& key, int64 value) LOCKS_EXCLUDED(mu_) {
- mutex_lock l(mu_);
- metadata_[key] = value;
- }
-
- // Sets the node name.
- void set_name(const string& name) LOCKS_EXCLUDED(mu_) {
- mutex_lock l(mu_);
- name_ = name;
- type_ = TypeFromName(name);
- }
-
- // Set the node output.
- void set_output(std::shared_ptr<Node> output) LOCKS_EXCLUDED(mu_) {
- mutex_lock l(mu_);
- output_ = output;
- }
-
- // Records that a node thread has started work.
- void start_work() LOCKS_EXCLUDED(mu_) {
- mutex_lock l(mu_);
- work_start_[std::this_thread::get_id()] = Env::Default()->NowNanos();
- }
-
- // Records that a node thread has stopped work.
- void stop_work() LOCKS_EXCLUDED(mu_) {
- mutex_lock l(mu_);
- auto iter = work_start_.find(std::this_thread::get_id());
- CHECK(work_start_.end() != iter)
- << "Encountered a stop event that was not preceded by a start event.";
- processing_time_ += Env::Default()->NowNanos() - iter->second;
- work_start_.erase(iter);
- }
+ Model() = default;
- private:
- // Represents a performance knob.
- struct Knob {
- Node* node;
- int64 processing_time;
- int64 value;
- };
+ // Adds a constant parameter for the given node.
+ void AddConstantParameter(const string& node_name,
+ const string& parameter_name, int64 value)
+ LOCKS_EXCLUDED(mu_);
- enum class Type {
- BATCH = 0,
- CACHE,
- CONCATENATE,
- FILTER,
- FLAT_MAP,
- INTERLEAVE,
- MAP,
- MAP_AND_BATCH,
- PADDED_BATCH,
- PARALLEL_INTERLEAVE,
- PARALLEL_INTERLEAVE_V2,
- PARALLEL_MAP,
- PREFETCH,
- REPEAT,
- SHUFFLE,
- SKIP,
- TAKE,
- ZIP,
- UNKNOWN,
- };
+ // Adds a node with the given name and given output (identified by name).
+ void AddNode(const string& name, const string& output_name)
+ LOCKS_EXCLUDED(mu_);
- // Collects performance knobs in the subtree rooted in this node.
- void CollectKnobs(std::vector<Node::Knob>* knobs) LOCKS_EXCLUDED(mu_);
+ // Increments the processing time for the given node.
+ void AddProcessingTime(const string& name, int64 delta) LOCKS_EXCLUDED(mu_);
- // Returns the per-element processing time spent in this node.
- int64 NanosPerElement() LOCKS_EXCLUDED(mu_) {
- mutex_lock l(mu_);
- return NanosPerElementLocked();
- }
+ // Adds a tunable parameter for the given node.
+ void AddTunableParameter(const string& node_name,
+ const string& parameter_name,
+ std::atomic<int64>* value, int64 min, int64 max,
+ condition_variable* cond_var) LOCKS_EXCLUDED(mu_);
- int64 NanosPerElementLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- if (num_elements_ == 0) {
- return 0;
- }
- return (int64)((double)processing_time_ / (double)num_elements_);
- }
-
- // Returns the per-element output time for this node.
- int64 OutputTime(std::vector<int64>* input_times) LOCKS_EXCLUDED(mu_) {
- mutex_lock l(mu_);
- return OutputTimeLocked(input_times);
- }
-
- int64 OutputTimeLocked(std::vector<int64>* input_times)
- EXCLUSIVE_LOCKS_REQUIRED(mu_);
-
- int64 OutputTimeForInputs(std::vector<int64>* input_times)
- EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- int64 sum = 0;
- for (auto input : inputs_) {
- sum += input->OutputTime(input_times);
- }
- return sum;
- }
-
- // Returns the per-element processing time spent in the subtree rooted in this
- // node.
- int64 ProcessingTime() LOCKS_EXCLUDED(mu_) {
- mutex_lock l(mu_);
- return ProcessingTimeLocked();
- }
-
- int64 ProcessingTimeLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_);
-
- // Returns the per-element processing time spent in the inputs of this node.
- int64 ProcessingTimeForInputs() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- int64 sum = 0;
- for (auto input : inputs_) {
- sum += input->ProcessingTimeLocked();
- }
- return sum;
- }
-
- // Serializes the node state into the given proto.
- void ToProto(proto::Node* node_proto) LOCKS_EXCLUDED(mu_) {
- mutex_lock l(mu_);
- node_proto->set_id(id_);
- node_proto->set_name(name_);
- node_proto->set_num_elements(num_elements_);
- node_proto->set_processing_time(processing_time_);
- for (const std::shared_ptr<Node>& input : inputs_) {
- node_proto->add_input(input->id());
- }
- if (output_) {
- node_proto->set_output(output_->id());
- }
- node_proto->mutable_metadata()->insert(metadata_.begin(), metadata_.end());
- }
+ // Runs optimization.
+ void Optimize(int64 cpu_budget) LOCKS_EXCLUDED(mu_);
- Type TypeFromName(const string& name) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- if (name_ == "Batch") {
- return Type::BATCH;
- }
- if (str_util::EndsWith(name_, "Cache")) {
- return Type::CACHE;
- }
- if (name_ == "Concatenate") {
- return Type::CONCATENATE;
- }
- if (name_ == "Filter") {
- return Type::FILTER;
- }
- if (name_ == "FlatMap") {
- return Type::FLAT_MAP;
- }
- if (name_ == "Interleave") {
- return Type::INTERLEAVE;
- }
- if (name_ == "Map") {
- return Type::MAP;
+ // Records that a node has produced an element.
+ void RecordElement(const string& name) LOCKS_EXCLUDED(mu_);
+
+ // Records that the given node has started work. If `stop_output` is set, it
+ // also records that the output of the given node has stopped work.
+ void RecordStart(const string& name, bool stop_output) LOCKS_EXCLUDED(mu_);
+
+ // Records that the given node has stopped work. If `start_output` is set, it
+ // also records that the output of the given node has started work.
+ void RecordStop(const string& name, bool start_output) LOCKS_EXCLUDED(mu_);
+
+ // Removes the given node.
+ void RemoveNode(const string& name) LOCKS_EXCLUDED(mu_);
+
+ private:
+ // Abstract representation of a TensorFlow input pipeline node. It collects
+ // information about inputs to this node, processing time spent executing the
+ // node logic, the number of elements produced by the node, and various other
+ // information (e.g. batch size or execution parallelism).
+ //
+ // Developers of tf.data transformations are not expected to interact with
+ // this class directly. Boilerplate code for creating the abstract
+ // representation of the input pipeline and collecting common information has
+ // been added to the implementation of `DatasetBase` and `DatasetBaseIterator`
+ // respectively.
+ //
+ // In addition, `DatasetBaseIterator` provides wrappers that can be used for
+ // transformation-specific information collection. The `SetMetadata` wrapper
+ // can be used to pass arbitrary metadata to the modeling framework, while the
+ // `StartWork` and `StopWork` wrappers should be used to correctly account for
+ // processing time of multi-threaded transformations that yield the CPU; such
+ // transformations should invoke `StartWork()` when a transformation thread
+ // starts executing (e.g. when created or woken up) and `StopWork()` when a
+ // transformation thread stops executing (e.g. when returning or waiting).
+ //
+ // TODO(jsimsa): Create an API to capture the abstract semantics of each
+ // tf.data transformation and replace switch-case blocks with inheritance.
+ class Node {
+ public:
+ // Represents a tunable parameter.
+ struct Tunable {
+ Tunable(std::atomic<int64>* value, int64 min, int64 max,
+ condition_variable* cond_var)
+ : value(*value),
+ min(min),
+ max(max),
+ value_ptr(value),
+ cond_var(cond_var) {}
+
+ // Identifies the model value of the parameter. This can be different from
+ // the actual value (e.g. during optimization search).
+ int64 value;
+
+ // Identifies the minimum value of the parameter.
+ int64 min;
+
+ // Identifies the maximum value of the parameter.
+ int64 max;
+
+ // Points to the actual value of the parameter. Not owned.
+ std::atomic<int64>* value_ptr;
+
+ // If non-null, this condition variable is notified when the model updates
+ // the actual value of the parameter (via `value_ptr`). Not owned.
+ condition_variable* cond_var;
+ };
+
+ Node(int64 id, const string& name, std::shared_ptr<Node> output)
+ : id_(id), name_(name), type_(TypeFromName(name)), output_(output) {}
+
+ // Adds a constant parameter.
+ void add_constant_param(const string& name, int64 value)
+ LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ constant_params_[name] = value;
}
- if (name_ == "MapAndBatch") {
- return Type::MAP_AND_BATCH;
+
+ // Adds an input.
+ void add_input(std::shared_ptr<Node> node) LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ inputs_.push_back(node);
}
- if (name_ == "PaddedBatch") {
- return Type::PADDED_BATCH;
+
+ // Increments the aggregate processing time by the given delta.
+ void add_processing_time(int64 delta) LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ processing_time_ += delta;
}
- if (name_ == "ParallelInterleave") {
- return Type::PARALLEL_INTERLEAVE;
+
+ // Adds a tunable parameter.
+ void add_tunable_param(const string& name, std::atomic<int64>* value,
+ int64 min, int64 max, condition_variable* cond_var)
+ LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ tunable_params_[name] =
+ std::make_shared<Tunable>(value, min, max, cond_var);
}
- if (name_ == "ParallelInterleaveV2") {
- return Type::PARALLEL_INTERLEAVE_V2;
+
+ // Returns the unique node ID.
+ int64 id() LOCKS_EXCLUDED(mu_) { return id_; }
+
+ // Returns the node inputs.
+ std::list<std::shared_ptr<Node>> inputs() LOCKS_EXCLUDED(mu_) {
+ tf_shared_lock l(mu_);
+ return inputs_;
}
- if (name_ == "ParallelMap") {
- return Type::PARALLEL_MAP;
+
+ // Returns the node name.
+ const string& name() LOCKS_EXCLUDED(mu_) {
+ tf_shared_lock l(mu_);
+ return name_;
}
- if (name_ == "Prefetch") {
- return Type::PREFETCH;
+
+ // Returns the number of elements produced by the node.
+ int64 num_elements() LOCKS_EXCLUDED(mu_) {
+ tf_shared_lock l(mu_);
+ return num_elements_;
}
- if (str_util::EndsWith(name_, "Repeat")) {
- return Type::REPEAT;
+
+ // Returns the node output.
+ std::shared_ptr<Node> output() LOCKS_EXCLUDED(mu_) {
+ tf_shared_lock l(mu_);
+ return output_;
}
- if (name_ == "Shuffle") {
- return Type::SHUFFLE;
+
+ // Records that the node produced an element.
+ void record_element() LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ num_elements_++;
}
- if (str_util::EndsWith(name_, "Skip")) {
- return Type::SKIP;
+
+ // Records that a node thread has started executing.
+ void record_start() LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ work_start_[std::this_thread::get_id()] = Env::Default()->NowNanos();
}
- if (str_util::EndsWith(name_, "Take")) {
- return Type::TAKE;
+
+ // Records that a node thread has stopped executing.
+ void record_stop() LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ std::thread::id tid = std::this_thread::get_id();
+ auto start_time = gtl::FindOrNull(work_start_, tid);
+ DCHECK(start_time)
+ << "Encountered a stop event that was not preceded by a start event.";
+ if (start_time) {
+ processing_time_ += Env::Default()->NowNanos() - *start_time;
+ work_start_.erase(tid);
+ }
}
- if (name_ == "Zip") {
- return Type::ZIP;
+
+ // Removes an input.
+ void remove_input(std::shared_ptr<Node> input) LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ inputs_.remove(input);
}
- return Type::UNKNOWN;
- }
- mutex mu_;
- const int64 id_;
- Type type_ GUARDED_BY(mu_);
- string name_ GUARDED_BY(mu_);
- int64 processing_time_ GUARDED_BY(mu_) = 0;
- int64 num_elements_ GUARDED_BY(mu_) = 0;
- std::map<std::thread::id, int64> work_start_ GUARDED_BY(mu_);
- std::map<string, int64> metadata_ GUARDED_BY(mu_);
- std::list<std::shared_ptr<Node>> inputs_ GUARDED_BY(mu_);
- std::shared_ptr<Node> output_ GUARDED_BY(mu_);
+ // Sets the node output.
+ void set_output(std::shared_ptr<Node> output) LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ output_ = output;
+ }
- friend class Model;
-};
+ // Collects tunable parameters in the subtree rooted in this node.
+ void CollectTunables(std::vector<std::shared_ptr<Tunable>>* tunables)
+ LOCKS_EXCLUDED(mu_);
-// Abstract representation of a TensorFlow input pipeline that can be used
-// for collecting runtime information and optimizing performance. It collects
-// runtime information about execution of the input pipeline that is used to
-// create a performance model, which is in turn used to identify optimal values
-// of performance knobs.
-//
-// Developers of tf.data transformations are not expected to interact with this
-// class directly. Boiler plate code for creating the abstract representation of
-// the input pipeline and collecting runtime information has been added to the
-// implementation of `DatasetBase` and `DatasetBaseIterator` respectively.
-//
-// TODO(jsimsa): Add a mechanism for feeding the result of the optimization
-// into the input pipeline.
-class Model {
- public:
- Model() = default;
- explicit Model(const proto::Model& model_proto);
+ // Returns the per-element output time for this node.
+ int64 OutputTime(std::vector<int64>* input_times) LOCKS_EXCLUDED(mu_) {
+ tf_shared_lock l(mu_);
+ return OutputTimeLocked(input_times);
+ }
- ~Model() {}
+ // Returns the per-element processing time spent in the subtree rooted in
+ // this node.
+ int64 ProcessingTime() LOCKS_EXCLUDED(mu_) {
+ tf_shared_lock l(mu_);
+ return ProcessingTimeLocked();
+ }
- // Returns the model output node.
- std::shared_ptr<Node> output() LOCKS_EXCLUDED(mu_) {
- tf_shared_lock l(mu_);
- return output_;
- }
+ private:
+ enum class Type {
+ BATCH = 0,
+ CACHE,
+ CONCATENATE,
+ FILTER,
+ FLAT_MAP,
+ INTERLEAVE,
+ MAP,
+ MAP_AND_BATCH,
+ PADDED_BATCH,
+ PARALLEL_INTERLEAVE,
+ PARALLEL_INTERLEAVE_V2,
+ PARALLEL_MAP,
+ PREFETCH,
+ REPEAT,
+ SHUFFLE,
+ SKIP,
+ TAKE,
+ ZIP,
+ UNKNOWN,
+ };
+
+ // Gets a value of the given parameter (tunable or constant).
+ int64 GetParameterValue(const string& name) SHARED_LOCKS_REQUIRED(mu_);
+
+ // Returns the per-element processing time spent in this node.
+ int64 NanosPerElement() LOCKS_EXCLUDED(mu_) {
+ tf_shared_lock l(mu_);
+ return NanosPerElementLocked();
+ }
- // Adds a node with the given name and given output (identified by name).
- std::shared_ptr<Node> AddNode(const string& name, const string& output_name)
- LOCKS_EXCLUDED(mu_);
+ int64 NanosPerElementLocked() SHARED_LOCKS_REQUIRED(mu_) {
+ if (num_elements_ == 0) {
+ return 0;
+ }
+ return (int64)((double)processing_time_ / (double)num_elements_);
+ }
- // Looks up the node using the given name.
- std::shared_ptr<Node> LookupNode(const string& name) LOCKS_EXCLUDED(mu_);
+ int64 OutputTimeLocked(std::vector<int64>* input_times)
+ SHARED_LOCKS_REQUIRED(mu_);
- // Runs optimization.
- void Optimize() LOCKS_EXCLUDED(mu_);
+ int64 OutputTimeForInputs(std::vector<int64>* input_times)
+ SHARED_LOCKS_REQUIRED(mu_) {
+ int64 sum = 0;
+ for (auto input : inputs_) {
+ sum += input->OutputTime(input_times);
+ }
+ return sum;
+ }
- // Outputs the state of a model to a file.
- //
- // TODO(jsimsa): Remove this method once the optimization loop is closed.
- void OutputToFile() LOCKS_EXCLUDED(mu_);
+ int64 ProcessingTimeLocked() SHARED_LOCKS_REQUIRED(mu_);
- // Removes the node identified by the given name.
- void RemoveNode(const string& prefix) LOCKS_EXCLUDED(mu_);
+ // Returns the per-element processing time spent in the inputs of this node.
+ int64 ProcessingTimeForInputs() SHARED_LOCKS_REQUIRED(mu_) {
+ int64 sum = 0;
+ for (auto input : inputs_) {
+ sum += input->ProcessingTime();
+ }
+ return sum;
+ }
- // Serializes the model state to the given proto.
- void ToProto(proto::Model* model_proto) LOCKS_EXCLUDED(mu_);
+ Type TypeFromName(const string& name) SHARED_LOCKS_REQUIRED(mu_) {
+ if (name_ == "Batch") {
+ return Type::BATCH;
+ }
+ if (str_util::EndsWith(name_, "Cache")) {
+ return Type::CACHE;
+ }
+ if (name_ == "Concatenate") {
+ return Type::CONCATENATE;
+ }
+ if (name_ == "Filter") {
+ return Type::FILTER;
+ }
+ if (name_ == "FlatMap") {
+ return Type::FLAT_MAP;
+ }
+ if (name_ == "Interleave") {
+ return Type::INTERLEAVE;
+ }
+ if (name_ == "Map") {
+ return Type::MAP;
+ }
+ if (name_ == "MapAndBatch") {
+ return Type::MAP_AND_BATCH;
+ }
+ if (name_ == "PaddedBatch") {
+ return Type::PADDED_BATCH;
+ }
+ if (name_ == "ParallelInterleave") {
+ return Type::PARALLEL_INTERLEAVE;
+ }
+ if (name_ == "ParallelInterleaveV2") {
+ return Type::PARALLEL_INTERLEAVE_V2;
+ }
+ if (name_ == "ParallelMap") {
+ return Type::PARALLEL_MAP;
+ }
+ if (name_ == "Prefetch") {
+ return Type::PREFETCH;
+ }
+ if (str_util::EndsWith(name_, "Repeat")) {
+ return Type::REPEAT;
+ }
+ if (name_ == "Shuffle") {
+ return Type::SHUFFLE;
+ }
+ if (str_util::EndsWith(name_, "Skip")) {
+ return Type::SKIP;
+ }
+ if (str_util::EndsWith(name_, "Take")) {
+ return Type::TAKE;
+ }
+ if (name_ == "Zip") {
+ return Type::ZIP;
+ }
+ return Type::UNKNOWN;
+ }
- private:
- static void AddNodeToProto(const std::shared_ptr<Node>& node,
- proto::Model* model_proto);
+ mutex mu_;
+ const int64 id_;
+ const string name_;
+ const Type type_;
+ int64 processing_time_ GUARDED_BY(mu_) = 0;
+ int64 num_elements_ GUARDED_BY(mu_) = 0;
+ std::map<std::thread::id, int64> work_start_ GUARDED_BY(mu_);
+ std::map<string, int64> constant_params_ GUARDED_BY(mu_);
+ // Tunables are shared with the model during optimization.
+ std::map<string, std::shared_ptr<Tunable>> tunable_params_ GUARDED_BY(mu_);
+ std::list<std::shared_ptr<Node>> inputs_ GUARDED_BY(mu_);
+ std::shared_ptr<Node> output_ GUARDED_BY(mu_);
+ };
- std::vector<Node::Knob> CollectKnobs() EXCLUSIVE_LOCKS_REQUIRED(mu_);
+ std::vector<std::shared_ptr<Node::Tunable>> CollectTunables()
+ SHARED_LOCKS_REQUIRED(mu_);
- int64 OutputTime() EXCLUSIVE_LOCKS_REQUIRED(mu_);
+ int64 OutputTime() SHARED_LOCKS_REQUIRED(mu_);
- int64 ProcessingTime() EXCLUSIVE_LOCKS_REQUIRED(mu_);
+ int64 ProcessingTime() SHARED_LOCKS_REQUIRED(mu_);
+ // Used for coordination between different input pipeline threads. Exclusive
+ // access is required only when adding or removing nodes. Concurrent access to
+ // existing nodes is protected by a node mutex.
mutex mu_;
int64 id_counter_ GUARDED_BY(mu_) = 1;
std::shared_ptr<Node> output_ GUARDED_BY(mu_);
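The `Model` class above is now driven imperatively from iterators instead of being serialized to a proto. A minimal sketch of the intended call pattern follows (illustrative only, not part of the patch; the node names, the registration of the root node via an empty output name, the enclosing namespace, and the CPU-budget helper are assumptions):

    Model model;
    // Register a two-node pipeline; "Prefetch" is the output of "ParallelMap".
    model.AddNode("Prefetch", "");
    model.AddNode("ParallelMap", "Prefetch");
    // Expose the map's parallelism as a tunable the optimizer may adjust.
    std::atomic<int64> parallelism(1);
    condition_variable cond_var;
    model.AddTunableParameter("ParallelMap", "parallelism", &parallelism,
                              /*min=*/1, /*max=*/8, &cond_var);
    // Per-thread accounting around the production of a single element.
    model.RecordStart("ParallelMap", /*stop_output=*/true);
    // ... transformation-specific work that yields one element ...
    model.RecordElement("ParallelMap");
    model.RecordStop("ParallelMap", /*start_output=*/true);
    // Periodically pick tunable values under the available CPU budget.
    model.Optimize(/*cpu_budget=*/port::NumSchedulableCPUs());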
diff --git a/tensorflow/core/framework/model.proto b/tensorflow/core/framework/model.proto
deleted file mode 100644
index 26000007af..0000000000
--- a/tensorflow/core/framework/model.proto
+++ /dev/null
@@ -1,30 +0,0 @@
-syntax = "proto3";
-
-package tensorflow.data.model.proto;
-option cc_enable_arenas = true;
-
-message Model {
- // Counter used for generating new node IDs.
- int64 id_counter = 1;
- // Nodes of this model.
- repeated Node node = 2;
- // The ID of the output node.
- int64 output = 3;
-};
-
-message Node {
- // The node ID.
- int64 id = 1;
- // The node name.
- string name = 2;
- // Input node IDs.
- repeated int64 input = 3;
- // Output node ID.
- int64 output = 4;
- // Number of elements produced by the node.
- int64 num_elements = 5;
- // The CPU time spent by running threads of this node.
- int64 processing_time = 6;
- // Key-value store for node metadata (e.g. batch size or parallelism).
- map<string, int32> metadata = 7;
-};
diff --git a/tensorflow/core/framework/node_def_util.cc b/tensorflow/core/framework/node_def_util.cc
index bacc1d72c4..43ac1d0ada 100644
--- a/tensorflow/core/framework/node_def_util.cc
+++ b/tensorflow/core/framework/node_def_util.cc
@@ -372,6 +372,14 @@ Status InputTypeForNode(const NodeDef& node_def, const OpDef& op_def,
node_def.name());
}
+Status InputTypesForNode(const NodeDef& node_def, const OpDef& op_def,
+ DataTypeVector* inputs) {
+ for (const auto& arg : op_def.input_arg()) {
+ TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, inputs));
+ }
+ return Status::OK();
+}
+
Status OutputTypeForNode(const NodeDef& node_def, const OpDef& op_def,
int output_port, DataType* output_type) {
DataTypeVector output_types;
@@ -397,12 +405,18 @@ Status OutputTypesForNode(const NodeDef& node_def, const OpDef& op_def,
Status InOutTypesForNode(const NodeDef& node_def, const OpDef& op_def,
DataTypeVector* inputs, DataTypeVector* outputs) {
- for (const auto& arg : op_def.input_arg()) {
- TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, inputs));
- }
+ TF_RETURN_IF_ERROR(InputTypesForNode(node_def, op_def, inputs));
return OutputTypesForNode(node_def, op_def, outputs);
}
+Status NumOutputsForNode(const NodeDef& node_def, const OpDef& op_def,
+ int* num_outputs) {
+ DataTypeVector outputs;
+ TF_RETURN_IF_ERROR(OutputTypesForNode(node_def, op_def, &outputs));
+ *num_outputs = outputs.size();
+ return Status::OK();
+}
+
Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def) {
if (node_def.op() != op_def.name()) {
return errors::InvalidArgument("NodeDef op '", node_def.op(),
diff --git a/tensorflow/core/framework/node_def_util.h b/tensorflow/core/framework/node_def_util.h
index 499034cab2..187bfa2c88 100644
--- a/tensorflow/core/framework/node_def_util.h
+++ b/tensorflow/core/framework/node_def_util.h
@@ -249,6 +249,10 @@ const string& GetNodeAttrString(const AttrSlice& attrs, StringPiece attr_name);
// REQUIRES: ValidateOpDef(op_def).ok()
Status InputTypeForNode(const NodeDef& node_def, const OpDef& op_def,
int input_port, DataType* input_type);
+// Computes the input types for a specific node.
+// REQUIRES: ValidateOpDef(op_def).ok()
+Status InputTypesForNode(const NodeDef& node_def, const OpDef& op_def,
+ DataTypeVector* inputs);
// Computes the output type for a specific node output.
// REQUIRES: ValidateOpDef(op_def).ok()
Status OutputTypeForNode(const NodeDef& node_def, const OpDef& op_def,
@@ -261,6 +265,10 @@ Status OutputTypesForNode(const NodeDef& node_def, const OpDef& op_def,
// REQUIRES: ValidateOpDef(op_def).ok()
Status InOutTypesForNode(const NodeDef& node_def, const OpDef& op_def,
DataTypeVector* inputs, DataTypeVector* outputs);
+// Computes the number of outputs for a specific node.
+// REQUIRES: ValidateOpDef(op_def).ok()
+Status NumOutputsForNode(const NodeDef& node_def, const OpDef& op_def,
+ int* num_outputs);
// Validates that the NodeDef:
// * Defines all expected attrs from the OpDef.
diff --git a/tensorflow/core/framework/node_def_util_test.cc b/tensorflow/core/framework/node_def_util_test.cc
index 74cc594863..d9d437024a 100644
--- a/tensorflow/core/framework/node_def_util_test.cc
+++ b/tensorflow/core/framework/node_def_util_test.cc
@@ -370,6 +370,48 @@ TEST(NodeDefUtilTest, ValidSyntax) {
"Illegal op input name 'a:00");
}
+TEST(InputTypesForNode, Simple) {
+ const OpDef op_def = ToOpDef(OpDefBuilder("Simple")
+ .Input("a: float")
+ .Input("b: int32")
+ .Output("c: string")
+ .Output("d: bool"));
+ const NodeDef node_def = ToNodeDef(
+ NodeDefBuilder("simple", &op_def).Input(FakeInput()).Input(FakeInput()));
+ DataTypeVector types;
+ EXPECT_TRUE(InputTypesForNode(node_def, op_def, &types).ok());
+ EXPECT_EQ(types[0], DT_FLOAT);
+ EXPECT_EQ(types[1], DT_INT32);
+
+ DataType type;
+ EXPECT_TRUE(InputTypeForNode(node_def, op_def, 0, &type).ok());
+ EXPECT_EQ(type, DT_FLOAT);
+ EXPECT_TRUE(InputTypeForNode(node_def, op_def, 1, &type).ok());
+ EXPECT_EQ(type, DT_INT32);
+ EXPECT_FALSE(InputTypeForNode(node_def, op_def, 2, &type).ok());
+}
+
+TEST(OutputTypesForNode, Simple) {
+ const OpDef op_def = ToOpDef(OpDefBuilder("Simple")
+ .Input("a: float")
+ .Input("b: int32")
+ .Output("c: string")
+ .Output("d: bool"));
+ const NodeDef node_def = ToNodeDef(
+ NodeDefBuilder("simple", &op_def).Input(FakeInput()).Input(FakeInput()));
+ DataTypeVector types;
+ EXPECT_TRUE(OutputTypesForNode(node_def, op_def, &types).ok());
+ EXPECT_EQ(types[0], DT_STRING);
+ EXPECT_EQ(types[1], DT_BOOL);
+
+ DataType type;
+ EXPECT_TRUE(OutputTypeForNode(node_def, op_def, 0, &type).ok());
+ EXPECT_EQ(type, DT_STRING);
+ EXPECT_TRUE(OutputTypeForNode(node_def, op_def, 1, &type).ok());
+ EXPECT_EQ(type, DT_BOOL);
+ EXPECT_FALSE(OutputTypeForNode(node_def, op_def, 2, &type).ok());
+}
+
TEST(NameRangesForNodeTest, Simple) {
const OpDef op_def = ToOpDef(OpDefBuilder("Simple")
.Input("a: float")
diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc
index 80f2b12987..3e34bf0418 100644
--- a/tensorflow/core/framework/op_kernel.cc
+++ b/tensorflow/core/framework/op_kernel.cc
@@ -265,9 +265,12 @@ OpKernelContext::OpKernelContext(Params* params, int num_outputs)
params_->ensure_eigen_gpu_device();
if (params_->eigen_gpu_device != nullptr) {
Allocator* eigen_gpu_allocator = get_allocator(AllocatorAttributes());
- params_->device->ReinitializeGpuDevice(this, params_->eigen_gpu_device,
- params_->op_device_context,
- eigen_gpu_allocator);
+ Status s = params_->device->ReinitializeGpuDevice(
+ this, params_->eigen_gpu_device, params_->op_device_context,
+ eigen_gpu_allocator);
+ if (!s.ok()) {
+ SetStatus(s);
+ }
}
if (params_->record_tensor_accesses) {
referenced_tensors_.Init();
diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h
index e752599de1..4bbd6c3d7d 100644
--- a/tensorflow/core/framework/op_kernel.h
+++ b/tensorflow/core/framework/op_kernel.h
@@ -372,18 +372,37 @@ class OpKernelConstruction {
template <typename ListType, typename ElementType>
class OpArgIterator {
public:
- typedef OpArgIterator<ListType, ElementType> ME;
+ using iterator_category = std::forward_iterator_tag;
+ using value_type = ElementType;
+ using pointer = ElementType*;
+ using reference = ElementType&;
+ using difference_type = ptrdiff_t;
+
OpArgIterator(const ListType* list, int i) : list_(list), i_(i) {}
- bool operator==(const ME& rhs) {
+
+ bool operator==(const OpArgIterator& rhs) {
DCHECK(list_ == rhs.list_);
return i_ == rhs.i_;
}
- bool operator!=(const ME& rhs) {
+
+ bool operator!=(const OpArgIterator& rhs) {
DCHECK(list_ == rhs.list_);
return i_ != rhs.i_;
}
- void operator++() { ++i_; }
- ElementType& operator*() { return (*list_)[i_]; }
+
+ OpArgIterator operator++() { // prefix ++it
+ ++i_;
+ return *this;
+ }
+
+ OpArgIterator operator++(int) { // postfix it++
+ OpArgIterator old_value = *this;
+ ++i_;
+ return old_value;
+ }
+
+ reference operator*() { return (*list_)[i_]; }
+ pointer operator->() { return &(*list_)[i_]; }
private:
const ListType* const list_;
@@ -394,7 +413,7 @@ class OpArgIterator {
// that are passed to the op as a single named argument.
class OpInputList {
public:
- typedef OpArgIterator<OpInputList, const Tensor&> Iterator;
+ typedef OpArgIterator<OpInputList, const Tensor> Iterator;
OpInputList() : ctx_(nullptr), start_(0), stop_(0) {}
OpInputList(OpKernelContext* ctx, int start, int stop)
: ctx_(ctx), start_(start), stop_(stop) {}
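With the iterator traits, postfix increment, and `operator->` in place, `OpArgIterator` satisfies the forward-iterator requirements, so an `OpInputList` can be handed to standard algorithms. A hypothetical kernel fragment (the input-list name "values" is an assumption; requires <numeric>):

    OpInputList values;
    OP_REQUIRES_OK(ctx, ctx->input_list("values", &values));
    // Sum the element counts of all tensors in the variadic input.
    const int64 total_elements = std::accumulate(
        values.begin(), values.end(), int64{0},
        [](int64 acc, const Tensor& t) { return acc + t.NumElements(); });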
diff --git a/tensorflow/core/framework/resource_mgr.cc b/tensorflow/core/framework/resource_mgr.cc
index ebdaaec153..508a8d3149 100644
--- a/tensorflow/core/framework/resource_mgr.cc
+++ b/tensorflow/core/framework/resource_mgr.cc
@@ -288,4 +288,13 @@ Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p) {
return ctx->resource_manager()->Delete(p);
}
+Status ResourceHandlesShape(shape_inference::InferenceContext* c) {
+ int n;
+ TF_RETURN_IF_ERROR(c->GetAttr("N", &n));
+ for (int i = 0; i < n; ++i) {
+ c->set_output(i, c->Scalar());
+ }
+ return Status::OK();
+}
+
} // end namespace tensorflow
diff --git a/tensorflow/core/framework/resource_mgr.h b/tensorflow/core/framework/resource_mgr.h
index d58deaa3fc..abb6635984 100644
--- a/tensorflow/core/framework/resource_mgr.h
+++ b/tensorflow/core/framework/resource_mgr.h
@@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_FRAMEWORK_RESOURCE_MGR_H_
#define TENSORFLOW_CORE_FRAMEWORK_RESOURCE_MGR_H_
+#include <memory>
#include <string>
#include <typeindex>
#include <typeinfo>
@@ -127,6 +128,14 @@ class ResourceMgr {
Status Lookup(const string& container, const string& name,
T** resource) const TF_MUST_USE_RESULT;
+ // Similar to Lookup, but looks up multiple resources at once, with only a
+ // single lock acquisition.
+ template <typename T>
+ Status LookupMany(absl::Span<std::pair<const string*, const string*> const>
+ containers_and_names,
+ std::vector<std::unique_ptr<T, core::RefCountDeleter>>*
+ resource) const TF_MUST_USE_RESULT;
+
// If "container" has a resource "name", returns it in
// "*resource". Otherwise, invokes creator() to create the resource.
// The caller takes the ownership of one ref on "*resource".
@@ -246,6 +255,12 @@ Status CreateResource(OpKernelContext* ctx, const ResourceHandle& p, T* value);
template <typename T>
Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p, T** value);
+// Looks up multiple resources pointed to by a sequence of resource handles.
+template <typename T>
+Status LookupResources(
+ OpKernelContext* ctx, absl::Span<ResourceHandle const* const> p,
+ std::vector<std::unique_ptr<T, core::RefCountDeleter>>* values);
+
// Looks up or creates a resource.
template <typename T>
Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p,
@@ -358,6 +373,26 @@ class ResourceHandleOp : public OpKernel {
std::atomic<bool> initialized_{false};
};
+// Utility op kernel to produce handles to resources of type T.
+template <typename T>
+class ResourceHandlesOp : public OpKernel {
+ public:
+ explicit ResourceHandlesOp(OpKernelConstruction* context);
+
+ void Compute(OpKernelContext* ctx) override;
+
+ bool IsExpensive() override { return false; }
+
+ private:
+ std::vector<string> containers_;
+ std::vector<string> names_;
+ mutex mutex_;
+ std::vector<Tensor> resources_;
+ std::atomic<bool> initialized_{false};
+};
+
+Status ResourceHandlesShape(shape_inference::InferenceContext* c);
+
// Registers a kernel for an op which produces a handle to a resource of the
// specified type.
#define REGISTER_RESOURCE_HANDLE_KERNEL(Type) \
@@ -390,6 +425,24 @@ Status ResourceMgr::Lookup(const string& container, const string& name,
}
template <typename T>
+Status ResourceMgr::LookupMany(
+ absl::Span<std::pair<const string*, const string*> const>
+ containers_and_names,
+ std::vector<std::unique_ptr<T, core::RefCountDeleter>>* resources) const {
+ CheckDeriveFromResourceBase<T>();
+ tf_shared_lock l(mu_);
+ resources->resize(containers_and_names.size());
+ for (size_t i = 0; i < containers_and_names.size(); ++i) {
+ T* resource;
+ TF_RETURN_IF_ERROR(LookupInternal(*containers_and_names[i].first,
+ *containers_and_names[i].second,
+ &resource));
+ (*resources)[i].reset(resource);
+ }
+ return Status::OK();
+}
+
+template <typename T>
Status ResourceMgr::LookupInternal(const string& container, const string& name,
T** resource) const {
ResourceBase* found = nullptr;
@@ -499,6 +552,19 @@ Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p,
}
template <typename T>
+Status LookupResources(
+ OpKernelContext* ctx, absl::Span<ResourceHandle const* const> p,
+ std::vector<std::unique_ptr<T, core::RefCountDeleter>>* values) {
+ std::vector<std::pair<const string*, const string*>> containers_and_names(
+ p.size());
+ for (size_t i = 0; i < p.size(); ++i) {
+ TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, *p[i]));
+ containers_and_names[i] = {&p[i]->container(), &p[i]->name()};
+ }
+ return ctx->resource_manager()->LookupMany(containers_and_names, values);
+}
+
+template <typename T>
Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p,
T** value, std::function<Status(T**)> creator) {
TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, p));
@@ -555,6 +621,46 @@ void ResourceHandleOp<T>::Compute(OpKernelContext* ctx) {
ctx->set_output(0, resource_);
}
+template <typename T>
+ResourceHandlesOp<T>::ResourceHandlesOp(OpKernelConstruction* context)
+ : OpKernel(context) {
+ int n;
+ OP_REQUIRES_OK(context, context->GetAttr("N", &n));
+ OP_REQUIRES_OK(context, context->GetAttr("containers", &containers_));
+ OP_REQUIRES_OK(context, context->GetAttr("shared_names", &names_));
+ OP_REQUIRES(
+ context, containers_.size() == n,
+ errors::InvalidArgument("Number of containers (", containers_.size(),
+ ") must be equal to N (", n, ")"));
+ OP_REQUIRES(context, names_.size() == n,
+ errors::InvalidArgument("Number of names (", containers_.size(),
+ ") must be equal to N (", n, ")"));
+ resources_.resize(n);
+}
+
+template <typename T>
+void ResourceHandlesOp<T>::Compute(OpKernelContext* ctx) {
+ if (!initialized_.load()) {
+ mutex_lock ml(mutex_);
+ // Checking again to see if another thread has initialized the resource.
+ if (!initialized_.load()) {
+ AllocatorAttributes attr;
+ attr.set_on_host(true);
+ for (size_t i = 0; i < resources_.size(); ++i) {
+ OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_RESOURCE, TensorShape({}),
+ &resources_[i], attr));
+ ResourceHandle h =
+ MakeResourceHandle<T>(ctx, containers_[i], names_[i]);
+ resources_[i].template scalar<ResourceHandle>()() = h;
+ }
+ initialized_.store(true);
+ }
+ }
+ for (size_t i = 0; i < resources_.size(); ++i) {
+ ctx->set_output(i, resources_[i]);
+ }
+}
+
} // end namespace tensorflow
#endif // TENSORFLOW_CORE_FRAMEWORK_RESOURCE_MGR_H_
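`ResourceHandlesOp` pairs with `ResourceHandlesShape` to define ops that emit N resource handles from a single kernel. A sketch of such a registration (the op name `MyVarHandles` and the resource type `Var` are assumptions, not part of the patch):

    REGISTER_OP("MyVarHandles")
        .Attr("containers: list(string)")
        .Attr("shared_names: list(string)")
        .Attr("N: int >= 0")
        .Output("handles: N * resource")
        .SetShapeFn(ResourceHandlesShape);

    REGISTER_KERNEL_BUILDER(Name("MyVarHandles").Device(DEVICE_CPU),
                            ResourceHandlesOp<Var>);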
diff --git a/tensorflow/core/framework/tensor.cc b/tensorflow/core/framework/tensor.cc
index 516afa517d..1dea6da911 100644
--- a/tensorflow/core/framework/tensor.cc
+++ b/tensorflow/core/framework/tensor.cc
@@ -812,6 +812,28 @@ Tensor Tensor::Slice(int64 start, int64 limit) const {
return ret;
}
+Tensor Tensor::SubSlice(int64 index) const {
+ CHECK_GE(dims(), 1); // Crash ok.
+ CHECK_LE(0, index); // Crash ok.
+ int64 dim0_size = shape_.dim_size(0);
+ CHECK_LE(index, dim0_size); // Crash ok.
+ Tensor ret;
+ ret.shape_ = shape_;
+ ret.shape_.RemoveDim(0);
+ ret.set_dtype(dtype());
+ ret.buf_ = nullptr;
+ if (dim0_size > 0) {
+ const int64 elems_per_dim0 = NumElements() / dim0_size;
+ const int64 delta = index * elems_per_dim0;
+ const int64 num_elems = elems_per_dim0;
+ if (buf_) {
+ DataType dt = dtype();
+ CASES(dt, ret.buf_ = new SubBuffer<T>(buf_, delta, num_elems));
+ }
+ }
+ return ret;
+}
+
bool Tensor::FromProto(const TensorProto& proto) {
return FromProto(cpu_allocator(), proto);
}
@@ -948,9 +970,69 @@ void PrintOneDim(int dim_index, const gtl::InlinedVector<int64, 4>& shape,
}
}
+// Appends the spacing between elements for a given dim onto a result string
+void PrintDimSpacing(int dim_index, int num_dims, string* result) {
+ if (dim_index == num_dims - 1) {
+ strings::StrAppend(result, " ");
+ return;
+ }
+ for (int j = 0; j < num_dims - dim_index - 1; j++) {
+ strings::StrAppend(result, "\n");
+ }
+ for (int j = 0; j <= dim_index; j++) {
+ strings::StrAppend(result, " ");
+ }
+}
+
+// Print from left dim to right dim recursively.
+template <typename T>
+void PrintOneDimV2(int dim_index, const gtl::InlinedVector<int64, 4>& shape,
+ int64 num_elts_at_ends, int num_dims, const T* data,
+ int64 data_index, string* result) {
+ // We have recursed beyond all the dimensions into a single element
+ // of the tensor.
+ if (dim_index == num_dims) {
+ strings::StrAppend(result, PrintOneElement(data[data_index]));
+ return;
+ }
+
+ strings::StrAppend(result, "[");
+ int64 element_count = shape[dim_index];
+ int64 start_of_end =
+ std::max(num_elts_at_ends, element_count - num_elts_at_ends);
+
+ // Loop over every element of this dim.
+ int64 elements_per_iter = 1;
+ for (int i = dim_index + 1; i < num_dims; i++) {
+ elements_per_iter *= shape[i];
+ }
+ for (int64 i = 0; (i < num_elts_at_ends) && (i < element_count); i++) {
+ if (i > 0) {
+ PrintDimSpacing(dim_index, num_dims, result);
+ }
+
+ // For each element, print the sub-dim.
+ PrintOneDimV2(dim_index + 1, shape, num_elts_at_ends, num_dims, data,
+ data_index + elements_per_iter * i, result);
+ }
+ if (element_count > 2 * num_elts_at_ends) {
+ PrintDimSpacing(dim_index, num_dims, result);
+ strings::StrAppend(result, "...");
+ }
+ for (int64 i = start_of_end; i < element_count; i++) {
+ // For each element, print the sub-dim.
+ PrintDimSpacing(dim_index, num_dims, result);
+ PrintOneDimV2(dim_index + 1, shape, num_elts_at_ends, num_dims, data,
+ data_index + elements_per_iter * i, result);
+ }
+
+ strings::StrAppend(result, "]");
+}
+
template <typename T>
string SummarizeArray(int64 limit, int64 num_elts,
- const TensorShape& tensor_shape, const char* data) {
+ const TensorShape& tensor_shape, const char* data,
+ const bool print_v2) {
string ret;
const T* array = reinterpret_cast<const T*>(data);
@@ -963,17 +1045,26 @@ string SummarizeArray(int64 limit, int64 num_elts,
if (num_elts > limit) strings::StrAppend(&ret, "...");
return ret;
}
- int64 data_index = 0;
- const int shape_size = tensor_shape.dims();
- PrintOneDim(0, shape, limit, shape_size, array, &data_index, &ret);
+ if (print_v2) {
+ const int num_dims = tensor_shape.dims();
+ PrintOneDimV2(0, shape, limit, num_dims, array, 0, &ret);
+ } else {
+ int64 data_index = 0;
+ const int shape_size = tensor_shape.dims();
+ PrintOneDim(0, shape, limit, shape_size, array, &data_index, &ret);
+
+ if (num_elts > limit) strings::StrAppend(&ret, "...");
+ }
- if (num_elts > limit) strings::StrAppend(&ret, "...");
return ret;
}
} // namespace
-string Tensor::SummarizeValue(int64 max_entries) const {
+string Tensor::SummarizeValue(int64 max_entries, bool print_v2) const {
const int64 num_elts = NumElements();
+ if (max_entries < 0) {
+ max_entries = num_elts;
+ }
size_t limit = std::min(max_entries, num_elts);
if ((limit > 0) && (buf_ == nullptr)) {
return strings::StrCat("uninitialized Tensor of ", num_elts,
@@ -982,50 +1073,54 @@ string Tensor::SummarizeValue(int64 max_entries) const {
const char* data = limit > 0 ? tensor_data().data() : nullptr;
switch (dtype()) {
case DT_HALF:
- return SummarizeArray<Eigen::half>(limit, num_elts, shape_, data);
+ return SummarizeArray<Eigen::half>(limit, num_elts, shape_, data,
+ print_v2);
break;
case DT_FLOAT:
- return SummarizeArray<float>(limit, num_elts, shape_, data);
+ return SummarizeArray<float>(limit, num_elts, shape_, data, print_v2);
break;
case DT_DOUBLE:
- return SummarizeArray<double>(limit, num_elts, shape_, data);
+ return SummarizeArray<double>(limit, num_elts, shape_, data, print_v2);
break;
case DT_UINT32:
- return SummarizeArray<uint32>(limit, num_elts, shape_, data);
+ return SummarizeArray<uint32>(limit, num_elts, shape_, data, print_v2);
break;
case DT_INT32:
- return SummarizeArray<int32>(limit, num_elts, shape_, data);
+ return SummarizeArray<int32>(limit, num_elts, shape_, data, print_v2);
break;
case DT_UINT8:
case DT_QUINT8:
- return SummarizeArray<uint8>(limit, num_elts, shape_, data);
+ return SummarizeArray<uint8>(limit, num_elts, shape_, data, print_v2);
break;
case DT_UINT16:
case DT_QUINT16:
- return SummarizeArray<uint16>(limit, num_elts, shape_, data);
+ return SummarizeArray<uint16>(limit, num_elts, shape_, data, print_v2);
break;
case DT_INT16:
case DT_QINT16:
- return SummarizeArray<int16>(limit, num_elts, shape_, data);
+ return SummarizeArray<int16>(limit, num_elts, shape_, data, print_v2);
break;
case DT_INT8:
case DT_QINT8:
- return SummarizeArray<int8>(limit, num_elts, shape_, data);
+ return SummarizeArray<int8>(limit, num_elts, shape_, data, print_v2);
break;
case DT_UINT64:
- return SummarizeArray<uint64>(limit, num_elts, shape_, data);
+ return SummarizeArray<uint64>(limit, num_elts, shape_, data, print_v2);
break;
case DT_INT64:
- return SummarizeArray<int64>(limit, num_elts, shape_, data);
+ return SummarizeArray<int64>(limit, num_elts, shape_, data, print_v2);
break;
case DT_BOOL:
// TODO(tucker): Is it better to emit "True False..."? This
// will emit "1 0..." which is more compact.
- return SummarizeArray<bool>(limit, num_elts, shape_, data);
+ return SummarizeArray<bool>(limit, num_elts, shape_, data, print_v2);
break;
default: {
// All irregular cases
string ret;
+ if (print_v2) {
+ strings::StrAppend(&ret, "[");
+ }
// TODO(irving): Don't call flat every time around this
// loop.
for (size_t i = 0; i < limit; ++i) {
@@ -1045,6 +1140,9 @@ string Tensor::SummarizeValue(int64 max_entries) const {
}
}
if (max_entries < num_elts) strings::StrAppend(&ret, "...");
+ if (print_v2) {
+ strings::StrAppend(&ret, "]");
+ }
return ret;
}
}
diff --git a/tensorflow/core/framework/tensor.h b/tensorflow/core/framework/tensor.h
index 696fd277cd..d0f9eb56e2 100644
--- a/tensorflow/core/framework/tensor.h
+++ b/tensorflow/core/framework/tensor.h
@@ -154,7 +154,7 @@ class Tensor {
/// Returns the estimated memory usage of this tensor.
size_t TotalBytes() const;
- // Returns the size of sallocated memory for this tensor.
+ // Returns the size of allocated memory for this tensor.
size_t AllocatedBytes() const;
/// Returns true iff this tensor is aligned.
@@ -200,10 +200,29 @@ class Tensor {
/// must check the returned tensor's alignment before calling certain
/// methods that have alignment requirement (e.g., `flat()`, `tensor()`).
///
+ /// NOTE: When fed with an N-dimensional tensor, this method returns a tensor
+ /// also with N dimensions. If you want to select a sub tensor, see SubSlice.
+ ///
/// REQUIRES: `dims()` >= 1
/// REQUIRES: `0 <= dim0_start <= dim0_limit <= dim_size(0)`
Tensor Slice(int64 dim0_start, int64 dim0_limit) const;
+ /// \brief Select a subslice from this tensor along the 1st dimension.
+ ///
+ /// When fed with an N-dimensional tensor, this method returns a tensor with
+ /// N-1 dimensions, where the returned tensor is a subslice of the input
+ /// tensor along the first dimension. The N-1 dimensions of the returned
+ /// tensor are the last N-1 dimensions of the input tensor.
+ ///
+ /// NOTE: The returned tensor may not satisfy the same alignment
+ /// requirement as this tensor depending on the shape. The caller
+ /// must check the returned tensor's alignment before calling certain
+ /// methods that have alignment requirement (e.g., `flat()`, `tensor()`).
+ ///
+ /// REQUIRES: `dims()` >= 1
+ /// REQUIRES: `0 <= index < dim_size(0)`
+ Tensor SubSlice(int64 index) const;
+
/// \brief Parse `other` and construct the tensor.
/// Returns `true` iff the parsing succeeds. If the parsing fails,
@@ -430,7 +449,7 @@ class Tensor {
int64 begin) const;
/// Render the first `max_entries` values in `*this` into a string.
- string SummarizeValue(int64 max_entries) const;
+ string SummarizeValue(int64 max_entries, bool print_v2 = false) const;
/// A human-readable summary of the tensor suitable for debugging.
string DebugString() const;
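Taken together, `SubSlice` and the extended `SummarizeValue` allow inspecting a batch row by row without copying. A brief sketch (illustrative only):

    Tensor batch(DT_FLOAT, TensorShape({8, 5}));
    batch.flat<float>().setZero();
    for (int64 i = 0; i < batch.dim_size(0); ++i) {
      Tensor row = batch.SubSlice(i);  // shape {5}; shares the underlying buffer
      // SubSlice makes no alignment guarantee, so use the unaligned accessor.
      row.unaligned_flat<float>().setConstant(static_cast<float>(i));
    }
    // A negative max_entries now means "print everything"; print_v2 adds the
    // bracketed, line-broken layout.
    LOG(INFO) << batch.SummarizeValue(-1, /*print_v2=*/true);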
diff --git a/tensorflow/core/framework/tensor_test.cc b/tensorflow/core/framework/tensor_test.cc
index 9a78cdc91e..c596604143 100644
--- a/tensorflow/core/framework/tensor_test.cc
+++ b/tensorflow/core/framework/tensor_test.cc
@@ -1228,6 +1228,45 @@ TEST(Tensor, Slice_Basic) {
}
}
+TEST(Tensor, SubSlice_Basic) {
+ { // General
+ Tensor x(DT_FLOAT, TensorShape({10, 4, 36}));
+ // Fills in known values.
+ for (int i = 0; i < 10; ++i) {
+ x.SubSlice(i).flat<float>().setConstant(i * 1.f);
+ }
+ // A simple sub-slice along dim0.
+ Tensor y = x.SubSlice(5);
+ EXPECT_TRUE(y.shape().IsSameSize(TensorShape({4, 36})));
+ auto tx = x.tensor<float, 3>();
+ auto ty = y.tensor<float, 2>();
+ for (int j = 0; j < 4; ++j) {
+ for (int k = 0; k < 36; ++k) {
+ EXPECT_EQ(ty(j, k), 5.0);
+ EXPECT_EQ(&tx(5, j, k), &ty(j, k));
+ }
+ }
+ Tensor z = y.SubSlice(3).SubSlice(31);
+ auto tz = z.unaligned_flat<float>();
+ EXPECT_EQ(*tz.data(), 5.0);
+ }
+ {
+ // Test unaligned access via a SubSlice.
+ Tensor x(DT_FLOAT, TensorShape({30, 5}));
+ x.flat<float>().setConstant(0.0);
+
+ // Take an unaligned subslice.
+ Tensor y = x.SubSlice(1);
+#if EIGEN_MAX_ALIGN_BYTES > 0
+ EXPECT_FALSE(y.IsAligned());
+#endif
+ y.unaligned_flat<float>().setConstant(1.0);
+ for (int64 i = 0; i < y.NumElements(); ++i) {
+ EXPECT_EQ(1.0, y.unaligned_flat<float>()(i));
+ }
+ }
+}
+
template <typename T>
Tensor MkTensor(DataType dt, const TensorShape& shape,
std::vector<T> init_values) {
@@ -1295,6 +1334,63 @@ TEST(SummarizeValue, STRING) {
EXPECT_EQ("one two three four five one...", x.SummarizeValue(6));
}
+TEST(SummarizeValue, INT32_PRINT_V2) {
+ Tensor x = MkTensor<int>(DT_INT32, TensorShape({5}), {1, 2, 3, 4, 0});
+ EXPECT_EQ("[1 2 3 4 0]", x.SummarizeValue(16, true));
+ EXPECT_EQ("[1 2 3 4 0]", x.SummarizeValue(-1, true));
+ EXPECT_EQ("[1 2 ... 4 0]", x.SummarizeValue(2, true));
+ EXPECT_EQ("[1 ... 0]", x.SummarizeValue(1, true));
+ x = MkTensor<int>(DT_INT32, TensorShape({2, 2}), {1, 2, 3, 4, 0});
+ EXPECT_EQ("[[1 2]\n [3 4]]", x.SummarizeValue(16, true));
+ x = MkTensor<int>(DT_INT32, TensorShape({2, 2, 1, 1}), {1, 2, 3, 4, 0});
+ EXPECT_EQ("[[[[1]]\n\n [[2]]]\n\n\n [[[3]]\n\n [[4]]]]",
+ x.SummarizeValue(16, true));
+ x = MkTensor<int>(DT_INT32, TensorShape({0}), {});
+ EXPECT_EQ("[]", x.SummarizeValue(16, true));
+}
+
+TEST(SummarizeValue, INT32Dims_PRINT_V2) {
+ Tensor x = MkTensor<int>(DT_INT32, TensorShape({3, 4}),
+ {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+ EXPECT_EQ("[[1 ... 4]\n ...\n [9 ... 12]]", x.SummarizeValue(1, true));
+ EXPECT_EQ("[[1 2 3 4]\n [5 6 7 8]\n [9 10 11 12]]",
+ x.SummarizeValue(10, true));
+ EXPECT_EQ("[[1 2 3 4]\n [5 6 7 8]\n [9 10 11 12]]",
+ x.SummarizeValue(-1, true));
+}
+
+TEST(SummarizeValue, FLOAT_PRINT_V2) {
+ Tensor x = MkTensor<float>(DT_FLOAT, TensorShape({5}), {1, 2, 3, 4, 0});
+ EXPECT_EQ("[1 2 3 4 0]", x.SummarizeValue(16, true));
+ EXPECT_EQ("[1 2 3 4 0]", x.SummarizeValue(-1, true));
+ EXPECT_EQ("[1 2 ... 4 0]", x.SummarizeValue(2, true));
+ EXPECT_EQ("[1 ... 0]", x.SummarizeValue(1, true));
+ x = MkTensor<float>(DT_FLOAT, TensorShape({2, 2}), {1, 2, 3, 4, 0});
+ EXPECT_EQ("[[1 2]\n [3 4]]", x.SummarizeValue(16, true));
+ x = MkTensor<float>(DT_FLOAT, TensorShape({2, 2, 1, 1}), {1, 2, 3, 4, 0});
+ EXPECT_EQ("[[[[1]]\n\n [[2]]]\n\n\n [[[3]]\n\n [[4]]]]",
+ x.SummarizeValue(16, true));
+ x = MkTensor<float>(DT_FLOAT, TensorShape({0}), {});
+ EXPECT_EQ("[]", x.SummarizeValue(16, true));
+}
+
+TEST(SummarizeValue, BOOL_PRINT_V2) {
+ Tensor x = MkTensor<bool>(DT_BOOL, TensorShape({5}), {false, true, true});
+ EXPECT_EQ("[0 1 1 0 1]", x.SummarizeValue(16, true));
+ EXPECT_EQ("[0 1 1 0 1]", x.SummarizeValue(-1, true));
+ EXPECT_EQ("[0 1 ... 0 1]", x.SummarizeValue(2, true));
+}
+
+TEST(SummarizeValue, STRING_PRINT_V2) {
+ Tensor x = MkTensor<string>(DT_STRING, TensorShape({5}),
+ {"one", "two", "three", "four", "five"});
+ EXPECT_EQ("[one two three four five]", x.SummarizeValue(16, true));
+ EXPECT_EQ("[one two three four five]", x.SummarizeValue(-1, true));
+ x = MkTensor<string>(DT_STRING, TensorShape({5, 1, 5}),
+ {"one", "two", "three", "four", "five"});
+ EXPECT_EQ("[one two three four five one...]", x.SummarizeValue(6, true));
+}
+
void BM_CreateAndDestroy(int iters) {
TensorShape shape({10, 20});
while (--iters) {
diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc
index 7399613f6a..eeb5c14eaa 100644
--- a/tensorflow/core/graph/graph_constructor.cc
+++ b/tensorflow/core/graph/graph_constructor.cc
@@ -1162,7 +1162,9 @@ Status GraphConstructor::PopulateMissingUnusedInputMapKeys() {
const NodeDef* node_def = node_defs_[pair->second.gdef_index];
const OpDef* op_def;
TF_RETURN_IF_ERROR(g_->op_registry()->LookUpOpDef(node_def->op(), &op_def));
- if (key.second >= op_def->output_arg_size()) {
+ int num_outputs;
+ TF_RETURN_IF_ERROR(NumOutputsForNode(*node_def, *op_def, &num_outputs));
+ if (key.second >= num_outputs) {
// key's index out of bounds
missing_unused_input_map_keys_->push_back(key);
}
diff --git a/tensorflow/core/graph/graph_constructor_test.cc b/tensorflow/core/graph/graph_constructor_test.cc
index 73142ebde7..3eef6bd2bd 100644
--- a/tensorflow/core/graph/graph_constructor_test.cc
+++ b/tensorflow/core/graph/graph_constructor_test.cc
@@ -199,6 +199,10 @@ REGISTER_OP("TestOneInputOneOutput")
.Output("y: T")
.Attr("T: {float, int64}")
.SetShapeFn(shape_inference::UnchangedShape);
+REGISTER_OP("TestVariadicOutput")
+ .Output("outputs: N * int32")
+ .Attr("N: int >= 0")
+ .SetShapeFn(shape_inference::UnknownShape);
REGISTER_OP("TestDefaultAttr")
.Attr("default_int: int=31415")
.SetShapeFn(shape_inference::NoOutputs);
@@ -1463,12 +1467,15 @@ TEST_F(GraphConstructorTest, ImportGraphDef_InputMapMissingUnusedKeys) {
opts.input_map[TensorId("DNE", 0)] = TensorId("input", 0);
// Unused but not missing
opts.input_map[TensorId("t1", 0)] = TensorId("W1", 0);
+ // Unused but not missing
+ opts.input_map[TensorId("variadic", 4)] = TensorId("input", 0);
ExpectOK(
R"EOF(
node { name: 'W2' op: 'TestParams' }
node { name: 'new_input' op: 'TestInput' input: [ '^W2' ] }
node { name: 't1' op: 'TestMul' input: [ 'new_input:0', 'new_input:1' ] }
- node { name: 't2' op: 'TestMul' input: [ 't1:0', 't1:0' ] }
+ node { name: 'variadic' op: 'TestVariadicOutput'
+ attr { key: "N" value { i: 5 } } }
)EOF",
opts, &refiner, &results);
diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc
index f5b0105862..06d3fefef1 100644
--- a/tensorflow/core/graph/mkl_layout_pass.cc
+++ b/tensorflow/core/graph/mkl_layout_pass.cc
@@ -977,7 +977,9 @@ std::vector<MklLayoutRewritePass::ContextInfo*> MklLayoutRewritePass::cinfo_;
// nodes. Do not change the ordering of the Mkl passes.
const OptimizationPassRegistry::Grouping kMklLayoutRewritePassGroup =
OptimizationPassRegistry::POST_PARTITIONING;
+#ifdef ENABLE_MKL
REGISTER_OPTIMIZATION(kMklLayoutRewritePassGroup, 1, MklLayoutRewritePass);
+#endif // ENABLE_MKL
//////////////////////////////////////////////////////////////////////////
// Helper functions for creating new node
@@ -2448,6 +2450,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
csinfo_.tanh = "Tanh";
csinfo_.tanh_grad = "TanhGrad";
csinfo_.reshape = "Reshape";
+ csinfo_.slice = "Slice";
csinfo_.softmax = "Softmax";
csinfo_.split = "Split";
// Element-wise ops. Ensure you also add any new ops to IsOpElementWise
@@ -2555,6 +2558,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
rinfo_.push_back({csinfo_.reshape,
mkl_op_registry::GetMklOpName(csinfo_.reshape),
CopyAttrsReshape, AlwaysRewrite});
+ rinfo_.push_back({csinfo_.slice,
+ mkl_op_registry::GetMklOpName(csinfo_.slice),
+ CopyAttrsSlice, AlwaysRewrite});
rinfo_.push_back({csinfo_.softmax,
mkl_op_registry::GetMklOpName(csinfo_.softmax),
CopyAttrsDataType, AlwaysRewrite});
@@ -2674,6 +2680,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
string tanh;
string tanh_grad;
string reshape;
+ string slice;
string softmax;
string split;
string squared_difference;
@@ -3132,6 +3139,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
static void CopyAttrsLRN(const Node* orig_node, NodeBuilder* nb);
static void CopyAttrsPooling(const Node* orig_node, NodeBuilder* nb);
static void CopyAttrsReshape(const Node* orig_node, NodeBuilder* nb);
+ static void CopyAttrsSlice(const Node* orig_node, NodeBuilder* nb);
static void CopyAttrsSplit(const Node* orig_node, NodeBuilder* nb);
// Generate a graph node in graph 'g' representing a dummy Mkl tensor node,
@@ -3150,7 +3158,9 @@ MklLayoutRewritePass::ConstStringsInfo MklLayoutRewritePass::csinfo_;
// nodes. Do not change the ordering of the Mkl passes.
const OptimizationPassRegistry::Grouping kMklLayoutRewritePassGroup =
OptimizationPassRegistry::POST_PARTITIONING;
+#ifdef ENABLE_MKL
REGISTER_OPTIMIZATION(kMklLayoutRewritePassGroup, 1, MklLayoutRewritePass);
+#endif // ENABLE_MKL
//////////////////////////////////////////////////////////////////////////
// Helper functions for creating new node
@@ -3735,6 +3745,19 @@ void MklLayoutRewritePass::CopyAttrsReshape(const Node* orig_node,
nb->Attr("Tshape", Tshape);
}
+void MklLayoutRewritePass::CopyAttrsSlice(const Node* orig_node,
+ NodeBuilder* nb) {
+ DataType T;
+ DataType Index;
+
+ // Get all attributes from old node.
+ TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
+ TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Index", &Index));
+ // Add attributes to new node.
+ nb->Attr("T", T);
+ nb->Attr("Index", Index);
+}
+
void MklLayoutRewritePass::CopyAttrsSplit(const Node* orig_node,
NodeBuilder* nb) {
DataType T;
diff --git a/tensorflow/core/graph/mkl_layout_pass_test.cc b/tensorflow/core/graph/mkl_layout_pass_test.cc
index e8bac847e5..77640e287c 100644
--- a/tensorflow/core/graph/mkl_layout_pass_test.cc
+++ b/tensorflow/core/graph/mkl_layout_pass_test.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifdef INTEL_MKL
+#if defined(INTEL_MKL) && defined(ENABLE_MKL)
#include "tensorflow/core/graph/mkl_layout_pass.h"
#include "tensorflow/core/graph/mkl_graph_util.h"
@@ -3510,6 +3510,26 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_DeviceTest) {
"B->C:1;C->E;D->E:1;E->Z;M->C:2;N->C:3;Y->Z:1");
}
+TEST_F(MklLayoutPassTest, NodeRewrite_Slice_DeviceTest) {
+ InitGraph(
+ "node { name: 'A' op: 'Input'}"
+ "node { name: 'B' op: 'Int32Input'}"
+ "node { name: 'C' op: 'Int32Input'}"
+ "node { name: 'D' op: 'Slice'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " attr { key: 'Index' value { type: DT_INT32 } }"
+ " input: ['A', 'B', 'C'] }"
+ "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A', 'D'] }");
+ EXPECT_EQ(DoMklLayoutOptimizationPass(),
+ "A(Input);B(Int32Input);C(Int32Input);"
+ "D(_MklSlice);DMT/_0(Const);DMT/_1(Const);DMT/"
+ "_2(Const);E(Zeta)|A->D;A->E;"
+ "A:control->DMT/_0:control;A:control->DMT/"
+ "_1:control;A:control->DMT/_2:control;"
+ "B->D:1;C->D:2;D->E:1;DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5");
+}
+
/////////////////////////////////////////////////////////////////////
// Post-rewrite fixup pass test
@@ -3586,4 +3606,4 @@ BENCHMARK(BM_MklLayoutRewritePass)->Arg(1000)->Arg(10000);
} // namespace tensorflow
-#endif /* INTEL_MKL */
+#endif // INTEL_MKL && ENABLE_MKL
diff --git a/tensorflow/core/graph/mkl_tfconversion_pass.cc b/tensorflow/core/graph/mkl_tfconversion_pass.cc
index b67a321fc1..8c5ffd71a3 100644
--- a/tensorflow/core/graph/mkl_tfconversion_pass.cc
+++ b/tensorflow/core/graph/mkl_tfconversion_pass.cc
@@ -133,7 +133,9 @@ class MklToTfConversionPass : public GraphOptimizationPass {
// complete picture of inputs and outputs of the nodes in the graphs.
const OptimizationPassRegistry::Grouping kMklTfConvPassGroup =
OptimizationPassRegistry::POST_PARTITIONING;
+#ifdef ENABLE_MKL
REGISTER_OPTIMIZATION(kMklTfConvPassGroup, 2, MklToTfConversionPass);
+#endif // ENABLE_MKL
Status MklToTfConversionPass::InsertConversionNodeOnEdge(
std::unique_ptr<Graph>* g, Edge* e) {
diff --git a/tensorflow/core/graph/mkl_tfconversion_pass_test.cc b/tensorflow/core/graph/mkl_tfconversion_pass_test.cc
index ebcb6de551..319437a801 100644
--- a/tensorflow/core/graph/mkl_tfconversion_pass_test.cc
+++ b/tensorflow/core/graph/mkl_tfconversion_pass_test.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifdef INTEL_MKL
+#if defined(INTEL_MKL) && defined(ENABLE_MKL)
#include "tensorflow/core/graph/mkl_tfconversion_pass.h"
#include "tensorflow/core/graph/mkl_graph_util.h"
@@ -304,4 +304,4 @@ BENCHMARK(BM_RunMklToTfConversionPass)->Arg(1000)->Arg(10000);
} // namespace
} // namespace tensorflow
-#endif /* INTEL_MKL */
+#endif // INTEL_MKL && ENABLE_MKL
diff --git a/tensorflow/core/graph/testlib.h b/tensorflow/core/graph/testlib.h
index bd0284d43a..b00196f587 100644
--- a/tensorflow/core/graph/testlib.h
+++ b/tensorflow/core/graph/testlib.h
@@ -32,7 +32,7 @@ namespace test {
namespace graph {
// Converts "g" into its corresponding GraphDef "def".
-// DEPRECATED: call g->ToGraphDef(def) instead.
+ABSL_DEPRECATED("Call g->ToGraphDef(def) instead.")
void ToGraphDef(Graph* g, GraphDef* def);
// A few helpers to construct a graph.
diff --git a/tensorflow/core/grappler/clusters/cluster.cc b/tensorflow/core/grappler/clusters/cluster.cc
index 7171ae059b..3b1d7d8347 100644
--- a/tensorflow/core/grappler/clusters/cluster.cc
+++ b/tensorflow/core/grappler/clusters/cluster.cc
@@ -83,6 +83,7 @@ void Cluster::DisableOptimizer(bool disable) {
rewriter_config->set_memory_optimization(RewriterConfig::NO_MEM_OPT);
rewriter_config->set_shape_optimization(RewriterConfig::OFF);
rewriter_config->set_remapping(RewriterConfig::OFF);
+ rewriter_config->set_pin_to_host_optimization(RewriterConfig::OFF);
rewriter_config->mutable_auto_parallel()->set_enable(false);
rewriter_config->clear_optimizers();
} else {
diff --git a/tensorflow/core/grappler/clusters/single_machine.cc b/tensorflow/core/grappler/clusters/single_machine.cc
index b97603c890..e4f6bf7c86 100644
--- a/tensorflow/core/grappler/clusters/single_machine.cc
+++ b/tensorflow/core/grappler/clusters/single_machine.cc
@@ -93,13 +93,13 @@ Status SingleMachine::Provision() {
strings::StrCat("Not able to parse GPU device name: ", dev.name()));
}
TfGpuId tf_gpu_id(parsed.id);
- CudaGpuId cuda_gpu_id;
- Status s = GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id);
+ PlatformGpuId platform_gpu_id;
+ Status s = GpuIdManager::TfToPlatformGpuId(tf_gpu_id, &platform_gpu_id);
if (!s.ok()) {
return errors::Unavailable("Unknown TF GPU device with id ",
tf_gpu_id.value(), ": ", s.ToString());
}
- attr = GetLocalGPUInfo(cuda_gpu_id);
+ attr = GetLocalGPUInfo(platform_gpu_id);
} else if (dev.device_type().find("XLA") == string::npos) {
// Filter out the fake XLA devices to avoid double counting the actual
// hardware resources that are available.
diff --git a/tensorflow/core/grappler/clusters/utils.cc b/tensorflow/core/grappler/clusters/utils.cc
index a7519725a5..567e7c075e 100644
--- a/tensorflow/core/grappler/clusters/utils.cc
+++ b/tensorflow/core/grappler/clusters/utils.cc
@@ -70,13 +70,14 @@ DeviceProperties GetLocalCPUInfo() {
return device;
}
-DeviceProperties GetLocalGPUInfo(CudaGpuId cuda_gpu_id) {
+DeviceProperties GetLocalGPUInfo(PlatformGpuId platform_gpu_id) {
DeviceProperties device;
device.set_type("GPU");
#if GOOGLE_CUDA
cudaDeviceProp properties;
- cudaError_t error = cudaGetDeviceProperties(&properties, cuda_gpu_id.value());
+ cudaError_t error =
+ cudaGetDeviceProperties(&properties, platform_gpu_id.value());
if (error != cudaSuccess) {
device.set_type("UNKNOWN");
LOG(ERROR) << "Failed to get device properties, error code: " << error;
@@ -122,15 +123,15 @@ DeviceProperties GetDeviceInfo(const DeviceNameUtils::ParsedName& device) {
} else if (device.type == "GPU") {
if (device.has_id) {
TfGpuId tf_gpu_id(device.id);
- CudaGpuId cuda_gpu_id;
- Status s = GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id);
+ PlatformGpuId platform_gpu_id;
+ Status s = GpuIdManager::TfToPlatformGpuId(tf_gpu_id, &platform_gpu_id);
if (!s.ok()) {
LOG(ERROR) << s;
return unknown;
}
- return GetLocalGPUInfo(cuda_gpu_id);
+ return GetLocalGPUInfo(platform_gpu_id);
} else {
- return GetLocalGPUInfo(CudaGpuId(0));
+ return GetLocalGPUInfo(PlatformGpuId(0));
}
}
return unknown;
diff --git a/tensorflow/core/grappler/clusters/utils.h b/tensorflow/core/grappler/clusters/utils.h
index ca15c48006..f0a342b728 100644
--- a/tensorflow/core/grappler/clusters/utils.h
+++ b/tensorflow/core/grappler/clusters/utils.h
@@ -28,7 +28,7 @@ DeviceProperties GetLocalCPUInfo();
// Returns the DeviceProperties for the specified GPU attached to the server on
// which grappler is running.
-DeviceProperties GetLocalGPUInfo(CudaGpuId cuda_gpu_id);
+DeviceProperties GetLocalGPUInfo(PlatformGpuId platform_gpu_id);
// Returns the DeviceProperties of the specified device
DeviceProperties GetDeviceInfo(const DeviceNameUtils::ParsedName& device);
diff --git a/tensorflow/core/grappler/clusters/utils_test.cc b/tensorflow/core/grappler/clusters/utils_test.cc
index 74218adbac..3863d62980 100644
--- a/tensorflow/core/grappler/clusters/utils_test.cc
+++ b/tensorflow/core/grappler/clusters/utils_test.cc
@@ -31,22 +31,22 @@ TEST(UtilsTest, GetLocalGPUInfo) {
LOG(INFO) << "CUDA is enabled.";
DeviceProperties properties;
- // Invalid CUDA GPU ID.
- properties = GetLocalGPUInfo(CudaGpuId(100));
+ // Invalid platform GPU ID.
+ properties = GetLocalGPUInfo(PlatformGpuId(100));
EXPECT_EQ("UNKNOWN", properties.type());
- // Succeed when a valid CUDA GPU id was inserted.
- properties = GetLocalGPUInfo(CudaGpuId(0));
+ // Succeed when a valid platform GPU id was inserted.
+ properties = GetLocalGPUInfo(PlatformGpuId(0));
EXPECT_EQ("GPU", properties.type());
EXPECT_EQ("NVIDIA", properties.vendor());
#else
LOG(INFO) << "CUDA is not enabled.";
DeviceProperties properties;
- properties = GetLocalGPUInfo(CudaGpuId(0));
+ properties = GetLocalGPUInfo(PlatformGpuId(0));
EXPECT_EQ("GPU", properties.type());
- properties = GetLocalGPUInfo(CudaGpuId(100));
+ properties = GetLocalGPUInfo(PlatformGpuId(100));
EXPECT_EQ("GPU", properties.type());
#endif
}
@@ -74,20 +74,20 @@ TEST(UtilsTest, GetDeviceInfo) {
EXPECT_EQ("NVIDIA", properties.vendor());
#endif
- // TF to CUDA GPU id mapping entry doesn't exist.
+ // TF to platform GPU id mapping entry doesn't exist.
device.has_id = true;
device.id = 0;
properties = GetDeviceInfo(device);
EXPECT_EQ("UNKNOWN", properties.type());
#if GOOGLE_CUDA
- // Invalid CUDA GPU id.
- GpuIdManager::InsertTfCudaGpuIdPair(TfGpuId(0), CudaGpuId(100));
+ // Invalid platform GPU id.
+ GpuIdManager::InsertTfPlatformGpuIdPair(TfGpuId(0), PlatformGpuId(100));
properties = GetDeviceInfo(device);
EXPECT_EQ("UNKNOWN", properties.type());
- // Valid CUDA GPU id.
- GpuIdManager::InsertTfCudaGpuIdPair(TfGpuId(1), CudaGpuId(0));
+ // Valid platform GPU id.
+ GpuIdManager::InsertTfPlatformGpuIdPair(TfGpuId(1), PlatformGpuId(0));
device.id = 1;
properties = GetDeviceInfo(device);
EXPECT_EQ("GPU", properties.type());
diff --git a/tensorflow/core/grappler/costs/utils.cc b/tensorflow/core/grappler/costs/utils.cc
index 83434ea40f..5415324b48 100644
--- a/tensorflow/core/grappler/costs/utils.cc
+++ b/tensorflow/core/grappler/costs/utils.cc
@@ -209,13 +209,13 @@ DeviceProperties GetDeviceInfo(const string& device_str) {
if (DeviceNameUtils::ParseFullName(device_str, &parsed)) {
if (parsed.type == "GPU") {
TfGpuId tf_gpu_id(parsed.id);
- CudaGpuId cuda_gpu_id;
- Status s = GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id);
+ PlatformGpuId platform_gpu_id;
+ Status s = GpuIdManager::TfToPlatformGpuId(tf_gpu_id, &platform_gpu_id);
if (!s.ok()) {
// We are probably running simulation without linking cuda libraries.
- cuda_gpu_id = CudaGpuId(parsed.id);
+ platform_gpu_id = PlatformGpuId(parsed.id);
}
- return GetLocalGPUInfo(cuda_gpu_id);
+ return GetLocalGPUInfo(platform_gpu_id);
} else if (parsed.type == "CPU") {
return GetLocalCPUInfo();
}
diff --git a/tensorflow/core/grappler/graph_view.cc b/tensorflow/core/grappler/graph_view.cc
index a6b6b6f8b2..0b8cb5e919 100644
--- a/tensorflow/core/grappler/graph_view.cc
+++ b/tensorflow/core/grappler/graph_view.cc
@@ -14,11 +14,44 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/graph_view.h"
+#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/grappler/utils.h"
namespace tensorflow {
namespace grappler {
+int OpOutputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id) {
+ for (int output_arg_id = 0; output_arg_id < op.output_arg_size();
+ ++output_arg_id) {
+ if (port_id < 0) {
+ return -1;
+ } else if (port_id == 0) {
+ return output_arg_id;
+ }
+
+ // Default is 1 port per output arg.
+ int n = 1;
+
+ const auto& output_arg = op.output_arg(output_arg_id);
+ if (!output_arg.number_attr().empty()) {
+ n = node.attr().at(output_arg.number_attr()).i();
+ } else if (!output_arg.type_list_attr().empty()) {
+ n = node.attr().at(output_arg.type_list_attr()).list().type_size();
+ }
+
+ if (n < 0) {
+ // This should never happen.
+ DCHECK_GE(n, 0);
+ return -1;
+ } else if (port_id < n) {
+ return output_arg_id;
+ }
+ port_id -= n;
+ }
+
+ return -1;
+}
+
GraphView::GraphView(GraphDef* graph) : graph_(graph) {
for (int i = 0; i < graph_->node_size(); i++) {
auto node = graph_->mutable_node(i);
@@ -39,7 +72,7 @@ void GraphView::AddUniqueNodeOrDie(NodeDef* node) {
void GraphView::AddFanouts(NodeDef* node) {
for (int i = 0; i < node->input_size(); ++i) {
OutputPort fanin;
- string fanin_name = ParseNodeName(node->input(i), &fanin.port_id);
+ const string fanin_name = ParseNodeName(node->input(i), &fanin.port_id);
fanin.node = nodes_[fanin_name];
InputPort input;
diff --git a/tensorflow/core/grappler/graph_view.h b/tensorflow/core/grappler/graph_view.h
index ac260f85a0..ec946ca3b5 100644
--- a/tensorflow/core/grappler/graph_view.h
+++ b/tensorflow/core/grappler/graph_view.h
@@ -20,11 +20,21 @@ limitations under the License.
#include <unordered_set>
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace grappler {
+// Map a node/op's output port_id to arg_id.
+//
+// The port_id refers to the n-th tensor of the node, while the arg_id refers to
+// the n-th arg of the op. These two can be different if an op's arg is a list
+// of tensors.
+//
+// We return -1 for any invalid port_id (i.e., no corresponding arg_id).
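+//
+// For illustration (mirroring the ShapeN test added below): a ShapeN node with
+// three inputs has a single list-typed output arg, so port_ids 0, 1 and 2 all
+// map to arg_id 0, while port_id 3 has no corresponding arg and maps to -1.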
+int OpOutputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id);
+
// A utility class to simplify the traversal of a GraphDef.
class GraphView {
public:
diff --git a/tensorflow/core/grappler/graph_view_test.cc b/tensorflow/core/grappler/graph_view_test.cc
index 958eb921fb..3d7d2faf7c 100644
--- a/tensorflow/core/grappler/graph_view_test.cc
+++ b/tensorflow/core/grappler/graph_view_test.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/graph_view.h"
+#include "tensorflow/cc/ops/parsing_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
@@ -25,6 +26,88 @@ namespace {
class GraphViewTest : public ::testing::Test {};
+TEST_F(GraphViewTest, OpOutputPortIdToArgIdShapeN) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output a = ops::Const(s.WithOpName("a"), 0.0f, {10, 10});
+ ops::ShapeN b(s.WithOpName("b"), {a, a, a});
+
+ GraphDef graph_def;
+ TF_CHECK_OK(s.ToGraphDef(&graph_def));
+ GraphView graph_view(&graph_def);
+
+ const NodeDef& a_node_def = *graph_view.GetNode("a");
+ const NodeDef& b_node_def = *graph_view.GetNode("b");
+
+ const OpDef* a_op_def = nullptr;
+ const OpDef* b_op_def = nullptr;
+ EXPECT_TRUE(
+ OpRegistry::Global()->LookUpOpDef(a_node_def.op(), &a_op_def).ok());
+ EXPECT_TRUE(
+ OpRegistry::Global()->LookUpOpDef(b_node_def.op(), &b_op_def).ok());
+
+ EXPECT_EQ(0, OpOutputPortIdToArgId(b_node_def, *a_op_def, 0));
+ EXPECT_EQ(-1, OpOutputPortIdToArgId(b_node_def, *a_op_def, 1));
+
+ EXPECT_EQ(0, OpOutputPortIdToArgId(b_node_def, *b_op_def, 0));
+ EXPECT_EQ(0, OpOutputPortIdToArgId(b_node_def, *b_op_def, 1));
+ EXPECT_EQ(0, OpOutputPortIdToArgId(b_node_def, *b_op_def, 2));
+ EXPECT_EQ(-1, OpOutputPortIdToArgId(b_node_def, *b_op_def, 3));
+ EXPECT_EQ(-1, OpOutputPortIdToArgId(b_node_def, *b_op_def, 4));
+}
+
+TEST_F(GraphViewTest, OpOutputPortIdToArgIdSparseSplit) {
+ for (int num_splits : {1, 2}) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output a = ops::Const<int64>(s.WithOpName("a"), 1, {10, 10});
+ ops::SparseSplit b(s.WithOpName("b"), a, a, a, a, num_splits);
+
+ GraphDef graph_def;
+ TF_CHECK_OK(s.ToGraphDef(&graph_def));
+ GraphView graph_view(&graph_def);
+
+ const NodeDef& b_node_def = *graph_view.GetNode("b");
+ const OpDef* b_op_def = nullptr;
+ EXPECT_TRUE(
+ OpRegistry::Global()->LookUpOpDef(b_node_def.op(), &b_op_def).ok());
+
+ for (int port_id = 0; port_id <= num_splits * 3; ++port_id) {
+ int arg_id = -1;
+ if (port_id < num_splits * 3) {
+ arg_id = port_id / num_splits;
+ }
+ EXPECT_EQ(arg_id, OpOutputPortIdToArgId(b_node_def, *b_op_def, port_id));
+ }
+ }
+}
+
+TEST_F(GraphViewTest, ParseSingleExample) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output a = ops::Const<string>(s.WithOpName("a"), "", {});
+ Output b = ops::Const<int64>(s.WithOpName("b"), 1, {1, 1});
+ ops::ParseSingleExample c(s.WithOpName("c"), a, {b, b}, 2, {"w", "x"},
+ {"y", "z"}, {DT_INT64, DT_INT64}, {{1}, {1}});
+
+ GraphDef graph_def;
+ TF_CHECK_OK(s.ToGraphDef(&graph_def));
+ GraphView graph_view(&graph_def);
+
+ const NodeDef& c_node_def = *graph_view.GetNode("c");
+
+ const OpDef* c_op_def = nullptr;
+ EXPECT_TRUE(
+ OpRegistry::Global()->LookUpOpDef(c_node_def.op(), &c_op_def).ok());
+
+ EXPECT_EQ(0, OpOutputPortIdToArgId(c_node_def, *c_op_def, 0));
+ EXPECT_EQ(0, OpOutputPortIdToArgId(c_node_def, *c_op_def, 1));
+ EXPECT_EQ(1, OpOutputPortIdToArgId(c_node_def, *c_op_def, 2));
+ EXPECT_EQ(1, OpOutputPortIdToArgId(c_node_def, *c_op_def, 3));
+ EXPECT_EQ(2, OpOutputPortIdToArgId(c_node_def, *c_op_def, 4));
+ EXPECT_EQ(2, OpOutputPortIdToArgId(c_node_def, *c_op_def, 5));
+ EXPECT_EQ(3, OpOutputPortIdToArgId(c_node_def, *c_op_def, 6));
+ EXPECT_EQ(3, OpOutputPortIdToArgId(c_node_def, *c_op_def, 7));
+ EXPECT_EQ(-1, OpOutputPortIdToArgId(c_node_def, *c_op_def, 8));
+}
+
TEST_F(GraphViewTest, BasicGraph) {
TrivialTestGraphInputYielder fake_input(4, 2, 2, false, {"/CPU:0", "/GPU:0"});
GrapplerItem item;
diff --git a/tensorflow/core/grappler/grappler_item_builder.cc b/tensorflow/core/grappler/grappler_item_builder.cc
index 029515ad3c..369046666d 100644
--- a/tensorflow/core/grappler/grappler_item_builder.cc
+++ b/tensorflow/core/grappler/grappler_item_builder.cc
@@ -192,9 +192,13 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
const string feed_name = NodeName(feed_node);
new_item->feed.emplace_back(feed_name, Tensor());
}
+ for (const auto& fetch_node : cfg.fetch_nodes) {
+ new_item->fetch.emplace_back(NodeName(fetch_node));
+ }
- // Attempt to detect the fetch node(s).
- if (meta_graph.collection_def().count("train_op") > 0) {
+ // Attempt to detect the fetch node(s) if they were not set explicitly.
+ if (new_item->fetch.empty() &&
+ meta_graph.collection_def().count("train_op") > 0) {
const CollectionDef& nodes = meta_graph.collection_def().at("train_op");
if (nodes.has_node_list()) {
for (const auto& node : nodes.node_list().value()) {
diff --git a/tensorflow/core/grappler/grappler_item_builder.h b/tensorflow/core/grappler/grappler_item_builder.h
index aafd2fdcda..1698587f8c 100644
--- a/tensorflow/core/grappler/grappler_item_builder.h
+++ b/tensorflow/core/grappler/grappler_item_builder.h
@@ -49,6 +49,8 @@ struct ItemConfig {
bool prune_graph = false;
// Override feed nodes list.
std::set<string> feed_nodes;
+ // Override fetch nodes list.
+ std::set<string> fetch_nodes;
};
// Factory method for creating a GrapplerItem from a MetaGraphDef.
diff --git a/tensorflow/core/grappler/grappler_item_builder_test.cc b/tensorflow/core/grappler/grappler_item_builder_test.cc
index 4b90bf3038..d00981f174 100644
--- a/tensorflow/core/grappler/grappler_item_builder_test.cc
+++ b/tensorflow/core/grappler/grappler_item_builder_test.cc
@@ -313,6 +313,29 @@ TEST_F(GrapplerItemBuilderTest, FromGraphWithUnknownDimInSignatureInput) {
EXPECT_EQ(item2->feed[0].second.NumElements(), 1);
}
+TEST_F(GrapplerItemBuilderTest, ExplicitFeedAndFetch) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ auto x = ops::Const(s.WithOpName("x"), 0);
+ auto y = ops::Const(s.WithOpName("y"), 1);
+ auto z = ops::Add(s.WithOpName("z"), x, y);
+
+ MetaGraphDef meta_graph;
+ TF_CHECK_OK(s.ToGraphDef(meta_graph.mutable_graph_def()));
+
+ ItemConfig config;
+ config.feed_nodes.insert("x");
+ config.fetch_nodes.insert("z");
+
+ std::unique_ptr<GrapplerItem> item =
+ GrapplerItemFromMetaGraphDef("0", meta_graph, config);
+ ASSERT_TRUE(item != nullptr);
+
+ EXPECT_EQ(item->feed.size(), 1);
+ EXPECT_EQ(item->fetch.size(), 1);
+ EXPECT_EQ(item->feed[0].first, "x");
+ EXPECT_EQ(item->fetch[0], "z");
+}
+
} // namespace
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index f094c151e6..960d1addb3 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -8,10 +8,6 @@ load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
# Platform specific build config
load(
- "//tensorflow/core:platform/default/build_config.bzl",
- "tf_protos_grappler",
-)
-load(
"//tensorflow/core:platform/default/build_config_root.bzl",
"if_static",
)
@@ -97,7 +93,6 @@ cc_library(
deps = [
":evaluation_utils",
":graph_optimizer",
- ":symbolic_shapes",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
@@ -107,6 +102,7 @@ cc_library(
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/clusters:cluster",
"//tensorflow/core/grappler/costs:graph_properties",
+ "//tensorflow/core/grappler/utils:symbolic_shapes",
],
)
@@ -261,7 +257,6 @@ cc_library(
":constant_folding",
":graph_optimizer",
":graph_optimizer_stage",
- ":symbolic_shapes",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
@@ -270,6 +265,7 @@ cc_library(
"//tensorflow/core/grappler:op_types",
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/costs:graph_properties",
+ "//tensorflow/core/grappler/utils:symbolic_shapes",
"//tensorflow/core/grappler/utils:topological_sort",
],
)
@@ -515,12 +511,14 @@ cc_library(
":custom_graph_optimizer_registry",
":debug_stripper",
":dependency_optimizer",
+ ":experimental_implementation_selector",
":function_optimizer",
":graph_optimizer",
":layout_optimizer",
":loop_optimizer",
":memory_optimizer",
":model_pruner",
+ ":pin_to_host_optimizer",
":remapper",
":scoped_allocator_optimizer",
":shape_optimizer",
@@ -647,7 +645,6 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":graph_optimizer",
- ":symbolic_shapes",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
@@ -657,6 +654,7 @@ cc_library(
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/costs:graph_properties",
"//tensorflow/core/grappler/utils:frame",
+ "//tensorflow/core/grappler/utils:symbolic_shapes",
],
)
@@ -714,31 +712,6 @@ tf_cuda_cc_test(
)
cc_library(
- name = "symbolic_shapes",
- srcs = ["symbolic_shapes.cc"],
- hdrs = ["symbolic_shapes.h"],
- visibility = ["//visibility:public"],
- deps = [
- "//tensorflow/core:framework",
- "//tensorflow/core:lib",
- "//tensorflow/core:protos_all_cc",
- ] + tf_protos_grappler(),
-)
-
-tf_cc_test(
- name = "symbolic_shapes_test",
- srcs = ["symbolic_shapes_test.cc"],
- deps = [
- ":symbolic_shapes",
- "//tensorflow/core:framework",
- "//tensorflow/core:lib",
- "//tensorflow/core:protos_all_cc",
- "//tensorflow/core:test",
- "//tensorflow/core:test_main",
- ],
-)
-
-cc_library(
name = "debug_stripper",
srcs = ["debug_stripper.cc"],
hdrs = [
@@ -911,3 +884,41 @@ tf_cc_test(
"//tensorflow/core/grappler/utils:grappler_test",
],
)
+
+cc_library(
+ name = "pin_to_host_optimizer",
+ srcs = ["pin_to_host_optimizer.cc"],
+ hdrs = [
+ "pin_to_host_optimizer.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":graph_optimizer",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core/grappler:graph_view",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler:op_types",
+ "//tensorflow/core/grappler:utils",
+ "//tensorflow/core/grappler/costs:graph_properties",
+ "//tensorflow/core/grappler/utils:frame",
+ "//tensorflow/core/grappler/utils:symbolic_shapes",
+ "//tensorflow/core/grappler/utils:topological_sort",
+ ],
+)
+
+tf_cuda_cc_test(
+ name = "pin_to_host_optimizer_test",
+ srcs = ["pin_to_host_optimizer_test.cc"],
+ deps = [
+ ":pin_to_host_optimizer",
+ "//tensorflow/cc:cc_ops",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler/utils:grappler_test",
+ ],
+)
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index 11ce121cba..3388ee8035 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -35,8 +35,8 @@ limitations under the License.
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/constant_folding.h"
#include "tensorflow/core/grappler/optimizers/graph_optimizer_stage.h"
-#include "tensorflow/core/grappler/optimizers/symbolic_shapes.h"
#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/grappler/utils/symbolic_shapes.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
@@ -276,7 +276,7 @@ class ArithmeticOptimizerStage : public GraphOptimizerStage<string> {
for (const NodeDef* output : ctx().node_map->GetOutputs(node.name())) {
for (int i = 0; i < output->input_size(); ++i) {
auto input = output->input(i);
- string name = ParseNodeName(input, &position);
+ StringPiece name = ParseNodeNameAsStringPiece(input, &position);
if (name == node.name() && /*control input*/ position < 0) {
return true;
}
@@ -1325,38 +1325,26 @@ class RemoveNegationStage : public ArithmeticOptimizerStage {
}
Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
- const string node_name = node->name();
NodeDef* x;
NodeDef* y;
TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &x));
TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &y));
bool updated = false;
- if (IsAdd(*node)) {
- if (IsNeg(*x)) {
- // (-a) + b = b - a
- node->set_op("Sub");
- node->mutable_input()->SwapElements(0, 1);
- node->set_input(1, x->input(0));
- node->add_input(AsControlDependency(x->name()));
- ctx().node_map->AddOutput(NodeName(x->input(0)), node_name);
- updated = true;
- } else if (IsNeg(*y)) {
- // a + (-b) = a - b
- node->set_op("Sub");
- node->set_input(1, y->input(0));
- node->add_input(AsControlDependency(y->name()));
- ctx().node_map->AddOutput(NodeName(y->input(0)), node_name);
- updated = true;
- }
- } else if (IsSub(*node)) {
- if (IsNeg(*y)) {
- // a - (-b) = a + b
- node->set_op("Add");
- node->set_input(1, y->input(0));
- node->add_input(AsControlDependency(y->name()));
- ctx().node_map->AddOutput(NodeName(y->input(0)), node_name);
- updated = true;
- }
+ if (IsNeg(*y)) {
+ // a - (-b) = a + b or a + (-b) = a - b
+ ForwardControlDependencies(node, {y});
+ ctx().node_map->UpdateInput(node->name(), node->input(1), y->input(0));
+ node->set_op(IsAdd(*node) ? "Sub" : "Add");
+ node->set_input(1, y->input(0));
+ updated = true;
+ } else if (IsAdd(*node) && IsNeg(*x)) {
+ // (-a) + b = b - a
+ ForwardControlDependencies(node, {x});
+ ctx().node_map->UpdateInput(node->name(), node->input(0), x->input(0));
+ node->set_op("Sub");
+ node->mutable_input()->SwapElements(0, 1);
+ node->set_input(1, x->input(0));
+ updated = true;
}
if (updated) {
AddToOptimizationQueue(node);
@@ -1580,7 +1568,8 @@ class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage {
for (NodeDef* output : outputs) {
if (IsControlInput(output->input(0))) continue;
int port;
- const string node_name = ParseNodeName(output->input(0), &port);
+ const StringPiece node_name =
+ ParseNodeNameAsStringPiece(output->input(0), &port);
if (node_name == node.name()) {
tails->insert(ChainLink(output, port));
} else {
@@ -1630,7 +1619,8 @@ class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage {
} else {
for (NodeDef* new_tail : ctx().node_map->GetOutputs(tail->name())) {
int port;
- const string node_name = ParseNodeName(new_tail->input(0), &port);
+ const StringPiece node_name =
+ ParseNodeNameAsStringPiece(new_tail->input(0), &port);
if (node_name != tail->name()) {
return Status::OK();
}
@@ -2379,26 +2369,24 @@ class ConvertPowStage : public ArithmeticOptimizerStage {
}
Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
- const auto& p = ctx().graph_properties->GetInputProperties(node->name())[1];
- for (int i = 0; i < p.shape().dim_size(); ++i) {
- if (p.shape().dim(i).size() < 0) {
+ const auto& pow_props =
+ ctx().graph_properties->GetInputProperties(node->name())[1];
+ for (int i = 0; i < pow_props.shape().dim_size(); ++i) {
+ if (pow_props.shape().dim(i).size() < 0) {
// skip if p is not fully defined.
return Status::OK();
}
}
- if (TensorShape::IsValid(p.shape()) && p.has_value()) {
- Tensor pow(p.dtype(), p.shape());
- if (!pow.FromProto(p.value())) {
+ if (TensorShape::IsValid(pow_props.shape()) && pow_props.has_value()) {
+ Tensor pow(pow_props.dtype(), pow_props.shape());
+ if (!pow.FromProto(pow_props.value())) {
return errors::InvalidArgument("Cannot parse tensor from proto: ",
- p.value().DebugString());
+ pow_props.value().DebugString());
}
complex128 prev, curr;
for (int i = 0; i < pow.NumElements(); ++i) {
- if (!GetElementUnexhaustive(pow, i,
- {DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE,
- DT_COMPLEX64, DT_COMPLEX128},
- &curr)) {
+ if (!GetElementUnexhaustive(pow, i, {pow_props.dtype()}, &curr)) {
// input data type is not supported by Pow. Skip.
return Status::OK();
}
@@ -2411,12 +2399,19 @@ class ConvertPowStage : public ArithmeticOptimizerStage {
NodeDef *x, *y;
TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &x));
TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &y));
+ const auto& value_props =
+ ctx().graph_properties->GetInputProperties(node->name())[0];
+ const TensorShapeProto& output_shape =
+ ctx().graph_properties->GetOutputProperties(node->name())[0].shape();
if (curr == complex128(2, 0)) {
node->set_op("Square");
node->set_input(1, AsControlDependency(y->name()));
AddToOptimizationQueue(node);
AddToOptimizationQueue(y);
- } else if (curr == complex128(1, 0)) {
+ } else if (curr == complex128(1, 0) &&
+ ShapesSymbolicallyEqual(value_props.shape(), output_shape)) {
+ // Pow could be used to broadcast, so make sure the shapes of the two
+ // arguments are identical before replacing Pow with Identity.
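+ // For instance (see the out_bcast1 test): Pow of a scalar base and a
+ // [1, 3] tensor of ones broadcasts the base to shape [1, 3], so an
+ // Identity of the scalar would produce the wrong output shape.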
node->set_op("Identity");
node->set_input(1, AsControlDependency(y->name()));
AddToOptimizationQueue(node);
@@ -2426,20 +2421,20 @@ class ConvertPowStage : public ArithmeticOptimizerStage {
node->set_input(1, AsControlDependency(y->name()));
AddToOptimizationQueue(node);
AddToOptimizationQueue(y);
- } else if (curr == complex128(0, 0)) {
- const auto& b =
- ctx().graph_properties->GetInputProperties(node->name())[0];
- for (int i = 0; i < b.shape().dim_size(); ++i) {
- if (b.shape().dim(i).size() < 0) {
+ } else if (curr == complex128(0, 0) &&
+ ShapesSymbolicallyEqual(value_props.shape(), output_shape)) {
+ for (int i = 0; i < value_props.shape().dim_size(); ++i) {
+ if (value_props.shape().dim(i).size() < 0) {
// skip if b is not fully defined.
return Status::OK();
}
}
- if (TensorShape::IsValid(b.shape()) && b.has_value()) {
- Tensor base(b.dtype(), b.shape());
- if (!base.FromProto(b.value())) {
+ if (TensorShape::IsValid(value_props.shape()) &&
+ value_props.has_value()) {
+ Tensor base(value_props.dtype(), value_props.shape());
+ if (!base.FromProto(value_props.value())) {
return errors::InvalidArgument("Cannot parse tensor from proto: ",
- b.value().DebugString());
+ value_props.value().DebugString());
}
node->set_op("Const");
Tensor c(base.dtype(), base.shape());
@@ -2597,12 +2592,10 @@ class ConvertExpm1Stage : public ArithmeticOptimizerStage {
~ConvertExpm1Stage() override = default;
bool IsSupported(const NodeDef* node) const override {
- if (!IsSub(*node))
- return false;
+ if (!IsSub(*node)) return false;
NodeDef* input;
- if (!GetInputNode(node->input(0), &input).ok())
- return false;
+ if (!GetInputNode(node->input(0), &input).ok()) return false;
return IsExp(*input);
}
@@ -2622,10 +2615,8 @@ class ConvertExpm1Stage : public ArithmeticOptimizerStage {
return Status::OK();
}
- const auto& t =
- ctx().graph_properties->GetInputProperties(exp->name())[0];
- const auto& c =
- ctx().graph_properties->GetInputProperties(node->name())[1];
+ const auto& t = ctx().graph_properties->GetInputProperties(exp->name())[0];
+ const auto& c = ctx().graph_properties->GetInputProperties(node->name())[1];
for (int k = 0; k < c.shape().dim_size(); ++k) {
// Skip if c shape is not fully determined.
if (c.shape().dim(k).size() < 0) {
@@ -2940,8 +2931,8 @@ uint64 UniqueNodes::ComputeSignature(const NodeDef& node) const {
for (const auto& input : node.input()) {
int pos;
- string node_name = ParseNodeName(input, &pos);
- h = Hash64CombineUnordered(Hash64(node_name), h);
+ const StringPiece node_name = ParseNodeNameAsStringPiece(input, &pos);
+ h = Hash64CombineUnordered(Hash64(node_name.data(), node_name.size()), h);
h = Hash64CombineUnordered(std::hash<int>()(pos), h);
}
for (const auto& attr : node.attr()) {
@@ -3053,6 +3044,13 @@ void ArithmeticOptimizer::DedupComputations() {
return;
}
std::set<int> duplicates;
+ // Populate feeds_inplace_op.
+ std::unordered_set<NodeDef*> feeds_inplace_op;
+ for (int i = 0; i < optimized_graph_->node_size(); ++i) {
+ if (FeedsInPlaceOp(graph_view, optimized_graph_->node(i))) {
+ feeds_inplace_op.insert(optimized_graph_->mutable_node(i));
+ }
+ }
do {
stop = true;
UniqueNodes nodes;
@@ -3061,19 +3059,19 @@ void ArithmeticOptimizer::DedupComputations() {
continue;
}
NodeDef* node = optimized_graph_->mutable_node(i);
- if (!CanDedup(*node)) {
+ if (!CanDedup(*node) ||
+ feeds_inplace_op.find(node) != feeds_inplace_op.end()) {
continue;
}
NodeDef* rep = nodes.FindOrAddRepresentative(node);
if (rep == node) {
continue;
}
- // If either node feeds an inplace op, deduping them may cause data races.
- // For example: If we dedup nodes initializing two independent inplace
- // accumulations, they will write to the same buffer, clobbering each
- // other's results.
- if (FeedsInPlaceOp(graph_view, *rep) ||
- FeedsInPlaceOp(graph_view, *node)) {
+ // If either node or rep feeds an inplace op, deduping them may cause data
+ // races. For example: If we dedup nodes initializing two independent
+ // inplace accumulations, they will write to the same buffer, clobbering
+ // each other's results.
+ if (feeds_inplace_op.find(rep) != feeds_inplace_op.end()) {
continue;
}
VLOG(3) << "Remove duplicated node: node=" << node->name()
@@ -3081,20 +3079,20 @@ void ArithmeticOptimizer::DedupComputations() {
const std::set<NodeDef*>& fanouts = node_map_->GetOutputs(node->name());
for (NodeDef* fanout : fanouts) {
for (int i = 0; i < fanout->input_size(); ++i) {
- string* name = fanout->mutable_input(i);
- int position;
- const string nodename = ParseNodeName(*name, &position);
- if (nodename == node->name()) {
- // Update name in-place.
- if (position > 0) {
- *name = StrCat(rep->name(), ":", position);
- } else if (position == 0) {
- *name = rep->name();
- } else {
- *name = StrCat("^", rep->name());
- }
- node_map_->AddOutput(rep->name(), fanout->name());
+ string* fanout_input = fanout->mutable_input(i);
+ const int position =
+ NodePositionIfSameNode(*fanout_input, node->name());
+ // Update name in-place.
+ if (position < -1) {
+ continue;
+ } else if (position > 0) {
+ *fanout_input = StrCat(rep->name(), ":", position);
+ } else if (position == 0) {
+ *fanout_input = rep->name();
+ } else {
+ *fanout_input = StrCat("^", rep->name());
}
+ node_map_->AddOutput(rep->name(), fanout->name());
}
}
duplicates.insert(i);
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index bc838c6659..77f3c64c65 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -2353,9 +2353,14 @@ TEST_F(ArithmeticOptimizerTest, RemoveNegation) {
Output sub_negx_y = ops::Sub(s.WithOpName("Sub_negx_y"), neg_x, y);
Output sub_x_negy = ops::Sub(s.WithOpName("Sub_x_negy"), x, neg_y);
Output sub_negx_negy = ops::Sub(s.WithOpName("Sub_negx_negy"), neg_x, neg_y);
- auto add_all = ops::AddN(s.WithOpName("add_all"),
- {add_x_y, add_negx_y, add_x_negy, add_negx_negy,
- sub_x_y, sub_negx_y, sub_x_negy, sub_negx_negy});
+ Output neg_x_with_dep = ops::Neg(
+ s.WithOpName("Neg_x_with_dep").WithControlDependencies({add_x_y}), x);
+ Output add_negx_with_dep_y =
+ ops::Add(s.WithOpName("Add_negx_with_dep_y"), neg_x_with_dep, y);
+ auto add_all =
+ ops::AddN(s.WithOpName("add_all"),
+ {add_x_y, add_negx_y, add_x_negy, add_negx_negy, sub_x_y,
+ sub_negx_y, sub_x_negy, sub_negx_negy, add_negx_with_dep_y});
GrapplerItem item;
item.fetch = {"add_all"};
@@ -2370,7 +2375,7 @@ TEST_F(ArithmeticOptimizerTest, RemoveNegation) {
GraphDef output;
ArithmeticOptimizer optimizer;
EnableOnlyRemoveNegation(&optimizer);
- OptimizeAndPrune(&optimizer, &item, &output);
+ OptimizeTwice(&optimizer, &item, &output);
EXPECT_EQ(item.graph.node_size(), output.node_size());
int found = 0;
@@ -2379,42 +2384,43 @@ TEST_F(ArithmeticOptimizerTest, RemoveNegation) {
if (node.name() == "Add_negx_y") {
++found;
EXPECT_EQ("Sub", node.op());
- EXPECT_EQ(3, node.input_size());
+ EXPECT_EQ(2, node.input_size());
EXPECT_EQ("y", node.input(0));
EXPECT_EQ("x", node.input(1));
- EXPECT_EQ("^Neg_x", node.input(2));
} else if (node.name() == "Add_x_negy") {
++found;
EXPECT_EQ("Sub", node.op());
- EXPECT_EQ(3, node.input_size());
+ EXPECT_EQ(2, node.input_size());
EXPECT_EQ("x", node.input(0));
EXPECT_EQ("y", node.input(1));
- EXPECT_EQ("^Neg_y", node.input(2));
} else if (node.name() == "Add_negx_negy") {
++found;
EXPECT_EQ("Sub", node.op());
- EXPECT_EQ(3, node.input_size());
- EXPECT_EQ("Neg_y", node.input(0));
- EXPECT_EQ("x", node.input(1));
- EXPECT_EQ("^Neg_x", node.input(2));
+ EXPECT_EQ(2, node.input_size());
+ EXPECT_EQ("Neg_x", node.input(0));
+ EXPECT_EQ("y", node.input(1));
} else if (node.name() == "Sub_x_negy") {
++found;
EXPECT_EQ("Add", node.op());
- EXPECT_EQ(3, node.input_size());
+ EXPECT_EQ(2, node.input_size());
EXPECT_EQ("x", node.input(0));
EXPECT_EQ("y", node.input(1));
- EXPECT_EQ("^Neg_y", node.input(2));
} else if (node.name() == "Sub_negx_negy") {
++found;
EXPECT_EQ("Sub", node.op());
- EXPECT_EQ(4, node.input_size());
+ EXPECT_EQ(2, node.input_size());
+ EXPECT_EQ("y", node.input(0));
+ EXPECT_EQ("x", node.input(1));
+ } else if (node.name() == "Add_negx_with_dep_y") {
+ ++found;
+ EXPECT_EQ("Sub", node.op());
+ EXPECT_EQ(3, node.input_size());
EXPECT_EQ("y", node.input(0));
EXPECT_EQ("x", node.input(1));
- EXPECT_EQ("^Neg_y", node.input(2));
- EXPECT_EQ("^Neg_x", node.input(3));
+ EXPECT_EQ("^Add_x_y", node.input(2));
}
}
- EXPECT_EQ(5, found);
+ EXPECT_EQ(6, found);
auto tensors = EvaluateNodes(output, item.fetch, feed);
EXPECT_EQ(1, tensors.size());
@@ -2468,6 +2474,9 @@ TEST_F(ArithmeticOptimizerTest, ConvertPow) {
auto y_Point5 = ops::Const(s.WithOpName("y_.5"), {-0.5f, -0.5f}, {1, 2});
auto y_1 = ops::Const(s.WithOpName("y_1"), {-1.0f, -1.0f}, {1, 2});
auto y = ops::Const(s.WithOpName("y"), {3.0f, 4.0f}, {1, 2});
+ auto z = ops::Const(s.WithOpName("z"), {42.0f}, {});
+ auto ones = ops::Const(s.WithOpName("ones"), {1.0f, 1.0f, 1.0f}, {1, 3});
+ auto zeros = ops::Const(s.WithOpName("zeros"), {0.0f, 0.0f, 0.0f}, {1, 3});
Output out2 = ops::Pow(s.WithOpName("out2"), x, y2);
Output out1 = ops::Pow(s.WithOpName("out1"), x, y1);
Output outPoint5 = ops::Pow(s.WithOpName("out.5"), x, yPoint5);
@@ -2475,21 +2484,24 @@ TEST_F(ArithmeticOptimizerTest, ConvertPow) {
Output out_Point5 = ops::Pow(s.WithOpName("out_.5"), x, y_Point5);
Output out_1 = ops::Pow(s.WithOpName("out_1"), x, y_1);
Output out = ops::Pow(s.WithOpName("out"), x, y);
+ Output out_bcast1 = ops::Pow(s.WithOpName("out_bcast1"), z, ones);
+ Output out_bcast2 = ops::Pow(s.WithOpName("out_bcast2"), z, zeros);
GrapplerItem item;
- item.fetch = {"out2", "out1", "out.5", "out0", "out_.5", "out_1", "out"};
+ item.fetch = {"out2", "out1", "out.5", "out0", "out_.5",
+ "out_1", "out", "out_bcast1", "out_bcast2"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
- EXPECT_EQ(7, tensors_expected.size());
+ EXPECT_EQ(9, tensors_expected.size());
GraphDef got;
ArithmeticOptimizer optimizer;
EnableOnlyConvertPow(&optimizer);
OptimizeAndPrune(&optimizer, &item, &got);
auto tensors = EvaluateNodes(got, item.fetch);
- EXPECT_EQ(7, tensors.size());
+ EXPECT_EQ(9, tensors.size());
- for (int i = 0; i < 7; ++i) {
+ for (int i = 0; i < tensors.size(); ++i) {
EXPECT_EQ(tensors[i].NumElements(), tensors_expected[i].NumElements());
test::ExpectTensorNear<float>(tensors[i], tensors_expected[i], 1e-6);
}
@@ -2503,6 +2515,9 @@ TEST_F(ArithmeticOptimizerTest, ConvertPow) {
AddNode("y_.5", "Const", {}, {}, &want);
AddNode("y_1", "Const", {}, {}, &want);
AddNode("y", "Const", {}, {}, &want);
+ AddNode("z", "Const", {}, {}, &want);
+ AddNode("ones", "Const", {}, {}, &want);
+ AddNode("zeros", "Const", {}, {}, &want);
AddNode("out2", "Square", {"x", AsControlDependency("y2")}, {}, &want);
AddNode("out1", "Identity", {"x", AsControlDependency("y1")}, {}, &want);
AddNode("out.5", "Sqrt", {"x", AsControlDependency("y.5")}, {}, &want);
@@ -2511,6 +2526,8 @@ TEST_F(ArithmeticOptimizerTest, ConvertPow) {
AddNode("out_.5", "Rsqrt", {"x", AsControlDependency("y_.5")}, {}, &want);
AddNode("out_1", "Reciprocal", {"x", AsControlDependency("y_1")}, {}, &want);
AddNode("out", "Pow", {"x", "y"}, {}, &want);
+ AddNode("out_bcast1", "Pow", {"z", "ones"}, {}, &want);
+ AddNode("out_bcast2", "Pow", {"z", "zeros"}, {}, &want);
CompareGraphs(want, got);
}
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index 99737a71eb..ca5d3a6dfd 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -32,8 +32,8 @@ limitations under the License.
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/evaluation_utils.h"
-#include "tensorflow/core/grappler/optimizers/symbolic_shapes.h"
#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/grappler/utils/symbolic_shapes.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
@@ -437,25 +437,6 @@ Status ConstantFolding::MaterializeShapes(const GraphProperties& properties) {
}
namespace {
-bool ShapesEqual(const TensorShapeProto& shape1,
- const TensorShapeProto& shape2) {
- if (shape1.unknown_rank() || shape2.unknown_rank()) {
- return false;
- }
- if (shape1.dim_size() != shape2.dim_size()) {
- return false;
- }
- for (int i = 0; i < shape1.dim_size(); ++i) {
- if (shape1.dim(i).size() != shape2.dim(i).size()) {
- return false;
- }
- if (shape1.dim(i).size() == -1 || shape2.dim(i).size() == -1) {
- return false;
- }
- }
- return true;
-}
-
bool ExtractShape(const NodeDef& shape_node, const GraphProperties& properties,
BCast::Vec* shape, int64* min_id) {
if (shape_node.op() == "Shape") {
@@ -2125,7 +2106,8 @@ bool ConstantFolding::SimplifyPack(GraphDef* optimized_graph, NodeDef* node) {
Tensor axis_t(DT_INT32, TensorShape({}));
NodeDef* axis_node = optimized_graph->add_node();
axis_node->set_name(OptimizedNodeName(*node, "_const_axis"));
- const int axis = node->attr().at("axis").i();
+ const int axis =
+ node->attr().count("axis") == 0 ? 0 : node->attr().at("axis").i();
if (!SetTensorValue(DT_INT32, axis, &axis_t).ok() ||
!CreateNodeDef(axis_node->name(), TensorValue(&axis_t), axis_node)
.ok()) {
@@ -2348,7 +2330,8 @@ Status ConstantFolding::SimplifyArithmeticOperations(
properties.GetInputProperties(node->name())[1].shape();
const bool x_is_zero = IsZeros(*x);
const bool x_is_one = x_is_zero ? false : IsOnes(*x);
- const bool y_matches_output_shape = ShapesEqual(output_shape, y_shape);
+ const bool y_matches_output_shape =
+ ShapesSymbolicallyEqual(output_shape, y_shape);
if (y_matches_output_shape &&
((is_mul && x_is_one) || (is_add && x_is_zero))) {
// 1 * y = y or 0 + y = y.
@@ -2378,7 +2361,8 @@ Status ConstantFolding::SimplifyArithmeticOperations(
properties.GetInputProperties(node->name())[0].shape();
const bool y_is_zero = IsZeros(*y);
const bool y_is_one = y_is_zero ? false : IsOnes(*y);
- const bool x_matches_output_shape = ShapesEqual(output_shape, x_shape);
+ const bool x_matches_output_shape =
+ ShapesSymbolicallyEqual(output_shape, x_shape);
if (x_matches_output_shape && (((is_mul || is_any_div) && y_is_one) ||
((is_add || is_sub) && y_is_zero))) {
// x * 1 = x or x / 1 = x or x +/- 0 = x
diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
index 2a19b3f95a..b09360a2c2 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
@@ -3015,37 +3015,48 @@ TEST_F(ConstantFoldingTest, TrivialPack) {
auto stack =
ops::Stack(scope.WithOpName("stack").WithControlDependencies({y}), {x},
ops::Stack::Axis(1));
+ auto stack_no_axis = ops::Stack(scope.WithOpName("stack_no_axis"), {x});
GrapplerItem item;
TF_CHECK_OK(scope.ToGraphDef(&item.graph));
- item.fetch.push_back("stack");
+ item.fetch = {"stack", "stack_no_axis"};
ConstantFolding optimizer(nullptr /* cpu_device */);
GraphDef output;
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
- EXPECT_EQ(5, output.node_size());
+ EXPECT_EQ(7, output.node_size());
+ int found = 0;
for (const auto& node : output.node()) {
if (node.name() == "stack") {
- EXPECT_EQ("stack", node.name());
EXPECT_EQ("ExpandDims", node.op());
EXPECT_EQ(3, node.input_size());
EXPECT_EQ("x", node.input(0));
EXPECT_EQ("ConstantFolding/stack_const_axis", node.input(1));
EXPECT_EQ("^y", node.input(2));
+ ++found;
+ } else if (node.name() == "stack_no_axis") {
+ EXPECT_EQ("ExpandDims", node.op());
+ EXPECT_EQ(2, node.input_size());
+ EXPECT_EQ("x", node.input(0));
+ EXPECT_EQ("ConstantFolding/stack_no_axis_const_axis", node.input(1));
+ ++found;
} else if (node.name() == "ConstantFolding/stack_const_axis") {
EXPECT_EQ("Const", node.op());
EXPECT_EQ(1, node.input_size());
EXPECT_EQ("^x", node.input(0));
+ ++found;
}
}
+ EXPECT_EQ(found, 3);
- std::vector<string> fetch = {"stack"};
+ std::vector<string> fetch = {"stack", "stack_no_axis"};
auto tensors_expected = EvaluateNodes(item.graph, fetch);
auto tensors = EvaluateNodes(output, fetch);
- EXPECT_EQ(1, tensors_expected.size());
- EXPECT_EQ(1, tensors.size());
+ EXPECT_EQ(2, tensors_expected.size());
+ EXPECT_EQ(2, tensors.size());
EXPECT_EQ(tensors_expected[0].shape(), tensors[0].shape());
+ EXPECT_EQ(tensors_expected[1].shape(), tensors[1].shape());
}
// The test does not evaluate the optimized and original graphs to check if their
diff --git a/tensorflow/core/grappler/optimizers/data/BUILD b/tensorflow/core/grappler/optimizers/data/BUILD
index e84df10778..81c1bddf67 100644
--- a/tensorflow/core/grappler/optimizers/data/BUILD
+++ b/tensorflow/core/grappler/optimizers/data/BUILD
@@ -22,6 +22,7 @@ cc_library(
"//tensorflow/core/grappler/utils:topological_sort",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
+ "//tensorflow/core:lib_internal",
] + tf_protos_all(),
)
@@ -31,6 +32,7 @@ tf_cc_test(
visibility = ["//visibility:public"],
deps = [
":filter_fusion",
+ ":graph_test_utils",
":graph_utils",
"//tensorflow/core:framework",
"//tensorflow/core:test",
@@ -49,6 +51,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":graph_utils",
+ ":function_utils",
"//tensorflow/core/grappler:mutable_graph_view",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
@@ -67,6 +70,7 @@ tf_cc_test(
srcs = ["fusion_utils_test.cc"],
visibility = ["//visibility:public"],
deps = [
+ ":function_utils",
":fusion_utils",
":graph_utils",
"//tensorflow/core:framework",
@@ -78,6 +82,41 @@ tf_cc_test(
)
cc_library(
+ name = "function_utils",
+ srcs = ["function_utils.cc"],
+ hdrs = [
+ "function_utils.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":graph_utils",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core/grappler:mutable_graph_view",
+ "//tensorflow/core/grappler:utils",
+ "//tensorflow/core:lib_internal",
+ ] + tf_protos_all(),
+)
+
+tf_cc_test(
+ name = "function_utils_test",
+ srcs = ["function_utils_test.cc"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":function_utils",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
+ "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
+ "//tensorflow/core/kernels:cast_op",
+ "//tensorflow/tools/graph_transforms:transform_utils",
+ ],
+)
+
+cc_library(
name = "graph_utils",
srcs = ["graph_utils.cc"],
hdrs = [
@@ -110,6 +149,62 @@ tf_cc_test(
)
cc_library(
+ name = "graph_test_utils",
+ testonly = 1,
+ srcs = ["graph_test_utils.cc"],
+ hdrs = [
+ "graph_test_utils.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core/grappler:mutable_graph_view",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler:utils",
+ "//tensorflow/core:testlib",
+ ] + tf_protos_all(),
+)
+
+cc_library(
+ name = "hoist_random_uniform",
+ srcs = ["hoist_random_uniform.cc"],
+ hdrs = [
+ "hoist_random_uniform.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":function_utils",
+ ":graph_utils",
+ "//tensorflow/core:lib",
+ "//tensorflow/core/grappler:mutable_graph_view",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler:op_types",
+ "//tensorflow/core/grappler:utils",
+ "//tensorflow/core/grappler/clusters:cluster",
+ "//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
+ "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
+ "//tensorflow/core:lib_internal",
+ ] + tf_protos_all(),
+)
+
+tf_cc_test(
+ name = "hoist_random_uniform_test",
+ srcs = ["hoist_random_uniform_test.cc"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":graph_test_utils",
+ ":graph_utils",
+ ":hoist_random_uniform",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "//tensorflow/core/grappler:grappler_item",
+ ] + tf_protos_all(),
+)
+
+cc_library(
name = "latency_all_edges",
srcs = ["latency_all_edges.cc"],
hdrs = [
@@ -137,7 +232,9 @@ cc_library(
],
visibility = ["//visibility:public"],
deps = [
+ ":function_utils",
":graph_utils",
+ ":vectorization_utils",
"//tensorflow/core:lib",
"//tensorflow/core/grappler:mutable_graph_view",
"//tensorflow/core/grappler:grappler_item",
@@ -218,7 +315,7 @@ cc_library(
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
"//tensorflow/core/grappler/utils:topological_sort",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
- "//tensorflow/core:ptr_util",
+ "//tensorflow/core:lib_internal",
] + tf_protos_all(),
)
@@ -227,6 +324,7 @@ tf_cc_test(
srcs = ["map_and_filter_fusion_test.cc"],
visibility = ["//visibility:public"],
deps = [
+ ":graph_test_utils",
":graph_utils",
":map_and_filter_fusion",
"//tensorflow/core:framework",
@@ -256,6 +354,7 @@ cc_library(
"//tensorflow/core/grappler/utils:topological_sort",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
+ "//tensorflow/core:lib_internal",
] + tf_protos_all(),
)
@@ -264,6 +363,7 @@ tf_cc_test(
srcs = ["map_fusion_test.cc"],
visibility = ["//visibility:public"],
deps = [
+ ":graph_test_utils",
":graph_utils",
":map_fusion",
"//tensorflow/core:framework",
@@ -301,6 +401,7 @@ tf_cc_test(
srcs = ["map_parallelization_test.cc"],
visibility = ["//visibility:public"],
deps = [
+ ":graph_test_utils",
":graph_utils",
":map_parallelization",
"//tensorflow/core:framework",
@@ -384,6 +485,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":filter_fusion",
+ ":hoist_random_uniform",
":latency_all_edges",
":map_and_batch_fusion",
":map_and_filter_fusion",
@@ -409,3 +511,43 @@ tf_cc_test(
"//tensorflow/core/grappler:grappler_item",
],
)
+
+cc_library(
+ name = "vectorization_utils",
+ srcs = ["vectorization_utils.cc"],
+ hdrs = [
+ "vectorization_utils.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":function_utils",
+ ":graph_utils",
+ "@com_google_absl//absl/strings",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core/grappler:mutable_graph_view",
+ "//tensorflow/core/grappler:utils",
+ "//tensorflow/core/grappler/optimizers/data/vectorization",
+ "//tensorflow/core/grappler/utils:functions",
+ ] + tf_protos_all(),
+)
+
+tf_cc_test(
+ name = "vectorization_utils_test",
+ srcs = ["vectorization_utils_test.cc"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":function_utils",
+ ":vectorization_utils",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
+ "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
+ "//tensorflow/core/kernels:cast_op",
+ "//tensorflow/tools/graph_transforms:transform_utils",
+ ] + tf_protos_all(),
+)
diff --git a/tensorflow/core/grappler/optimizers/data/filter_fusion.cc b/tensorflow/core/grappler/optimizers/data/filter_fusion.cc
index c71aa6e804..1ad495bbad 100644
--- a/tensorflow/core/grappler/optimizers/data/filter_fusion.cc
+++ b/tensorflow/core/grappler/optimizers/data/filter_fusion.cc
@@ -43,19 +43,14 @@ NodeDef MakeFusedFilterNode(const NodeDef& first_filter_node,
fused_node.set_op("FilterDataset");
fused_node.add_input(first_filter_node.input(0));
- auto copy_attribute = [](const string& attribute_name, const NodeDef& from,
- NodeDef* to) {
- (*to->mutable_attr())[attribute_name] = from.attr().at(attribute_name);
- };
-
auto attr = first_filter_node.attr().at("predicate");
*attr.mutable_func()->mutable_name() = fused_function.signature().name();
(*fused_node.mutable_attr())["predicate"] = std::move(attr);
- copy_attribute("Targuments", first_filter_node, &fused_node);
+ graph_utils::CopyAttribute("Targuments", first_filter_node, &fused_node);
for (auto key : {"output_shapes", "output_types"})
- copy_attribute(key, second_filter_node, &fused_node);
+ graph_utils::CopyAttribute(key, second_filter_node, &fused_node);
return fused_node;
}
@@ -120,8 +115,8 @@ Status FilterFusion::Optimize(Cluster* cluster, const GrapplerItem& item,
// functions, or make sure that optimization passes run after filter
// fusion.
TF_RETURN_IF_ERROR(function_library.AddFunctionDef(*fused_predicate));
- // TODO(prazek): we could also remove map functions from library if they
- // are not used anymore.
+ // TODO(b/116285210): we could also remove map functions from library if
+ // they are not used anymore.
nodes_to_delete.insert(first_filter_node->name());
nodes_to_delete.insert(second_filter_node->name());
}
diff --git a/tensorflow/core/grappler/optimizers/data/filter_fusion_test.cc b/tensorflow/core/grappler/optimizers/data/filter_fusion_test.cc
index 12b1924efd..c8becc5cc0 100644
--- a/tensorflow/core/grappler/optimizers/data/filter_fusion_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/filter_fusion_test.cc
@@ -19,8 +19,8 @@ limitations under the License.
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
-
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
@@ -28,14 +28,7 @@ namespace tensorflow {
namespace grappler {
namespace {
-NodeDef MakeFilterNode(StringPiece name, StringPiece input_node_name) {
- return test::function::NDef(
- name, "FilterDataset", {string(input_node_name)},
- {{"predicate", FunctionDefHelper::FunctionRef("IsZero")},
- {"Targuments", {}},
- {"output_shapes", {}},
- {"output_types", {}}});
-}
+using graph_tests_utils::MakeFilterNode;
TEST(FilterFusionTest, FuseTwoFilterIntoOne) {
using test::function::NDef;
diff --git a/tensorflow/core/grappler/optimizers/data/function_utils.cc b/tensorflow/core/grappler/optimizers/data/function_utils.cc
new file mode 100644
index 0000000000..311df15bc2
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/function_utils.cc
@@ -0,0 +1,176 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
+
+#include "tensorflow/core/framework/device_base.h"
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/lib/strings/scanner.h"
+#include "tensorflow/core/util/ptr_util.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace function_utils {
+
+FunctionDefTensorDesc::FunctionDefTensorDesc(const string& node_name,
+ const string& output, int position)
+ : node_name(node_name), node_output(output), position(position) {
+ full_str = strings::StrCat(node_name, ":", node_output, ":", position);
+}
+
+FunctionDefTensorDesc::FunctionDefTensorDesc(const string& input) {
+ // Parses node_name:node_output:position string into its components.
+ full_str = input;
+ StringPiece capture;
+ StringPiece remaining;
+
+ // Parse "node_name"
+ if (strings::Scanner(input)
+ .One(strings::Scanner::LETTER_DIGIT_DOT_UNDERSCORE)
+ .Any(strings::Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE)
+ .GetResult(&remaining, &capture)) {
+ node_name = string(capture.data(), capture.size());
+ }
+
+ // Parse "node_output" if it exists
+ if (strings::Scanner(remaining)
+ .OneLiteral(":")
+ .RestartCapture()
+ .One(strings::Scanner::LETTER)
+ .Any(strings::Scanner::LETTER_DIGIT_UNDERSCORE)
+ .GetResult(&remaining, &capture)) {
+ node_output = string(capture.data(), capture.size());
+ }
+
+ // Parse "position" if it exists
+ if (strings::Scanner(remaining)
+ .OneLiteral(":")
+ .RestartCapture()
+ .Many(strings::Scanner::DIGIT)
+ .GetResult(nullptr, &capture)) {
+ CHECK(strings::safe_strto32(capture, &position));
+ }
+}
+
+// TODO(rachelim): Create a utility class similar to MutableGraphView for
+// FunctionDefs, and use that to manipulate functions. It'll be more
+// performant if we kept mappings of nodes->inputs/outputs, so that we don't
+// have to search over all nodes each time.
+// Note that we're not using GrapplerFunctionItem because it doesn't cover
+// some of our desired uses (e.g., changing the outputs of a function), and the
+// FunctionDef -> GraphDef conversion isn't really necessary in this case.
+void ReplaceReferences(const string& from, const string& to,
+ FunctionDef* func) {
+ for (NodeDef& n : *func->mutable_node_def()) {
+ std::replace(n.mutable_input()->begin(), n.mutable_input()->end(), from,
+ to);
+ }
+
+ for (auto& p : *func->mutable_ret()) {
+ if (p.second == from) {
+ p.second = to;
+ }
+ }
+}
+
+void AddFunctionOutputWithUniqueName(StringPiece prefix,
+ StringPiece output_tensor_name,
+ FunctionDef* function, DataType dt) {
+ string name = string(prefix);
+ int id = function->signature().output_arg_size();
+ while (ContainsFunctionOutputWithName(name, *function)) {
+ name = strings::StrCat(prefix, "/_", id);
+ ++id;
+ }
+ auto* output = function->mutable_signature()->mutable_output_arg()->Add();
+ output->set_name(name);
+ output->set_type(dt);
+
+ (*function->mutable_ret())[name] = string(output_tensor_name);
+}
+
+NodeDef* AddNode(StringPiece name, StringPiece op,
+ const std::vector<string>& inputs,
+ const std::vector<std::pair<string, AttrValue>>& attributes,
+ FunctionDef* fd) {
+ NodeDef* node = fd->add_node_def();
+ if (!name.empty()) {
+ node->set_name(string(name));
+ } else {
+ SetUniqueFunctionNodeName(op, fd, node);
+ }
+ node->set_op(string(op));
+ for (const string& input : inputs) {
+ node->add_input(input);
+ }
+ for (auto attr : attributes) {
+ (*node->mutable_attr())[attr.first] = attr.second;
+ }
+ return node;
+}
+
+bool ContainsFunctionNodeWithName(StringPiece name,
+ const FunctionDef& function) {
+ return FindFunctionNodeWithName(name, function) != -1;
+}
+
+bool ContainsFunctionNodeWithOp(StringPiece op, const FunctionDef& function) {
+ return FindFunctionNodeWithOp(op, function) != -1;
+}
+
+bool ContainsFunctionOutputWithName(StringPiece name,
+ const FunctionDef& function) {
+ return FindFunctionOutputWithName(name, function) != -1;
+}
+
+int FindFunctionInputWithName(StringPiece name, const FunctionDef& function) {
+ return graph_utils::GetFirstElementIndexWithPredicate(
+ [&name](const OpDef_ArgDef& arg) { return arg.name() == name; },
+ function.signature().input_arg());
+}
+
+int FindFunctionOutputWithName(StringPiece name, const FunctionDef& function) {
+ return graph_utils::GetFirstElementIndexWithPredicate(
+ [&name](const OpDef_ArgDef& arg) { return arg.name() == name; },
+ function.signature().output_arg());
+}
+
+int FindFunctionNodeWithName(StringPiece name, const FunctionDef& function) {
+ return graph_utils::GetFirstElementIndexWithPredicate(
+ [&name](const NodeDef& node) { return node.name() == name; },
+ function.node_def());
+}
+
+int FindFunctionNodeWithOp(StringPiece op, const FunctionDef& function) {
+ return graph_utils::GetFirstElementIndexWithPredicate(
+ [&op](const NodeDef& node) { return node.op() == op; },
+ function.node_def());
+}
+
+void SetUniqueFunctionNodeName(StringPiece prefix, FunctionDef* function,
+ NodeDef* node) {
+ string name = string(prefix);
+ int id = function->node_def_size();
+ while (ContainsFunctionNodeWithName(name, *function)) {
+ name = strings::StrCat(prefix, "/_", id);
+ ++id;
+ }
+ node->set_name(std::move(name));
+}
+
+} // end namespace function_utils
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/function_utils.h b/tensorflow/core/grappler/optimizers/data/function_utils.h
new file mode 100644
index 0000000000..d4ce824652
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/function_utils.h
@@ -0,0 +1,108 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FUNCTION_UTILS_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FUNCTION_UTILS_H_
+
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/function.pb.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/tensor.pb.h"
+#include "tensorflow/core/framework/tensor_shape.pb.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/grappler/mutable_graph_view.h"
+#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace function_utils {
+// This namespace contains utility functions for querying and modifying
+// FunctionDefs.
+
+// Describes a FunctionDef input tensor. In FunctionDefs, input tensor strings
+// have the format node_name:node_output:position (if they derive from nodes),
+// or input_name (if they derive from an argument).
+struct FunctionDefTensorDesc {
+ FunctionDefTensorDesc() = default;
+
+ FunctionDefTensorDesc(const string& node_name, const string& output,
+ int position);
+
+ // Parses node_name:node_output:position string into its components.
+ explicit FunctionDefTensorDesc(const string& input);
+
+ // TODO(rachelim): Add provisions to deal with special formats, like how
+ // GrapplerFunctionItem expands node output range if position is not defined
+ string full_str;
+ string node_name;
+ string node_output;
+ int position = -1;
+};
+
+// Replaces all references to `from` tensor in func's nodes' inputs and retvals
+// to `to` tensor. This is similar to `MutableGraphView::ReplaceInputs`.
+void ReplaceReferences(const string& from, const string& to, FunctionDef* func);
+
+// Adds a function output to the function def, ensuring that the output key
+// is unique, and maps to output_tensor_name in the ret dict.
+void AddFunctionOutputWithUniqueName(StringPiece prefix,
+ StringPiece output_tensor_name,
+ FunctionDef* function, DataType dt);
+
+// Adds a node to a FunctionDef.
+NodeDef* AddNode(StringPiece name, StringPiece op,
+ const std::vector<string>& inputs,
+ const std::vector<std::pair<string, AttrValue>>& attributes,
+ FunctionDef* fd);
+
+// Checks whether the function contains a node with the given name.
+bool ContainsFunctionNodeWithName(StringPiece name,
+ const FunctionDef& function);
+
+// Checks whether the function contains a node with the given op.
+bool ContainsFunctionNodeWithOp(StringPiece op, const FunctionDef& function);
+
+// Checks whether the function contains an output with the given name.
+bool ContainsFunctionOutputWithName(StringPiece name,
+ const FunctionDef& function);
+
+// Returns the index of the function input with the given name, or -1 if no
+// such input exists.
+int FindFunctionInputWithName(StringPiece name, const FunctionDef& function);
+
+// Returns the index of the function output with the given name, or -1 if no
+// such output exists.
+int FindFunctionOutputWithName(StringPiece name, const FunctionDef& function);
+
+// Returns the index of the function node with the given name or -1 if the
+// function node does not exist.
+int FindFunctionNodeWithName(StringPiece name, const FunctionDef& function);
+
+// Returns the index of the function node with the given op or -1 if the
+// function node does not exist.
+int FindFunctionNodeWithOp(StringPiece op, const FunctionDef& function);
+
+// Sets the function node name using `prefix` as a prefix while guaranteeing
+// that the name is unique across the function's nodes.
+void SetUniqueFunctionNodeName(StringPiece prefix, FunctionDef* function,
+ NodeDef* node);
+
+} // end namespace function_utils
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FUNCTION_UTILS_H_
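As a minimal usage sketch of the helpers declared above (the AddCastForIllustration function, its input name "x", and the attribute values are hypothetical, not part of the patch):

    #include "tensorflow/core/grappler/optimizers/data/function_utils.h"
    #include "tensorflow/core/platform/logging.h"

    namespace tensorflow {
    namespace grappler {

    // Appends a Cast node to `fn` and locates it again via the lookup helpers.
    void AddCastForIllustration(FunctionDef* fn) {
      AttrValue src, dst;
      src.set_type(DT_INT32);
      dst.set_type(DT_INT64);
      // An empty name makes AddNode derive a unique name from the op via
      // SetUniqueFunctionNodeName.
      NodeDef* cast = function_utils::AddNode(
          /*name=*/"", /*op=*/"Cast", /*inputs=*/{"x"},
          {{"SrcT", src}, {"DstT", dst}}, fn);
      int idx = function_utils::FindFunctionNodeWithOp("Cast", *fn);
      DCHECK_GE(idx, 0);
      DCHECK_EQ(&fn->node_def(idx), cast);
    }

    }  // namespace grappler
    }  // namespace tensorflow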
diff --git a/tensorflow/core/grappler/optimizers/data/function_utils_test.cc b/tensorflow/core/grappler/optimizers/data/function_utils_test.cc
new file mode 100644
index 0000000000..3739e20eb1
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/function_utils_test.cc
@@ -0,0 +1,164 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
+
+#include "tensorflow/core/framework/function_testlib.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/tools/graph_transforms/transform_utils.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace function_utils {
+namespace {
+
+TEST(FunctionDefTensorDesc, Parsing) {
+ FunctionDefTensorDesc f("Cast:y:0");
+ EXPECT_EQ(f.full_str, "Cast:y:0");
+ EXPECT_EQ(f.node_name, "Cast");
+ EXPECT_EQ(f.node_output, "y");
+ EXPECT_EQ(f.position, 0);
+
+ FunctionDefTensorDesc f2("Arg0");
+ EXPECT_EQ(f2.full_str, "Arg0");
+ EXPECT_EQ(f2.node_name, "Arg0");
+ EXPECT_EQ(f2.node_output, "");
+ EXPECT_EQ(f2.position, -1);
+}
+
+TEST(ReplaceReferencesTest, ReplaceReferencesTest) {
+ FunctionDef outer = FunctionDefHelper::Create(
+ "outer", {"arg0: int32"}, {"out: int32", "out2: int64"}, {}, {},
+ {{"out", "MapDefun:output:0"}, {"out2", "Cast:y:0"}});
+ NodeDef* derive_node =
+ AddNode("X", "Some_Op", {"MapDefun:output:0"}, {}, &outer);
+ // Check that both the input to "X" and retval of "outer" are replaced.
+ ReplaceReferences("MapDefun:output:0", "arg0", &outer);
+ EXPECT_EQ(outer.ret().at("out"), "arg0");
+ EXPECT_EQ(derive_node->input(0), "arg0");
+}
+
+TEST(FunctionUtilsTest, AddFunctionOutputWithUniqueName) {
+ FunctionDef function = test::function::XTimesTwo();
+ AddFunctionOutputWithUniqueName("y", "two", &function, DT_INT64);
+ EXPECT_TRUE(ContainsFunctionOutputWithName("y/_1", function));
+ EXPECT_EQ(function.ret().at("y/_1"), "two");
+}
+
+TEST(FunctionUtilsTest, ContainsFunctionNodeWithName) {
+ FunctionDef function = test::function::XTimesTwo();
+ EXPECT_FALSE(ContainsFunctionNodeWithName(
+ "weird_name_that_should_not_be_there", function));
+ EXPECT_TRUE(ContainsFunctionNodeWithName("two", function));
+}
+
+TEST(FunctionUtilsTest, ContainsFunctionNodeWithOp) {
+ FunctionDef function = test::function::XTimesTwo();
+ EXPECT_FALSE(ContainsFunctionNodeWithOp("weird_op_that_should_not_be_there",
+ function));
+ EXPECT_TRUE(ContainsFunctionNodeWithOp("Mul", function));
+}
+
+TEST(FunctionUtilsTest, ContainsFunctionOutputWithName) {
+ FunctionDef function = test::function::XTimesTwo();
+ EXPECT_TRUE(ContainsFunctionOutputWithName("y", function));
+ EXPECT_FALSE(ContainsFunctionOutputWithName("Add:z:0", function));
+}
+
+TEST(FunctionUtilsTest, FindFunctionNodeWithName) {
+ FunctionDef function = test::function::XTimesTwo();
+ EXPECT_EQ(
+ FindFunctionNodeWithName("weird_name_that_should_not_be_there", function),
+ -1);
+ EXPECT_NE(FindFunctionNodeWithName("two", function), -1);
+}
+
+TEST(FunctionUtilsTest, FindFunctionNodeWithOp) {
+ FunctionDef function = test::function::XTimesTwo();
+ EXPECT_EQ(
+ FindFunctionNodeWithOp("weird_op_that_should_not_be_there", function),
+ -1);
+ EXPECT_NE(FindFunctionNodeWithOp("Mul", function), -1);
+}
+
+TEST(FunctionUtilsTest, FindFunctionInputWithName) {
+ FunctionDef function = test::function::XTimesTwo();
+ EXPECT_EQ(FindFunctionInputWithName("x", function), 0);
+ EXPECT_EQ(FindFunctionInputWithName("not_a_name", function), -1);
+}
+
+TEST(FunctionUtilsTest, FindFunctionOutputWithName) {
+ FunctionDef function = test::function::XTimesTwo();
+ EXPECT_EQ(FindFunctionOutputWithName("y", function), 0);
+ EXPECT_EQ(FindFunctionOutputWithName("Add:z:0", function), -1);
+}
+
+TEST(FunctionUtilsTest, SetUniqueFunctionNodeName) {
+ FunctionDef function = test::function::XTimesTwo();
+ NodeDef node;
+ SetUniqueFunctionNodeName("abc", &function, &node);
+ for (const NodeDef& function_node : function.node_def()) {
+ EXPECT_NE(node.name(), function_node.name());
+ }
+ auto* new_node = function.add_node_def();
+ *new_node = node;
+
+ NodeDef other;
+ SetUniqueFunctionNodeName("abc", &function, &other);
+ EXPECT_NE(other.name(), new_node->name());
+}
+
+TEST(FunctionUtilsTest, AddNodeToFunctionDef) {
+ FunctionDef func;
+ const char* op_name = "xxx";
+ AddNode(op_name, op_name, {}, {}, &func);
+
+ const NodeDef& node1 = func.node_def(FindFunctionNodeWithName("xxx", func));
+ EXPECT_EQ(node1.op(), op_name);
+ EXPECT_EQ(node1.input_size(), 0);
+ EXPECT_EQ(node1.attr_size(), 0);
+
+ const std::vector<string> inputs({"input1", "input2"});
+ AddNode("", op_name, inputs, {}, &func);
+ const NodeDef& node2 =
+ func.node_def(FindFunctionNodeWithName("xxx/_2", func));
+ EXPECT_EQ(node2.op(), op_name);
+ EXPECT_EQ(node2.attr_size(), 0);
+ EXPECT_EQ(node2.input_size(), inputs.size());
+ for (size_t i = 0; i < inputs.size(); ++i) {
+ EXPECT_EQ(node2.input(i), inputs[i]);
+ }
+
+ AttrValue a1, a2;
+ a1.set_type(DT_INT32);
+ a2.set_type(DT_INT64);
+ const std::vector<std::pair<string, AttrValue>> attrs(
+ {{"attr1", a1}, {"attr2", a2}});
+ AddNode("", op_name, {}, attrs, &func);
+ const NodeDef& node3 =
+ func.node_def(FindFunctionNodeWithName("xxx/_3", func));
+ EXPECT_EQ(node3.op(), op_name);
+ EXPECT_EQ(node3.input_size(), 0);
+ EXPECT_EQ(node3.attr_size(), attrs.size());
+ for (size_t i = 0; i < attrs.size(); ++i) {
+ EXPECT_EQ(attrs[i].second.type(), node3.attr().at(attrs[i].first).type());
+ }
+}
+
+} // namespace
+} // namespace function_utils
+} // namespace grappler
+} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/fusion_utils.cc b/tensorflow/core/grappler/optimizers/data/fusion_utils.cc
index 01a78c04b0..b3bfee138f 100644
--- a/tensorflow/core/grappler/optimizers/data/fusion_utils.cc
+++ b/tensorflow/core/grappler/optimizers/data/fusion_utils.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/grappler/mutable_graph_view.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
+#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
@@ -407,7 +408,7 @@ void LazyConjunctionNodes(const FunctionDef& first_function,
auto* if_node = fused_function->add_node_def();
// This is guaranteed to succeed.
TF_CHECK_OK(if_builder.Finalize(if_node));
- graph_utils::SetUniqueFunctionNodeName("cond", fused_function, if_node);
+ function_utils::SetUniqueFunctionNodeName("cond", fused_function, if_node);
GetMutableOutputNode(fused_function, 0) = if_node->name() + ":output:0";
}
diff --git a/tensorflow/core/grappler/optimizers/data/fusion_utils_test.cc b/tensorflow/core/grappler/optimizers/data/fusion_utils_test.cc
index d5c6466080..e667affeea 100644
--- a/tensorflow/core/grappler/optimizers/data/fusion_utils_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/fusion_utils_test.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -110,9 +111,9 @@ TEST(FusionUtilsTest, FuseFunctionWithPredicate) {
CheckUniqueNames(*fused_function);
ASSERT_TRUE(
- graph_utils::ContainsFunctionNodeWithOp("Equal", *fused_function));
+ function_utils::ContainsFunctionNodeWithOp("Equal", *fused_function));
const auto &equal_node = fused_function->node_def(
- graph_utils::FindFunctionNodeWithOp("Equal", *fused_function));
+ function_utils::FindFunctionNodeWithOp("Equal", *fused_function));
EXPECT_EQ(xtimes_two->signature().output_arg(0).name(),
fused_function->signature().output_arg(0).name());
diff --git a/tensorflow/core/grappler/optimizers/data/graph_test_utils.cc b/tensorflow/core/grappler/optimizers/data/graph_test_utils.cc
new file mode 100644
index 0000000000..b2eec7220e
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/graph_test_utils.cc
@@ -0,0 +1,49 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h"
+
+#include "tensorflow/core/framework/function_testlib.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace graph_tests_utils {
+
+NodeDef MakeMapNode(StringPiece name, StringPiece input_node_name,
+ StringPiece function_name) {
+ return test::function::NDef(
+ name, "MapDataset", {string(input_node_name)},
+ {{"f", FunctionDefHelper::FunctionRef(string(function_name))},
+ {"Targuments", {}},
+ {"output_shapes", gtl::ArraySlice<TensorShape>{}},
+ {"output_types", gtl::ArraySlice<DataType>{}}});
+}
+
+NodeDef MakeFilterNode(StringPiece name, StringPiece input_node_name,
+ StringPiece function_name) {
+ return test::function::NDef(
+ name, "FilterDataset", {string(input_node_name)},
+ {{"predicate", FunctionDefHelper::FunctionRef(string(function_name))},
+ {"Targuments", {}},
+ {"output_shapes", gtl::ArraySlice<TensorShape>{}},
+ {"output_types", gtl::ArraySlice<DataType>{}}});
+}
+
+} // end namespace graph_tests_utils
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/graph_test_utils.h b/tensorflow/core/grappler/optimizers/data/graph_test_utils.h
new file mode 100644
index 0000000000..ca0fde997d
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/graph_test_utils.h
@@ -0,0 +1,36 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_GRAPH_TEST_UTILS_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_GRAPH_TEST_UTILS_H_
+
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace graph_tests_utils {
+
+NodeDef MakeMapNode(StringPiece name, StringPiece input_node_name,
+ StringPiece function_name = "XTimesTwo");
+
+NodeDef MakeFilterNode(StringPiece name, StringPiece input_node_name,
+ StringPiece function_name = "IsZero");
+
+} // end namespace graph_tests_utils
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_GRAPH_TEST_UTILS_H_
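A rough sketch of a test body that combines these helpers with test::function::GDef, mirroring the optimizer tests later in this patch (node names "map1"/"filter1" are illustrative; the includes are assumed to match those tests):

    TEST(GraphTestUtilsSketch, BuildSmallPipeline) {
      using test::function::NDef;
      GrapplerItem item;
      item.graph = test::function::GDef(
          {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
           NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}),
           NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
           NDef("range", "RangeDataset", {"start", "stop", "step"},
                {{"output_shapes", gtl::ArraySlice<TensorShape>{}},
                 {"output_types", gtl::ArraySlice<DataType>{}}}),
           // MakeMapNode defaults to "XTimesTwo", MakeFilterNode to "IsZero".
           graph_tests_utils::MakeMapNode("map1", "range"),
           graph_tests_utils::MakeFilterNode("filter1", "map1")},
          // FunctionLib
          {test::function::XTimesTwo(), test::function::IsZero()});
      EXPECT_TRUE(graph_utils::ContainsNodeWithOp("MapDataset", item.graph));
      EXPECT_TRUE(graph_utils::ContainsNodeWithOp("FilterDataset", item.graph));
    }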
diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.cc b/tensorflow/core/grappler/optimizers/data/graph_utils.cc
index d4ab444036..3eaaf8fbef 100644
--- a/tensorflow/core/grappler/optimizers/data/graph_utils.cc
+++ b/tensorflow/core/grappler/optimizers/data/graph_utils.cc
@@ -88,6 +88,16 @@ NodeDef* AddScalarConstNodeHelper(
} // namespace
+NodeDef* AddScalarPlaceholder(DataType dtype, MutableGraphView* graph) {
+ NodeDef node;
+ node.set_op("Placeholder");
+ SetUniqueGraphNodeName(node.op(), graph->GetGraph(), &node);
+ (*node.mutable_attr())["dtype"].set_type(dtype);
+ TensorShapeProto* shape = (*node.mutable_attr())["shape"].mutable_shape();
+ shape->set_unknown_rank(false);
+ return graph->AddNode(std::move(node));
+}
+
NodeDef* AddNode(StringPiece name, StringPiece op,
const std::vector<string>& inputs,
const std::vector<std::pair<string, AttrValue>>& attributes,
@@ -108,26 +118,6 @@ NodeDef* AddNode(StringPiece name, StringPiece op,
return graph->AddNode(std::move(node));
}
-NodeDef* AddNode(StringPiece name, StringPiece op,
- const std::vector<string>& inputs,
- const std::vector<std::pair<string, AttrValue>>& attributes,
- FunctionDef* fd) {
- NodeDef* node = fd->add_node_def();
- if (!name.empty()) {
- node->set_name(string(name));
- } else {
- SetUniqueFunctionNodeName(op, fd, node);
- }
- node->set_op(string(op));
- for (const string& input : inputs) {
- node->add_input(input);
- }
- for (auto attr : attributes) {
- (*node->mutable_attr())[attr.first] = attr.second;
- }
- return node;
-}
-
template <>
NodeDef* AddScalarConstNode(bool v, MutableGraphView* graph) {
return AddScalarConstNodeHelper(
@@ -196,6 +186,11 @@ bool Compare(const GraphDef& g1, const GraphDef& g2) {
return true;
}
+bool ContainsGraphFunctionWithName(StringPiece name,
+ const FunctionDefLibrary& library) {
+ return FindGraphFunctionWithName(name, library) != -1;
+}
+
bool ContainsGraphNodeWithName(StringPiece name, const GraphDef& graph) {
return FindGraphNodeWithName(name, graph) != -1;
}
@@ -204,31 +199,24 @@ bool ContainsNodeWithOp(StringPiece op, const GraphDef& graph) {
return FindGraphNodeWithOp(op, graph) != -1;
}
-bool ContainsGraphFunctionWithName(StringPiece name,
- const FunctionDefLibrary& library) {
- return FindGraphFunctionWithName(name, library) != -1;
-}
-
-bool ContainsFunctionNodeWithName(StringPiece name,
- const FunctionDef& function) {
- return FindFunctionNodeWithName(name, function) != -1;
-}
-
-bool ContainsFunctionNodeWithOp(StringPiece op, const FunctionDef& function) {
- return FindFunctionNodeWithOp(op, function) != -1;
+int FindGraphFunctionWithName(StringPiece name,
+ const FunctionDefLibrary& library) {
+ return GetFirstElementIndexWithPredicate(
+ [&name](const FunctionDef& function) {
+ return function.signature().name() == name;
+ },
+ library.function());
}
int FindGraphNodeWithName(StringPiece name, const GraphDef& graph) {
- std::vector<int> indices = GetElementIndicesWithPredicate(
+ return GetFirstElementIndexWithPredicate(
[&name](const NodeDef& node) { return node.name() == name; },
graph.node());
- return indices.empty() ? -1 : indices.front();
}
int FindGraphNodeWithOp(StringPiece op, const GraphDef& graph) {
- std::vector<int> indices = GetElementIndicesWithPredicate(
+ return GetFirstElementIndexWithPredicate(
[&op](const NodeDef& node) { return node.op() == op; }, graph.node());
- return indices.empty() ? -1 : indices.front();
}
std::vector<int> FindAllGraphNodesWithOp(const string& op,
@@ -237,31 +225,6 @@ std::vector<int> FindAllGraphNodesWithOp(const string& op,
[&op](const NodeDef& node) { return node.op() == op; }, graph.node());
}
-int FindGraphFunctionWithName(StringPiece name,
- const FunctionDefLibrary& library) {
- std::vector<int> indices = GetElementIndicesWithPredicate(
- [&name](const FunctionDef& function) {
- return function.signature().name() == name;
- },
- library.function());
- return indices.empty() ? -1 : indices.front();
-}
-
-int FindFunctionNodeWithName(StringPiece name, const FunctionDef& function) {
- std::vector<int> indices = GetElementIndicesWithPredicate(
- [&name](const NodeDef& node) { return node.name() == name; },
- function.node_def());
- return indices.empty() ? -1 : indices.front();
-}
-
-int FindFunctionNodeWithOp(StringPiece op, const FunctionDef& function) {
- std::vector<int> indices = GetElementIndicesWithPredicate(
- [&op](const NodeDef& node) { return node.op() == op; },
- function.node_def());
-
- return indices.empty() ? -1 : indices.front();
-}
-
NodeDef* GetInputNode(const NodeDef& node, const MutableGraphView& graph) {
if (node.input_size() == 0) return nullptr;
GraphView::InputPort input_port = graph.GetInputPort(node.name(), 0);
@@ -284,17 +247,6 @@ void SetUniqueGraphNodeName(StringPiece prefix, GraphDef* graph,
node->set_name(std::move(name));
}
-void SetUniqueFunctionNodeName(StringPiece prefix, FunctionDef* function,
- NodeDef* node) {
- string name = string(prefix);
- int id = function->node_def_size();
- while (ContainsFunctionNodeWithName(name, *function)) {
- name = strings::StrCat(prefix, "/_", id);
- ++id;
- }
- node->set_name(std::move(name));
-}
-
void SetUniqueGraphFunctionName(StringPiece prefix, FunctionDefLibrary* library,
FunctionDef* function) {
string name = string(prefix);
@@ -306,6 +258,20 @@ void SetUniqueGraphFunctionName(StringPiece prefix, FunctionDefLibrary* library,
function->mutable_signature()->set_name(std::move(name));
}
+void CopyAttribute(const string& attribute_name, const NodeDef& from,
+ NodeDef* to_node) {
+ (*to_node->mutable_attr())[attribute_name] = from.attr().at(attribute_name);
+}
+
+void ConcatAttributeList(const string& attribute_name, const NodeDef& first,
+ const NodeDef& second, NodeDef* to_node) {
+ CopyAttribute(attribute_name, first, to_node);
+ (*to_node->mutable_attr())
+ .at(attribute_name)
+ .mutable_list()
+ ->MergeFrom(second.attr().at(attribute_name).list());
+}
+
} // end namespace graph_utils
} // end namespace grappler
} // end namespace tensorflow
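The new AddScalarPlaceholder helper is not exercised by the tests below; a minimal sketch of what it produces, assuming a fresh graph, could be:

    GraphDef graph_def;
    MutableGraphView graph(&graph_def);
    NodeDef* placeholder = graph_utils::AddScalarPlaceholder(DT_INT64, &graph);
    // `placeholder` now has op "Placeholder", a unique name derived from the
    // op, attr dtype = DT_INT64, and a scalar (known rank, zero dims) shape.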
diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.h b/tensorflow/core/grappler/optimizers/data/graph_utils.h
index 6f431c232d..5dd7819100 100644
--- a/tensorflow/core/grappler/optimizers/data/graph_utils.h
+++ b/tensorflow/core/grappler/optimizers/data/graph_utils.h
@@ -31,17 +31,29 @@ namespace tensorflow {
namespace grappler {
namespace graph_utils {
+// Returns the index of the first element in collection that fulfills predicate.
+// If no such element exists, returns -1.
+template <typename Predicate, typename Collection>
+int GetFirstElementIndexWithPredicate(const Predicate& predicate,
+ const Collection& collection) {
+ unsigned idx = 0;
+ for (auto&& element : collection) {
+ if (predicate(element)) {
+ return idx;
+ }
+ idx++;
+ }
+ return -1;
+}
+
// Adds a node to the graph.
NodeDef* AddNode(StringPiece name, StringPiece op,
const std::vector<string>& inputs,
const std::vector<std::pair<string, AttrValue>>& attributes,
MutableGraphView* graph);
-// Adds a node to a FunctionDef.
-NodeDef* AddNode(StringPiece name, StringPiece op,
- const std::vector<string>& inputs,
- const std::vector<std::pair<string, AttrValue>>& attributes,
- FunctionDef* fd);
+// Adds a Placeholder node for the given type to the graph.
+NodeDef* AddScalarPlaceholder(DataType dtype, MutableGraphView* graph);
// Adds a Const node with the given value to the graph.
template <typename T>
@@ -76,13 +88,6 @@ bool ContainsGraphNodeWithName(StringPiece name, const GraphDef& graph);
bool ContainsGraphFunctionWithName(StringPiece name,
const FunctionDefLibrary& library);
-// Checks whether the function contains a node with the given name.
-bool ContainsFunctionNodeWithName(StringPiece name,
- const FunctionDef& function);
-
-// Checks whether the function contains a node with the given op.
-bool ContainsFunctionNodeWithOp(StringPiece op, const FunctionDef& function);
-
// Checks whether the graph contains a node with the given op.
bool ContainsNodeWithOp(StringPiece op, const GraphDef& graph);
@@ -95,14 +100,6 @@ int FindGraphNodeWithName(StringPiece name, const GraphDef& graph);
int FindGraphFunctionWithName(StringPiece name,
const FunctionDefLibrary& library);
-// Returns the index of the function node with the given name or -1 if the
-// function node does not exist.
-int FindFunctionNodeWithName(StringPiece name, const FunctionDef& function);
-
-// Returns the index of the function node with the given op or -1 if the
-// function node does not exist.
-int FindFunctionNodeWithOp(StringPiece op, const FunctionDef& function);
-
// Returns the index of the first node with the given op or -1 if no such node
// exists.
int FindGraphNodeWithOp(StringPiece op, const GraphDef& graph);
@@ -119,16 +116,21 @@ std::vector<int> FindAllGraphNodesWithOp(const string& op,
// is unique across the graph.
void SetUniqueGraphNodeName(StringPiece prefix, GraphDef* graph, NodeDef* node);
-// Sets the function node name using the `prefix` as a prefix while guaranteeing
-// the name is unique across the functions nodes.
-void SetUniqueFunctionNodeName(StringPiece prefix, FunctionDef* function,
- NodeDef* node);
-
// Sets the node name using the `prefix` name as a prefix while guaranteeing the
// name is unique across the graph.
void SetUniqueGraphFunctionName(StringPiece prefix, FunctionDefLibrary* library,
FunctionDef* function);
+// Copies attribute having name `attribute_name` from node `from` to node
+// `to_node`.
+void CopyAttribute(const string& attribute_name, const NodeDef& from,
+ NodeDef* to_node);
+
+// Concatenates list attribute having name `attribute_name` from `first` and
+// `second` node, setting it to `to_node`.
+void ConcatAttributeList(const string& attribute_name, const NodeDef& first,
+ const NodeDef& second, NodeDef* to_node);
+
} // end namespace graph_utils
} // end namespace grappler
} // end namespace tensorflow
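A small sketch of the two new attribute helpers (the node variables are illustrative only):

    NodeDef first, second, fused;
    (*first.mutable_attr())["output_types"].mutable_list()->add_type(DT_INT32);
    (*second.mutable_attr())["output_types"].mutable_list()->add_type(DT_INT64);

    // Copies the "output_types" attr of `first` onto `fused`.
    graph_utils::CopyAttribute("output_types", first, &fused);

    // Copies `first`'s list and then appends `second`'s, leaving
    // {DT_INT32, DT_INT64} on `fused`.
    graph_utils::ConcatAttributeList("output_types", first, second, &fused);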
diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc b/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc
index c19ac7b880..db986542b2 100644
--- a/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc
@@ -24,6 +24,18 @@ namespace grappler {
namespace graph_utils {
namespace {
+TEST(GraphUtilsTest, GetFirstElementIndexWithPredicate) {
+ std::vector<int> vec({1, 2, 3, 4, 5, 6});
+ auto result = GetFirstElementIndexWithPredicate(
+ [](int elem) { return elem % 3 == 0; }, vec);
+
+ EXPECT_EQ(result, 2);
+
+ result = GetFirstElementIndexWithPredicate(
+ [](int elem) { return elem % 7 == 0; }, vec);
+ EXPECT_EQ(result, -1);
+}
+
TEST(GraphUtilsTest, AddScalarConstNodeBool) {
GraphDef graph_def;
MutableGraphView graph(&graph_def);
@@ -112,20 +124,6 @@ TEST(GraphUtilsTest, ContainsGraphFunctionWithName) {
ContainsGraphFunctionWithName(new_function->signature().name(), library));
}
-TEST(GraphUtilsTest, ContainsFunctionNodeWithName) {
- FunctionDef function = test::function::XTimesTwo();
- EXPECT_FALSE(ContainsFunctionNodeWithName(
- "weird_name_that_should_not_be_there", function));
- EXPECT_TRUE(ContainsFunctionNodeWithName("two", function));
-}
-
-TEST(GraphUtilsTest, ContainsFunctionNodeWithOp) {
- FunctionDef function = test::function::XTimesTwo();
- EXPECT_FALSE(ContainsFunctionNodeWithOp("weird_op_that_should_not_be_there",
- function));
- EXPECT_TRUE(ContainsFunctionNodeWithOp("Mul", function));
-}
-
TEST(GraphUtilsTest, ContainsNodeWithOp) {
GraphDef graph_def;
MutableGraphView graph(&graph_def);
@@ -150,22 +148,6 @@ TEST(GraphUtilsTest, FindGraphNodeWithName) {
EXPECT_EQ(FindGraphNodeWithName("A", *graph.GetGraph()), -1);
}
-TEST(GraphUtilsTest, FindFunctionNodeWithName) {
- FunctionDef function = test::function::XTimesTwo();
- EXPECT_EQ(
- FindFunctionNodeWithName("weird_name_that_should_not_be_there", function),
- -1);
- EXPECT_NE(FindFunctionNodeWithName("two", function), -1);
-}
-
-TEST(GraphUtilsTest, FindFunctionNodeWithOp) {
- FunctionDef function = test::function::XTimesTwo();
- EXPECT_EQ(
- FindFunctionNodeWithOp("weird_op_that_should_not_be_there", function),
- -1);
- EXPECT_NE(FindFunctionNodeWithOp("Mul", function), -1);
-}
-
TEST(GraphUtilsTest, FindGraphFunctionWithName) {
FunctionDefLibrary library;
EXPECT_EQ(FindGraphFunctionWithName("new_function", library), -1);
@@ -225,21 +207,6 @@ TEST(GraphUtilsTest, SetUniqueGraphNodeName) {
EXPECT_NE(node2->name(), node3->name());
}
-TEST(GraphUtilsTest, SetUniqueFunctionNodeName) {
- FunctionDef function = test::function::XTimesTwo();
- NodeDef node;
- SetUniqueFunctionNodeName("abc", &function, &node);
- for (const NodeDef& function_node : function.node_def()) {
- EXPECT_NE(node.name(), function_node.name());
- }
- auto* new_node = function.add_node_def();
- *new_node = node;
-
- NodeDef other;
- SetUniqueFunctionNodeName("abc", &function, &other);
- EXPECT_NE(other.name(), new_node->name());
-}
-
TEST(GraphUtilsTest, SetUniqueGraphFunctionName) {
FunctionDefLibrary library;
FunctionDef* new_function = library.add_function();
@@ -251,43 +218,6 @@ TEST(GraphUtilsTest, SetUniqueGraphFunctionName) {
other_function->signature().name());
}
-TEST(GraphUtilsTest, AddNodeToFunctionDef) {
- FunctionDef func;
- const char* op_name = "xxx";
- AddNode(op_name, op_name, {}, {}, &func);
-
- const NodeDef& node1 = func.node_def(FindFunctionNodeWithName("xxx", func));
- EXPECT_EQ(node1.op(), op_name);
- EXPECT_EQ(node1.input_size(), 0);
- EXPECT_EQ(node1.attr_size(), 0);
-
- const std::vector<string> inputs({"input1", "input2"});
- AddNode("", op_name, inputs, {}, &func);
- const NodeDef& node2 =
- func.node_def(FindFunctionNodeWithName("xxx/_2", func));
- EXPECT_EQ(node2.op(), op_name);
- EXPECT_EQ(node2.attr_size(), 0);
- EXPECT_EQ(node2.input_size(), inputs.size());
- for (size_t i = 0; i < inputs.size(); ++i) {
- EXPECT_EQ(node2.input(i), inputs[i]);
- }
-
- AttrValue a1, a2;
- a1.set_type(DT_INT32);
- a2.set_type(DT_INT64);
- const std::vector<std::pair<string, AttrValue>> attrs(
- {{"attr1", a1}, {"attr2", a2}});
- AddNode("", op_name, {}, attrs, &func);
- const NodeDef& node3 =
- func.node_def(FindFunctionNodeWithName("xxx/_3", func));
- EXPECT_EQ(node3.op(), op_name);
- EXPECT_EQ(node3.input_size(), 0);
- EXPECT_EQ(node3.attr_size(), attrs.size());
- for (size_t i = 0; i < attrs.size(); ++i) {
- EXPECT_EQ(attrs[i].second.type(), node3.attr().at(attrs[i].first).type());
- }
-}
-
TEST(GraphUtilsTest, GetInputNode) {
GraphDef graph_def;
MutableGraphView graph(&graph_def);
diff --git a/tensorflow/core/grappler/optimizers/data/hoist_random_uniform.cc b/tensorflow/core/grappler/optimizers/data/hoist_random_uniform.cc
new file mode 100644
index 0000000000..ce0b2db039
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/hoist_random_uniform.cc
@@ -0,0 +1,289 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/hoist_random_uniform.h"
+
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/grappler/clusters/cluster.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/mutable_graph_view.h"
+#include "tensorflow/core/grappler/op_types.h"
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
+#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
+#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/protobuf.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+NodeDef MakeStatelessMap(const NodeDef& map_node, const NodeDef& zip_node,
+ const FunctionDef& stateless_function,
+ MutableGraphView* graph) {
+ NodeDef stateless_map;
+ graph_utils::SetUniqueGraphNodeName("stateless_map", graph->GetGraph(),
+ &stateless_map);
+
+ stateless_map.set_op("MapDataset");
+ stateless_map.add_input(zip_node.name());
+ // Add placeholders.
+ for (int i = 1; i < map_node.input_size(); i++)
+ stateless_map.add_input(map_node.input(i));
+
+ auto attr = map_node.attr().at("f");
+ *attr.mutable_func()->mutable_name() = stateless_function.signature().name();
+ *attr.mutable_func()->mutable_attr() = stateless_function.attr();
+ (*stateless_map.mutable_attr())["f"] = std::move(attr);
+
+ graph_utils::CopyAttribute("Targuments", map_node, &stateless_map);
+ for (auto key : {"output_shapes", "output_types"})
+ graph_utils::CopyAttribute(key, map_node, &stateless_map);
+
+ if (const auto* attr =
+ gtl::FindOrNull(map_node.attr(), "use_inter_op_parallelism"))
+ (*stateless_map.mutable_attr())["use_inter_op_parallelism"] = *attr;
+
+ return stateless_map;
+}
+
+NodeDef MakeRandomDataset(const NodeDef& random_uniform_node,
+ MutableGraphView* graph) {
+ NodeDef random_dataset;
+ random_dataset.set_op("RandomDataset");
+ graph_utils::SetUniqueGraphNodeName("RandomDataset", graph->GetGraph(),
+ &random_dataset);
+
+ const auto* seed = graph_utils::AddScalarConstNode<int64>(
+ random_uniform_node.attr().at("seed").i(), graph);
+ const auto* seed2 = graph_utils::AddScalarConstNode<int64>(
+ random_uniform_node.attr().at("seed2").i(), graph);
+
+ random_dataset.add_input(seed->name());
+ random_dataset.add_input(seed2->name());
+
+ (*random_dataset.mutable_attr())["output_shapes"].mutable_list()->add_shape();
+ (*random_dataset.mutable_attr())["output_types"].mutable_list()->add_type(
+ DT_INT64);
+
+ return random_dataset;
+}
+
+NodeDef MakeBatchTwo(const NodeDef& random_dataset, MutableGraphView* graph) {
+ NodeDef batch_dataset;
+ batch_dataset.set_op("BatchDatasetV2");
+ graph_utils::SetUniqueGraphNodeName("pair_of_random", graph->GetGraph(),
+ &batch_dataset);
+ const auto* batch_size = graph_utils::AddScalarConstNode<int64>(2, graph);
+ const auto* drop_remainder = graph_utils::AddScalarConstNode(false, graph);
+ batch_dataset.add_input(random_dataset.name());
+ batch_dataset.add_input(batch_size->name());
+ batch_dataset.add_input(drop_remainder->name());
+
+ (*batch_dataset.mutable_attr())["output_shapes"]
+ .mutable_list()
+ ->add_shape()
+ ->mutable_dim()
+ ->Add()
+ ->set_size(-1);
+ (*batch_dataset.mutable_attr())["output_types"].mutable_list()->add_type(
+ DT_INT64);
+
+ return batch_dataset;
+}
+
+NodeDef MakeZipNode(const NodeDef& first_node, const NodeDef& second_node,
+ MutableGraphView* graph) {
+ NodeDef zip_node;
+ graph_utils::SetUniqueGraphNodeName("zip_with_random", graph->GetGraph(),
+ &zip_node);
+
+ zip_node.set_op("ZipDataset");
+ zip_node.add_input(first_node.name());
+ zip_node.add_input(second_node.name());
+
+ for (auto key : {"output_shapes", "output_types"})
+ graph_utils::ConcatAttributeList(key, first_node, second_node, &zip_node);
+
+ (*zip_node.mutable_attr())["N"].set_i(2);
+
+ return zip_node;
+}
+
+// We need to insert our argument before the placeholders, which are the last
+// arguments.
+OpDef_ArgDef* InsertSeedArgument(OpDef* signature, int num_placeholders) {
+ int new_argument_idx = signature->input_arg_size() - num_placeholders;
+ signature->add_input_arg();
+ for (int i = signature->input_arg_size() - 1; i > new_argument_idx; i--) {
+ signature->mutable_input_arg()->SwapElements(i - 1, i);
+ }
+ auto* seed_arg = signature->mutable_input_arg(new_argument_idx);
+ seed_arg->set_name(strings::StrCat("seed_arg", new_argument_idx));
+ seed_arg->set_type(DT_INT64);
+
+ return seed_arg;
+}
+
+// Makes a function that uses `StatelessRandomUniform` instead of `RandomUniform`
+// to make it less stateful. The function can still be stateful, but when the
+// only remaining stateful ops are ones like `Assert`, it will be parallelizable.
+const FunctionDef* MakeLessStatefulFunction(const FunctionDef& map_function,
+ bool is_stateful,
+ int num_placeholders,
+ FunctionDefLibrary* library) {
+ FunctionDef* stateless_function = library->add_function();
+ *stateless_function = map_function;
+ if (is_stateful)
+ stateless_function->mutable_signature()->set_is_stateful(is_stateful);
+ graph_utils::SetUniqueGraphFunctionName("stateless_function", library,
+ stateless_function);
+
+ auto* seed_arg = InsertSeedArgument(stateless_function->mutable_signature(),
+ num_placeholders);
+
+ auto* const random_uniform = stateless_function->mutable_node_def(
+ function_utils::FindFunctionNodeWithOp("RandomUniform",
+ *stateless_function));
+
+ // Replace RandomUniform node with StatelessRandomUniform.
+ random_uniform->set_op("StatelessRandomUniform");
+ random_uniform->add_input(seed_arg->name());
+ (*random_uniform->mutable_attr())["Tseed"].set_type(DT_INT64);
+ random_uniform->mutable_attr()->erase("seed");
+ random_uniform->mutable_attr()->erase("seed2");
+
+ return stateless_function;
+}
+// This function returns true if the function is stateful, has a single
+// RandomUniform op, and has no other stateful ops except Assert.
+// `is_stateful_after_hoisting` is set to true if RandomUniform is the only
+// stateful op and hoisting can be performed.
+bool CanHoistRandomUniform(const FunctionDef& map_function,
+ const FunctionLibraryDefinition& library,
+ bool* is_stateful_after_hoisting,
+ const NodeDef** random_uniform_op) {
+ if (!map_function.signature().is_stateful()) return false;
+ *is_stateful_after_hoisting = true;
+
+ bool have_other_stateful_ops = false;
+
+ for (const auto& node : map_function.node_def()) {
+ const OpDef* op_def;
+ TF_CHECK_OK(library.LookUpOpDef(node.op(), &op_def));
+ // Skip stateless nodes, and Assert, which does not actually carry state.
+ if (!op_def->is_stateful()) continue;
+
+ if (op_def->name() == "Assert") {
+ have_other_stateful_ops = true;
+ continue;
+ }
+
+ // TODO(prazek): For now we only handle RandomUniform, we should handle
+ // RandomUniformInt as well.
+ if (op_def->name() != "RandomUniform") return false;
+
+ // TODO(prazek): For now we can only hoist single RandomUniform.
+ if (*random_uniform_op != nullptr) return false;
+
+ *random_uniform_op = &node;
+ }
+
+ if (!have_other_stateful_ops) *is_stateful_after_hoisting = false;
+
+ // Have we found a single RandomUniform?
+ return *random_uniform_op != nullptr;
+}
+
+int NumberOfPlaceholders(const NodeDef& map_node) {
+ // The first input of MapDataset is the input dataset; the remaining inputs
+ // are the placeholders (captured arguments).
+ return map_node.input_size() - 1;
+}
+
+} // namespace
+
+Status HoistRandomUniform::Optimize(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* output) {
+ *output = item.graph;
+
+ MutableGraphView graph(output);
+ std::set<string> nodes_to_delete;
+ FunctionLibraryDefinition function_library(OpRegistry::Global(),
+ item.graph.library());
+
+ auto get_map_node = [](const NodeDef& node) -> const NodeDef* {
+ // TODO(prazek): we could also handle ParallelMapDataset and
+ // MapAndBatchDataset.
+ if (node.op() == "MapDataset") return &node;
+ return nullptr;
+ };
+
+ for (const NodeDef& node : item.graph.node()) {
+ const NodeDef* map_node = get_map_node(node);
+ if (!map_node) continue;
+
+ const auto& fun = map_node->attr().at("f");
+ const FunctionDef* func = function_library.Find(fun.func().name());
+
+ const NodeDef* random_uniform_op = nullptr;
+ bool is_stateful_after_hoisting = true;
+ if (!CanHoistRandomUniform(*func, function_library,
+ &is_stateful_after_hoisting, &random_uniform_op))
+ continue;
+ const auto* random_seed_dataset =
+ graph.AddNode(MakeRandomDataset(*random_uniform_op, &graph));
+
+ const auto* batch_dataset =
+ graph.AddNode(MakeBatchTwo(*random_seed_dataset, &graph));
+
+ const NodeDef& parent_node = *graph_utils::GetInputNode(*map_node, graph);
+
+ const auto* zip_node =
+ graph.AddNode(MakeZipNode(parent_node, *batch_dataset, &graph));
+
+ const auto* stateless_func = MakeLessStatefulFunction(
+ *func, is_stateful_after_hoisting, NumberOfPlaceholders(*map_node),
+ output->mutable_library());
+
+ const auto* stateless_map = graph.AddNode(
+ MakeStatelessMap(*map_node, *zip_node, *stateless_func, &graph));
+
+ graph.ReplaceInput(*map_node, *stateless_map);
+
+ // TODO(b/116285210): we could also remove map functions from library if
+ // they are not used anymore.
+ nodes_to_delete.insert(map_node->name());
+ }
+
+ graph.DeleteNodes(nodes_to_delete);
+ return Status::OK();
+}
+
+void HoistRandomUniform::Feedback(Cluster* cluster, const GrapplerItem& item,
+ const GraphDef& optimize_output,
+ double result) {
+ // no-op
+}
+
+REGISTER_GRAPH_OPTIMIZER_AS(HoistRandomUniform, "hoist_random_uniform");
+
+} // end namespace grappler
+} // end namespace tensorflow
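To make the argument splice performed by InsertSeedArgument concrete (the helper lives in this file's anonymous namespace, so the sketch only applies inside it; the signature below is hypothetical): given a signature (x, ph0, ph1) with two trailing placeholders, the seed is inserted in front of them:

    OpDef signature;
    signature.add_input_arg()->set_name("x");    // the map function's argument
    signature.add_input_arg()->set_name("ph0");  // trailing placeholders
    signature.add_input_arg()->set_name("ph1");
    OpDef_ArgDef* seed = InsertSeedArgument(&signature, /*num_placeholders=*/2);
    // signature.input_arg() is now {"x", "seed_arg1", "ph0", "ph1"}; `seed`
    // points at the new DT_INT64 argument spliced in at index 1.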
diff --git a/tensorflow/core/grappler/optimizers/data/hoist_random_uniform.h b/tensorflow/core/grappler/optimizers/data/hoist_random_uniform.h
new file mode 100644
index 0000000000..d1bcf6782d
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/hoist_random_uniform.h
@@ -0,0 +1,55 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_HOIST_RANDOM_UNIFORM_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_HOIST_RANDOM_UNIFORM_H_
+
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
+
+namespace tensorflow {
+namespace grappler {
+
+// This optimization hoists instances of `random_uniform` out of a function
+// with the aim of making it stateless. It creates a new function that takes a
+// random seed as an extra argument and uses `stateless_random_uniform` instead
+// of `random_uniform` to make it stateless.
+// It also creates RandomDataset(seed).batch(2), which is zipped with the old
+// input to the map. The batching of RandomDataset is needed because
+// `stateless_random_uniform` requires a pair of seeds.
+// TODO(prazek): for now only `RandomUniform` is handled, but we could handle
+// `RandomUniformInt` similarly.
+class HoistRandomUniform : public CustomGraphOptimizer {
+ public:
+ HoistRandomUniform() = default;
+ ~HoistRandomUniform() override = default;
+
+ string name() const override { return "hoist_random_uniform"; }
+
+ Status Init(
+ const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
+ return Status::OK();
+ }
+
+ Status Optimize(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* output) override;
+
+ void Feedback(Cluster* cluster, const GrapplerItem& item,
+ const GraphDef& optimize_output, double result) override;
+};
+
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_HOIST_RANDOM_UNIFORM_H_
diff --git a/tensorflow/core/grappler/optimizers/data/hoist_random_uniform_test.cc b/tensorflow/core/grappler/optimizers/data/hoist_random_uniform_test.cc
new file mode 100644
index 0000000000..455459e3f6
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/hoist_random_uniform_test.cc
@@ -0,0 +1,84 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/hoist_random_uniform.h"
+
+#include "tensorflow/core/framework/function_testlib.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+TEST(HoistRandomUniform, SimpleHoisting) {
+ using test::function::NDef;
+ GrapplerItem item;
+ item.graph = test::function::GDef(
+ {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
+ NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}),
+ NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
+ NDef("filename", "Const", {}, {{"value", ""}, {"dtype", DT_STRING}}),
+ NDef("range", "RangeDataset", {"start", "stop", "step"},
+ {{"output_shapes", gtl::ArraySlice<TensorShape>{}},
+ {"output_types", gtl::ArraySlice<DataType>{}}}),
+ graph_tests_utils::MakeMapNode("map1", "range", "RandomUniform"),
+ NDef("cache", "CacheDataset", {"map1", "filename"}, {})},
+ // FunctionLib
+ {
+ test::function::RandomUniform(),
+ });
+
+ HoistRandomUniform optimizer;
+ GraphDef output;
+ TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ EXPECT_FALSE(graph_utils::ContainsGraphNodeWithName("map1", output));
+ const int new_map_id = graph_utils::FindGraphNodeWithOp("MapDataset", output);
+ const int zip_dataset_id =
+ graph_utils::FindGraphNodeWithOp("ZipDataset", output);
+ const int random_dataset_id =
+ graph_utils::FindGraphNodeWithOp("RandomDataset", output);
+ const int batch_random_id =
+ graph_utils::FindGraphNodeWithOp("BatchDatasetV2", output);
+ ASSERT_NE(random_dataset_id, -1);
+ ASSERT_NE(zip_dataset_id, -1);
+ ASSERT_NE(new_map_id, -1);
+ ASSERT_NE(batch_random_id, -1);
+
+ const auto& new_map = output.node(new_map_id);
+ const auto& zip = output.node(zip_dataset_id);
+ const auto& random = output.node(random_dataset_id);
+ const auto& batch = output.node(batch_random_id);
+
+ ASSERT_EQ(new_map.input_size(), 1);
+ EXPECT_EQ(new_map.input(0), zip.name());
+
+ ASSERT_EQ(zip.input_size(), 2);
+ EXPECT_EQ(zip.input(0), "range");
+ EXPECT_EQ(zip.input(1), batch.name());
+
+ ASSERT_EQ(batch.input_size(), 3);
+ EXPECT_EQ(batch.input(0), random.name());
+}
+
+} // namespace
+} // namespace grappler
+} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc
index 63945b8b9e..e66766eb23 100644
--- a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc
@@ -80,11 +80,12 @@ NodeDef MakeMapAndBatchNode(const NodeDef& map_node, const NodeDef& batch_node,
// Set `f` and `Targuments` attributes.
for (auto key : {"f", "Targuments"}) {
- (*new_node.mutable_attr())[key] = map_node.attr().at(key);
+ graph_utils::CopyAttribute(key, map_node, &new_node);
}
+
// Set `output_types` and `output_shapes` attributes.
for (auto key : {"output_shapes", "output_types"}) {
- (*new_node.mutable_attr())[key] = batch_node.attr().at(key);
+ graph_utils::CopyAttribute(key, batch_node, &new_node);
}
return new_node;
}
diff --git a/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion.cc b/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion.cc
index f1844a141c..c4868eacbb 100644
--- a/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/protobuf.h"
namespace tensorflow {
@@ -41,19 +42,18 @@ NodeDef MakeFusedNode(const NodeDef& map_node,
fused_node.set_op("MapDataset");
fused_node.add_input(map_node.input(0));
- auto copy_attribute = [](const string& attribute_name, const NodeDef& from,
- NodeDef* to) {
- (*to->mutable_attr())[attribute_name] = from.attr().at(attribute_name);
- };
-
auto attr = map_node.attr().at("f");
attr.mutable_func()->set_name(fused_function.signature().name());
(*fused_node.mutable_attr())["f"] = std::move(attr);
- copy_attribute("Targuments", map_node, &fused_node);
+ graph_utils::CopyAttribute("Targuments", map_node, &fused_node);
for (auto key : {"output_shapes", "output_types"})
- copy_attribute(key, map_node, &fused_node);
+ graph_utils::CopyAttribute(key, map_node, &fused_node);
+
+ if (const auto* attr =
+ gtl::FindOrNull(map_node.attr(), "use_inter_op_parallelism"))
+ (*fused_node.mutable_attr())["use_inter_op_parallelism"] = *attr;
// Add the predicate output attributes.
(*fused_node.mutable_attr())["output_types"]
diff --git a/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion_test.cc b/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion_test.cc
index f029a093fa..6e6da37d7c 100644
--- a/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion_test.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -27,24 +28,8 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
namespace {
-
-NodeDef MakeMapNode(StringPiece name, StringPiece input_node_name) {
- return test::function::NDef(
- name, "MapDataset", {string(input_node_name)},
- {{"f", FunctionDefHelper::FunctionRef("XTimesTwo")},
- {"Targuments", {}},
- {"output_shapes", {}},
- {"output_types", {}}});
-}
-
-NodeDef MakeFilterNode(StringPiece name, StringPiece input_node_name) {
- return test::function::NDef(
- name, "FilterDataset", {string(input_node_name)},
- {{"predicate", FunctionDefHelper::FunctionRef("IsZero")},
- {"Targuments", {}},
- {"output_shapes", {}},
- {"output_types", {}}});
-}
+using graph_tests_utils::MakeFilterNode;
+using graph_tests_utils::MakeMapNode;
TEST(MapAndFilterFusionTest, FuseMapAndFilter) {
using test::function::NDef;
diff --git a/tensorflow/core/grappler/optimizers/data/map_fusion.cc b/tensorflow/core/grappler/optimizers/data/map_fusion.cc
index a78ecb09f7..bd943342e8 100644
--- a/tensorflow/core/grappler/optimizers/data/map_fusion.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_fusion.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/protobuf.h"
namespace tensorflow {
@@ -40,24 +41,31 @@ NodeDef MakeFusedNode(const NodeDef& parent_map_node, const NodeDef& map_node,
NodeDef fused_node;
graph_utils::SetUniqueGraphNodeName("fused_map", graph->GetGraph(),
&fused_node);
-
fused_node.set_op("MapDataset");
fused_node.add_input(parent_map_node.input(0));
- auto copy_attribute = [](const string& attribute_name, const NodeDef& from,
- NodeDef* to) {
- (*to->mutable_attr())[attribute_name] = from.attr().at(attribute_name);
- };
-
auto attr = parent_map_node.attr().at("f");
*attr.mutable_func()->mutable_name() = fused_function.signature().name();
(*fused_node.mutable_attr())["f"] = std::move(attr);
- copy_attribute("Targuments", parent_map_node, &fused_node);
-
+ graph_utils::CopyAttribute("Targuments", parent_map_node, &fused_node);
for (auto key : {"output_shapes", "output_types"})
- copy_attribute(key, map_node, &fused_node);
+ graph_utils::CopyAttribute(key, map_node, &fused_node);
+ auto value_or_false = [](const AttrValue* attr) {
+ if (!attr) return false;
+ return attr->b();
+ };
+
+ const auto* first_parallelism =
+ gtl::FindOrNull(parent_map_node.attr(), "use_inter_op_parallelism");
+ const auto* second_parallelism =
+ gtl::FindOrNull(map_node.attr(), "use_inter_op_parallelism");
+ // Some graphs cannot execute with use_inter_op_parallelism=False, so we need
+ // to set it to true if either of the fused ops has it set to true.
+ if (value_or_false(first_parallelism) || value_or_false(second_parallelism)) {
+ (*fused_node.mutable_attr())["use_inter_op_parallelism"].set_b(true);
+ }
return fused_node;
}
@@ -123,8 +131,8 @@ Status MapFusion::Optimize(Cluster* cluster, const GrapplerItem& item,
// fusion.
TF_RETURN_IF_ERROR(function_library.AddFunctionDef(*fused_function));
- // TODO(prazek): we could also remove map functions from library if they
- // are not used anymore.
+ // TODO(b/116285210): we could also remove map functions from library if
+ // they are not used anymore.
nodes_to_delete.insert(parent_map_node->name());
nodes_to_delete.insert(map_node->name());
}
diff --git a/tensorflow/core/grappler/optimizers/data/map_fusion_test.cc b/tensorflow/core/grappler/optimizers/data/map_fusion_test.cc
index b25dfbd0b8..8889f9dddd 100644
--- a/tensorflow/core/grappler/optimizers/data/map_fusion_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_fusion_test.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -28,14 +29,7 @@ namespace tensorflow {
namespace grappler {
namespace {
-NodeDef MakeMapNode(StringPiece name, StringPiece input_node_name) {
- return test::function::NDef(
- name, "MapDataset", {string(input_node_name)},
- {{"f", FunctionDefHelper::FunctionRef("XTimesTwo")},
- {"Targuments", {}},
- {"output_shapes", {}},
- {"output_types", {}}});
-}
+using graph_tests_utils::MakeMapNode;
TEST(MapFusionTest, FuseTwoMapNodesIntoOne) {
using test::function::NDef;
diff --git a/tensorflow/core/grappler/optimizers/data/map_parallelization.cc b/tensorflow/core/grappler/optimizers/data/map_parallelization.cc
index 305325e434..782c9f48b7 100644
--- a/tensorflow/core/grappler/optimizers/data/map_parallelization.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_parallelization.cc
@@ -84,9 +84,6 @@ Status MapParallelization::Optimize(Cluster* cluster, const GrapplerItem& item,
auto* parallel_map = graph.AddNode(MakeParallelMap(*map_node, &graph));
graph.ReplaceInput(*map_node, *parallel_map);
-
- // TODO(prazek): we could also remove map functions from library if they
- // are not used anymore.
nodes_to_delete.insert(map_node->name());
}
diff --git a/tensorflow/core/grappler/optimizers/data/map_parallelization_test.cc b/tensorflow/core/grappler/optimizers/data/map_parallelization_test.cc
index b2a5d9b6af..9fdfe8af30 100644
--- a/tensorflow/core/grappler/optimizers/data/map_parallelization_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_parallelization_test.cc
@@ -19,8 +19,8 @@ limitations under the License.
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
-
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
@@ -28,16 +28,7 @@ namespace tensorflow {
namespace grappler {
namespace {
-NodeDef MakeMapNode(StringPiece name, StringPiece input_node_name,
- StringPiece function_name) {
- return test::function::NDef(
- name, "MapDataset", {string(input_node_name)},
- {{"f", FunctionDefHelper::FunctionRef(string(function_name))},
- {"Targuments", {}},
- {"output_shapes", {}},
- {"output_types", {}}});
-}
-
+using graph_tests_utils::MakeMapNode;
const char stateless_fun_name[] = "XTimesTwo";
const char stateful_fun_name[] = "RandomUniform";
diff --git a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc
index a019b77eb7..32ab912619 100644
--- a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/optimizers/data/map_vectorization.h"
+#include "tensorflow/core/grappler/optimizers/data/vectorization_utils.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
@@ -24,6 +25,7 @@ limitations under the License.
#include "tensorflow/core/grappler/mutable_graph_view.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
+#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/gtl/map_util.h"
@@ -33,15 +35,11 @@ namespace tensorflow {
namespace grappler {
namespace {
-void CopyAttribute(const string& attr_name, const NodeDef& from, NodeDef* to) {
- (*to->mutable_attr())[attr_name] = from.attr().at(attr_name);
-}
-
-FunctionDef* AddVectorizedFunction(const NodeDef& map_node,
+// Returns a FunctionDef containing a MapDefun op that wraps the original
+// function.
+FunctionDef* CreateMapDefunWrapper(const NodeDef& map_node,
const FunctionDef& orig_func,
FunctionDefLibrary* library) {
- // If we decide to use a different method of vectorization, we can just
- // swap out this part.
FunctionDef* vectorized_func = library->add_function();
// Function inputs and outputs are the same as original, just
// with different shapes.
@@ -52,14 +50,14 @@ FunctionDef* AddVectorizedFunction(const NodeDef& map_node,
// Add MapDefun node
NodeDef* map_defun_node = vectorized_func->mutable_node_def()->Add();
map_defun_node->set_op("MapDefun");
- graph_utils::SetUniqueFunctionNodeName(map_defun_node->op(), vectorized_func,
- map_defun_node);
+ function_utils::SetUniqueFunctionNodeName(map_defun_node->op(),
+ vectorized_func, map_defun_node);
// Set attrs and inputs
for (const string& k : {"f", "output_types", "output_shapes"}) {
// Function, output types and (unbatched) shapes are the same as the
// original map node.
- CopyAttribute(k, map_node, map_defun_node);
+ graph_utils::CopyAttribute(k, map_node, map_defun_node);
}
// Get types of input arguments from original map function
@@ -81,6 +79,30 @@ FunctionDef* AddVectorizedFunction(const NodeDef& map_node,
return vectorized_func;
}
+FunctionDef* AddVectorizedFunction(const NodeDef& map_node,
+ const FunctionDef& orig_func,
+ FunctionDefLibrary* library) {
+ // Vectorizes orig_func naively by wrapping it in a MapDefun op, then
+ // performing efficient vectorization with VectorizeMapDefun.
+ FunctionDef* vectorized_func =
+ CreateMapDefunWrapper(map_node, orig_func, library);
+ NodeDef* map_defun_node = vectorized_func->mutable_node_def()->Mutable(0);
+ DCHECK_EQ(map_defun_node->op(), "MapDefun");
+
+ // Create a copy of the original function so that we can mutate it, and
+ // attach that to the map defun node.
+ FunctionDef* map_defun_fn = library->add_function();
+ *map_defun_fn = orig_func;
+ graph_utils::SetUniqueGraphFunctionName(orig_func.signature().name(), library,
+ map_defun_fn);
+ (*map_defun_node->mutable_attr())["f"].mutable_func()->set_name(
+ map_defun_fn->signature().name());
+
+ vectorization_utils::VectorizeMapDefun(vectorized_func, map_defun_fn,
+ map_defun_node);
+ return vectorized_func;
+}
+
bool IsOutputShapesFullyDefined(const NodeDef& node) {
auto* shapes_attr = gtl::FindOrNull(node.attr(), "output_shapes");
if (shapes_attr == nullptr) return false;
@@ -169,13 +191,16 @@ NodeDef MakeNewMapNode(const NodeDef& old_map_node,
}
// Set attrs
- CopyAttribute("Targuments", old_map_node, &map_node);
+ graph_utils::CopyAttribute("Targuments", old_map_node, &map_node);
auto& func_attr = (*map_node.mutable_attr())["f"];
func_attr.mutable_func()->set_name(vectorized_func.signature().name());
for (auto key : {"output_shapes", "output_types"}) {
- CopyAttribute(key, old_batch_node, &map_node);
+ graph_utils::CopyAttribute(key, old_batch_node, &map_node);
}
+
+ (*map_node.mutable_attr())["use_inter_op_parallelism"].set_b(true);
+
return map_node;
}
diff --git a/tensorflow/core/grappler/optimizers/data/noop_elimination.cc b/tensorflow/core/grappler/optimizers/data/noop_elimination.cc
index a26f1000a3..cf5a19bab1 100644
--- a/tensorflow/core/grappler/optimizers/data/noop_elimination.cc
+++ b/tensorflow/core/grappler/optimizers/data/noop_elimination.cc
@@ -33,25 +33,27 @@ namespace {
bool IsTakeAll(const NodeDef& take_node, const GraphView& graph) {
if (take_node.op() != "TakeDataset") return false;
- const NodeDef& count_node = *graph.GetNode(take_node.input(1));
+ const auto& count_node = *graph.GetNode(take_node.input(1));
+ if (count_node.op() != "Const") return false;
// We are looking only for 'take' with negative count.
return count_node.attr().at("value").tensor().int64_val(0) < 0;
}
+bool IsConstNodeWithValue(const NodeDef& node, int value) {
+ if (node.op() != "Const") return false;
+ return node.attr().at("value").tensor().int64_val(0) == value;
+}
+
bool IsSkipNone(const NodeDef& skip_node, const GraphView& graph) {
if (skip_node.op() != "SkipDataset") return false;
-
- const NodeDef& count_node = *graph.GetNode(skip_node.input(1));
// We are looking only for skip(0) nodes.
- return count_node.attr().at("value").tensor().int64_val(0) == 0;
+ return IsConstNodeWithValue(*graph.GetNode(skip_node.input(1)), 0);
}
bool IsRepeatOne(const NodeDef& repeat_node, const GraphView& graph) {
if (repeat_node.op() != "RepeatDataset") return false;
-
- const NodeDef& count_node = *graph.GetNode(repeat_node.input(1));
// We are looking only for repeat(1) nodes.
- return count_node.attr().at("value").tensor().int64_val(0) == 1;
+ return IsConstNodeWithValue(*graph.GetNode(repeat_node.input(1)), 1);
}
bool IsNoOp(const NodeDef& node, const GraphView& graph) {
diff --git a/tensorflow/core/grappler/optimizers/data/noop_elimination_test.cc b/tensorflow/core/grappler/optimizers/data/noop_elimination_test.cc
index f445e75aa7..be1a66df75 100644
--- a/tensorflow/core/grappler/optimizers/data/noop_elimination_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/noop_elimination_test.cc
@@ -43,6 +43,14 @@ NodeDef *MakeUnaryNode(StringPiece node_type, int count, string input_node,
GetCommonAttributes(), graph);
}
+NodeDef *MakeUnaryNonConstNode(StringPiece node_type, string input_node,
+ MutableGraphView *graph) {
+ NodeDef *node_count = graph_utils::AddScalarPlaceholder(DT_INT32, graph);
+ return graph_utils::AddNode("", node_type,
+ {std::move(input_node), node_count->name()},
+ GetCommonAttributes(), graph);
+}
+
NodeDef *MakeCacheNode(string input_node, MutableGraphView *graph) {
NodeDef *node_filename =
graph_utils::AddScalarConstNode<StringPiece>("", graph);
@@ -205,6 +213,41 @@ INSTANTIATE_TEST_CASE_P(
::testing::Values(*kTakeNode, *kSkipNode,
*kRepeatNode)));
+struct NoOpPlaceholdersTest
+ : ::testing::TestWithParam<std::tuple<string, string>> {};
+
+TEST_P(NoOpPlaceholdersTest, NonConstNoOpNode) {
+ GrapplerItem item;
+ MutableGraphView graph(&item.graph);
+
+ static_assert(std::tuple_size<NodesTypes>::value == 2,
+ "Make sure to include everything in the test");
+ const std::vector<string> noop_nodes = {std::get<0>(GetParam()),
+ std::get<1>(GetParam())};
+ NodeDef *range_node = MakeRangeNode(&graph);
+ std::vector<string> nodes_to_keep;
+ nodes_to_keep.reserve(noop_nodes.size());
+ NodeDef *previous = range_node;
+
+ for (const auto &noop_node : noop_nodes) {
+ NodeDef *node = MakeUnaryNonConstNode(noop_node, previous->name(), &graph);
+ nodes_to_keep.push_back(node->name());
+ previous = node;
+ }
+
+ NoOpElimination optimizer;
+ GraphDef output;
+ TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+ for (const auto &noop_node_name : nodes_to_keep)
+ EXPECT_TRUE(graph_utils::ContainsGraphNodeWithName(noop_node_name, output));
+}
+
+INSTANTIATE_TEST_CASE_P(
+ DoNotRemovePlaceholders, NoOpPlaceholdersTest,
+ ::testing::Combine(
+ ::testing::Values("TakeDataset", "SkipDataset", "RepeatDataset"),
+ ::testing::Values("TakeDataset", "SkipDataset", "RepeatDataset")));
+
} // namespace
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc b/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc
index cb0ff670e8..99c4afa634 100644
--- a/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc
+++ b/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc
@@ -64,7 +64,7 @@ Status ShuffleAndRepeatFusion::Optimize(Cluster* cluster,
// Set `output_types` and `output_shapes` attributes.
for (auto key : {"output_shapes", "output_types"}) {
- (*new_node.mutable_attr())[key] = repeat_node.attr().at(key);
+ graph_utils::CopyAttribute(key, repeat_node, &new_node);
}
return new_node;
};
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/BUILD b/tensorflow/core/grappler/optimizers/data/vectorization/BUILD
new file mode 100644
index 0000000000..1462cb234d
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/BUILD
@@ -0,0 +1,69 @@
+package(
+ default_visibility = ["//visibility:private"],
+)
+
+licenses(["notice"]) # Apache 2.0
+
+load("//tensorflow:tensorflow.bzl", "tf_cc_test")
+load("//tensorflow/core:platform/default/build_config.bzl", "tf_protos_all")
+
+VECTORIZER_DEPS = [
+ ":vectorizer_registry",
+ "//tensorflow/core/grappler/optimizers/data:function_utils",
+] + tf_protos_all()
+
+cc_library(
+ name = "vectorizer",
+ hdrs = ["vectorizer.h"],
+ deps = [
+ "//tensorflow/core:lib",
+ ] + tf_protos_all(),
+)
+
+cc_library(
+ name = "vectorizer_registry",
+ srcs = ["vectorizer_registry.cc"],
+ hdrs = ["vectorizer_registry.h"],
+ deps = [
+ ":vectorizer",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ ],
+)
+
+cc_library(
+ name = "cast_vectorizer",
+ srcs = ["cast_vectorizer.cc"],
+ deps = VECTORIZER_DEPS,
+ alwayslink = 1,
+)
+
+cc_library(
+ name = "unpack_vectorizer",
+ srcs = ["unpack_vectorizer.cc"],
+ deps = VECTORIZER_DEPS,
+ alwayslink = 1,
+)
+
+cc_library(
+ name = "vectorization",
+ hdrs = ["vectorizer_registry.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":cast_vectorizer",
+ ":unpack_vectorizer",
+ ":vectorizer",
+ ":vectorizer_registry",
+ ],
+)
+
+tf_cc_test(
+ name = "vectorizer_registry_test",
+ srcs = ["vectorizer_registry_test.cc"],
+ deps = [
+ ":vectorizer_registry",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ ] + tf_protos_all(),
+)
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc b/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc
new file mode 100644
index 0000000000..c1739737a0
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc
@@ -0,0 +1,54 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
+#include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace vectorization_utils {
+
+class CastVectorizer : public Vectorizer {
+ public:
+ Status Vectorize(const NodeDef& node, gtl::ArraySlice<string> inputs,
+ FunctionDef* outer_scope,
+ std::map<string, string>* conversion_map) override {
+ if (inputs.size() != 1) {
+ return errors::Internal("Cast op should only have one input.");
+ }
+
+ // Add new Cast node
+ NodeDef* new_cast_node = outer_scope->add_node_def();
+ *new_cast_node = node;
+ new_cast_node->clear_name();
+ function_utils::SetUniqueFunctionNodeName(
+ strings::StrCat("vectorized/", node.name()), outer_scope,
+ new_cast_node);
+ new_cast_node->set_input(0, inputs[0]);
+
+ // Add the output mapping to conversion map
+ (*conversion_map)[strings::StrCat(node.name(), ":y:0")] =
+ strings::StrCat(new_cast_node->name(), ":y:0");
+
+ return Status::OK();
+ }
+};
+
+REGISTER_VECTORIZER("Cast", CastVectorizer);
+
+} // namespace vectorization_utils
+} // namespace grappler
+} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc b/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc
new file mode 100644
index 0000000000..776d3179c5
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc
@@ -0,0 +1,61 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
+#include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace vectorization_utils {
+
+class UnpackVectorizer : public Vectorizer {
+ public:
+ Status Vectorize(const NodeDef& node, gtl::ArraySlice<string> inputs,
+ FunctionDef* outer_scope,
+ std::map<string, string>* conversion_map) override {
+ if (inputs.size() != 1) {
+ return errors::Internal("Unpack op should only have one input.");
+ }
+
+ // Add new Unpack node
+ NodeDef* new_unpack_node = outer_scope->add_node_def();
+ *new_unpack_node = node;
+ new_unpack_node->clear_name();
+ function_utils::SetUniqueFunctionNodeName(
+ strings::StrCat("vectorized/", node.name()), outer_scope,
+ new_unpack_node);
+
+ // Increment "axis" attr by 1:
+ (*new_unpack_node->mutable_attr())["axis"].set_i(
+ node.attr().at("axis").i() + 1);
+ new_unpack_node->set_input(0, inputs[0]);
+
+ // Add the output mappings to conversion map
+ int num = new_unpack_node->attr().at("num").i();
+ for (int i = 0; i < num; ++i) {
+ (*conversion_map)[strings::StrCat(node.name(), ":output:", i)] =
+ strings::StrCat(new_unpack_node->name(), ":output:", i);
+ }
+
+ return Status::OK();
+ }
+};
+
+REGISTER_VECTORIZER("Unpack", UnpackVectorizer);
+
+} // namespace vectorization_utils
+} // namespace grappler
+} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h
new file mode 100644
index 0000000000..d341dbba7d
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h
@@ -0,0 +1,49 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_VECTORIZER_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_VECTORIZER_H_
+
+#include "tensorflow/core/framework/function.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace vectorization_utils {
+
+// Interface for vectorization of TensorFlow operations. See `CastVectorizer`
+// for an example.
+class Vectorizer {
+ public:
+ virtual ~Vectorizer() {}
+
+ // Vectorizes an operation, `node`, by adding operation(s) to `outer_scope`
+ // that produce the same vector output(s) as executing `node`'s op
+ // on elements of the vector inputs, and adding mappings to `conversion_map`
+ // from old output tensor names to new (vectorized) output tensor names.
+ // The new node(s) collectively have the same number of inputs and outputs as
+ // the node being converted, and use the tensor names in `inputs` as their
+ // inputs.
+ virtual Status Vectorize(const NodeDef& node, gtl::ArraySlice<string> inputs,
+ FunctionDef* outer_scope,
+ std::map<string, string>* conversion_map) = 0;
+};
+
+} // namespace vectorization_utils
+} // namespace grappler
+} // namespace tensorflow
+#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_VECTORIZER_H_
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.cc b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.cc
new file mode 100644
index 0000000000..a6551e36ac
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.cc
@@ -0,0 +1,47 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace vectorization_utils {
+
+VectorizerRegistry* VectorizerRegistry::Global() {
+ static VectorizerRegistry* registry = new VectorizerRegistry;
+ return registry;
+}
+
+Vectorizer* VectorizerRegistry::Get(const string& op_type) {
+ auto found = vectorizers_.find(op_type);
+ if (found == vectorizers_.end()) {
+ return nullptr;
+ }
+ return found->second.get();
+}
+
+void VectorizerRegistry::Register(const string& op_type,
+ std::unique_ptr<Vectorizer> vectorizer) {
+ auto existing = Get(op_type);
+ CHECK_EQ(existing, nullptr)
+ << "Vectorizer for op type: " << op_type << " already registered";
+ vectorizers_.insert(std::pair<const string&, std::unique_ptr<Vectorizer>>(
+ op_type, std::move(vectorizer)));
+}
+} // namespace vectorization_utils
+} // namespace grappler
+} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h
new file mode 100644
index 0000000000..16159d47ca
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h
@@ -0,0 +1,75 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_VECTORIZER_REGISTRY_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_VECTORIZER_REGISTRY_H_
+
+#include <functional>
+#include <map>
+
+#include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace vectorization_utils {
+
+// A global VectorizerRegistry is used to hold all the vectorizers.
+class VectorizerRegistry {
+ public:
+ // Returns a pointer to a global VectorizerRegistry object.
+ static VectorizerRegistry* Global();
+
+ // Returns the vectorizer registered for the given op type, or nullptr if none exists.
+ Vectorizer* Get(const string& op_type);
+
+ // Registers a vectorizer that can vectorize an op for the given op type.
+ void Register(const string& op_type, std::unique_ptr<Vectorizer> vectorizer);
+
+ private:
+ std::map<string, std::unique_ptr<Vectorizer>> vectorizers_;
+};
+
+namespace vectorizer_registration {
+
+class VectorizerRegistration {
+ public:
+ VectorizerRegistration(const string& op_type,
+ std::unique_ptr<Vectorizer> vectorizer) {
+ VectorizerRegistry::Global()->Register(op_type, std::move(vectorizer));
+ }
+};
+
+} // namespace vectorizer_registration
+
+#define REGISTER_VECTORIZER(op_type, vectorizer) \
+ REGISTER_VECTORIZER_UNIQ_HELPER(__COUNTER__, op_type, vectorizer)
+
+#define REGISTER_VECTORIZER_UNIQ_HELPER(ctr, op_type, vectorizer) \
+ REGISTER_VECTORIZER_UNIQ(ctr, op_type, vectorizer)
+
+#define REGISTER_VECTORIZER_UNIQ(ctr, op_type, vectorizer) \
+ static ::tensorflow::grappler::vectorization_utils:: \
+ vectorizer_registration::VectorizerRegistration \
+ vectorizer_registration_##ctr( \
+ op_type, \
+ ::std::unique_ptr< \
+ ::tensorflow::grappler::vectorization_utils::Vectorizer>( \
+ new vectorizer()))
+
+} // namespace vectorization_utils
+} // namespace grappler
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_VECTORIZER_REGISTRY_H_
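For reference, the registration pattern enabled by REGISTER_VECTORIZER looks like the sketch below. It mirrors cast_vectorizer.cc and the TestVectorizer in the registry test later in this patch; the "MyOp" vectorizer is hypothetical and not part of the change.

#include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h"

namespace tensorflow {
namespace grappler {
namespace vectorization_utils {

// Hypothetical vectorizer for an op type "MyOp"; illustrative only.
class MyOpVectorizer : public Vectorizer {
 public:
  Status Vectorize(const NodeDef& node, gtl::ArraySlice<string> inputs,
                   FunctionDef* outer_scope,
                   std::map<string, string>* conversion_map) override {
    // A real implementation would add vectorized node(s) to outer_scope and
    // record old-output-name -> new-output-name entries in conversion_map.
    return Status::OK();
  }
};

// Static registration: after this, VectorizerRegistry::Global()->Get("MyOp")
// returns a pointer to an instance of MyOpVectorizer.
REGISTER_VECTORIZER("MyOp", MyOpVectorizer);

}  // namespace vectorization_utils
}  // namespace grappler
}  // namespace tensorflow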
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc
new file mode 100644
index 0000000000..86e303564b
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc
@@ -0,0 +1,50 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h"
+#include "tensorflow/core/framework/function.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace vectorization_utils {
+
+class TestVectorizer : public Vectorizer {
+ public:
+ Status Vectorize(const NodeDef& node, gtl::ArraySlice<string> inputs,
+ FunctionDef* outer_scope,
+ std::map<string, string>* conversion_map) override {
+ return Status::OK();
+ }
+};
+
+REGISTER_VECTORIZER("test_op", TestVectorizer);
+
+TEST(TestVectorizer, TestTestVectorizer) {
+ EXPECT_EQ(VectorizerRegistry::Global()->Get("nonexistent"), nullptr);
+
+ auto vectorizer = VectorizerRegistry::Global()->Get("test_op");
+ EXPECT_NE(vectorizer, nullptr);
+
+ FunctionDef function;
+ NodeDef node;
+ std::map<string, string> conversion_map;
+ EXPECT_TRUE(vectorizer->Vectorize(node, {}, &function, &conversion_map).ok());
+}
+
+} // namespace vectorization_utils
+} // namespace grappler
+} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
new file mode 100644
index 0000000000..cb56b65985
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
@@ -0,0 +1,292 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/vectorization_utils.h"
+#include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h"
+
+#include "absl/strings/str_join.h"
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/device_base.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/grappler/mutable_graph_view.h"
+#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
+#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/grappler/utils/functions.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/strings/scanner.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace vectorization_utils {
+
+using function_utils::FunctionDefTensorDesc;
+
+namespace {
+
+void AddMapDefunOutput(FunctionDef* map_defun_fn, NodeDef* map_defun_node,
+ const string& output_retval, const DataType t) {
+ // Set to unknown shape
+ TensorShapeProto tensor_shape_proto;
+ PartialTensorShape().AsProto(&tensor_shape_proto);
+
+ function_utils::AddFunctionOutputWithUniqueName(
+ "vectorized_out", output_retval, map_defun_fn, t);
+
+ *(*map_defun_node->mutable_attr())["output_shapes"]
+ .mutable_list()
+ ->add_shape() = tensor_shape_proto;
+ (*map_defun_node->mutable_attr())["output_types"].mutable_list()->add_type(t);
+}
+
+void RemoveMapDefunOutput(FunctionDef* outer_scope, FunctionDef* map_defun_fn,
+ NodeDef* map_defun_node, int output_position) {
+ DCHECK_LT(output_position, map_defun_fn->signature().output_arg_size())
+ << "Trying to remove output that doesn't exist. Output number: "
+ << output_position;
+
+ int num_later_outputs =
+ map_defun_fn->signature().output_arg_size() - output_position - 1;
+
+ // Remove from map_defun_fn's ret dict and output args
+ map_defun_fn->mutable_ret()->erase(
+ map_defun_fn->signature().output_arg(output_position).name());
+ map_defun_fn->mutable_signature()->mutable_output_arg()->DeleteSubrange(
+ output_position, 1);
+
+ // Renumber outputs that come after
+ for (int i = 0; i < num_later_outputs; ++i) {
+ function_utils::ReplaceReferences(
+ strings::StrCat(map_defun_node->name(),
+ ":output:", output_position + i + 1),
+ strings::StrCat(map_defun_node->name(),
+ ":output:", output_position + i),
+ outer_scope);
+ }
+ map_defun_node->mutable_attr()
+ ->at("output_shapes")
+ .mutable_list()
+ ->mutable_shape()
+ ->DeleteSubrange(output_position, 1);
+ map_defun_node->mutable_attr()
+ ->at("output_types")
+ .mutable_list()
+ ->mutable_type()
+ ->ExtractSubrange(output_position, 1, nullptr);
+}
+
+int FindOutputToConvert(const FunctionDef& function,
+ const std::set<string>& unconvertible,
+ FunctionDefTensorDesc* f) {
+ for (int i = function.signature().output_arg_size() - 1; i >= 0; --i) {
+ const string& ret_key = function.signature().output_arg(i).name();
+ *f = FunctionDefTensorDesc(function.ret().at(ret_key));
+
+ if (unconvertible.find(f->node_name) == unconvertible.end()) {
+ return i;
+ }
+ }
+ return -1;
+}
+
+// Helper class that vectorizes the body of a MapDefun node, adding new
+// operations to the graph that collectively compute the same value as what
+// running the MapDefun function on slices of the input would produce.
+// Each instance of the class encapsulates all the data necessary to vectorize a
+// MapDefun op in place.
+class Vectorization {
+ public:
+ Vectorization(FunctionDef* outer_scope, FunctionDef* map_defun_fn,
+ NodeDef* map_defun_node)
+ : outer_scope_(outer_scope),
+ map_defun_fn_(map_defun_fn),
+ map_defun_node_(map_defun_node) {}
+
+ // Repeatedly tries to convert outputs of map_defun_fn_ into new nodes in
+ // the outer_scope_, until there are no convertible outputs remaining.
+ // This method is idempotent.
+ void Vectorize();
+
+ private:
+ // Vectorizes the map defun function's output at output_position
+ Status ConvertOutput(int output_position, const FunctionDefTensorDesc& desc);
+ // Given a descriptor of the original output tensor, gets a string
+ // corresponding to the converted output tensor.
+ Status ConvertOutputHelper(const FunctionDefTensorDesc& output_desc,
+ string* converted);
+ Status AddConversionMappingFromInput(
+ const FunctionDefTensorDesc& output_desc);
+
+ // Adds mappings from node's output tensors to converted output tensors,
+ // creating the necessary new node(s). Generally, the steps to convert an op
+ // are:
+ // 1) Promote the inputs of the op to outputs of the map_defun_fn_,
+ // and modify map_defun_node_ attrs accordingly
+ // 2) Create new node(s) in outer_scope_ that act on batched input tensors.
+ // These operations collectively compute the same value as what running
+ // the original operation on slices of the input tensors would produce.
+ // For example, a Cast op in MapDefun translates to a Cast op in
+ // outer_scope_, since Cast is its own vectorized form.
+ // 3) Set inputs of new node(s) to the corresponding converted inputs (that
+ // are now outputs of map_defun_node_)
+ // 4) For each output of the old node, add the mapping of output strings to
+ // the conversion map (e.g. "Cast:y:0" -> "Vectorize/Cast:y:0")
+ Status AddConversionMappingFromOp(const NodeDef& node,
+ const FunctionDefTensorDesc& output_desc);
+
+ // Maps a tensor name to the name of the corresponding vectorized tensor. For
+ // example, "Cast:y:0" -> "Vectorize/Cast:y:0"
+ std::map<string, string> conversion_map_;
+ // Unconvertible node names
+ std::set<string> unconvertible_;
+
+ FunctionDef* outer_scope_;
+ FunctionDef* map_defun_fn_;
+ NodeDef* map_defun_node_;
+};
+
+Status Vectorization::AddConversionMappingFromOp(
+ const NodeDef& node, const FunctionDefTensorDesc& output_desc) {
+ for (const string& input_name : node.input()) {
+ if (IsControlInput(input_name)) {
+ return errors::InvalidArgument(
+ "Vectorizing outputs with control inputs is currently not "
+ "supported.");
+ }
+ }
+
+ // TODO(rachelim): Have some mechanism for registering converters and some
+ // uniform, simpler way to represent them.
+
+ DataTypeVector types;
+ const OpDef* op_def = nullptr;
+ TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(node.op(), &op_def));
+ TF_RETURN_IF_ERROR(InputTypesForNode(node, *op_def, &types));
+
+ std::vector<string> promoted_inputs;
+ promoted_inputs.reserve(node.input_size());
+ for (int i = 0; i < node.input_size(); ++i) {
+ promoted_inputs.push_back(strings::StrCat(
+ map_defun_node_->name(),
+ ":output:", map_defun_fn_->signature().output_arg_size() + i));
+ }
+
+ auto vectorizer = VectorizerRegistry::Global()->Get(node.op());
+ if (vectorizer == nullptr) {
+ return errors::Unimplemented("No vectorizer registered for op: ",
+ node.op());
+ }
+
+ TF_RETURN_IF_ERROR(vectorizer->Vectorize(node, promoted_inputs, outer_scope_,
+ &conversion_map_));
+
+ // If we get here, the conversion was successful, so we promote the inputs
+ // of the op to MapDefun outputs.
+ for (int i = 0; i < types.size(); ++i) {
+ AddMapDefunOutput(map_defun_fn_, map_defun_node_, node.input(i), types[i]);
+ }
+
+ return Status::OK();
+}
+
+Status Vectorization::AddConversionMappingFromInput(
+ const FunctionDefTensorDesc& output_desc) {
+ int input_index = function_utils::FindFunctionInputWithName(
+ output_desc.node_name, *map_defun_fn_);
+ if (input_index == -1) {
+ return errors::Internal("Cannot convert non-existent input.");
+ }
+
+ conversion_map_[output_desc.full_str] = map_defun_node_->input(input_index);
+ return Status::OK();
+}
+
+Status Vectorization::ConvertOutputHelper(
+ const FunctionDefTensorDesc& output_desc, string* converted) {
+ // It's possible the output already has a mapping, if it comes from a node
+ // that has already been converted.
+ if (auto found = gtl::FindOrNull(conversion_map_, output_desc.full_str)) {
+ *converted = *found;
+ return Status::OK();
+ }
+
+ int index = function_utils::FindFunctionNodeWithName(output_desc.node_name,
+ *map_defun_fn_);
+ if (index == -1) { // The output comes from an input
+ TF_RETURN_IF_ERROR(AddConversionMappingFromInput(output_desc));
+ } else {
+ TF_RETURN_IF_ERROR(AddConversionMappingFromOp(
+ map_defun_fn_->node_def(index), output_desc));
+ }
+ *converted = conversion_map_.at(output_desc.full_str);
+ return Status::OK();
+}
+
+Status Vectorization::ConvertOutput(int output_position,
+ const FunctionDefTensorDesc& output_desc) {
+ string converted_output_name;
+ TF_RETURN_IF_ERROR(ConvertOutputHelper(output_desc, &converted_output_name));
+
+ // Remove the old output and make everything that referenced it point
+ // to the new string
+ function_utils::ReplaceReferences(
+ strings::StrCat(map_defun_node_->name(), ":output:", output_position),
+ converted_output_name, outer_scope_);
+ RemoveMapDefunOutput(outer_scope_, map_defun_fn_, map_defun_node_,
+ output_position);
+
+ return Status::OK();
+}
+
+void Vectorization::Vectorize() {
+ while (true) {
+ FunctionDefTensorDesc desc;
+ int output_position =
+ FindOutputToConvert(*map_defun_fn_, unconvertible_, &desc);
+ if (output_position == -1) break;
+
+ if (!ConvertOutput(output_position, desc).ok()) {
+ unconvertible_.insert(desc.node_name);
+ }
+ }
+
+ // If we've converted all the outputs of the MapDefun function, we no longer
+ // need the MapDefun node and can delete it.
+ if (map_defun_fn_->signature().output_arg_size() == 0) {
+ outer_scope_->mutable_node_def()->DeleteSubrange(
+ function_utils::FindFunctionNodeWithName(map_defun_node_->name(),
+ *outer_scope_),
+ 1);
+ }
+
+ if (!unconvertible_.empty()) {
+ VLOG(2) << "The following nodes could not be converted: ["
+ << absl::StrJoin(unconvertible_, ", ") << "].";
+ }
+}
+} // namespace
+
+void VectorizeMapDefun(FunctionDef* outer_scope, FunctionDef* map_defun_fn,
+ NodeDef* map_defun_node) {
+ Vectorization(outer_scope, map_defun_fn, map_defun_node).Vectorize();
+}
+
+} // end namespace vectorization_utils
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils.h b/tensorflow/core/grappler/optimizers/data/vectorization_utils.h
new file mode 100644
index 0000000000..bb405faa77
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils.h
@@ -0,0 +1,90 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_UTILS_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_UTILS_H_
+
+#include "tensorflow/core/framework/function.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace vectorization_utils {
+
+// Given a function, `map_defun_fn`, that is mapped across some input vector
+// elements via a MapDefun operation, `VectorizeMapDefun` attempts to
+// vectorize the MapDefun by "lifting" operations from the `map_defun_fn` to the
+// `outer_scope`; that is, replacing `map_defun_fn` operations with new
+// `outer_scope` operations that produce the same vector output(s) as executing
+// the `map_defun_fn` operations on elements of vector input(s) would. If all
+// `map_defun_fn` operations are successfully lifted, `map_defun_node` is
+// eliminated from `outer_scope` altogether. However, if some operations cannot
+// be lifted, and this vectorization only succeeds partially, `map_defun_node`
+// remains to be used for operations that were not lifted.
+//
+// Example:
+// If the input to the `VectorizeMapDefun` function is a MapDefun
+// whose `map_defun_fn` performs the Cast operation, the vectorization will
+// eliminate the MapDefun. This is because the Cast operation supports
+// any tensor shape and can thus be lifted to the `outer_scope`.
+//
+// Before:
+//
+//
+// outer_scope +------+
+// +---------------+ Arg0 +---------+
+// | +---+--+ |
+// | | |
+// | map_defun_fn +---v--+ |
+// | +-----------+ Arg0 +-----+ |
+// | | +---+--+ | |
+// | | | | |
+// | | | | |
+// | | +---v--+ | |
+// | | | Cast | | |
+// | | +---+--+ | |
+// | | | | |
+// | | +---v--+ | |
+// | +-----------+ Ret0 +-----+ |
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// +---------------+ Ret0 +---------+
+// +------+
+//
+//
+// After:
+//
+// outer_scope +------+
+// +---------------+ Arg0 +---------+
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// | | Cast | |
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// +---------------+ Ret0 +---------+
+// +------+
+//
+void VectorizeMapDefun(FunctionDef* outer_scope, FunctionDef* map_defun_fn,
+ NodeDef* map_defun_node);
+
+} // end namespace vectorization_utils
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_UTILS_H_
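A minimal usage sketch of the entry point declared above, assuming the caller has already built an outer function containing a MapDefun node, as AddVectorizedFunction in map_vectorization.cc does earlier in this patch; the wrapper function name below is illustrative and not part of the change.

#include "tensorflow/core/grappler/optimizers/data/vectorization_utils.h"

namespace tensorflow {
namespace grappler {

// `outer` holds a MapDefun node `map_defun_node` whose function is `inner`.
// VectorizeMapDefun lifts whatever ops it can out of `inner` into `outer`;
// if every op is lifted, the MapDefun node is deleted from `outer`.
void VectorizeExample(FunctionDef* outer, FunctionDef* inner,
                      NodeDef* map_defun_node) {
  vectorization_utils::VectorizeMapDefun(outer, inner, map_defun_node);
}

}  // namespace grappler
}  // namespace tensorflow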
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc
new file mode 100644
index 0000000000..e129fa9237
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc
@@ -0,0 +1,600 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/vectorization_utils.h"
+
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/tools/graph_transforms/transform_utils.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace vectorization_utils {
+namespace {
+
+NodeDef* AddCastNode(const string& name, const std::vector<string>& inputs,
+ DataType src, DataType dst, bool truncate,
+ FunctionDef* fn) {
+ NodeDef* node = function_utils::AddNode(name, "Cast", inputs, {}, fn);
+ graph_transforms::SetNodeAttr("SrcT", src, node);
+ graph_transforms::SetNodeAttr("DstT", dst, node);
+ graph_transforms::SetNodeAttr("Truncate", truncate, node);
+ return node;
+}
+
+NodeDef* AddUnstackNode(const string& name, const std::vector<string>& inputs,
+ DataType t, int axis, int num, FunctionDef* fn) {
+ NodeDef* node = function_utils::AddNode(name, "Unpack", inputs, {}, fn);
+ graph_transforms::SetNodeAttr("T", t, node);
+ graph_transforms::SetNodeAttr("axis", axis, node);
+ graph_transforms::SetNodeAttr("num", num, node);
+ return node;
+}
+
+NodeDef* AddMapDefunNode(const string& name, const std::vector<string>& inputs,
+ const std::vector<DataType>& t_arguments,
+ const std::vector<DataType>& output_types,
+ const std::vector<TensorShape>& output_shapes,
+ const string& function_name, FunctionDef* fn) {
+ NameAttrList func;
+ func.set_name(function_name);
+ NodeDef* node = function_utils::AddNode(name, "MapDefun", inputs, {}, fn);
+ graph_transforms::SetNodeAttr("Targuments", t_arguments, node);
+ graph_transforms::SetNodeAttr("output_types", output_types, node);
+ graph_transforms::SetNodeAttr("output_shapes", output_shapes, node);
+ graph_transforms::SetNodeAttr("f", func, node);
+ return node;
+}
+
+// TODO(rachelim): Use FunctionDefHelper::Create instead
+FunctionDef CreateFunction(
+ StringPiece name, const std::vector<std::pair<string, DataType>>& inputs,
+ const std::vector<std::pair<string, DataType>>& outputs,
+ const std::map<string, string>& rets) {
+ FunctionDef func;
+ auto* signature = func.mutable_signature();
+ signature->set_name(string(name));
+ for (const auto& x : inputs) {
+ auto* arg_def = signature->add_input_arg();
+ arg_def->set_name(x.first);
+ arg_def->set_type(x.second);
+ }
+ for (const auto& x : outputs) {
+ auto* arg_def = signature->add_output_arg();
+ arg_def->set_name(x.first);
+ arg_def->set_type(x.second);
+ }
+ for (const auto& x : rets) {
+ (*func.mutable_ret())[x.first] = x.second;
+ }
+
+ return func;
+}
+
+TEST(FunctionDefInputDescTest, ConstructedCorrectly) {}
+
+// Before:
+//
+// +------+ +------+
+// +---------------+ Arg0 +---+ Arg1 +--------+
+// | +---+--+ +---+--+ |
+// | | | |
+// | +---v--+ +---v--+ |
+// | +-----------+ Arg0 +---+ Arg1 +----+ |
+// | | +---+--+ +---+--+ | |
+// | | | | | |
+// | | MapDefun +---v--+ +---v--+ | |
+// | +-----------+ Ret0 +---+ Ret1 +----+ |
+// | +---+--+ +---+--+ |
+// | | | |
+// | +---v--+ +---v--+ |
+// +---------------+ Ret0 +---+ Ret1 +--------+
+// +------+ +------+
+//
+//
+// After:
+//
+// +------+ +------+
+// +---------------+ Arg0 +---+ Arg1 +--------+
+// | +---+--+ +---+--+ |
+// | | | |
+// | | | |
+// | | | |
+// | +---v--+ +---v--+ |
+// +---------------+ Ret0 +---+ Ret1 +--------+
+// +------+ +------+
+//
+TEST(VectorizeMapDefunTest, VectorizeDefunNoOps) {
+ FunctionDef inner =
+ CreateFunction("inner_function", {{"arg0", DT_INT32}, {"arg1", DT_INT32}},
+ {{"ret0", DT_INT32}, {"ret1", DT_INT32}},
+ {{"ret0", "arg0"}, {"ret1", "arg1"}});
+ FunctionDef outer = CreateFunction(
+ "outer_function", {{"ret0", DT_INT32}, {"ret1", DT_INT32}},
+ {{"mapdefun", DT_INT32}, {"mapdefun_0", DT_INT32}},
+ {{"mapdefun", "MapDefun:output:0"}, {"mapdefun_0", "MapDefun:output:1"}});
+
+ NodeDef* map_defun = AddMapDefunNode(
+ "MapDefun", {"ret0", "ret1"}, {DT_INT32, DT_INT32}, {DT_INT32, DT_INT32},
+ {{}, {}}, inner.signature().name(), &outer);
+ CHECK_NOTNULL(map_defun);
+
+ VectorizeMapDefun(&outer, &inner, map_defun);
+ EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer));
+ EXPECT_EQ(outer.ret().at("mapdefun"), "ret0");
+ EXPECT_EQ(outer.ret().at("mapdefun_0"), "ret1");
+}
+
+// Before:
+//
+// +------+ +------+
+// +---------------+ Arg0 +---+ Arg1 +--------+
+// | +---+--+ +---+--+ |
+// | | | |
+// | +---v--+ +---v--+ |
+// | +-----------+ Arg0 +---+ Arg1 +----+ |
+// | | +---+--+ +---+--+ | |
+// | | | | | |
+// | | +------+ | +---v--+ | |
+// | | |Const | | | Op0 | | |
+// | | +---v--+ | +---+--+ | |
+// | | | | | | |
+// | | | +---v--+ +---v--+ | |
+// | | +---| XOp1 | | XOp2 | | |
+// | | +---+--+ +---+--+ | |
+// | | | | | |
+// | | MapDefun +---v--+ +---v--+ | |
+// | +-----------+ Ret0 +---+ Ret1 +----+ |
+// | +---+--+ +---+--+ |
+// | | | |
+// | +---v--+ +---v--+ |
+// +---------------+ Ret0 +---+ Ret1 +--------+
+// +------+ +------+
+//
+// where XOp1 and XOp2 are not convertible.
+//
+// After:
+//
+// No change because the ops are not convertible.
+//
+TEST(VectorizeMapDefunTest, VectorizeDefunUnconvertible) {
+ FunctionDef inner =
+ CreateFunction("inner_function", {{"arg0", DT_INT32}, {"arg1", DT_INT32}},
+ {{"ret0", DT_INT32}, {"ret1", DT_INT32}},
+ {{"ret0", "XOp1:output:0"}, {"ret1", "XOp2:output:0"}});
+ NodeDef* x_op1 =
+ function_utils::AddNode("XOp1", "XOp1", {"const", "arg0"}, {}, &inner);
+ CHECK_NOTNULL(x_op1);
+
+ NodeDef* x_op2 = function_utils::AddNode("XOp2", "XOp2", {"op1"}, {}, &inner);
+ CHECK_NOTNULL(x_op2);
+
+ FunctionDef outer = CreateFunction(
+ "outer_function", {{"x", DT_INT32}, {"y", DT_INT32}},
+ {{"mapdefun", DT_INT32}, {"mapdefun_0", DT_INT32}},
+ {{"mapdefun", "MapDefun:output:0"}, {"mapdefun_0", "MapDefun:output:1"}});
+
+ NodeDef* map_defun = AddMapDefunNode(
+ "MapDefun", {"x", "y"}, {DT_INT32, DT_INT32}, {DT_INT32, DT_INT32},
+ {{}, {}}, inner.signature().name(), &outer);
+ CHECK_NOTNULL(map_defun);
+
+ FunctionDef outer_copy(outer);
+ FunctionDef inner_copy(inner);
+ VectorizeMapDefun(&outer, &inner, map_defun);
+ // They should be unchanged
+ EXPECT_TRUE(FunctionDefsEqual(outer_copy, outer));
+ EXPECT_TRUE(FunctionDefsEqual(inner_copy, inner));
+}
+
+// Before:
+//
+//
+// +------+
+// +---------------+ Arg0 +---------+
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// | +-----------+ Arg0 +-----+ |
+// | | +---+--+ | |
+// | | | | |
+// | | | | |
+// | | +---v--+ | |
+// | | | Cast | | |
+// | | +---+--+ | |
+// | | | | |
+// | | MapDefun +---v--+ | |
+// | +-----------+ Ret0 +-----+ |
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// +---------------+ Ret0 +---------+
+// +------+
+//
+//
+// After:
+//
+// +------+
+// +---------------+ Arg0 +---------+
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// | | Cast | |
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// +---------------+ Ret0 +---------+
+// +------+
+//
+TEST(VectorizeMapDefunTest, VectorizeDefunSimpleCast) {
+ FunctionDef inner =
+ CreateFunction("inner_function", {{"arg0", DT_INT32}},
+ {{"ret0", DT_INT64}}, {{"ret0", "Cast:y:0"}});
+ NodeDef* cast_op =
+ AddCastNode("Cast", {"arg0"}, DT_INT32, DT_INT64, false, &inner);
+ CHECK_NOTNULL(cast_op);
+
+ FunctionDef outer = CreateFunction("outer_function", {{"x", DT_INT32}},
+ {{"mapdefun", DT_INT64}},
+ {{"mapdefun", "MapDefun:output:0"}});
+
+ NodeDef* map_defun =
+ AddMapDefunNode("MapDefun", {"x"}, {DT_INT32}, {DT_INT64}, {{}},
+ inner.signature().name(), &outer);
+ CHECK_NOTNULL(map_defun);
+
+ VectorizeMapDefun(&outer, &inner, map_defun);
+ EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer));
+ const NodeDef& cast_node =
+ outer.node_def(function_utils::FindFunctionNodeWithOp("Cast", outer));
+ EXPECT_EQ(cast_node.input(0), "x");
+ EXPECT_EQ(outer.ret().at("mapdefun"),
+ strings::StrCat(cast_node.name(), ":y:0"));
+ EXPECT_EQ(outer.node_def_size(), 1);
+}
+
+// Before:
+//
+// +------+
+// +---------------+ Arg0 +-------------------+
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// | +-----------+ Arg0 +---------------+ |
+// | | +---+--+ | |
+// | | | | |
+// | | | | |
+// | | +---v--+ | |
+// | | | Cast | | |
+// | | +---+--+ | |
+// | | | | |
+// | | +----------+ | |
+// | | | | | |
+// | | MapDefun +---v--+ +---v--+ | |
+// | +-----------+ Ret0 +---+ Ret1 +----+ |
+// | +---+--+ +---+--+ |
+// | | | |
+// | +---v--+ +---v--+ |
+// +---------------+ Ret0 +---+ Ret1 +--------+
+// +------+ +------+
+//
+//
+// After:
+//
+// +------+
+// +---------------+ Arg0 +-------------------+
+// | +---+--+ |
+// | | |
+// | | |
+// | +---v--+ |
+// | | Cast | |
+// | +---+--+ |
+// | | |
+// | +----------+ |
+// | | | |
+// | +---v--+ +---v--+ |
+// +---------------+ Ret0 +---+ Ret1 +--------+
+// +------+ +------+
+//
+TEST(VectorizeMapDefunTest, VectorizeDefunCastUsedTwice) {
+ // Tests that behavior is correct when an output is used more than once.
+ FunctionDef inner =
+ CreateFunction("inner_function", {{"arg0", DT_INT32}},
+ {{"ret0", DT_INT64}, {"ret1", DT_INT64}},
+ {{"ret0", "Cast:y:0"}, {"ret1", "Cast:y:0"}});
+ NodeDef* cast_op =
+ AddCastNode("Cast", {"arg0"}, DT_INT32, DT_INT64, false, &inner);
+ CHECK_NOTNULL(cast_op);
+
+ FunctionDef outer = CreateFunction(
+ "outer_function", {{"x", DT_INT32}},
+ {{"mapdefun", DT_INT64}, {"mapdefun_0", DT_INT64}},
+ {{"mapdefun", "MapDefun:output:0"}, {"mapdefun_0", "MapDefun:output:1"}});
+
+ NodeDef* map_defun =
+ AddMapDefunNode("MapDefun", {"x"}, {DT_INT32}, {DT_INT64, DT_INT64},
+ {{}, {}}, inner.signature().name(), &outer);
+ CHECK_NOTNULL(map_defun);
+
+ VectorizeMapDefun(&outer, &inner, map_defun);
+ EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer));
+ const NodeDef& cast_node =
+ outer.node_def(function_utils::FindFunctionNodeWithOp("Cast", outer));
+ EXPECT_EQ(cast_node.input(0), "x");
+ EXPECT_EQ(outer.ret().at("mapdefun"),
+ strings::StrCat(cast_node.name(), ":y:0"));
+ EXPECT_EQ(outer.ret().at("mapdefun_0"),
+ strings::StrCat(cast_node.name(), ":y:0"));
+ EXPECT_EQ(outer.node_def_size(), 1);
+}
+
+// Before:
+//
+// +------+
+// +----------------------+ Arg0 +----------------------+
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// | +------------------+ Arg0 +------------------+ |
+// | | +---+--+ | |
+// | | | | |
+// | | | | |
+// | | +---v---+ num=3 | |
+// | | |Unstack| axis=0 | |
+// | | ++--+--++ | |
+// | | | | | | |
+// | | +----+ | +-------+ | |
+// | | | | | | |
+// | | MapDefun +---v--+ +-v----+ +--v---+ | |
+// | +----------+ Ret0 +--+ Ret1 +--+ Ret2 +------+ |
+// | +---+--+ +--+---+ +--+---+ |
+// | | | | |
+// | +---v--+ +--v---+ +--v---+ |
+// +--------------+ Ret0 +--+ Ret1 +--+ Ret2 +----------+
+// +------+ +------+ +------+
+//
+//
+// After:
+//
+// +------+
+// +----------------------+ Arg0 +----------------------+
+// | +---+--+ |
+// | | |
+// | | |
+// | | |
+// | +---v---+ num=3 |
+// | |Unstack| axis=1 |
+// | ++--+--++ |
+// | | | | |
+// | +----+ | +-------+ |
+// | | | | |
+// | | | | |
+// | +---v--+ +-v----+ +--v---+ |
+// +--------------+ Ret0 +--+ Ret1 +--+ Ret2 +----------+
+// +------+ +------+ +------+
+//
+TEST(VectorizeMapDefunTest, VectorizeDefunOpWithMultipleOutputs) {
+ FunctionDef inner = CreateFunction(
+ "inner_function", {{"arg0", DT_INT32}},
+ {{"ret0", DT_INT32}, {"ret1", DT_INT32}, {"ret2", DT_INT32}},
+ {{"ret0", "MyUnstack:output:0"},
+ {"ret1", "MyUnstack:output:1"},
+ {"ret2", "MyUnstack:output:2"}});
+ NodeDef* unstack_op =
+ AddUnstackNode("MyUnstack", {"arg0"}, DT_INT32, 0, 3, &inner);
+ CHECK_NOTNULL(unstack_op);
+
+ FunctionDef outer = CreateFunction("outer_function", {{"x", DT_INT32}},
+ {{"mapdefun", DT_INT32},
+ {"mapdefun_0", DT_INT32},
+ {"mapdefun_1", DT_INT32}},
+ {{"mapdefun", "MapDefun:output:0"},
+ {"mapdefun_0", "MapDefun:output:1"},
+ {"mapdefun_1", "MapDefun:output:2"}});
+
+ NodeDef* map_defun = AddMapDefunNode(
+ "MapDefun", {"x"}, {DT_INT32}, {DT_INT32, DT_INT32, DT_INT32},
+ {{1}, {1}, {1}}, inner.signature().name(), &outer);
+ CHECK_NOTNULL(map_defun);
+
+ VectorizeMapDefun(&outer, &inner, map_defun);
+ EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer));
+ const NodeDef& unpack_node =
+ outer.node_def(function_utils::FindFunctionNodeWithOp("Unpack", outer));
+ EXPECT_EQ(unpack_node.input(0), "x");
+ EXPECT_EQ(unpack_node.attr().at("axis").i(), 1);
+ EXPECT_EQ(unpack_node.attr().at("T").type(), DT_INT32);
+ EXPECT_EQ(unpack_node.attr().at("num").i(), 3);
+ EXPECT_EQ(outer.ret().at("mapdefun"),
+ strings::StrCat(unpack_node.name(), ":output:0"));
+ EXPECT_EQ(outer.ret().at("mapdefun_0"),
+ strings::StrCat(unpack_node.name(), ":output:1"));
+ EXPECT_EQ(outer.ret().at("mapdefun_1"),
+ strings::StrCat(unpack_node.name(), ":output:2"));
+ EXPECT_EQ(outer.node_def_size(), 1);
+}
+
+// Before:
+//
+// +------+
+// +----------------------+ Arg0 +----------------------+
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// | +------------------+ Arg0 +------------------+ |
+// | | +---+--+ | |
+// | | | | |
+// | | +---+--+ | |
+// | | | Cast | | |
+// | | +---+--+ | |
+// | | | | |
+// | | +---v---+ num=3 | |
+// | | |Unstack| axis=0 | |
+// | | ++--+--++ | |
+// | | | | | | |
+// | | +----+ | +-------+ | |
+// | | | | | | |
+// | | MapDefun +---v--+ +-v----+ +--v---+ | |
+// | +----------+ Ret0 +--+ Ret1 +--+ Ret2 +------+ |
+// | +---+--+ +--+---+ +--+---+ |
+// | | | | |
+// | +---v--+ +--v---+ +--v---+ |
+// +--------------+ Ret0 +--+ Ret1 +--+ Ret2 +----------+
+// +------+ +------+ +------+
+//
+//
+// After:
+//
+// +------+
+// +----------------------+ Arg0 +----------------------+
+// | +---+--+ |
+// | | |
+// | +---+--+ |
+// | | Cast | |
+// | +---+--+ |
+// | | |
+// | +---v---+ num=3 |
+// | |Unstack| axis=1 |
+// | ++--+--++ |
+// | | | | |
+// | +----+ | +-------+ |
+// | | | | |
+// | | | | |
+// | +---v--+ +-v----+ +--v---+ |
+// +--------------+ Ret0 +--+ Ret1 +--+ Ret2 +----------+
+// +------+ +------+ +------+
+//
+TEST(VectorizeMapDefunTest, VectorizeDefunChainedConvertibleOps) {
+ FunctionDef inner = CreateFunction(
+ "inner_function", {{"arg0", DT_INT32}},
+ {{"ret0", DT_INT32}, {"ret1", DT_INT32}, {"ret2", DT_INT32}},
+ {{"ret0", "MyUnstack:output:0"},
+ {"ret1", "MyUnstack:output:1"},
+ {"ret2", "MyUnstack:output:2"}});
+ NodeDef* cast_op =
+ AddCastNode("Cast", {"arg0"}, DT_INT32, DT_INT64, false, &inner);
+ CHECK_NOTNULL(cast_op);
+ NodeDef* unstack_op =
+ AddUnstackNode("MyUnstack", {"Cast:y:0"}, DT_INT32, 0, 3, &inner);
+ CHECK_NOTNULL(unstack_op);
+
+ FunctionDef outer = CreateFunction("outer_function", {{"x", DT_INT32}},
+ {{"mapdefun", DT_INT32},
+ {"mapdefun_0", DT_INT32},
+ {"mapdefun_1", DT_INT32}},
+ {{"mapdefun", "MapDefun:output:0"},
+ {"mapdefun_0", "MapDefun:output:1"},
+ {"mapdefun_1", "MapDefun:output:2"}});
+
+ NodeDef* map_defun = AddMapDefunNode(
+ "MapDefun", {"x"}, {DT_INT32}, {DT_INT32, DT_INT32, DT_INT32},
+ {{1}, {1}, {1}}, inner.signature().name(), &outer);
+ CHECK_NOTNULL(map_defun);
+
+ VectorizeMapDefun(&outer, &inner, map_defun);
+ EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer));
+ const NodeDef& cast_node =
+ outer.node_def(function_utils::FindFunctionNodeWithOp("Cast", outer));
+ EXPECT_EQ(cast_node.input(0), "x");
+ const NodeDef& unpack_node =
+ outer.node_def(function_utils::FindFunctionNodeWithOp("Unpack", outer));
+ EXPECT_EQ(unpack_node.input(0), strings::StrCat(cast_node.name(), ":y:0"));
+ EXPECT_EQ(unpack_node.attr().at("axis").i(), 1);
+ EXPECT_EQ(unpack_node.attr().at("T").type(), DT_INT32);
+ EXPECT_EQ(unpack_node.attr().at("num").i(), 3);
+
+ EXPECT_EQ(outer.ret().at("mapdefun"),
+ strings::StrCat(unpack_node.name(), ":output:0"));
+ EXPECT_EQ(outer.ret().at("mapdefun_0"),
+ strings::StrCat(unpack_node.name(), ":output:1"));
+ EXPECT_EQ(outer.ret().at("mapdefun_1"),
+ strings::StrCat(unpack_node.name(), ":output:2"));
+ EXPECT_EQ(outer.node_def_size(), 2);
+}
+
+// Before:
+//
+//
+// +------+
+// +---------------+ Arg0 +---------+
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// | +-----------+ Arg0 +-----+ |
+// | | +---+--+ | |
+// | | +---------+ | |
+// | | +---v--+ | | |
+// | | |Print | | | |
+// | | +---+--+ | | |
+// | | : +---v--+ | |
+// | | ::::::> Cast | | |
+// | | +---+--+ | |
+// | | | | |
+// | | MapDefun +---v--+ | |
+// | +-----------+ Ret0 +-----+ |
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// +---------------+ Ret0 +---------+
+// +------+
+//
+//
+// After:
+//
+// No change because we don't deal with control inputs for now.
+//
+TEST(VectorizeMapDefunTest, VectorizeDefunWithControlInputs) {
+ FunctionDef inner =
+ CreateFunction("inner_function", {{"arg0", DT_INT32}},
+ {{"ret0", DT_INT64}}, {{"ret0", "Cast:y:0"}});
+ // The attrs aren't relevant
+ NodeDef* print_op =
+ function_utils::AddNode("Print", "Print", {"arg0", "arg0"}, {}, &inner);
+ CHECK_NOTNULL(print_op);
+ NodeDef* cast_op = AddCastNode("Cast", {"arg0", "^Print"}, DT_INT32, DT_INT64,
+ false, &inner);
+ CHECK_NOTNULL(cast_op);
+
+ FunctionDef outer = CreateFunction("outer_function", {{"x", DT_INT32}},
+ {{"mapdefun", DT_INT64}},
+ {{"mapdefun", "MapDefun:output:0"}});
+
+ NodeDef* map_defun =
+ AddMapDefunNode("MapDefun", {"x"}, {DT_INT32}, {DT_INT64}, {{}},
+ inner.signature().name(), &outer);
+ CHECK_NOTNULL(map_defun);
+
+ FunctionDef outer_copy(outer);
+ FunctionDef inner_copy(inner);
+ VectorizeMapDefun(&outer, &inner, map_defun);
+ // They should be unchanged
+ EXPECT_TRUE(FunctionDefsEqual(outer_copy, outer));
+}
+
+// TODO(rachelim): More test cases when we get around to implementing them:
+// [] A badly defined converter, e.g. doesn't produce nodes that have the
+// same number of outputs/inputs as the nodes to be converted
+// [] Converter where the 'converted' form has multiple nodes.
+// [] Case with dependent nodes, e.g. ops with const inputs that are
+// broadcasted.
+// [] Python-side tests to actually run the functions to make sure
+// they work.
+
+} // namespace
+} // namespace vectorization_utils
+} // namespace grappler
+} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/debug_stripper.cc b/tensorflow/core/grappler/optimizers/debug_stripper.cc
index 9701a038d0..800160e649 100644
--- a/tensorflow/core/grappler/optimizers/debug_stripper.cc
+++ b/tensorflow/core/grappler/optimizers/debug_stripper.cc
@@ -38,7 +38,7 @@ Status DebugStripper::Optimize(Cluster* cluster, const GrapplerItem& item,
// be optimized away by dependency optimizer.
for (string& inp : *node.mutable_input()) {
if (!IsControlInput(inp)) {
- inp = AsControlDependency(inp);
+ inp = AsControlDependency(NodeName(inp));
}
}
} else if (IsCheckNumerics(node) || IsPrint(node)) {
@@ -54,7 +54,7 @@ Status DebugStripper::Optimize(Cluster* cluster, const GrapplerItem& item,
// input.
for (size_t i = 1; i < node.input_size(); ++i) {
if (!IsControlInput(node.input(i))) {
- *node.mutable_input(i) = AsControlDependency(node.input(i));
+ *node.mutable_input(i) = AsControlDependency(NodeName(node.input(i)));
}
}
}
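The change above wraps each forwarded input in NodeName() before converting it to a control dependency, so a port suffix such as ":1" (produced by multi-output ops like Split) is stripped before the "^" marker is added. A rough, self-contained sketch of the effect, using stand-ins for grappler's helpers (the bodies below are assumed behaviour for illustration, not the real implementations):

#include <iostream>
#include <string>

// Stand-in for grappler's NodeName(): drop a leading '^' and a trailing ":port".
std::string NodeName(const std::string& s) {
  const size_t begin = (!s.empty() && s[0] == '^') ? 1 : 0;
  const size_t colon = s.find(':', begin);
  return s.substr(begin, colon == std::string::npos ? std::string::npos : colon - begin);
}

// Stand-in for AsControlDependency(): prepend '^' unless already present.
std::string AsControlDependency(const std::string& s) {
  return (!s.empty() && s[0] == '^') ? s : "^" + s;
}

int main() {
  std::cout << AsControlDependency("split:1") << "\n";            // "^split:1" -- malformed control input
  std::cout << AsControlDependency(NodeName("split:1")) << "\n";  // "^split"   -- what the fix produces
}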
diff --git a/tensorflow/core/grappler/optimizers/debug_stripper_test.cc b/tensorflow/core/grappler/optimizers/debug_stripper_test.cc
index 96ceee791f..affd2d51c2 100644
--- a/tensorflow/core/grappler/optimizers/debug_stripper_test.cc
+++ b/tensorflow/core/grappler/optimizers/debug_stripper_test.cc
@@ -43,6 +43,35 @@ TEST_F(DebugStripperTest, OutputEqualToInput) {
CompareGraphs(item.graph, output);
}
+TEST_F(DebugStripperTest, StripAssertOnTwoOutputs) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT,
+ ops::Placeholder::Shape({6}));
+ auto split =
+ ops::Split(s.WithOpName("split"), /*axis=*/0, input, /*num_split=*/2);
+ Output x = split[0];
+ Output y = split[1];
+ Output ge = ops::GreaterEqual(s.WithOpName("GreaterEqual"), x, y);
+ auto assert = ops::Assert(s.WithOpName("Assert"), ge, {x, y});
+ Output add = ops::Add(
+ s.WithOpName("add").WithControlDependencies({assert.operation}), x, y);
+
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ DebugStripper optimizer;
+ GraphDef output;
+ TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ for (const NodeDef& node : output.node()) {
+ for (const string& input : node.input()) {
+ if (IsControlInput(input)) {
+ EXPECT_EQ(input.find(':'), -1);
+ }
+ }
+ }
+}
+
TEST_F(DebugStripperTest, StripAssertFromGraph) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
diff --git a/tensorflow/core/grappler/optimizers/experimental_implementation_selector.cc b/tensorflow/core/grappler/optimizers/experimental_implementation_selector.cc
index eeea269fb0..2c36c9b7b3 100644
--- a/tensorflow/core/grappler/optimizers/experimental_implementation_selector.cc
+++ b/tensorflow/core/grappler/optimizers/experimental_implementation_selector.cc
@@ -32,8 +32,6 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
-REGISTER_GRAPH_OPTIMIZER(ExperimentalImplementationSelector);
-
Status ExperimentalImplementationSelector::LoadFunctions(
const GraphDef& graph) {
lib_info_.reset(new FunctionLibraryApiInfo);
@@ -43,8 +41,20 @@ Status ExperimentalImplementationSelector::LoadFunctions(
Status ExperimentalImplementationSelector::MaybeOptimizeFunctionCall(
NodeDef* node_def) const {
- const FunctionApiInfo* info = lib_info_->GetApiInfo(node_def->op());
- if (info == nullptr) {
+ // There are two ways of calling functions:
+ // 1. By using the function name directly as the op name, or
+ // 2. Via the @defun functional interface, where the real function name
+ // appears in an attribute of type func.
+ std::vector<string> function_attribute_names;
+ for (const auto& attr : node_def->attr()) {
+ if (attr.second.has_func() &&
+ lib_info_->GetApiInfo(attr.second.func().name()) != nullptr) {
+ function_attribute_names.emplace_back(attr.first);
+ }
+ }
+
+ if (function_attribute_names.empty() &&
+ lib_info_->GetApiInfo(node_def->op()) == nullptr) {
// A regular op, or a function which has no interface.
return Status::OK();
}
@@ -58,17 +68,25 @@ Status ExperimentalImplementationSelector::MaybeOptimizeFunctionCall(
DeviceNameUtils::ParsedName parsed_name;
DeviceNameUtils::ParseLocalName(device, &parsed_name);
- string best_function_name;
- lib_info_->GetBestImplementation(node_def->op(), parsed_name.type,
- &best_function_name);
- if (node_def->op() != best_function_name) {
- // The current implementation is not the best, swap the op to the best one.
- // There will be duplicates in the graph and they will be pruned by other
- // grappler plugin since no other node is using their output as inputs.
- // TODO(scottzhu): Update the tf.eager.defun to register functions without
- // having to call them with input data. That will reduce the graph size and
- // save the work for prune them.
- node_def->set_op(best_function_name);
+ for (const auto& attr_name : function_attribute_names) {
+ string function_name = node_def->attr().at(attr_name).func().name();
+ string best_function_name;
+ lib_info_->GetBestImplementation(function_name, parsed_name.type,
+ &best_function_name);
+ if (function_name != best_function_name) {
+ node_def->mutable_attr()
+ ->find(attr_name)
+ ->second.mutable_func()
+ ->set_name(best_function_name);
+ }
+ }
+ if (lib_info_->GetApiInfo(node_def->op()) != nullptr) {
+ string best_function_name;
+ lib_info_->GetBestImplementation(node_def->op(), parsed_name.type,
+ &best_function_name);
+ if (node_def->op() != best_function_name) {
+ node_def->set_op(best_function_name);
+ }
}
return Status::OK();
}
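The rewritten MaybeOptimizeFunctionCall() above now handles both call styles from its comment: a direct call where the op name is the function name, and a @defun-style call where the function name sits in a func-typed attribute. A hedged sketch of the second style (the op, the attribute key "f", and the function name below are illustrative assumptions, not taken from this patch):

#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"

// Builds a call node whose function is named via a func-typed attribute.
tensorflow::NodeDef MakeFunctionalCallNode() {
  tensorflow::NodeDef node;
  node.set_name("call");
  node.set_op("PartitionedCall");  // the op name is not the function name here
  node.set_device("/device:GPU:0");
  tensorflow::AttrValue fn;
  fn.mutable_func()->set_name("my_matmul_gpu");  // hypothetical function name
  (*node.mutable_attr())["f"] = fn;  // the selector now rewrites this name too
  return node;
}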
diff --git a/tensorflow/core/grappler/optimizers/experimental_implementation_selector_test.cc b/tensorflow/core/grappler/optimizers/experimental_implementation_selector_test.cc
index 2368e577c2..3f1ebefac6 100644
--- a/tensorflow/core/grappler/optimizers/experimental_implementation_selector_test.cc
+++ b/tensorflow/core/grappler/optimizers/experimental_implementation_selector_test.cc
@@ -45,9 +45,8 @@ TEST_F(ExperimentalImplementationSelectorTest, NoUpdate) {
GrapplerItem item;
CHECK(fake_input.NextItem(&item));
- std::unique_ptr<CustomGraphOptimizer> optimizer =
- CustomGraphOptimizerRegistry::CreateByNameOrNull(
- "ExperimentalImplementationSelector");
+ std::unique_ptr<CustomGraphOptimizer> optimizer(
+ new ExperimentalImplementationSelector);
ASSERT_NE(nullptr, optimizer);
TF_ASSERT_OK(optimizer->Init());
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
index 7ed4a67333..e18a5f21d2 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
@@ -23,11 +23,13 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
#include "tensorflow/core/grappler/optimizers/debug_stripper.h"
#include "tensorflow/core/grappler/optimizers/dependency_optimizer.h"
+#include "tensorflow/core/grappler/optimizers/experimental_implementation_selector.h"
#include "tensorflow/core/grappler/optimizers/function_optimizer.h"
#include "tensorflow/core/grappler/optimizers/layout_optimizer.h"
#include "tensorflow/core/grappler/optimizers/loop_optimizer.h"
#include "tensorflow/core/grappler/optimizers/memory_optimizer.h"
#include "tensorflow/core/grappler/optimizers/model_pruner.h"
+#include "tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h"
#include "tensorflow/core/grappler/optimizers/remapper.h"
#include "tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.h"
#include "tensorflow/core/grappler/optimizers/shape_optimizer.h"
@@ -104,6 +106,7 @@ std::unique_ptr<GraphOptimizer> MetaOptimizer::MakeNewOptimizer(
MK_OPT("scoped_allocator",
new ScopedAllocatorOptimizer(cfg_.scoped_allocator_optimization(),
cfg_.scoped_allocator_opts()));
+ MK_OPT("small_op", new PinToHostOptimizer(cfg_.pin_to_host_optimization()));
return std::unique_ptr<GraphOptimizer>();
}
@@ -132,6 +135,9 @@ Status MetaOptimizer::InitializeOptimizers(
if (cfg_.remapping() != RewriterConfig::OFF) {
optimizers->push_back(MakeUnique<Remapper>(cfg_.remapping()));
}
+ if (cfg_.pin_to_host_optimization() != RewriterConfig::OFF) {
+ optimizers->push_back(MakeUnique<PinToHostOptimizer>());
+ }
if (cfg_.arithmetic_optimization() != RewriterConfig::OFF) {
optimizers->push_back(
MakeUnique<ArithmeticOptimizer>(cfg_.arithmetic_optimization()));
@@ -166,11 +172,12 @@ Status MetaOptimizer::InitializeOptimizers(
optimizers->push_back(MakeUnique<ScopedAllocatorOptimizer>(
cfg_.scoped_allocator_optimization(), cfg_.scoped_allocator_opts()));
}
- return InitializeCustomGraphOptimizers(optimizers);
+ return InitializeCustomGraphOptimizers(std::set<string>(), optimizers);
}
Status MetaOptimizer::InitializeOptimizersByName(
std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const {
+ std::set<string> initialized_custom_optimizers;
for (const string& optimizer_name : cfg_.optimizers()) {
auto optimizer = MakeNewOptimizer(optimizer_name);
if (optimizer) {
@@ -184,26 +191,54 @@ Status MetaOptimizer::InitializeOptimizersByName(
if (custom_optimizer) {
VLOG(2) << "Registered custom graph optimizer: " << optimizer_name;
- TF_RETURN_IF_ERROR(custom_optimizer->Init());
+ TF_RETURN_IF_ERROR(custom_optimizer->Init(
+ GetCustomGraphOptimizerConfig(optimizer_name)));
optimizers->push_back(std::move(custom_optimizer));
+ initialized_custom_optimizers.insert(optimizer_name);
} else {
VLOG(2) << "Can't register an optimizer by name: " << optimizer_name;
}
}
- return InitializeCustomGraphOptimizers(optimizers);
+ return InitializeCustomGraphOptimizers(initialized_custom_optimizers,
+ optimizers);
}
Status MetaOptimizer::InitializeCustomGraphOptimizers(
+ const std::set<string>& pre_initialized_optimizers,
std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const {
for (const auto& optimizer_config : cfg_.custom_optimizers()) {
- auto custom_optimizer = CustomGraphOptimizerRegistry::CreateByNameOrNull(
- optimizer_config.name());
+ if (pre_initialized_optimizers.find(optimizer_config.name()) !=
+ pre_initialized_optimizers.end()) {
+ continue;
+ }
+ // Initialize the ExperimentalImplementationSelector here instead of through
+ // the CustomGraphOptimizer registry, due to the static linking issue in
+ // TensorRT that causes double registration.
+ // TODO(laigd): Remove this hack and change it back to use the registry once
+ // the duplicate static import issue is fixed.
+ std::unique_ptr<CustomGraphOptimizer> custom_optimizer;
+ if (optimizer_config.name() == "ExperimentalImplementationSelector") {
+ custom_optimizer.reset(new ExperimentalImplementationSelector());
+ } else {
+ custom_optimizer = CustomGraphOptimizerRegistry::CreateByNameOrNull(
+ optimizer_config.name());
+ }
if (custom_optimizer) {
VLOG(2) << "Registered custom configurable graph optimizer: "
<< optimizer_config.name();
TF_RETURN_IF_ERROR(custom_optimizer->Init(&optimizer_config));
optimizers->push_back(std::move(custom_optimizer));
} else {
+ // If there are no custom optimizers with the given name, try to initialize a
+ // default optimizer. This way, custom configurable optimizers can be
+ // mixed with default optimizers in any order.
+ auto optimizer = MakeNewOptimizer(optimizer_config.name());
+ if (optimizer) {
+ VLOG(2) << "Registered default graph optimizer: "
+ << optimizer_config.name();
+ optimizers->push_back(std::move(optimizer));
+ continue;
+ }
VLOG(2) << "Can't register an optimizer by name: "
<< optimizer_config.name();
}
@@ -211,6 +246,16 @@ Status MetaOptimizer::InitializeCustomGraphOptimizers(
return Status::OK();
}
+const RewriterConfig::CustomGraphOptimizer*
+MetaOptimizer::GetCustomGraphOptimizerConfig(const string& name) const {
+ for (const auto& config : cfg_.custom_optimizers()) {
+ if (config.name() == name) {
+ return &config;
+ }
+ }
+ return nullptr;
+}
+
Status MetaOptimizer::OptimizeGraph(Cluster* cluster, const GrapplerItem& item,
GraphDef* optimized_graph) {
int min_graph_nodes = cfg_.min_graph_nodes() == 0 ? kDefaultMinGraphNodes
@@ -341,7 +386,7 @@ Status MetaOptimizer::RunOptimizer(
Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
GraphDef* optimized_graph) {
- LOG(INFO) << "Starting optimization for grappler item: " << item.id;
+ VLOG(1) << "Starting optimization for grappler item: " << item.id;
optimization_results_.clear();
// 1. Optimize main graph
@@ -457,6 +502,7 @@ bool MetaOptimizerEnabled(const RewriterConfig& cfg) {
cfg.memory_optimization() != RewriterConfig::NO_MEM_OPT ||
cfg.debug_stripper() == RewriterConfig::ON ||
cfg.scoped_allocator_optimization() == RewriterConfig::ON ||
+ cfg.pin_to_host_optimization() != RewriterConfig::OFF ||
!cfg.optimizers().empty() || !cfg.custom_optimizers().empty();
}
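The wiring above exposes the new pass in two ways: as a RewriterConfig toggle checked in InitializeOptimizers() and MetaOptimizerEnabled(), and under the registry name "small_op" in MakeNewOptimizer(). A minimal usage sketch based only on the fields and names added in this diff:

#include "tensorflow/core/protobuf/rewriter_config.pb.h"

tensorflow::RewriterConfig MakePinToHostConfig() {
  tensorflow::RewriterConfig cfg;
  // Toggle-based: picked up automatically unless explicitly set to OFF.
  cfg.set_pin_to_host_optimization(tensorflow::RewriterConfig::ON);
  // Name-based: "small_op" is the key registered in MakeNewOptimizer() above.
  cfg.add_optimizers("small_op");
  return cfg;
}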
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.h b/tensorflow/core/grappler/optimizers/meta_optimizer.h
index 831c5e37c0..99a0a33ffa 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer.h
@@ -54,7 +54,11 @@ class MetaOptimizer : public GraphOptimizer {
std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const;
// Initialize active optimizers from RewriterConfig.custom_optimizers.
Status InitializeCustomGraphOptimizers(
+ const std::set<string>& pre_initialized_optimizers,
std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const;
+ // Returns the config for a custom graph optimizer. Null if none was found.
+ const RewriterConfig::CustomGraphOptimizer* GetCustomGraphOptimizerConfig(
+ const string& name) const;
// Run optimization pass over a single GrapplerItem. Meta optimizer might run
// multiple such passes: 1) for the main graph 2) for the function library
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc
index e74e0f7501..c477c4d4b1 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc
@@ -71,6 +71,17 @@ class TestGraphOptimizer : public TestOptimizer {
REGISTER_GRAPH_OPTIMIZER(TestGraphOptimizer);
+class TestOptimizerWithParams : public TestOptimizer {
+ public:
+ Status Init(
+ const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
+ CHECK(config != nullptr);
+ return Status::OK();
+ }
+};
+
+REGISTER_GRAPH_OPTIMIZER(TestOptimizerWithParams);
+
class MetaOptimizerTest : public GrapplerTest {};
TEST_F(MetaOptimizerTest, RunsCustomOptimizer) {
@@ -90,6 +101,25 @@ TEST_F(MetaOptimizerTest, RunsCustomOptimizer) {
EXPECT_TRUE(TestOptimizer::IsOptimized());
}
+TEST_F(MetaOptimizerTest, RunsCustomOptimizerWithParams) {
+ TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"});
+ GrapplerItem item;
+ CHECK(fake_input.NextItem(&item));
+
+ TestOptimizer::SetOptimized(false);
+ RewriterConfig rewriter_config;
+ rewriter_config.add_optimizers("TestOptimizerWithParams");
+ auto* custom_config = rewriter_config.add_custom_optimizers();
+ custom_config->set_name("TestOptimizerWithParams");
+ (*custom_config->mutable_parameter_map())["foo"] = AttrValue();
+
+ MetaOptimizer optimizer(nullptr, rewriter_config);
+ GraphDef output;
+ const Status status = optimizer.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+ EXPECT_TRUE(TestOptimizer::IsOptimized());
+}
+
TEST_F(MetaOptimizerTest, RunsCustomOptimizerAndCustomGraphOptimizer) {
TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"});
GrapplerItem item;
diff --git a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc
new file mode 100644
index 0000000000..2190d38937
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc
@@ -0,0 +1,264 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h"
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/tensor_shape.pb.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/grappler/graph_view.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/op_types.h"
+#include "tensorflow/core/grappler/utils/symbolic_shapes.h"
+#include "tensorflow/core/grappler/utils/topological_sort.h"
+#include "tensorflow/core/lib/core/error_codes.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace internal {
+
+// TODO(williamchan): Change this constant to be something smarter, maybe
+// dynamically determined.
+constexpr int64 kTensorMaxSize = 64;
+
+// Find KernelDef for `node`.
+Status TryFindKernelDef(const NodeDef& node, const KernelDef** kdef) {
+ // Try to find a KernelDef for node.device; otherwise fall back to GPU, then CPU.
+ for (const DeviceType& device :
+ {node.device().c_str(), DEVICE_GPU, DEVICE_CPU}) {
+ Status s = FindKernelDef(device, node, kdef, nullptr);
+ if (s.ok()) {
+ return Status::OK();
+ }
+ }
+
+ return errors::NotFound("Could not find KernelDef for op: ", node.op());
+}
+
+// Check if all of the node's inputs are pinned to CPU memory.
+bool AreAllNodeInputsPinnedToHost(const GraphView& graph, const NodeDef& node) {
+ // Loop through all the inputs, excluding control inputs.
+ for (const GraphView::OutputPort& fanin : graph.GetFanins(node, false)) {
+ // Check if (the fanin) op's device is on CPU.
+ if (str_util::StrContains(fanin.node->device(), DEVICE_CPU)) {
+ continue;
+ }
+
+ // Check if (the fanin) op's output port is pinned to HostMemory.
+ const OpDef* fanin_odef = nullptr;
+ Status s = OpRegistry::Global()->LookUpOpDef(fanin.node->op(), &fanin_odef);
+ if (!s.ok()) {
+ LOG(INFO) << "Could not find OpDef for : " << fanin.node->op();
+ return false;
+ }
+
+ const int output_arg_id =
+ OpOutputPortIdToArgId(*fanin.node, *fanin_odef, fanin.port_id);
+ if (output_arg_id < 0) {
+ LOG(WARNING) << "Invalid port: " << fanin.port_id << "!\n"
+ << node.DebugString() << "\n"
+ << fanin.node->DebugString() << "\n"
+ << fanin_odef->DebugString();
+ return false;
+ }
+
+ const KernelDef* fanin_kdef = nullptr;
+ s = TryFindKernelDef(*fanin.node, &fanin_kdef);
+ if (!s.ok()) {
+ LOG(INFO) << "Could not find KernelDef for : " << fanin.node->op();
+ return false;
+ }
+
+ bool fanin_pinned = false;
+ for (const string& host_memory_arg : fanin_kdef->host_memory_arg()) {
+ if (fanin_odef->output_arg(output_arg_id).name() == host_memory_arg) {
+ fanin_pinned = true;
+ break;
+ }
+ }
+
+ if (!fanin_pinned) {
+ return false;
+ }
+ }
+
+ return true;
+}
+
+bool IsTensorIntegerAndSmall(const OpInfo::TensorProperties& prop) {
+ // Check if the tensor is integer-typed and small.
+
+ // Check that the dtype is int32 or int64.
+ if (prop.dtype() != DataType::DT_INT32 &&
+ prop.dtype() != DataType::DT_INT64) {
+ return false;
+ }
+
+ // Check that the size is known and small.
+ const int64 size = NumCoefficients(prop.shape());
+ if (size < 0 || size > kTensorMaxSize) {
+ return false;
+ }
+
+ return true;
+}
+
+bool AreAllNodeInputsAndOutputsIntsAndSmall(const GraphProperties& properties,
+ const NodeDef& node) {
+ for (const auto& prop : properties.GetInputProperties(node.name())) {
+ if (!IsTensorIntegerAndSmall(prop)) {
+ return false;
+ }
+ }
+
+ for (const auto& prop : properties.GetOutputProperties(node.name())) {
+ if (!IsTensorIntegerAndSmall(prop)) {
+ return false;
+ }
+ }
+ return true;
+}
+
+string TryFindHostDevice(const gtl::FlatSet<string>& devices,
+ bool has_device_cpu, const string& device) {
+ // Force this node onto the CPU.
+ if (device.empty() && has_device_cpu) {
+ return "/device:CPU:0";
+ } else if (str_util::StrContains(device, DEVICE_GPU)) {
+ // Sometimes the cluster can have:
+ // devices = {"/device:CPU:0", "/device:XLA_GPU:0"}
+ // and we need to handle them properly.
+ for (const auto& device_match :
+ {std::pair<string, string>("GPU", "CPU:0"),
+ std::pair<string, string>("/device", "/device:CPU:0")}) {
+ const string device_host =
+ strings::StrCat(device.substr(0, device.rfind(device_match.first)),
+ device_match.second);
+ if (devices.find(device_host) != devices.end()) {
+ return device_host;
+ }
+ }
+ }
+
+ // We couldn't find an appropriate host device, so return the original device.
+ return device;
+}
+
+bool IsTPUGraphDef(const GraphDef& def) {
+ for (const auto& node : def.node()) {
+ if (node.op() == "TPUCompile" || node.op() == "TPUExecute" ||
+ node.op() == "TPUPartitionedCall") {
+ return true;
+ }
+ }
+ return false;
+}
+
+// Returns true for nodes that are blacklisted and should not be swapped.
+bool IsBlacklisted(const NodeDef& node) { return IsCollective(node); }
+} // end namespace internal
+
+Status PinToHostOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* optimized_graph) {
+ *optimized_graph = item.graph;
+
+ // Skip all TPU graphs.
+ if (internal::IsTPUGraphDef(*optimized_graph)) {
+ return Status::OK();
+ }
+
+ GraphProperties properties(item);
+ bool has_properties = false;
+ GraphView graph(optimized_graph);
+
+ gtl::FlatSet<string> devices;
+ if (cluster) {
+ const std::vector<string> device_names = cluster->GetDeviceNames();
+ devices.insert(device_names.begin(), device_names.end());
+ } else {
+ devices = {"/device:CPU:0"};
+ }
+
+ const bool has_device_cpu = devices.find("/device:CPU:0") != devices.end();
+
+ // Topologically sort the graph, so that we traverse the nodes in order. This
+ // will help us discover producer->consumer chains of Host ops.
+ TF_RETURN_IF_ERROR(TopologicalSort(optimized_graph));
+
+ // All the Const nodes, and their original devices in topological order.
+ std::vector<std::pair<NodeDef*, string>> const_nodes;
+
+ for (auto& node : *optimized_graph->mutable_node()) {
+ // Check if node already on CPU.
+ if (str_util::StrContains(node.device(), DEVICE_CPU)) {
+ continue;
+ }
+
+ // Skip these node types.
+ if (internal::IsBlacklisted(node)) {
+ continue;
+ }
+
+ // Check the node can be run on CPU.
+ Status s = FindKernelDef(DEVICE_CPU, node, nullptr, nullptr);
+ if (!s.ok()) {
+ continue;
+ }
+
+ // Check that all inputs are pinned to the CPU.
+ if (!internal::AreAllNodeInputsPinnedToHost(graph, node)) {
+ continue;
+ }
+
+ if (!has_properties) {
+ // Property inference is expensive, so do it lazily.
+ TF_RETURN_IF_ERROR(properties.InferStatically(false));
+ has_properties = true;
+ }
+
+ // Check that all inputs and outputs are small integer tensors.
+ if (!internal::AreAllNodeInputsAndOutputsIntsAndSmall(properties, node)) {
+ continue;
+ }
+
+ if (IsConstant(node)) {
+ const_nodes.emplace_back(&node, node.device());
+ }
+ // Try and swap the device to Host.
+ node.set_device(
+ internal::TryFindHostDevice(devices, has_device_cpu, node.device()));
+ }
+
+ // Traverse all `const_nodes` and greedily swap them back to their original device where needed.
+ for (auto& it : const_nodes) {
+ NodeDef* node = it.first;
+ const string& device = it.second;
+
+ // Check all the consumers of this node; if any of them are on the original
+ // device, swap this node back onto the original device.
+ for (const GraphView::InputPort& fanout : graph.GetFanouts(*node, false)) {
+ if (fanout.node->device() == device) {
+ node->set_device(device);
+ break;
+ }
+ }
+ }
+ return Status::OK();
+}
+
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h
new file mode 100644
index 0000000000..d557a03463
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h
@@ -0,0 +1,62 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_PIN_TO_HOST_OPTIMIZER_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_PIN_TO_HOST_OPTIMIZER_H_
+
+#include <unordered_set>
+#include "tensorflow/core/grappler/costs/graph_properties.h"
+#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
+#include "tensorflow/core/lib/gtl/flatset.h"
+#include "tensorflow/core/protobuf/rewriter_config.pb.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace internal {
+// Try to find an appropriate host device in `devices` given `device`.
+string TryFindHostDevice(const gtl::FlatSet<string>& devices,
+ bool has_device_cpu, const string& device);
+} // end namespace internal
+
+// Optimize TensorFlow ops that should be swapped onto the CPU to avoid
+// excessive cpu<->gpu memcpy/sync.
+//
+// TODO(williamchan): The current heuristic will swap any small integer Const to
+// CPU. This may create a cpu->cpu->gpu pattern where the original
+// gpu->gpu->gpu placement may have been better/faster. We should probably fix this.
+class PinToHostOptimizer : public GraphOptimizer {
+ public:
+ PinToHostOptimizer() : opt_level_(RewriterConfig::DEFAULT) {}
+ explicit PinToHostOptimizer(RewriterConfig::Toggle opt_level)
+ : opt_level_(opt_level) {}
+
+ ~PinToHostOptimizer() override {}
+
+ string name() const override { return "pin_to_host_optimizer"; };
+
+ Status Optimize(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* optimized_graph) override;
+
+ void Feedback(Cluster* cluster, const GrapplerItem& item,
+ const GraphDef& optimized_graph, double result) override {}
+
+ private:
+ RewriterConfig::Toggle opt_level_;
+};
+
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_PIN_TO_HOST_OPTIMIZER_H_
diff --git a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc
new file mode 100644
index 0000000000..173cb3fe3c
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc
@@ -0,0 +1,194 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/utils/grappler_test.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+class PinToHostOptimizerTest : public GrapplerTest {};
+
+TEST_F(PinToHostOptimizerTest, TryFindHostDevice) {
+ gtl::FlatSet<string> devices = {};
+ EXPECT_EQ("ABC", internal::TryFindHostDevice(devices, false, "ABC"));
+
+ devices = {"/device:CPU:0", "/device:XLA_GPU:0"};
+ EXPECT_EQ(internal::TryFindHostDevice(devices, true, ""), "/device:CPU:0");
+ EXPECT_EQ(internal::TryFindHostDevice(devices, true, "/device:XLA_GPU:0"),
+ "/device:CPU:0");
+ EXPECT_EQ(internal::TryFindHostDevice(devices, true, "/device:XLA_GPU:*"),
+ "/device:CPU:0");
+
+ devices = {"/device:XLA_CPU:0", "/device:XLA_GPU:0"};
+ EXPECT_EQ(internal::TryFindHostDevice(devices, false, ""), "");
+ EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:0"),
+ "/device:XLA_CPU:0");
+ EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:*"),
+ "/device:XLA_CPU:0");
+
+ devices = {"/device:XLA_GPU:0"};
+ EXPECT_EQ(internal::TryFindHostDevice(devices, false, ""), "");
+ EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:0"),
+ "/device:XLA_GPU:0");
+ EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:*"),
+ "/device:XLA_GPU:*");
+}
+
+TEST_F(PinToHostOptimizerTest, OptimizeSmallOpsToHost) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output a = ops::Const(s.WithOpName("a"), 1, {1024, 1024});
+ Output c = ops::Shape(s.WithOpName("c"), a);
+ Output d = ops::Const(s.WithOpName("d"), 0, {1});
+ Output e = ops::ReduceProd(s.WithOpName("e"), c, d);
+
+ GrapplerItem item;
+ item.fetch = {"a", "c", "d", "e"};
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+
+ GraphDef output;
+ PinToHostOptimizer optimizer(RewriterConfig::ON);
+ TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ auto tensors = EvaluateNodes(item.graph, item.fetch);
+ EXPECT_EQ(tensors_expected.size(), tensors.size());
+ for (int i = 0; i < tensors.size(); ++i) {
+ test::ExpectTensorEqual<int32>(tensors[i], tensors_expected[i]);
+ }
+
+ int found = 0;
+ for (const NodeDef& node : output.node()) {
+ if (node.name() == "a" || node.name() == "c") {
+ EXPECT_TRUE(node.device().empty());
+ } else if (node.name() == "d" || node.name() == "e") {
+ EXPECT_EQ(node.device(), "/device:CPU:0");
+ }
+ ++found;
+ }
+ EXPECT_EQ(found, 4);
+}
+
+TEST_F(PinToHostOptimizerTest, TopologicalSort) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output a = ops::Const(s.WithOpName("a"), 1, {1024, 1024});
+ Output c = ops::Shape(s.WithOpName("c"), a);
+ Output d = ops::Const(s.WithOpName("d"), 0, {1});
+ Output e = ops::ReduceProd(s.WithOpName("e"), c, d);
+
+ GrapplerItem item;
+ item.fetch = {"a", "c", "d", "e"};
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+
+ // Reverse the graph, and hence rely on the optimizer to sort it.
+ std::reverse(item.graph.mutable_node()->begin(),
+ item.graph.mutable_node()->end());
+
+ GraphDef output;
+ PinToHostOptimizer optimizer(RewriterConfig::ON);
+ TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ auto tensors = EvaluateNodes(item.graph, item.fetch);
+ EXPECT_EQ(tensors_expected.size(), tensors.size());
+ for (int i = 0; i < tensors.size(); ++i) {
+ test::ExpectTensorEqual<int32>(tensors[i], tensors_expected[i]);
+ }
+
+ int found = 0;
+ for (const NodeDef& node : output.node()) {
+ if (node.name() == "a" || node.name() == "c") {
+ EXPECT_TRUE(node.device().empty());
+ } else if (node.name() == "d" || node.name() == "e") {
+ EXPECT_EQ(node.device(), "/device:CPU:0");
+ }
+ ++found;
+ }
+ EXPECT_EQ(found, 4);
+}
+
+TEST_F(PinToHostOptimizerTest, NoSwap) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ // `b` should be too big to swap, so `c` should not be swapped either.
+ // PinToHostOptimizer should then detect that `a` should not be swapped.
+ Output a = ops::Const(s.WithOpName("a"), 1, {1, 1});
+ Output b = ops::Const(s.WithOpName("b"), 1, {1, 1024 * 1024});
+ Output c = ops::MatMul(s.WithOpName("c"), a, b);
+
+ GrapplerItem item;
+ item.fetch = {"a", "b", "c"};
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+
+ GraphDef output;
+ PinToHostOptimizer optimizer(RewriterConfig::ON);
+ TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ auto tensors = EvaluateNodes(item.graph, item.fetch);
+ EXPECT_EQ(tensors_expected.size(), tensors.size());
+ for (int i = 0; i < tensors.size(); ++i) {
+ test::ExpectTensorEqual<int32>(tensors[i], tensors_expected[i]);
+ }
+
+ int found = 0;
+ for (const NodeDef& node : output.node()) {
+ EXPECT_TRUE(node.device().empty());
+ ++found;
+ }
+ EXPECT_EQ(found, 3);
+}
+
+TEST_F(PinToHostOptimizerTest, PortIdToArgId) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output a = ops::Const(s.WithOpName("a"), 1, {1, 2, 3});
+ ops::ShapeN b(s.WithOpName("b"), {a, a, a});
+
+ GrapplerItem item;
+ item.fetch = {"a", "b"};
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+
+ GraphDef output;
+ PinToHostOptimizer optimizer(RewriterConfig::ON);
+ TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ auto tensors = EvaluateNodes(item.graph, item.fetch);
+ EXPECT_EQ(tensors_expected.size(), tensors.size());
+ for (int i = 0; i < tensors.size(); ++i) {
+ test::ExpectTensorEqual<int32>(tensors[i], tensors_expected[i]);
+ }
+
+ int found = 0;
+ for (const NodeDef& node : output.node()) {
+ EXPECT_EQ(node.device(), "/device:CPU:0");
+ ++found;
+ }
+ EXPECT_EQ(found, 2);
+}
+
+} // namespace
+} // namespace grappler
+} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/remapper.cc b/tensorflow/core/grappler/optimizers/remapper.cc
index 03e36a7b9c..9ada8b7ff9 100644
--- a/tensorflow/core/grappler/optimizers/remapper.cc
+++ b/tensorflow/core/grappler/optimizers/remapper.cc
@@ -168,11 +168,12 @@ void AddBatchNormNodes(GraphDef* optimized_graph, const NodeDef& fused_node) {
Status Remapper::Optimize(Cluster* /*cluster*/, const GrapplerItem& item,
GraphDef* optimized_graph) {
GraphProperties properties(item);
- TF_RETURN_IF_ERROR(properties.InferStatically(false));
+ bool inferred_properties = false;
GraphView graph(const_cast<GraphDef*>(&item.graph));
// During inference, most of the inputs to FusedBatchNorm are constant, and we
// can therefore replace the op with a much cheaper set of primitives.
+ optimized_graph->mutable_node()->Reserve(item.graph.node_size());
for (const NodeDef& node : item.graph.node()) {
if (node.op() == "FusedBatchNorm" || node.op() == "FusedBatchNormV2") {
bool optimizable = (node.attr().count("T") == 0 ||
@@ -181,6 +182,11 @@ Status Remapper::Optimize(Cluster* /*cluster*/, const GrapplerItem& item,
!node.attr().at("is_training").b());
if (optimizable) {
int const_inputs = 0;
+ if (!inferred_properties) {
+ // Infer properties lazily in case they are not needed.
+ TF_RETURN_IF_ERROR(properties.InferStatically(false));
+ inferred_properties = true;
+ }
const auto& props = properties.GetInputProperties(node.name());
for (const auto& prop : props) {
if (prop.has_value()) {
@@ -218,7 +224,7 @@ Status Remapper::Optimize(Cluster* /*cluster*/, const GrapplerItem& item,
void Remapper::Feedback(Cluster* /*cluster*/, const GrapplerItem& /*item*/,
const GraphDef& /*optimized_graph*/,
double /*result*/) {
- // Nothing to do for ArithmeticOptimizer.
+ // Nothing to do for RemapperOptimizer.
}
} // namespace grappler
diff --git a/tensorflow/core/grappler/optimizers/shape_optimizer.cc b/tensorflow/core/grappler/optimizers/shape_optimizer.cc
index caa0b7b0cb..6ccb1cd783 100644
--- a/tensorflow/core/grappler/optimizers/shape_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/shape_optimizer.cc
@@ -20,10 +20,9 @@ limitations under the License.
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/grappler/graph_view.h"
#include "tensorflow/core/grappler/grappler_item.h"
-#include "tensorflow/core/grappler/optimizers/symbolic_shapes.h"
-
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/grappler/utils/symbolic_shapes.h"
#include "tensorflow/core/lib/core/errors.h"
namespace tensorflow {
@@ -34,7 +33,7 @@ Status ShapeOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
*optimized_graph = item.graph;
GraphProperties properties(item);
- TF_RETURN_IF_ERROR(properties.InferStatically(false));
+ bool inferred_properties = false;
GraphView graph(optimized_graph);
// The product of all the dimensions in a tensor shape can be expressed more
@@ -56,6 +55,11 @@ Status ShapeOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
}
const GraphView::OutputPort reduce_indices =
graph.GetRegularFanin(GraphView::InputPort(fanout.node, 1));
+ if (!inferred_properties) {
+ // Infer properties lazily in case they are not needed.
+ TF_RETURN_IF_ERROR(properties.InferStatically(false));
+ inferred_properties = true;
+ }
const auto& prop =
properties.GetOutputProperties(reduce_indices.node->name());
if (prop.size() < reduce_indices.port_id) {
@@ -93,6 +97,11 @@ Status ShapeOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
if (!IsSize(*input1.node) || !IsSize(*input2.node)) {
continue;
}
+ if (!inferred_properties) {
+ // Infer properties lazily in case they are not needed.
+ TF_RETURN_IF_ERROR(properties.InferStatically(false));
+ inferred_properties = true;
+ }
const auto& prop1 = properties.GetInputProperties(input1.node->name());
const auto& prop2 = properties.GetInputProperties(input2.node->name());
if (prop1.size() != 1 || prop2.size() != 1) {
diff --git a/tensorflow/core/grappler/utils.cc b/tensorflow/core/grappler/utils.cc
index 153785d3b4..5867d01324 100644
--- a/tensorflow/core/grappler/utils.cc
+++ b/tensorflow/core/grappler/utils.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/grappler/utils.h"
+#include <iterator>
#include <memory>
#include <queue>
#include <vector>
@@ -24,6 +25,7 @@ limitations under the License.
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/scanner.h"
#include "tensorflow/core/lib/strings/strcat.h"
@@ -154,17 +156,6 @@ bool IsControlInput(const string& name) {
return !name.empty() && name[0] == '^';
}
-string NodeName(const string& name) {
- int position;
- return ParseNodeName(name, &position);
-}
-
-int NodePosition(const string& name) {
- int position;
- ParseNodeNameAsStringPiece(name, &position);
- return position;
-}
-
string AddPrefixToNodeName(const string& name, const string& prefix,
const string& delimiter) {
if (!name.empty()) {
diff --git a/tensorflow/core/grappler/utils.h b/tensorflow/core/grappler/utils.h
index 20dbeea2cf..95126d470c 100644
--- a/tensorflow/core/grappler/utils.h
+++ b/tensorflow/core/grappler/utils.h
@@ -29,7 +29,6 @@ limitations under the License.
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
-#include "tensorflow/core/lib/strings/scanner.h"
namespace tensorflow {
namespace grappler {
@@ -102,39 +101,92 @@ bool IsControlInput(const string& name);
// True iff 'name1' and 'name2' refer to the same input.
bool IsSameInput(const string& name1, const string& name2);
+// Returns the trailing position number (or zero if no number is present) if
+// NodeName(input_name) is equal to node_name. Returns -1 for control inputs.
+// Returns -2 if NodeName(input_name) is not equal to node_name.
+// Note: This function is used very heavily, and this hand-optimized
+// version is 3-4x faster than the version using Scanner, which it replaced.
+// This is worth the reduction in readability.
+inline int NodePositionIfSameNode(const string& input_name,
+ const string& node_name) {
+ if (input_name.empty()) return -2;
+ const bool is_ctrl = input_name[0] == '^';
+ auto input_it = is_ctrl ? input_name.begin() + 1 : input_name.begin();
+ auto node_it = node_name.begin();
+ if (node_name.empty() ||
+ std::distance(input_it, input_name.end()) < node_name.size()) {
+ return -2;
+ }
+ while (node_it != node_name.end()) {
+ if (*input_it++ != *node_it++) {
+ return -2;
+ }
+ }
+ if (input_it == input_name.end()) {
+ return is_ctrl ? -1 : 0;
+ } else if (*input_it++ == ':') {
+ StringPiece remaining(&(*input_it),
+ std::distance(input_it, input_name.end()));
+ int position;
+ if (!strings::safe_strto32(remaining, &position)) {
+ return -2;
+ }
+ return is_ctrl ? -1 : position;
+ } else {
+ return -2;
+ }
+}
+
// Return the node name corresponding to 'name' if name is valid, or the empty
// string otherwise.
-string NodeName(const string& name);
+inline StringPiece NodeNameAsStringPiece(const string& name) {
+ static const string empty;
+ if (name.empty()) return StringPiece(empty);
+ const auto begin_it = name[0] == '^' ? name.begin() + 1 : name.begin();
+ auto end_it = begin_it;
+ while (end_it != name.end() && *end_it != ':') {
+ ++end_it;
+ }
+ if (end_it != name.end() && *end_it != ':') {
+ return StringPiece(empty);
+ }
+ return StringPiece(&(*begin_it), std::distance(begin_it, end_it));
+}
-// Get the trailing position number ":{digits}" (if any) of a node name.
-int NodePosition(const string& name);
+// Return the node name corresponding to 'name' if name is valid, or the empty
+// string otherwise.
+inline string NodeName(const string& name) {
+ return string(NodeNameAsStringPiece(name));
+}
+// Returns the node name and position in a single call.
inline StringPiece ParseNodeNameAsStringPiece(const string& name,
int* position) {
- // Strip the prefix '^' (if any), and strip the trailing ":{digits} (if any)
- // to get a node name.
- strings::Scanner scan(name);
- scan.ZeroOrOneLiteral("^")
- .RestartCapture()
- .One(strings::Scanner::LETTER_DIGIT_DOT_UNDERSCORE)
- .Any(strings::Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
- StringPiece capture;
- StringPiece remaining;
- if (scan.Peek(':') != ':' || !scan.GetResult(&remaining, &capture)) {
+ static const string empty;
+ if (name.empty()) {
*position = 0;
- static const string empty;
return StringPiece(empty);
- } else {
- if (name[0] == '^') {
- *position = -1;
- } else if (remaining.empty()) {
- *position = 0;
- } else {
- // Skip the first ':' character.
- CHECK(strings::safe_strto32(remaining.substr(1), position));
+ }
+ const bool is_ctrl = name[0] == '^';
+ const auto begin_it = is_ctrl ? name.begin() + 1 : name.begin();
+ *position = is_ctrl ? -1 : 0;
+ auto end_it = begin_it;
+ while (end_it != name.end() && *end_it != ':') {
+ ++end_it;
+ }
+ const StringPiece node_name(&(*begin_it), std::distance(begin_it, end_it));
+ if (end_it != name.end()) {
+ if (*end_it != ':') {
+ return StringPiece(empty);
+ } else if (!is_ctrl) {
+ ++end_it;
+ StringPiece remaining(&(*end_it), std::distance(end_it, name.end()));
+ if (!strings::safe_strto32(remaining, position)) {
+ return StringPiece(empty);
+ }
}
- return capture;
}
+ return node_name;
}
// Returns the node name and position in a single call.
@@ -142,6 +194,12 @@ inline string ParseNodeName(const string& name, int* position) {
return string(ParseNodeNameAsStringPiece(name, position));
}
+inline int NodePosition(const string& name) {
+ int position;
+ ParseNodeNameAsStringPiece(name, &position);
+ return position;
+}
+
// Add a prefix to a node name with a custom delimiter.
string AddPrefixToNodeName(const string& name, const string& prefix,
const string& delimiter);
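A short sketch of how the new parsing helpers behave, mirroring the unit tests added later in this patch (the surrounding function is only for illustration):

#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/platform/logging.h"

void ParseHelperExamples() {
  using tensorflow::grappler::NodeName;
  using tensorflow::grappler::NodePositionIfSameNode;
  CHECK_EQ(NodeName("^foo/bar:2"), "foo/bar");                  // '^' and ":port" stripped
  CHECK_EQ(NodePositionIfSameNode("foo/bar:2", "foo/bar"), 2);  // trailing port number
  CHECK_EQ(NodePositionIfSameNode("^foo/bar", "foo/bar"), -1);  // control input
  CHECK_EQ(NodePositionIfSameNode("foo/bar", "baz"), -2);       // different node
}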
diff --git a/tensorflow/core/grappler/utils/BUILD b/tensorflow/core/grappler/utils/BUILD
index e540cc0476..bdbb8836e1 100644
--- a/tensorflow/core/grappler/utils/BUILD
+++ b/tensorflow/core/grappler/utils/BUILD
@@ -1,6 +1,10 @@
licenses(["notice"]) # Apache 2.0
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
+load(
+ "//tensorflow/core:platform/default/build_config.bzl",
+ "tf_protos_grappler",
+)
cc_library(
name = "scc",
@@ -210,3 +214,28 @@ tf_cc_test(
"//tensorflow/core:testlib",
],
)
+
+cc_library(
+ name = "symbolic_shapes",
+ srcs = ["symbolic_shapes.cc"],
+ hdrs = ["symbolic_shapes.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ ] + tf_protos_grappler(),
+)
+
+tf_cc_test(
+ name = "symbolic_shapes_test",
+ srcs = ["symbolic_shapes_test.cc"],
+ deps = [
+ ":symbolic_shapes",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
diff --git a/tensorflow/core/grappler/utils/grappler_test.cc b/tensorflow/core/grappler/utils/grappler_test.cc
index 910b0acaef..6266733f3e 100644
--- a/tensorflow/core/grappler/utils/grappler_test.cc
+++ b/tensorflow/core/grappler/utils/grappler_test.cc
@@ -30,13 +30,16 @@ GrapplerTest::GrapplerTest() {
// optimizations interfering in the comparison.
RewriterConfig* cfg =
options_.config.mutable_graph_options()->mutable_rewrite_options();
- cfg->set_constant_folding(RewriterConfig::OFF);
+ // TODO(rmlarsen): Add utility to generate config w/ all optimizers turned
+ // off.
cfg->set_arithmetic_optimization(RewriterConfig::OFF);
+ cfg->set_constant_folding(RewriterConfig::OFF);
+ cfg->set_debug_stripper(RewriterConfig::OFF);
cfg->set_dependency_optimization(RewriterConfig::OFF);
- cfg->set_loop_optimization(RewriterConfig::OFF);
cfg->set_function_optimization(RewriterConfig::OFF);
cfg->set_layout_optimizer(RewriterConfig::OFF);
- cfg->set_debug_stripper(RewriterConfig::OFF);
+ cfg->set_loop_optimization(RewriterConfig::OFF);
+ cfg->set_pin_to_host_optimization(RewriterConfig::OFF);
}
std::vector<Tensor> GrapplerTest::EvaluateNodes(
diff --git a/tensorflow/core/grappler/optimizers/symbolic_shapes.cc b/tensorflow/core/grappler/utils/symbolic_shapes.cc
index 155843a744..1666de4b80 100644
--- a/tensorflow/core/grappler/optimizers/symbolic_shapes.cc
+++ b/tensorflow/core/grappler/utils/symbolic_shapes.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/core/grappler/optimizers/symbolic_shapes.h"
+#include "tensorflow/core/grappler/utils/symbolic_shapes.h"
#include "tensorflow/core/util/bcast.h"
namespace tensorflow {
diff --git a/tensorflow/core/grappler/optimizers/symbolic_shapes.h b/tensorflow/core/grappler/utils/symbolic_shapes.h
index ace7bd1fe7..0a7d8ac82b 100644
--- a/tensorflow/core/grappler/optimizers/symbolic_shapes.h
+++ b/tensorflow/core/grappler/utils/symbolic_shapes.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_SYMBOLIC_SHAPES_H_
-#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_SYMBOLIC_SHAPES_H_
+#ifndef TENSORFLOW_CORE_GRAPPLER_UTILS_SYMBOLIC_SHAPES_H_
+#define TENSORFLOW_CORE_GRAPPLER_UTILS_SYMBOLIC_SHAPES_H_
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/grappler/costs/op_performance_data.pb.h"
@@ -74,4 +74,4 @@ int64 ComputeSizeRatio(const TensorShapeProto& numerator,
} // namespace grappler
} // end namespace tensorflow
-#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_SYMBOLIC_SHAPES_H_
+#endif // TENSORFLOW_CORE_GRAPPLER_UTILS_SYMBOLIC_SHAPES_H_
diff --git a/tensorflow/core/grappler/optimizers/symbolic_shapes_test.cc b/tensorflow/core/grappler/utils/symbolic_shapes_test.cc
index 7ce995d1c5..6ac644cdb1 100644
--- a/tensorflow/core/grappler/optimizers/symbolic_shapes_test.cc
+++ b/tensorflow/core/grappler/utils/symbolic_shapes_test.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/core/grappler/optimizers/symbolic_shapes.h"
+#include "tensorflow/core/grappler/utils/symbolic_shapes.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/platform/test.h"
diff --git a/tensorflow/core/grappler/utils_test.cc b/tensorflow/core/grappler/utils_test.cc
index c6e035834c..9b6c1f690b 100644
--- a/tensorflow/core/grappler/utils_test.cc
+++ b/tensorflow/core/grappler/utils_test.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/notification.h"
#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
namespace tensorflow {
namespace grappler {
@@ -147,6 +148,21 @@ TEST_F(UtilsTest, NodePosition) {
EXPECT_EQ(0, NodePosition(""));
}
+TEST_F(UtilsTest, NodePositionIfSameNode) {
+ EXPECT_EQ(-2, NodePositionIfSameNode(":123", ""));
+ EXPECT_EQ(-2, NodePositionIfSameNode(":", ""));
+ EXPECT_EQ(-2, NodePositionIfSameNode("", ""));
+ EXPECT_EQ(123, NodePositionIfSameNode("abc:123", "abc"));
+ EXPECT_EQ(-1, NodePositionIfSameNode("^abc", "abc"));
+ EXPECT_EQ(-1, NodePositionIfSameNode("^abc:123", "abc"));
+ EXPECT_EQ(-2, NodePositionIfSameNode("abc", "xyz"));
+ EXPECT_EQ(-2, NodePositionIfSameNode("abc", "abc/xyz"));
+ EXPECT_EQ(-2, NodePositionIfSameNode("abc/xyz", "abc"));
+ EXPECT_EQ(-2, NodePositionIfSameNode("abc:123", "xyz"));
+ EXPECT_EQ(-2, NodePositionIfSameNode("^abc", "xyz"));
+ EXPECT_EQ(-2, NodePositionIfSameNode("^abc:123", "xyz"));
+}
+
TEST_F(UtilsTest, AddNodeNamePrefix) {
EXPECT_EQ("OPTIMIZED/abc", AddPrefixToNodeName("abc", "OPTIMIZED"));
EXPECT_EQ("^OPTIMIZED/abc", AddPrefixToNodeName("^abc", "OPTIMIZED"));
@@ -209,7 +225,6 @@ TEST_F(UtilsTest, GetTailOfChain) {
auto noop = ops::NoOp(s.WithControlDependencies(neg0).WithOpName("noop"));
GraphDef graph;
TF_CHECK_OK(s.ToGraphDef(&graph));
- LOG(INFO) << graph.DebugString();
ASSERT_EQ("c0", graph.node(0).name());
ASSERT_EQ("c1", graph.node(1).name());
@@ -336,9 +351,45 @@ TEST_F(UtilsTest, NumNonControlOutputs) {
}
TEST_F(UtilsTest, DeleteNodes) {
- // TODO(rmlarsen): write forgtten test.
+ // TODO(rmlarsen): write forgotten test.
}
+#define BM_NodePositionIfSameNode(I, N, NAME) \
+ static void BM_NodePositionIfSameNode_##NAME(int iters) { \
+ string input = I; \
+ string node = N; \
+ for (int i = 0; i < iters; ++i) { \
+ const int pos = NodePositionIfSameNode(input, node); \
+ CHECK_GT(pos, -3); \
+ } \
+ } \
+ BENCHMARK(BM_NodePositionIfSameNode_##NAME)
+
+BM_NodePositionIfSameNode("foo/bar/baz:7", "foo/bar/baz", Match_7);
+BM_NodePositionIfSameNode("foo/bar/baz", "foo/bar/baz", Match_0);
+BM_NodePositionIfSameNode("^foo/bar/baz", "foo/bar/baz", Match_Ctrl);
+BM_NodePositionIfSameNode("blah", "foo/bar/baz", NoMatch_0);
+BM_NodePositionIfSameNode("foo/bar/baz/gnu", "foo/bar/baz", NoMatch_end);
+
+#define BM_ParseNodeNameAsStringPiece(I, NAME) \
+ static void BM_ParseNodeNameAsStringPiece_##NAME(int iters) { \
+ string input = I; \
+ for (int i = 0; i < iters; ++i) { \
+ int position; \
+ const StringPiece name = ParseNodeNameAsStringPiece(input, &position); \
+ CHECK_GE(position, -1); \
+ CHECK(!name.empty()); \
+ } \
+ } \
+ BENCHMARK(BM_ParseNodeNameAsStringPiece_##NAME)
+
+BM_ParseNodeNameAsStringPiece("foo", foo);
+BM_ParseNodeNameAsStringPiece("foo/bar/baz", foo_bar_baz);
+BM_ParseNodeNameAsStringPiece("^foo/bar/baz", foo_bar_baz_ctrl);
+BM_ParseNodeNameAsStringPiece("foo:123", foo123);
+BM_ParseNodeNameAsStringPiece("foo/bar/baz:123", foo_bar_baz_123);
+BM_ParseNodeNameAsStringPiece("^foo/bar/baz:123", foo_bar_baz_123_ctrl);
+
} // namespace
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 94d3ab4467..30171708c1 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -30,6 +30,7 @@ load(
"//tensorflow:tensorflow.bzl",
"if_android",
"tf_cc_test",
+ "tf_cc_test_mkl",
"tf_cc_tests",
"tf_cc_binary",
"tf_copts",
@@ -50,6 +51,10 @@ load(
"tf_kernel_tests_linkstatic",
)
load(
+ "//tensorflow/core:platform/default/build_config_root.bzl",
+ "tf_cuda_tests_tags",
+)
+load(
"//third_party/mkl:build_defs.bzl",
"if_mkl",
"if_mkl_ml",
@@ -212,6 +217,19 @@ tf_kernel_library(
],
)
+tf_kernel_library(
+ name = "extract_volume_patches_op",
+ prefix = "extract_volume_patches_op",
+ deps = [
+ ":bounds_check",
+ ":eigen_helpers",
+ ":ops_util",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//third_party/eigen3",
+ ],
+)
+
cc_library(
name = "conv_3d",
hdrs = ["conv_3d.h"],
@@ -617,6 +635,7 @@ cc_library(
":diag_op",
":edit_distance_op",
":extract_image_patches_op",
+ ":extract_volume_patches_op",
":gather_nd_op",
":gather_op",
":guarantee_const_op",
@@ -636,6 +655,7 @@ cc_library(
":reshape_op",
":reverse_op",
":reverse_sequence_op",
+ ":searchsorted_op",
":shape_ops",
":slice_op",
":snapshot_op",
@@ -869,6 +889,12 @@ tf_kernel_library(
)
tf_kernel_library(
+ name = "searchsorted_op",
+ prefix = "searchsorted_op",
+ deps = ARRAY_DEPS,
+)
+
+tf_kernel_library(
name = "inplace_ops",
prefix = "inplace_ops",
deps = ARRAY_DEPS,
@@ -1105,7 +1131,7 @@ tf_cuda_cc_test(
name = "depthwise_conv_ops_test",
size = "small",
srcs = ["depthwise_conv_ops_test.cc"],
- tags = ["requires-gpu-sm35"],
+ tags = tf_cuda_tests_tags(),
deps = [
":conv_ops",
":image",
@@ -2002,8 +2028,8 @@ tf_kernel_library(
":variable_ops",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
"//tensorflow/core:resource_variable_ops_op_lib",
- "//third_party/eigen3",
],
)
@@ -2702,6 +2728,7 @@ cc_library(
)
LOGGING_DEPS = [
+ "@com_google_absl//absl/strings",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
@@ -2759,6 +2786,7 @@ tf_cc_tests(
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
+ "@com_google_absl//absl/strings",
],
)
@@ -4021,11 +4049,6 @@ cc_library(
)
SPARSE_DEPS = [
- ":bounds_check",
- ":cwise_op",
- ":fill_functor",
- ":scatter_functor",
- "//third_party/eigen3",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:sparse_ops_op_lib",
@@ -4058,7 +4081,9 @@ tf_kernel_library(
tf_kernel_library(
name = "sparse_cross_op",
prefix = "sparse_cross_op",
- deps = SPARSE_DEPS,
+ deps = SPARSE_DEPS + [
+ "//third_party/eigen3",
+ ],
)
tf_kernel_library(
@@ -4070,13 +4095,19 @@ tf_kernel_library(
tf_kernel_library(
name = "sparse_dense_binary_op_shared",
prefix = "sparse_dense_binary_op_shared",
- deps = SPARSE_DEPS,
+ deps = SPARSE_DEPS + [
+ ":cwise_op",
+ "//third_party/eigen3",
+ ],
)
tf_kernel_library(
name = "sparse_sparse_binary_op_shared",
prefix = "sparse_sparse_binary_op_shared",
- deps = SPARSE_DEPS,
+ deps = SPARSE_DEPS + [
+ ":cwise_op",
+ "//third_party/eigen3",
+ ],
)
tf_kernel_library(
@@ -4108,7 +4139,9 @@ tf_kernel_library(
tf_kernel_library(
name = "sparse_softmax",
prefix = "sparse_softmax",
- deps = SPARSE_DEPS,
+ deps = SPARSE_DEPS + [
+ "//third_party/eigen3",
+ ],
)
tf_kernel_library(
@@ -4120,25 +4153,37 @@ tf_kernel_library(
tf_kernel_library(
name = "sparse_tensor_dense_add_op",
prefix = "sparse_tensor_dense_add_op",
- deps = SPARSE_DEPS,
+ deps = SPARSE_DEPS + [
+ ":scatter_functor",
+ "//third_party/eigen3",
+ ],
)
tf_kernel_library(
name = "sparse_tensor_dense_matmul_op",
prefix = "sparse_tensor_dense_matmul_op",
- deps = SPARSE_DEPS,
+ deps = SPARSE_DEPS + [
+ ":bounds_check",
+ ":fill_functor",
+ "//third_party/eigen3",
+ ],
)
tf_kernel_library(
name = "sparse_to_dense_op",
prefix = "sparse_to_dense_op",
- deps = SPARSE_DEPS,
+ deps = SPARSE_DEPS + [
+ "//third_party/eigen3",
+ ],
)
tf_kernel_library(
name = "sparse_xent_op",
prefix = "sparse_xent_op",
- deps = SPARSE_DEPS,
+ deps = SPARSE_DEPS + [
+ ":bounds_check",
+ "//third_party/eigen3",
+ ],
)
tf_kernel_library(
@@ -4396,17 +4441,27 @@ cc_library(
":reduce_join_op",
":regex_full_match_op",
":regex_replace_op",
+ ":string_format_op",
":string_join_op",
":string_length_op",
":string_split_op",
":string_strip_op",
":string_to_hash_bucket_op",
":substr_op",
+ ":unicode_script_op",
],
)
+cc_library(
+ name = "string_util",
+ srcs = ["string_util.cc"],
+ hdrs = ["string_util.h"],
+ deps = ["//tensorflow/core:lib"],
+)
+
STRING_DEPS = [
":bounds_check",
+ ":string_util",
"//third_party/eigen3",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
@@ -4427,6 +4482,30 @@ tf_kernel_library(
)
tf_kernel_library(
+ name = "string_format_op",
+ prefix = "string_format_op",
+ deps = STRING_DEPS + ["@com_google_absl//absl/strings"],
+)
+
+tf_cc_test(
+ name = "string_format_op_test",
+ size = "small",
+ srcs = ["string_format_op_test.cc"],
+ deps = [
+ ":string_format_op",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "//tensorflow/core/kernels:ops_testutil",
+ "//tensorflow/core/kernels:ops_util",
+ ],
+)
+
+tf_kernel_library(
name = "string_join_op",
prefix = "string_join_op",
deps = STRING_DEPS,
@@ -5113,6 +5192,7 @@ filegroup(
"spacetobatch_functor.h",
"spacetodepth_op.h",
"spectrogram.h",
+ "string_util.h",
"tensor_array.h",
"tile_functor.h",
"tile_ops_cpu_impl.h",
@@ -5192,6 +5272,8 @@ filegroup(
"cwise_op_squared_difference.cc",
"cwise_op_sub.cc",
"cwise_op_tanh.cc",
+ "cwise_op_xlogy.cc",
+ "cwise_op_xdivy.cc",
"data_format_ops.cc",
"decode_wav_op.cc",
"deep_conv2d.cc",
@@ -5281,6 +5363,7 @@ filegroup(
"spectrogram_op.cc",
"stack_ops.cc",
"string_join_op.cc",
+ "string_util.cc",
"summary_op.cc",
"tensor_array.cc",
"tensor_array_ops.cc",
@@ -5406,6 +5489,7 @@ filegroup(
"batch_kernels.*",
"regex_full_match_op.cc",
"regex_replace_op.cc",
+ "unicode_script_op.cc",
# Ops that are inherently incompatible with Android (e.g. tied to x86 platform).
"mkl_*",
"xsmm_*",
@@ -6228,6 +6312,26 @@ tf_mkl_kernel_library(
] + mkl_deps(),
)
+tf_cc_test_mkl(
+ name = "mkl_conv_ops_test",
+ size = "small",
+ srcs = ["mkl_conv_ops_test.cc"],
+ deps = [
+ ":ops_testutil",
+ ":ops_util",
+ "//tensorflow/cc:cc_ops",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:tensorflow",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ ],
+)
+
tf_mkl_kernel_library(
name = "mkl_tfconv_op",
prefix = "mkl_tfconv",
@@ -6331,6 +6435,12 @@ tf_mkl_kernel_library(
)
tf_mkl_kernel_library(
+ name = "mkl_slice_op",
+ prefix = "mkl_slice_op",
+ deps = ARRAY_DEPS + mkl_deps(),
+)
+
+tf_mkl_kernel_library(
name = "mkl_identity_op",
prefix = "mkl_identity_op",
deps = ARRAY_DEPS + mkl_deps(),
@@ -6474,6 +6584,16 @@ tf_kernel_library(
],
)
+tf_kernel_library(
+ name = "unicode_script_op",
+ srcs = ["unicode_script_op.cc"],
+ deps = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:string_ops_op_lib",
+ "@icu//:common",
+ ],
+)
+
# -----------------------------------------------------------------------------
# Google-internal targets. These must be at the end for syncrepo.
diff --git a/tensorflow/core/kernels/batch_matmul_op_complex.cc b/tensorflow/core/kernels/batch_matmul_op_complex.cc
index 54c45bfe63..f48bd0c318 100644
--- a/tensorflow/core/kernels/batch_matmul_op_complex.cc
+++ b/tensorflow/core/kernels/batch_matmul_op_complex.cc
@@ -17,14 +17,18 @@ limitations under the License.
namespace tensorflow {
-#if !defined(INTEL_MKL) || defined(INTEL_MKL_DNN_ONLY)
+// MKL_ML registers its own complex64/128 kernels in mkl_batch_matmul_op.cc
+// if defined(INTEL_MKL) && !defined(INTEL_MKL_DNN_ONLY) && defined(ENABLE_MKL).
+// Anything else (the complement) should register the TF ones.
+// (MKL-DNN doesn't implement these kernels either.)
+#if !defined(INTEL_MKL) || defined(INTEL_MKL_DNN_ONLY) || !defined(ENABLE_MKL)
TF_CALL_complex64(REGISTER_BATCH_MATMUL_CPU);
TF_CALL_complex128(REGISTER_BATCH_MATMUL_CPU);
-#endif
+#endif // !INTEL_MKL || INTEL_MKL_DNN_ONLY || !ENABLE_MKL
#if GOOGLE_CUDA
TF_CALL_complex64(REGISTER_BATCH_MATMUL_GPU);
TF_CALL_complex128(REGISTER_BATCH_MATMUL_GPU);
-#endif
+#endif // GOOGLE_CUDA
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/batch_matmul_op_real.cc b/tensorflow/core/kernels/batch_matmul_op_real.cc
index 584b507c70..25ae795d8e 100644
--- a/tensorflow/core/kernels/batch_matmul_op_real.cc
+++ b/tensorflow/core/kernels/batch_matmul_op_real.cc
@@ -21,10 +21,15 @@ limitations under the License.
namespace tensorflow {
-#if !defined(INTEL_MKL) || defined(INTEL_MKL_DNN_ONLY)
+// MKL_ML registers its own float and double kernels in mkl_batch_matmul_op.cc
+// if defined(INTEL_MKL) && !defined(INTEL_MKL_DNN_ONLY) && defined(ENABLE_MKL).
+// Anything else (the complement) should register the TF ones.
+// (MKL-DNN doesn't implement these kernels either.)
+#if !defined(INTEL_MKL) || defined(INTEL_MKL_DNN_ONLY) || !defined(ENABLE_MKL)
TF_CALL_float(REGISTER_BATCH_MATMUL_CPU);
TF_CALL_double(REGISTER_BATCH_MATMUL_CPU);
-#endif
+#endif // !INTEL_MKL || INTEL_MKL_DNN_ONLY || !ENABLE_MKL
+
TF_CALL_half(REGISTER_BATCH_MATMUL_CPU);
TF_CALL_int32(REGISTER_BATCH_MATMUL_CPU);
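
The new guard in both batch_matmul files is the logical complement of the condition under which mkl_batch_matmul_op.cc registers its own kernels: by De Morgan's law, NOT (INTEL_MKL && !INTEL_MKL_DNN_ONLY && ENABLE_MKL) is exactly the three-way disjunction used above. A tiny stand-alone check of that equivalence (arbitrary flag values, for illustration only):

#include <cassert>

// Exhaustively verify that !(a && !b && c) == (!a || b || !c), i.e. the TF
// registration guard really is the complement of the MKL registration guard.
int main() {
  for (int a = 0; a <= 1; ++a)
    for (int b = 0; b <= 1; ++b)
      for (int c = 0; c <= 1; ++c)
        assert((!(a && !b && c)) == (!a || b || !c));
  return 0;
}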
diff --git a/tensorflow/core/kernels/batching_util/BUILD b/tensorflow/core/kernels/batching_util/BUILD
index 792eb74e31..039b0db144 100644
--- a/tensorflow/core/kernels/batching_util/BUILD
+++ b/tensorflow/core/kernels/batching_util/BUILD
@@ -1,7 +1,7 @@
# Description: Utilities.
package(
- default_visibility = ["//tensorflow:internal"],
+ default_visibility = ["//visibility:public"],
)
licenses(["notice"]) # Apache 2.0
@@ -12,7 +12,11 @@ cc_library(
name = "periodic_function_dynamic",
srcs = ["periodic_function.cc"],
hdrs = ["periodic_function.h"],
- visibility = ["//visibility:public"],
+ visibility = [
+ "//learning/serving:__subpackages__",
+ "//tensorflow:internal",
+ "//tensorflow_serving:__subpackages__",
+ ],
deps = [
"//tensorflow/core:framework_headers_lib",
"//tensorflow/core:protos_all_cc",
@@ -21,7 +25,11 @@ cc_library(
cc_library(
name = "periodic_function",
- visibility = ["//visibility:public"],
+ visibility = [
+ "//learning/serving:__subpackages__",
+ "//tensorflow:internal",
+ "//tensorflow_serving:__subpackages__",
+ ],
deps = [
":periodic_function_dynamic",
"//tensorflow/core:lib",
@@ -190,7 +198,11 @@ cc_library(
testonly = 1,
srcs = ["fake_clock_env.cc"],
hdrs = ["fake_clock_env.h"],
- visibility = ["//visibility:public"],
+ visibility = [
+ "//learning/serving:__subpackages__",
+ "//tensorflow:internal",
+ "//tensorflow_serving:__subpackages__",
+ ],
deps = [
"//tensorflow/core:lib",
"//tensorflow/core:tensorflow",
diff --git a/tensorflow/core/kernels/bincount_op_gpu.cu.cc b/tensorflow/core/kernels/bincount_op_gpu.cu.cc
index 6074b3e1f6..7d09e9b820 100644
--- a/tensorflow/core/kernels/bincount_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/bincount_op_gpu.cu.cc
@@ -17,7 +17,7 @@ limitations under the License.
#define EIGEN_USE_GPU
-#include "external/cub_archive/cub/device/device_histogram.cuh"
+#include "third_party/cub/device/device_histogram.cuh"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
diff --git a/tensorflow/core/kernels/boosted_trees/boosted_trees.proto b/tensorflow/core/kernels/boosted_trees/boosted_trees.proto
index c9664f0c1c..1ab72af059 100644
--- a/tensorflow/core/kernels/boosted_trees/boosted_trees.proto
+++ b/tensorflow/core/kernels/boosted_trees/boosted_trees.proto
@@ -11,6 +11,7 @@ message Node {
oneof node {
Leaf leaf = 1;
BucketizedSplit bucketized_split = 2;
+ CategoricalSplit categorical_split = 3;
}
NodeMetadata metadata = 777;
}
@@ -57,6 +58,18 @@ message BucketizedSplit {
int32 right_id = 4;
}
+message CategoricalSplit {
+  // Categorical feature column and split describing the rule:
+  // feature value == value.
+ int32 feature_id = 1;
+ int32 value = 2;
+
+ // Node children indexing into a contiguous
+ // vector of nodes starting from the root.
+ int32 left_id = 3;
+ int32 right_id = 4;
+}
+
// Tree describes a list of connected nodes.
// Node 0 must be the root and can carry any payload including a leaf
// in the case of representing the bias.
diff --git a/tensorflow/core/kernels/boosted_trees/prediction_ops.cc b/tensorflow/core/kernels/boosted_trees/prediction_ops.cc
index b2efa06941..4ae26fb95b 100644
--- a/tensorflow/core/kernels/boosted_trees/prediction_ops.cc
+++ b/tensorflow/core/kernels/boosted_trees/prediction_ops.cc
@@ -334,30 +334,34 @@ class BoostedTreesExampleDebugOutputsOp : public OpKernel {
// Proto to store debug outputs, per example.
boosted_trees::DebugOutput example_debug_info;
// Initial bias prediction. E.g., prediction based off training mean.
- example_debug_info.add_logits_path(resource->GetTreeWeight(0) *
- resource->node_value(0, 0));
+ float tree_logit =
+ resource->GetTreeWeight(0) * resource->node_value(0, 0);
+ example_debug_info.add_logits_path(tree_logit);
int32 node_id = 0;
int32 tree_id = 0;
int32 feature_id;
- float tree_logit;
float past_trees_logit = 0; // Sum of leaf logits from prior trees.
- // Populate proto.
+ // Go through each tree and populate proto.
while (tree_id <= last_tree) {
- // Feature id used to split.
- feature_id = resource->feature_id(tree_id, node_id);
- example_debug_info.add_feature_ids(feature_id);
- // Get logit after split.
- node_id = resource->next_node(tree_id, node_id, i,
- batch_bucketized_features);
- tree_logit = resource->GetTreeWeight(tree_id) *
- resource->node_value(tree_id, node_id);
- // Output logit incorporates sum of leaf logits from prior trees.
- example_debug_info.add_logits_path(tree_logit + past_trees_logit);
- if (resource->is_leaf(tree_id, node_id)) {
- // Move onto other trees.
- past_trees_logit += tree_logit;
+ if (resource->is_leaf(tree_id, node_id)) { // Move onto other trees.
+        // Accumulate tree_logits only if the leaf is non-root, but always do
+        // so for the bias tree.
+ if (tree_id == 0 || node_id > 0) {
+ past_trees_logit += tree_logit;
+ }
++tree_id;
node_id = 0;
+ } else { // Add to proto.
+ // Feature id used to split.
+ feature_id = resource->feature_id(tree_id, node_id);
+ example_debug_info.add_feature_ids(feature_id);
+ // Get logit after split.
+ node_id = resource->next_node(tree_id, node_id, i,
+ batch_bucketized_features);
+ tree_logit = resource->GetTreeWeight(tree_id) *
+ resource->node_value(tree_id, node_id);
+ // Output logit incorporates sum of leaf logits from prior trees.
+ example_debug_info.add_logits_path(tree_logit + past_trees_logit);
}
}
// Set output as serialized proto containing debug info.
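
The reordered loop above records split features and cumulative logits while walking internal nodes, and only folds a tree's leaf logit into past_trees_logit when a leaf is reached, so the bias tree (tree 0, a single root leaf) still contributes. The simplified sketch below shows only the cross-tree accumulation with made-up, already-weighted leaf logits; the real kernel also records one path entry per split within each tree.

#include <iostream>
#include <vector>

int main() {
  // Hypothetical per-tree leaf logits for one example; bias tree comes first.
  std::vector<float> tree_leaf_logits = {0.5f, 0.2f, -0.1f};
  float past_trees_logit = 0.0f;
  std::vector<float> logits_path;
  for (float tree_logit : tree_leaf_logits) {
    logits_path.push_back(tree_logit + past_trees_logit);  // cumulative logit
    past_trees_logit += tree_logit;
  }
  for (float l : logits_path) std::cout << l << " ";  // prints: 0.5 0.7 0.6
  std::cout << "\n";
  return 0;
}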
diff --git a/tensorflow/core/kernels/boosted_trees/resources.cc b/tensorflow/core/kernels/boosted_trees/resources.cc
index cc90bb2f45..2798722536 100644
--- a/tensorflow/core/kernels/boosted_trees/resources.cc
+++ b/tensorflow/core/kernels/boosted_trees/resources.cc
@@ -60,14 +60,26 @@ int32 BoostedTreesEnsembleResource::next_node(
DCHECK_LT(tree_id, tree_ensemble_->trees_size());
DCHECK_LT(node_id, tree_ensemble_->trees(tree_id).nodes_size());
const auto& node = tree_ensemble_->trees(tree_id).nodes(node_id);
- DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit);
- const auto& split = node.bucketized_split();
- if (bucketized_features[split.feature_id()](index_in_batch) <=
- split.threshold()) {
- return split.left_id();
- } else {
- return split.right_id();
+
+ switch (node.node_case()) {
+ case boosted_trees::Node::kBucketizedSplit: {
+ const auto& split = node.bucketized_split();
+ return (bucketized_features[split.feature_id()](index_in_batch) <=
+ split.threshold())
+ ? split.left_id()
+ : split.right_id();
+ }
+ case boosted_trees::Node::kCategoricalSplit: {
+ const auto& split = node.categorical_split();
+ return (bucketized_features[split.feature_id()](index_in_batch) ==
+ split.value())
+ ? split.left_id()
+ : split.right_id();
+ }
+ default:
+ DCHECK(false) << "Node type " << node.node_case() << " not supported.";
}
+ return -1;
}
float BoostedTreesEnsembleResource::node_value(const int32 tree_id,
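
With the new switch in next_node, a bucketized split routes on feature_value <= threshold while a categorical split routes on feature_value == value; any other node type hits the DCHECK. A stripped-down sketch of just the routing decision, using plain structs rather than the proto types:

#include <cassert>

// Illustrative routing only; the real code reads these fields from the
// boosted_trees::Node proto.
struct BucketizedSplit { int feature_id, threshold, left_id, right_id; };
struct CategoricalSplit { int feature_id, value, left_id, right_id; };

int NextNodeBucketized(const BucketizedSplit& s, int feature_value) {
  return feature_value <= s.threshold ? s.left_id : s.right_id;
}

int NextNodeCategorical(const CategoricalSplit& s, int feature_value) {
  return feature_value == s.value ? s.left_id : s.right_id;
}

int main() {
  BucketizedSplit b{/*feature_id=*/0, /*threshold=*/3, /*left_id=*/1, /*right_id=*/2};
  CategoricalSplit c{/*feature_id=*/1, /*value=*/7, /*left_id=*/3, /*right_id=*/4};
  assert(NextNodeBucketized(b, 2) == 1);   // 2 <= 3 goes left
  assert(NextNodeBucketized(b, 5) == 2);   // 5 > 3 goes right
  assert(NextNodeCategorical(c, 7) == 3);  // exact category match goes left
  assert(NextNodeCategorical(c, 9) == 4);  // any other category goes right
  return 0;
}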
diff --git a/tensorflow/core/kernels/collective_ops.cc b/tensorflow/core/kernels/collective_ops.cc
index e0da91125b..fa959b5a0e 100644
--- a/tensorflow/core/kernels/collective_ops.cc
+++ b/tensorflow/core/kernels/collective_ops.cc
@@ -132,6 +132,7 @@ class CollectiveReduceOpKernel : public CollectiveOpKernel {
"Failed to get CollectiveExecutor from OpKernelContext for Op ",
col_params_.name),
done);
+ col_params_.instance.shape = c->input(0).shape();
// Allocate output on the first pass through this function. This must be
// done immediately, while we're still in the executor thread. Otherwise
// the memory is not guaranteed to be unused by any concurrently executing
@@ -171,7 +172,7 @@ class CollectiveBcastSendOpKernel : public CollectiveOpKernel {
OP_REQUIRES_OK(
c, c->GetAttr("instance_key", &col_params_.instance.instance_key));
OP_REQUIRES_OK(c, c->GetAttr("T", &col_params_.instance.data_type));
- OP_REQUIRES_OK(c, c->GetAttr("shape", &shape_));
+ OP_REQUIRES_OK(c, c->GetAttr("shape", &col_params_.instance.shape));
col_params_.is_source = true;
col_params_.instance.impl_details.subdiv_offsets = {0};
@@ -195,13 +196,14 @@ class CollectiveBcastSendOpKernel : public CollectiveOpKernel {
if (c->mutable_output(0) == nullptr) {
// Allocate the output tensor, trying to reuse the input.
Tensor* output = nullptr;
- OP_REQUIRES_OK_ASYNC(
- c, c->forward_input_or_allocate_output({0}, 0, shape_, &output),
- done);
+ OP_REQUIRES_OK_ASYNC(c,
+ c->forward_input_or_allocate_output(
+ {0}, 0, col_params_.instance.shape, &output),
+ done);
}
if (!CanProceedWithCompute(c, col_exec, done)) return;
OP_REQUIRES_ASYNC(
- c, shape_.IsSameSize(c->input(0).shape()),
+ c, col_params_.instance.shape.IsSameSize(c->input(0).shape()),
errors::Internal("Declared shape of op ", col_params_.name,
" does not match shape of input"),
done);
@@ -214,8 +216,6 @@ class CollectiveBcastSendOpKernel : public CollectiveOpKernel {
}
private:
- TensorShape shape_;
-
TF_DISALLOW_COPY_AND_ASSIGN(CollectiveBcastSendOpKernel);
};
@@ -234,7 +234,7 @@ class CollectiveBcastRecvOpKernel : public CollectiveOpKernel {
OP_REQUIRES_OK(
c, c->GetAttr("instance_key", &col_params_.instance.instance_key));
OP_REQUIRES_OK(c, c->GetAttr("T", &col_params_.instance.data_type));
- OP_REQUIRES_OK(c, c->GetAttr("shape", &shape_));
+ OP_REQUIRES_OK(c, c->GetAttr("shape", &col_params_.instance.shape));
col_params_.is_source = false;
col_params_.instance.impl_details.subdiv_offsets = {0};
@@ -258,7 +258,8 @@ class CollectiveBcastRecvOpKernel : public CollectiveOpKernel {
if (c->mutable_output(0) == nullptr) {
// No input, so must allocate output.
Tensor* output = nullptr;
- OP_REQUIRES_OK_ASYNC(c, c->allocate_output(0, shape_, &output), done);
+ OP_REQUIRES_OK_ASYNC(
+ c, c->allocate_output(0, col_params_.instance.shape, &output), done);
}
if (!CanProceedWithCompute(c, col_exec, done)) return;
@@ -270,8 +271,6 @@ class CollectiveBcastRecvOpKernel : public CollectiveOpKernel {
}
private:
- TensorShape shape_;
-
TF_DISALLOW_COPY_AND_ASSIGN(CollectiveBcastRecvOpKernel);
};
diff --git a/tensorflow/core/kernels/conv_ops.cc b/tensorflow/core/kernels/conv_ops.cc
index 6f5c8d8461..78856c4a99 100644
--- a/tensorflow/core/kernels/conv_ops.cc
+++ b/tensorflow/core/kernels/conv_ops.cc
@@ -264,150 +264,198 @@ class LaunchXsmmConvOp<CPUDevice, float> {
};
#endif
+#define TF_REQUIRES(EXP, STATUS) \
+ do { \
+ if (!TF_PREDICT_TRUE(EXP)) return (STATUS); \
+ } while (false)
+
+Status InitConv2DParameters(const OpKernelConstruction* context,
+ Conv2DParameters* params) {
+ TF_RETURN_IF_ERROR(context->GetAttr("dilations", &params->dilations));
+ TF_RETURN_IF_ERROR(context->GetAttr("strides", &params->strides));
+ TF_RETURN_IF_ERROR(context->GetAttr("padding", &params->padding));
+ string data_format_string;
+ TF_RETURN_IF_ERROR(context->GetAttr("data_format", &data_format_string));
+ TF_REQUIRES(FormatFromString(data_format_string, &params->data_format),
+ errors::InvalidArgument("Invalid data format"));
+
+ const auto& strides = params->strides;
+ const auto& dilations = params->dilations;
+ const auto& data_format = params->data_format;
+
+ TF_REQUIRES(dilations.size() == 4,
+ errors::InvalidArgument("Sliding window dilations field must "
+ "specify 4 dimensions"));
+ TF_REQUIRES(strides.size() == 4,
+ errors::InvalidArgument("Sliding window strides field must "
+ "specify 4 dimensions"));
+ const int64 stride_n = GetTensorDim(strides, data_format, 'N');
+ const int64 stride_c = GetTensorDim(strides, data_format, 'C');
+ const int64 stride_h = GetTensorDim(strides, data_format, 'H');
+ const int64 stride_w = GetTensorDim(strides, data_format, 'W');
+ TF_REQUIRES(
+ stride_n == 1 && stride_c == 1,
+ errors::InvalidArgument("Current implementation does not yet support "
+ "strides in the batch and depth dimensions."));
+ TF_REQUIRES(stride_h > 0 && stride_w > 0,
+ errors::InvalidArgument(
+ "Row and column strides should be larger than 0."));
+
+ const int64 dilation_n = GetTensorDim(dilations, data_format, 'N');
+ const int64 dilation_c = GetTensorDim(dilations, data_format, 'C');
+ const int64 dilation_h = GetTensorDim(dilations, data_format, 'H');
+ const int64 dilation_w = GetTensorDim(dilations, data_format, 'W');
+ TF_REQUIRES(
+ dilation_n == 1 && dilation_c == 1,
+ errors::InvalidArgument("Current implementation does not yet support "
+ "dilations in the batch and depth dimensions."));
+ TF_REQUIRES(
+ dilation_h > 0 && dilation_w > 0,
+ errors::InvalidArgument("Dilated rates should be larger than 0."));
+
+ return Status::OK();
+}
+
+Status ComputeConv2DDimension(const Conv2DParameters& params,
+ const Tensor& input, const Tensor& filter,
+ Conv2DDimensions* dimensions) {
+ // Check that 2D convolution input and filter have exactly 4 dimensions.
+ TF_REQUIRES(input.dims() == 4,
+ errors::InvalidArgument("input must be 4-dimensional",
+ input.shape().DebugString()));
+ TF_REQUIRES(filter.dims() == 4,
+ errors::InvalidArgument("filter must be 4-dimensional: ",
+ filter.shape().DebugString()));
+ for (int i = 0; i < 3; i++) {
+ TF_REQUIRES(
+ FastBoundsCheck(filter.dim_size(i), std::numeric_limits<int>::max()),
+ errors::InvalidArgument("filter too large"));
+ }
+
+ // The last dimension for input is in_depth. Check that it is the same as the
+ // filter's in_depth or it is evenly divisible by filter's in_depth.
+ const int64 in_depth_raw = GetTensorDim(input, params.data_format, 'C');
+ const int64 patch_depth_raw = filter.dim_size(2);
+ TF_REQUIRES(FastBoundsCheck(in_depth_raw, std::numeric_limits<int>::max()),
+ errors::InvalidArgument("Input depth too large"));
+ TF_REQUIRES(FastBoundsCheck(patch_depth_raw, std::numeric_limits<int>::max()),
+ errors::InvalidArgument("Patch depth too large"));
+ const int in_depth = static_cast<int>(in_depth_raw);
+ const int patch_depth = static_cast<int>(patch_depth_raw);
+ TF_REQUIRES(in_depth % patch_depth == 0,
+ errors::InvalidArgument(
+ "input depth must be evenly divisible by filter depth: ",
+ in_depth, " vs ", patch_depth));
+
+ // The last dimension for filter is out_depth.
+ const int out_depth = static_cast<int>(filter.dim_size(3));
+
+ // The second dimension for input is rows/height.
+ // The first dimension for filter is rows/height.
+ const int64 input_rows_raw = GetTensorDim(input, params.data_format, 'H');
+ TF_REQUIRES(FastBoundsCheck(input_rows_raw, std::numeric_limits<int>::max()),
+ errors::InvalidArgument("Input rows too large"));
+ const int input_rows = static_cast<int>(input_rows_raw);
+ const int filter_rows = static_cast<int>(filter.dim_size(0));
+
+ // The third dimension for input is columns/width.
+ // The second dimension for filter is columns/width.
+ const int64 input_cols_raw = GetTensorDim(input, params.data_format, 'W');
+ TF_REQUIRES(FastBoundsCheck(input_cols_raw, std::numeric_limits<int>::max()),
+ errors::InvalidArgument("Input cols too large"));
+ const int input_cols = static_cast<int>(input_cols_raw);
+ const int filter_cols = static_cast<int>(filter.dim_size(1));
+
+ // The first dimension for input is batch.
+ const int64 batch_raw = GetTensorDim(input, params.data_format, 'N');
+ TF_REQUIRES(FastBoundsCheck(batch_raw, std::numeric_limits<int>::max()),
+ errors::InvalidArgument("batch is too large"));
+ const int batch = static_cast<int>(batch_raw);
+
+ // Take the stride and dilation from the second and third dimensions only (we
+ // do not support striding or dilation on the batch or depth dimension).
+ const int stride_rows = GetTensorDim(params.strides, params.data_format, 'H');
+ const int stride_cols = GetTensorDim(params.strides, params.data_format, 'W');
+ const int dilation_rows =
+ GetTensorDim(params.dilations, params.data_format, 'H');
+ const int dilation_cols =
+ GetTensorDim(params.dilations, params.data_format, 'W');
+
+ // Compute windowed output sizes for rows and columns.
+ int64 out_rows = 0, out_cols = 0, pad_rows = 0, pad_cols = 0;
+ TF_RETURN_IF_ERROR(GetWindowedOutputSizeV2(
+ input_rows, filter_rows, dilation_rows, stride_rows, params.padding,
+ &out_rows, &pad_rows));
+ TF_RETURN_IF_ERROR(GetWindowedOutputSizeV2(
+ input_cols, filter_cols, dilation_cols, stride_cols, params.padding,
+ &out_cols, &pad_cols));
+
+ dimensions->batch = batch;
+ dimensions->input_rows = input_rows;
+ dimensions->input_cols = input_cols;
+ dimensions->in_depth = in_depth;
+ dimensions->filter_rows = filter_rows;
+ dimensions->filter_cols = filter_cols;
+ dimensions->patch_depth = patch_depth;
+ dimensions->out_depth = out_depth;
+ dimensions->stride_rows = stride_rows;
+ dimensions->stride_cols = stride_cols;
+ dimensions->dilation_rows = dilation_rows;
+ dimensions->dilation_cols = dilation_cols;
+ dimensions->out_rows = out_rows;
+ dimensions->out_cols = out_cols;
+ dimensions->pad_rows = pad_rows;
+ dimensions->pad_cols = pad_cols;
+
+ return Status::OK();
+}
+
+#undef TF_REQUIRES
+
template <typename Device, typename T>
class Conv2DOp : public BinaryOp<T> {
public:
explicit Conv2DOp(OpKernelConstruction* context) : BinaryOp<T>(context) {
- OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_));
- OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
- string data_format;
- OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
- OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
- errors::InvalidArgument("Invalid data format"));
+ OP_REQUIRES_OK(context, InitConv2DParameters(context, &params_));
+
OP_REQUIRES_OK(context, context->GetAttr("use_cudnn_on_gpu", &use_cudnn_));
use_cudnn_ &= CanUseCudnn();
cudnn_use_autotune_ = CudnnUseAutotune();
- OP_REQUIRES(context, dilations_.size() == 4,
- errors::InvalidArgument("Sliding window dilations field must "
- "specify 4 dimensions"));
- OP_REQUIRES(context, strides_.size() == 4,
- errors::InvalidArgument("Sliding window strides field must "
- "specify 4 dimensions"));
- const int64 stride_n = GetTensorDim(strides_, data_format_, 'N');
- const int64 stride_c = GetTensorDim(strides_, data_format_, 'C');
- const int64 stride_h = GetTensorDim(strides_, data_format_, 'H');
- const int64 stride_w = GetTensorDim(strides_, data_format_, 'W');
- OP_REQUIRES(
- context, stride_n == 1 && stride_c == 1,
- errors::InvalidArgument("Current implementation does not yet support "
- "strides in the batch and depth dimensions."));
- OP_REQUIRES(context, stride_h > 0 && stride_w > 0,
- errors::InvalidArgument(
- "Row and column strides should be larger than 0."));
-
- const int64 dilation_n = GetTensorDim(dilations_, data_format_, 'N');
- const int64 dilation_c = GetTensorDim(dilations_, data_format_, 'C');
- const int64 dilation_h = GetTensorDim(dilations_, data_format_, 'H');
- const int64 dilation_w = GetTensorDim(dilations_, data_format_, 'W');
- OP_REQUIRES(context, dilation_n == 1 && dilation_c == 1,
- errors::InvalidArgument(
- "Current implementation does not yet support "
- "dilations in the batch and depth dimensions."));
- OP_REQUIRES(
- context, dilation_h > 0 && dilation_w > 0,
- errors::InvalidArgument("Dilated rates should be larger than 0."));
- OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
}
void Compute(OpKernelContext* context) override {
// Input tensor is of the following dimensions:
// [ batch, in_rows, in_cols, in_depth ]
-
const Tensor& input = context->input(0);
// Input filter is of the following dimensions:
// [ filter_rows, filter_cols, in_depth, out_depth]
const Tensor& filter = context->input(1);
- // For 2D convolution, there should be 4 dimensions.
- OP_REQUIRES(context, input.dims() == 4,
- errors::InvalidArgument("input must be 4-dimensional",
- input.shape().DebugString()));
- OP_REQUIRES(context, filter.dims() == 4,
- errors::InvalidArgument("filter must be 4-dimensional: ",
- filter.shape().DebugString()));
-
- for (int i = 0; i < 3; i++) {
- OP_REQUIRES(
- context,
- FastBoundsCheck(filter.dim_size(i), std::numeric_limits<int>::max()),
- errors::InvalidArgument("filter too large"));
- }
+ Conv2DDimensions dimensions;
+ OP_REQUIRES_OK(context,
+ ComputeConv2DDimension(params_, input, filter, &dimensions));
- // The last dimension for input is in_depth. It must be the same as the
- // filter's in_depth or be evenly divisible by filter's in_depth.
- const int64 in_depth = GetTensorDim(input, data_format_, 'C');
- const int64 patch_depth = filter.dim_size(2);
- OP_REQUIRES(context, in_depth % patch_depth == 0,
- errors::InvalidArgument(
- "input depth must be evenly divisible by filter depth: ",
- in_depth, " vs ", patch_depth));
-
- // The last dimension for filter is out_depth.
- const int out_depth = static_cast<int>(filter.dim_size(3));
-
- // The second dimension for input is rows/height.
- // The first dimension for filter is rows/height.
- const int64 input_rows_raw = GetTensorDim(input, data_format_, 'H');
- OP_REQUIRES(
- context,
- FastBoundsCheck(input_rows_raw, std::numeric_limits<int>::max()),
- errors::InvalidArgument("Input rows too large"));
- const int input_rows = static_cast<int>(input_rows_raw);
- const int filter_rows = static_cast<int>(filter.dim_size(0));
-
- // The third dimension for input is columns/width.
- // The second dimension for filter is columns/width.
- const int64 input_cols_raw = GetTensorDim(input, data_format_, 'W');
- OP_REQUIRES(
- context,
- FastBoundsCheck(input_cols_raw, std::numeric_limits<int>::max()),
- errors::InvalidArgument("Input cols too large"));
- const int input_cols = static_cast<int>(input_cols_raw);
- const int filter_cols = static_cast<int>(filter.dim_size(1));
-
- // The first dimension for input is batch.
- const int64 batch_raw = GetTensorDim(input, data_format_, 'N');
- OP_REQUIRES(context,
- FastBoundsCheck(batch_raw, std::numeric_limits<int>::max()),
- errors::InvalidArgument("batch is too large"));
- const int batch = static_cast<int>(batch_raw);
-
- // For now we take the stride and dilation from the second and third
- // dimensions only (we do not support striding or dilation on the batch or
- // depth dimension).
- const int stride_rows = GetTensorDim(strides_, data_format_, 'H');
- const int stride_cols = GetTensorDim(strides_, data_format_, 'W');
-
- const int dilation_rows = GetTensorDim(dilations_, data_format_, 'H');
- const int dilation_cols = GetTensorDim(dilations_, data_format_, 'W');
-
- int64 out_rows = 0, out_cols = 0, pad_rows = 0, pad_cols = 0;
- OP_REQUIRES_OK(context, GetWindowedOutputSizeV2(
- input_rows, filter_rows, dilation_rows,
- stride_rows, padding_, &out_rows, &pad_rows));
- OP_REQUIRES_OK(context, GetWindowedOutputSizeV2(
- input_cols, filter_cols, dilation_cols,
- stride_cols, padding_, &out_cols, &pad_cols));
- TensorShape out_shape =
- ShapeFromFormat(data_format_, batch, out_rows, out_cols, out_depth);
+ TensorShape out_shape = ShapeFromFormat(
+ params_.data_format, dimensions.batch, dimensions.out_rows,
+ dimensions.out_cols, dimensions.out_depth);
// Output tensor is of the following dimensions:
// [ in_batch, out_rows, out_cols, out_depth ]
Tensor* output = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
- VLOG(2) << "Conv2D: in_depth = " << in_depth
- << ", patch_depth = " << patch_depth
- << ", input_cols = " << input_cols
- << ", filter_cols = " << filter_cols
- << ", input_rows = " << input_rows
- << ", filter_rows = " << filter_rows
- << ", stride_rows = " << stride_rows
- << ", stride_cols = " << stride_cols
- << ", dilation_rows = " << dilation_rows
- << ", dilation_cols = " << dilation_cols
- << ", out_depth = " << out_depth;
+ VLOG(2) << "Conv2D: in_depth = " << dimensions.in_depth
+ << ", patch_depth = " << dimensions.patch_depth
+ << ", input_cols = " << dimensions.input_cols
+ << ", filter_cols = " << dimensions.filter_cols
+ << ", input_rows = " << dimensions.input_rows
+ << ", filter_rows = " << dimensions.filter_rows
+ << ", stride_rows = " << dimensions.stride_rows
+ << ", stride_cols = " << dimensions.stride_cols
+ << ", dilation_rows = " << dimensions.dilation_rows
+ << ", dilation_cols = " << dimensions.dilation_cols
+ << ", out_depth = " << dimensions.out_depth;
// If there is nothing to compute, return.
if (out_shape.num_elements() == 0) {
@@ -416,36 +464,41 @@ class Conv2DOp : public BinaryOp<T> {
#ifdef TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS
if (LaunchXsmmConvOp<Device, T>::Run(
- context, input, filter, batch, input_rows, input_cols, in_depth,
- filter_rows, filter_cols, pad_rows, pad_cols, out_rows, out_cols,
- out_depth, dilation_rows, dilation_cols, stride_rows, stride_cols,
- output, data_format_)) {
+ context, input, filter, dimensions.batch, dimensions.input_rows,
+ dimensions.input_cols, dimensions.in_depth, dimensions.filter_rows,
+ dimensions.filter_cols, dimensions.pad_rows, dimensions.pad_cols,
+ dimensions.out_rows, dimensions.out_cols, dimensions.out_depth,
+ dimensions.dilation_rows, dimensions.dilation_cols,
+ dimensions.stride_rows, dimensions.stride_cols, output,
+ params_.data_format)) {
return;
}
#endif
if (LaunchDeepConvOp<Device, T>::Run(
- context, input, filter, batch, input_rows, input_cols, in_depth,
- filter_rows, filter_cols, pad_rows, pad_cols, out_rows, out_cols,
- out_depth, dilation_rows, dilation_cols, stride_rows, stride_cols,
- output, data_format_)) {
+ context, input, filter, dimensions.batch, dimensions.input_rows,
+ dimensions.input_cols, dimensions.in_depth, dimensions.filter_rows,
+ dimensions.filter_cols, dimensions.pad_rows, dimensions.pad_cols,
+ dimensions.out_rows, dimensions.out_cols, dimensions.out_depth,
+ dimensions.dilation_rows, dimensions.dilation_cols,
+ dimensions.stride_rows, dimensions.stride_cols, output,
+ params_.data_format)) {
return;
}
launcher_(context, use_cudnn_, cudnn_use_autotune_, input, filter,
- dilation_rows, dilation_cols, stride_rows, stride_cols, padding_,
- output, data_format_);
+ dimensions.dilation_rows, dimensions.dilation_cols,
+ dimensions.stride_rows, dimensions.stride_cols, params_.padding,
+ output, params_.data_format);
}
private:
- std::vector<int32> dilations_;
- std::vector<int32> strides_;
+ Conv2DParameters params_;
bool use_cudnn_;
- Padding padding_;
- TensorFormat data_format_;
- LaunchConv2DOp<Device, T> launcher_;
bool cudnn_use_autotune_;
+ LaunchConv2DOp<Device, T> launcher_;
+
TF_DISALLOW_COPY_AND_ASSIGN(Conv2DOp);
};
@@ -731,9 +784,15 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
if (cudnn_use_autotune &&
!AutoTuneConv::GetInstance()->Find(conv_parameters, &algorithm_config)) {
std::vector<AlgorithmDesc> algorithms;
- CHECK(stream->parent()->GetConvolveAlgorithms(
- conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(stream->parent()),
- &algorithms));
+ OP_REQUIRES(
+ ctx,
+ stream->parent()->GetConvolveAlgorithms(
+ conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(
+ stream->parent()),
+ &algorithms),
+ errors::Unknown("Failed to get convolution algorithm. This is probably "
+ "because cuDNN failed to initialize, so try looking to "
+ "see if a warning log message was printed above."));
ProfileResult best_result;
ProfileResult best_result_no_scratch;
for (auto profile_algorithm : algorithms) {
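
For reference, the per-dimension output sizes that ComputeConv2DDimension obtains from GetWindowedOutputSizeV2 follow standard convolution arithmetic. The sketch below restates that arithmetic from the public Conv2D padding semantics (VALID and SAME); it is an illustration, not the TF helper itself.

#include <cassert>
#include <cstdint>

// effective_filter accounts for dilation; VALID keeps only fully covered
// windows, SAME pads so that out = ceil(in / stride).
int64_t OutSizeValid(int64_t in, int64_t filter, int64_t dilation, int64_t stride) {
  const int64_t effective_filter = (filter - 1) * dilation + 1;
  return (in - effective_filter + stride) / stride;  // ceil((in - eff + 1) / stride)
}

int64_t OutSizeSame(int64_t in, int64_t stride) {
  return (in + stride - 1) / stride;  // ceil(in / stride)
}

int main() {
  // 224x224 input, 3x3 filter, stride 2, no dilation.
  assert(OutSizeValid(224, 3, 1, 2) == 111);
  assert(OutSizeSame(224, 2) == 112);
  return 0;
}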
diff --git a/tensorflow/core/kernels/conv_ops.h b/tensorflow/core/kernels/conv_ops.h
index adf4601b43..7ec878e0b2 100644
--- a/tensorflow/core/kernels/conv_ops.h
+++ b/tensorflow/core/kernels/conv_ops.h
@@ -66,6 +66,50 @@ struct Im2ColBufferResource : public ResourceBase {
string DebugString() { return "Im2ColBufferResource"; }
};
+// Convolution parameters specified by Op attributes.
+struct Conv2DParameters {
+ std::vector<int32> dilations;
+ std::vector<int32> strides;
+ Padding padding;
+ TensorFormat data_format;
+};
+
+// Convolution dimensions inferred from parameters, input and filter tensors.
+struct Conv2DDimensions {
+ int batch;
+ int input_rows;
+ int input_cols;
+ int in_depth;
+
+ int filter_rows;
+ int filter_cols;
+ int patch_depth;
+ int out_depth;
+
+ int stride_rows;
+ int stride_cols;
+
+ int dilation_rows;
+ int dilation_cols;
+
+ int64 out_rows;
+ int64 out_cols;
+ int64 pad_rows;
+ int64 pad_cols;
+};
+
+// Initializes and validates Conv2D parameters configured by OpKernel
+// attributes.
+Status InitConv2DParameters(const OpKernelConstruction* context,
+ Conv2DParameters* params);
+
+// Computes and validates convolution dimensions from Conv2D parameters. If
+// the parameters are valid, `dimensions` is updated with the derived
+// convolution dimensions; otherwise an error is returned.
+Status ComputeConv2DDimension(const Conv2DParameters& params,
+ const Tensor& input, const Tensor& filter,
+ Conv2DDimensions* dimensions);
+
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_CONV_OPS_H_
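
Factoring the attribute checks into Status-returning helpers lets Conv2DOp and other kernels share the same validation, and the file-local TF_REQUIRES macro is just a return-a-Status counterpart of OP_REQUIRES. A minimal sketch of the same pattern outside TensorFlow, with a hypothetical Status type standing in for tensorflow::Status:

#include <string>

// Hypothetical stand-ins for tensorflow::Status and errors::InvalidArgument.
struct Status {
  bool ok = true;
  std::string message;
  static Status OK() { return {}; }
};
Status InvalidArgument(const std::string& msg) { return {false, msg}; }

// Same shape as the TF_REQUIRES macro in conv_ops.cc: bail out with a Status
// instead of calling OP_REQUIRES on a kernel context.
#define REQUIRES(EXP, STATUS)    \
  do {                           \
    if (!(EXP)) return (STATUS); \
  } while (false)

struct Params { int stride_h = 1, stride_w = 1; };

// Validation runs once at construction time; Compute only consumes the
// already-checked parameters.
Status ValidateParams(const Params& p) {
  REQUIRES(p.stride_h > 0 && p.stride_w > 0,
           InvalidArgument("Row and column strides should be larger than 0."));
  return Status::OK();
}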
diff --git a/tensorflow/core/kernels/conv_ops_3d.cc b/tensorflow/core/kernels/conv_ops_3d.cc
index 5c2b88924b..83df4dce38 100644
--- a/tensorflow/core/kernels/conv_ops_3d.cc
+++ b/tensorflow/core/kernels/conv_ops_3d.cc
@@ -435,10 +435,16 @@ struct LaunchConvOp<GPUDevice, T> {
if (cudnn_use_autotune && !AutoTuneConv3d::GetInstance()->Find(
conv_parameters, &algorithm_config)) {
std::vector<AlgorithmDesc> algorithms;
- CHECK(stream->parent()->GetConvolveAlgorithms(
- conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(
- stream->parent()),
- &algorithms));
+ OP_REQUIRES(ctx,
+ stream->parent()->GetConvolveAlgorithms(
+ conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(
+ stream->parent()),
+ &algorithms),
+ errors::Unknown(
+ "Failed to get convolution algorithm. This is probably "
+ "because cuDNN failed to initialize, so try looking to "
+ "see if a warning log message was printed above."));
+
ProfileResult best_result;
ProfileResult best_result_no_scratch;
for (auto profile_algorithm : algorithms) {
diff --git a/tensorflow/core/kernels/conv_ops_gpu.h b/tensorflow/core/kernels/conv_ops_gpu.h
index afc611f277..21d135decd 100644
--- a/tensorflow/core/kernels/conv_ops_gpu.h
+++ b/tensorflow/core/kernels/conv_ops_gpu.h
@@ -142,8 +142,12 @@ class ConvParameters {
template <typename T>
bool ShouldIncludeWinogradNonfusedAlgo(
se::StreamExecutor* stream_exec) const {
+ auto* dnn_support = stream_exec->AsDnn();
+ if (!dnn_support) {
+ return false;
+ }
// Skip this check for cuDNN 7 and newer.
- auto version = stream_exec->AsDnn()->GetVersion();
+ auto version = dnn_support->GetVersion();
if (version.ok() && version.ValueOrDie().major_version() >= 7) {
return true;
}
diff --git a/tensorflow/core/kernels/cwise_op_gpu_xdivy.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_xdivy.cu.cc
new file mode 100644
index 0000000000..e4b21a66c6
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_gpu_xdivy.cu.cc
@@ -0,0 +1,26 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+
+#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
+
+namespace tensorflow {
+namespace functor {
+DEFINE_BINARY5(xdivy, Eigen::half, float, double, complex64, complex128);
+} // namespace functor
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/cwise_op_gpu_xlogy.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_xlogy.cu.cc
new file mode 100644
index 0000000000..1e1b5a426e
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_gpu_xlogy.cu.cc
@@ -0,0 +1,26 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+
+#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
+
+namespace tensorflow {
+namespace functor {
+DEFINE_BINARY5(xlogy, Eigen::half, float, double, complex64, complex128);
+} // namespace functor
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/cwise_op_xdivy.cc b/tensorflow/core/kernels/cwise_op_xdivy.cc
new file mode 100644
index 0000000000..6a6aec5e86
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_xdivy.cc
@@ -0,0 +1,38 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/cwise_ops_common.h"
+
+namespace tensorflow {
+REGISTER5(BinaryOp, CPU, "Xdivy", functor::xdivy, float, Eigen::half, double,
+ complex64, complex128);
+
+#if TENSORFLOW_USE_SYCL
+#define REGISTER_SYCL_KERNEL(TYPE) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Xdivy").Device(DEVICE_SYCL).TypeConstraint<TYPE>("T"), \
+ BinaryOp<SYCLDevice, functor::xdivy<TYPE>>);
+REGISTER_SYCL_KERNEL(float);
+REGISTER_SYCL_KERNEL(double);
+#undef REGISTER_SYCL_KERNEL
+
+#endif // TENSORFLOW_USE_SYCL
+
+#if GOOGLE_CUDA
+REGISTER5(BinaryOp, GPU, "Xdivy", functor::xdivy, float, Eigen::half, double,
+ complex64, complex128);
+#endif // GOOGLE_CUDA
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_xlogy.cc b/tensorflow/core/kernels/cwise_op_xlogy.cc
new file mode 100644
index 0000000000..e71a9109b2
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_xlogy.cc
@@ -0,0 +1,41 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/cwise_ops_common.h"
+
+namespace tensorflow {
+REGISTER5(BinaryOp, CPU, "Xlogy", functor::xlogy, float, Eigen::half, double,
+ complex64, complex128);
+
+#if TENSORFLOW_USE_SYCL
+#define REGISTER_SYCL_KERNEL(TYPE) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Xlogy").Device(DEVICE_SYCL).TypeConstraint<TYPE>("T"), \
+ BinaryOp<SYCLDevice, functor::xlogy<TYPE>>);
+REGISTER_SYCL_KERNEL(Eigen::half);
+REGISTER_SYCL_KERNEL(float);
+REGISTER_SYCL_KERNEL(double);
+REGISTER_SYCL_KERNEL(complex64);
+REGISTER_SYCL_KERNEL(complex128);
+#undef REGISTER_SYCL_KERNEL
+
+#endif // TENSORFLOW_USE_SYCL
+
+#if GOOGLE_CUDA
+REGISTER5(BinaryOp, GPU, "Xlogy", functor::xlogy, float, Eigen::half, double,
+ complex64, complex128);
+#endif // GOOGLE_CUDA
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_ops.h b/tensorflow/core/kernels/cwise_ops.h
index 22eb66e979..66ba827a90 100644
--- a/tensorflow/core/kernels/cwise_ops.h
+++ b/tensorflow/core/kernels/cwise_ops.h
@@ -471,6 +471,45 @@ struct functor_traits<bitwise_xor_op<Scalar>> {
enum { Cost = Eigen::NumTraits<Scalar>::AddCost, PacketAccess = true };
};
+// TODO(srvasude): Add packet versions of this operation.
+template <typename Scalar>
+struct xlogy_op {
+ EIGEN_EMPTY_STRUCT_CTOR(xlogy_op)
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar
+ operator()(const Scalar& x, const Scalar& y) const {
+ if (x == Scalar(0.)) {
+ return Scalar(0.);
+ }
+ return x * numext::log(y);
+ }
+};
+
+template <typename Scalar>
+struct functor_traits<xlogy_op<Scalar>> {
+ enum {
+ Cost = (sizeof(Scalar) == 4 ? 40 : 85) + Eigen::NumTraits<Scalar>::MulCost,
+ PacketAccess = false
+ };
+};
+
+// TODO(srvasude): Add packet versions of this operation.
+template <typename Scalar>
+struct xdivy_op {
+ EIGEN_EMPTY_STRUCT_CTOR(xdivy_op)
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar
+ operator()(const Scalar& x, const Scalar& y) const {
+ if (x == Scalar(0.)) {
+ return Scalar(0.);
+ }
+ return x / y;
+ }
+};
+
+template <typename Scalar>
+struct functor_traits<xdivy_op<Scalar>> {
+ enum { Cost = Eigen::NumTraits<Scalar>::MulCost, PacketAccess = false };
+};
+
} // end namespace internal
} // end namespace Eigen
@@ -830,6 +869,12 @@ struct squared_difference
Eigen::internal::scalar_difference_op<T>>> {};
template <typename T>
+struct xdivy : base<T, Eigen::internal::xdivy_op<T>> {};
+
+template <typename T>
+struct xlogy : base<T, Eigen::internal::xlogy_op<T>> {};
+
+template <typename T>
struct less : base<T, Eigen::internal::less<T>, bool> {};
template <typename T>
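
Both new functors special-case x == 0 so that xlogy(0, 0) and xdivy(0, 0) return 0 instead of the NaN that 0 * log(0) or 0 / 0 would produce, which is the usual convention in cross-entropy style expressions. A tiny stand-alone sketch of the same semantics:

#include <cassert>
#include <cmath>

// Same convention as the xlogy_op / xdivy_op functors: x == 0 wins, even when
// y would make log(y) or the division undefined.
double Xlogy(double x, double y) { return x == 0.0 ? 0.0 : x * std::log(y); }
double Xdivy(double x, double y) { return x == 0.0 ? 0.0 : x / y; }

int main() {
  assert(Xlogy(0.0, 0.0) == 0.0);  // not NaN
  assert(Xdivy(0.0, 0.0) == 0.0);  // not NaN
  assert(std::abs(Xlogy(2.0, std::exp(1.0)) - 2.0) < 1e-12);  // 2 * log(e) = 2
  assert(Xdivy(6.0, 3.0) == 2.0);
  return 0;
}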
diff --git a/tensorflow/core/kernels/cwise_ops_common.cc b/tensorflow/core/kernels/cwise_ops_common.cc
index 980edffceb..8ad3b4d1fc 100644
--- a/tensorflow/core/kernels/cwise_ops_common.cc
+++ b/tensorflow/core/kernels/cwise_ops_common.cc
@@ -20,9 +20,9 @@ namespace tensorflow {
BinaryOpShared::BinaryOpShared(OpKernelConstruction* ctx, DataType out,
DataType in)
: OpKernel(ctx) {
-#ifndef INTEL_MKL
+#if !defined(INTEL_MKL) || !defined(ENABLE_MKL)
OP_REQUIRES_OK(ctx, ctx->MatchSignature({in, in}, {out}));
-#endif
+#endif // !INTEL_MKL || !ENABLE_MKL
}
void BinaryOpShared::SetUnimplementedError(OpKernelContext* ctx) {
diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD
index b3c359010d..6333853cdf 100644
--- a/tensorflow/core/kernels/data/BUILD
+++ b/tensorflow/core/kernels/data/BUILD
@@ -628,6 +628,20 @@ tf_kernel_library(
)
tf_kernel_library(
+ name = "multi_device_iterator_ops",
+ srcs = ["multi_device_iterator_ops.cc"],
+ deps = [
+ ":dataset",
+ ":dataset_utils",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core/kernels:ops_util",
+ ],
+)
+
+tf_kernel_library(
name = "optional_ops",
srcs = ["optional_ops.cc"],
hdrs = ["optional_ops.h"],
@@ -722,6 +736,7 @@ tf_kernel_library(
":map_dataset_op",
":map_defun_op",
":model_dataset_op",
+ ":multi_device_iterator_ops",
":optimize_dataset_op",
":optional_ops",
":padded_batch_dataset_op",
@@ -750,6 +765,7 @@ tf_kernel_library(
":window_dataset_op",
":writer_ops",
":zip_dataset_op",
+ "//tensorflow/core/kernels/data/experimental:dataset_kernels",
],
)
diff --git a/tensorflow/core/kernels/data/batch_dataset_op.cc b/tensorflow/core/kernels/data/batch_dataset_op.cc
index 887b8c8365..d1db1d7bec 100644
--- a/tensorflow/core/kernels/data/batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/batch_dataset_op.cc
@@ -117,7 +117,7 @@ class BatchDatasetOp : public UnaryDatasetOpKernel {
: DatasetIterator<Dataset>(params) {}
Status Initialize(IteratorContext* ctx) override {
- SetMetadata(ctx, "batch_size", dataset()->batch_size_);
+ AddConstantParameter(ctx, "batch_size", dataset()->batch_size_);
return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
}
diff --git a/tensorflow/core/kernels/data/captured_function.cc b/tensorflow/core/kernels/data/captured_function.cc
index 31c8f5c0ea..0bb929b3ce 100644
--- a/tensorflow/core/kernels/data/captured_function.cc
+++ b/tensorflow/core/kernels/data/captured_function.cc
@@ -22,39 +22,96 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/notification.h"
+#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
namespace data {
-/* static */
-Status CapturedFunction::Create(
- const NameAttrList& func, std::vector<Tensor> captured_inputs,
- std::unique_ptr<CapturedFunction>* out_function) {
- return Create(func, std::move(captured_inputs), true, out_function);
-}
+namespace {
+
+// Simplistic implementation of the `StepStatsCollectorInterface` that only
+// cares about collecting the CPU time needed to execute a captured function.
+class SimpleStepStatsCollector : public StepStatsCollectorInterface {
+ public:
+ void IncrementProcessingTime(int64 delta) {
+ mutex_lock l(mu_);
+ processing_time_ += delta;
+ }
+
+ NodeExecStatsInterface* CreateNodeExecStats(const Node* node) override {
+ return new SimpleNodeExecStats(this);
+ }
+
+ string ReportAllocsOnResourceExhausted(const string& err) override {
+ return "";
+ }
+
+ int64 processing_time() {
+ tf_shared_lock l(mu_);
+ return processing_time_;
+ }
+
+ private:
+ class SimpleNodeExecStats : public NodeExecStatsInterface {
+ public:
+ explicit SimpleNodeExecStats(SimpleStepStatsCollector* step_stats_collector)
+ : step_stats_collector_(step_stats_collector) {}
+
+ void Done(const string& device) override {
+ step_stats_collector_->IncrementProcessingTime(end_time_ns_ -
+ start_time_ns_);
+ delete this;
+ }
+
+ void RecordExecutorStarted() override {
+ start_time_ns_ = Env::Default()->NowNanos();
+ }
+
+ void RecordComputeStarted() override {}
+
+ void RecordComputeEnded() override {}
+
+ void RecordExecutorEnded() override {
+ end_time_ns_ = Env::Default()->NowNanos();
+ }
+
+ void SetMemory(OpKernelContext* ctx) override {}
+
+ void SetOutput(int slot, const Tensor* tensor) override {}
+
+ void SetReferencedTensors(const TensorReferenceVector& tensors) override {}
+
+ void SetScheduled(int64 nanos) override {}
+
+ private:
+ int64 start_time_ns_ = 0;
+ int64 end_time_ns_ = 0;
+ SimpleStepStatsCollector* step_stats_collector_; // Not owned.
+ };
+
+ mutex mu_;
+ int64 processing_time_ GUARDED_BY(mu_) = 0;
+};
+
+} // namespace
/* static */
Status CapturedFunction::Create(
- const NameAttrList& func, std::vector<Tensor> captured_inputs,
- bool use_inter_op_parallelism,
+ const NameAttrList& func, OpKernelContext* ctx, const string& argument,
std::unique_ptr<CapturedFunction>* out_function) {
- out_function->reset(new CapturedFunction(func, std::move(captured_inputs),
- use_inter_op_parallelism));
- return Status::OK();
+ return CapturedFunction::Create(func, ctx, argument, true, out_function);
}
-/* static */
Status CapturedFunction::Create(
const NameAttrList& func, OpKernelContext* ctx, const string& argument,
+ bool use_inter_op_parallelism,
std::unique_ptr<CapturedFunction>* out_function) {
- OpInputList argument_inputs;
- TF_RETURN_IF_ERROR(ctx->input_list(argument, &argument_inputs));
- std::vector<Tensor> arguments_t;
- arguments_t.reserve(argument_inputs.size());
- for (const Tensor& t : argument_inputs) {
- arguments_t.push_back(t);
- }
- return CapturedFunction::Create(func, std::move(arguments_t), out_function);
+ OpInputList inputs;
+ TF_RETURN_IF_ERROR(ctx->input_list(argument, &inputs));
+ std::vector<Tensor> arguments(inputs.begin(), inputs.end());
+ *out_function = WrapUnique(new CapturedFunction(func, std::move(arguments),
+ use_inter_op_parallelism));
+ return Status::OK();
}
CapturedFunction::~CapturedFunction() {
@@ -370,13 +427,13 @@ void CapturedFunction::RunAsync(IteratorContext* ctx,
done(s);
return;
}
- auto frame =
+ OwnedArgsCallFrame* frame =
new OwnedArgsCallFrame(std::move(args), &captured_inputs_, ret_types_);
FunctionLibraryRuntime::Options f_opts;
f_opts.step_id = CapturedFunction::generate_step_id();
ResourceMgr* resource_mgr = ctx->lib()->device()->resource_manager();
- auto step_container = new ScopedStepContainer(
+ ScopedStepContainer* step_container = new ScopedStepContainer(
f_opts.step_id, [resource_mgr](const string& name) {
resource_mgr->Cleanup(name).IgnoreError();
});
@@ -391,24 +448,19 @@ void CapturedFunction::RunAsync(IteratorContext* ctx,
// (such as queue kernels) that depend on the non-nullness of
// `OpKernelContext::cancellation_manager()`, but additional effort
// will be required to plumb it through the `IteratorContext`.
- auto c_mgr = new CancellationManager;
+ CancellationManager* c_mgr = new CancellationManager;
f_opts.cancellation_manager = c_mgr;
- StepStats* stats = nullptr;
- StepStatsCollector* stats_collector = nullptr;
- std::shared_ptr<model::Node> node;
+ std::shared_ptr<SimpleStepStatsCollector> stats_collector;
if (ctx->model()) {
- node = ctx->model()->LookupNode(prefix);
- if (node) {
- // TODO(b/114104975): Use something light-weight here.
- stats = new StepStats();
- stats_collector = new StepStatsCollector(stats);
- }
+ stats_collector = MakeUnique<SimpleStepStatsCollector>();
}
- f_opts.stats_collector = stats_collector;
+ f_opts.stats_collector = stats_collector.get();
auto callback = std::bind(
- [rets, step_container, c_mgr, frame, stats, stats_collector, node](
- FunctionLibraryRuntime::DoneCallback done,
+ [rets, step_container, c_mgr, frame](
+ const FunctionLibraryRuntime::DoneCallback& done,
+ const std::shared_ptr<model::Model>& model, const string& prefix,
+ const std::shared_ptr<SimpleStepStatsCollector>& stats_collector,
// Begin unbound arguments.
Status s) {
delete step_container;
@@ -417,25 +469,17 @@ void CapturedFunction::RunAsync(IteratorContext* ctx,
s = frame->ConsumeRetvals(rets);
}
delete frame;
- if (node) {
- int64 delta = 0;
- stats_collector->Finalize();
- for (auto dev_stats : stats->dev_stats()) {
- for (auto node_stats : dev_stats.node_stats()) {
- delta += node_stats.all_end_rel_nanos();
- }
- }
- delete stats_collector;
- delete stats;
- node->add_processing_time(delta);
- node->start_work();
+ if (model) {
+ model->AddProcessingTime(prefix, stats_collector->processing_time());
+ model->RecordStart(prefix, false /* stop_output */);
}
done(s);
- if (node) {
- node->stop_work();
+ if (model) {
+ model->RecordStop(prefix, false /* start_output */);
}
},
- std::move(done), std::placeholders::_1);
+ std::move(done), ctx->model(), prefix, std::move(stats_collector),
+ std::placeholders::_1);
ctx->lib()->Run(f_opts, handle, frame, std::move(callback));
}
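
The new collector replaces the full StepStats machinery with a mutex-guarded counter: each node execution records its wall-clock interval when Done() fires, and the total is handed to the autotuning model via AddProcessingTime once the function finishes. A reduced sketch of that accumulation pattern, independent of the TF interfaces:

#include <chrono>
#include <cstdint>
#include <mutex>

// Minimal version of the idea: per-node timers report into one shared,
// mutex-protected counter that is read once at the end of the step.
class ProcessingTimeCollector {
 public:
  void Add(int64_t delta_ns) {
    std::lock_guard<std::mutex> l(mu_);
    total_ns_ += delta_ns;
  }
  int64_t total_ns() {
    std::lock_guard<std::mutex> l(mu_);
    return total_ns_;
  }

 private:
  std::mutex mu_;
  int64_t total_ns_ = 0;
};

// Analogue of SimpleNodeExecStats: records start/end and reports on Done().
class NodeTimer {
 public:
  explicit NodeTimer(ProcessingTimeCollector* c) : collector_(c) {}
  void Start() { start_ = std::chrono::steady_clock::now(); }
  void Done() {
    const auto end = std::chrono::steady_clock::now();
    collector_->Add(
        std::chrono::duration_cast<std::chrono::nanoseconds>(end - start_)
            .count());
  }

 private:
  ProcessingTimeCollector* collector_;  // not owned
  std::chrono::steady_clock::time_point start_;
};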
diff --git a/tensorflow/core/kernels/data/captured_function.h b/tensorflow/core/kernels/data/captured_function.h
index 8b420fa5db..a10376bf97 100644
--- a/tensorflow/core/kernels/data/captured_function.h
+++ b/tensorflow/core/kernels/data/captured_function.h
@@ -42,27 +42,19 @@ namespace data {
// context.
class CapturedFunction {
public:
- // Creates a new instance from a list of named attributes and captured inputs.
- //
- // NOTE(mrry): The `captured_inputs` are passed by value. For
- // efficiency, you are recommended to move this argument into the call.
- static Status Create(const NameAttrList& func,
- std::vector<Tensor> captured_inputs,
+ // Creates a new instance using a list of named attributes, fetching captured
+ // inputs from a context argument.
+ static Status Create(const NameAttrList& func, OpKernelContext* ctx,
+ const string& argument,
std::unique_ptr<CapturedFunction>* out_function);
- // Creates a new instance from a list of named attributes and captured inputs.
+ // Creates a new instance using a list of named attributes, fetching captured
+ // inputs from a context argument.
//
// If `use_inter_op_parallelism` is false, the runtime may use an executor
// that is optimized for small functions.
- static Status Create(const NameAttrList& func,
- std::vector<Tensor> captured_inputs,
- bool use_inter_op_parallelism,
- std::unique_ptr<CapturedFunction>* out_function);
-
- // Creates a new instance using a list of named attributes, fetching captured
- // inputs from a context argument.
static Status Create(const NameAttrList& func, OpKernelContext* ctx,
- const string& argument,
+ const string& argument, bool use_inter_op_parallelism,
std::unique_ptr<CapturedFunction>* out_function);
~CapturedFunction();
diff --git a/tensorflow/core/kernels/data/dataset_utils.cc b/tensorflow/core/kernels/data/dataset_utils.cc
index e7ac368ae3..e10833f525 100644
--- a/tensorflow/core/kernels/data/dataset_utils.cc
+++ b/tensorflow/core/kernels/data/dataset_utils.cc
@@ -44,5 +44,42 @@ Status MakeIteratorFromInputElement(
ctx, strings::StrCat(prefix, "[", thread_index, "]"), out_iterator);
}
+Status VerifyTypesMatch(const DataTypeVector& expected,
+ const DataTypeVector& received) {
+ if (expected.size() != received.size()) {
+ return errors::InvalidArgument(
+ "Number of components does not match: expected ", expected.size(),
+ " types but got ", received.size(), ".");
+ }
+ for (size_t i = 0; i < expected.size(); ++i) {
+ if (expected[i] != received[i]) {
+ return errors::InvalidArgument("Data type mismatch at component ", i,
+ ": expected ", DataTypeString(expected[i]),
+ " but got ", DataTypeString(received[i]),
+ ".");
+ }
+ }
+ return Status::OK();
+}
+
+Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
+ const std::vector<PartialTensorShape>& received) {
+ if (expected.size() != received.size()) {
+ return errors::InvalidArgument(
+ "Number of components does not match: expected ", expected.size(),
+ " shapes but got ", received.size(), ".");
+ }
+ for (size_t i = 0; i < expected.size(); ++i) {
+ if (!expected[i].IsCompatibleWith(received[i])) {
+ return errors::InvalidArgument("Incompatible shapes at component ", i,
+ ": expected ", expected[i].DebugString(),
+ " but got ", received[i].DebugString(),
+ ".");
+ }
+ }
+
+ return Status::OK();
+}
+
} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/dataset_utils.h b/tensorflow/core/kernels/data/dataset_utils.h
index 234856ea39..6ec1350cd4 100644
--- a/tensorflow/core/kernels/data/dataset_utils.h
+++ b/tensorflow/core/kernels/data/dataset_utils.h
@@ -27,6 +27,16 @@ Status MakeIteratorFromInputElement(
int64 thread_index, CapturedFunction* captured_func, StringPiece prefix,
std::unique_ptr<IteratorBase>* out_iterator);
+// Returns Status::OK() if `expected` and `received` types match,
+// errors::InvalidArgument otherwise.
+Status VerifyTypesMatch(const DataTypeVector& expected,
+ const DataTypeVector& received);
+
+// Returns Status::OK() if `expected` and `received` shapes are compatible,
+// errors::InvalidArgument otherwise.
+Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
+ const std::vector<PartialTensorShape>& received);
+
} // namespace data
} // namespace tensorflow
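A minimal sketch of how these helpers can be used to check a dataset's output signature against what a caller expects; the `expected_*` locals are illustrative, not part of this change:

    // Illustrative check of a dataset's signature against caller expectations.
    DataTypeVector expected_types = {DT_INT64, DT_STRING};
    std::vector<PartialTensorShape> expected_shapes = {PartialTensorShape({}),
                                                       PartialTensorShape({-1})};
    TF_RETURN_IF_ERROR(
        VerifyTypesMatch(expected_types, dataset->output_dtypes()));
    TF_RETURN_IF_ERROR(
        VerifyShapesCompatible(expected_shapes, dataset->output_shapes()));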
diff --git a/tensorflow/contrib/data/kernels/BUILD b/tensorflow/core/kernels/data/experimental/BUILD
index ec6cb37193..43406db3ed 100644
--- a/tensorflow/contrib/data/kernels/BUILD
+++ b/tensorflow/core/kernels/data/experimental/BUILD
@@ -1,22 +1,26 @@
# Description:
-# Contains kernels for datasets and iterators.
+# Contains experimental kernels for datasets and iterators.
package(default_visibility = ["//tensorflow:internal"])
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
+load(
+ "//tensorflow:tensorflow.bzl",
+ "tf_kernel_library",
+)
+
cc_library(
name = "indexed_dataset_headers",
hdrs = ["indexed_dataset.h"],
deps = [
- "//tensorflow/core:framework_headers_lib",
+ "//tensorflow/core:framework",
"//third_party/eigen3",
- "@protobuf_archive//:protobuf_headers",
],
)
-cc_library(
+tf_kernel_library(
name = "indexed_dataset",
srcs = [
"identity_indexed_dataset.cc",
@@ -24,103 +28,102 @@ cc_library(
],
deps = [
":indexed_dataset_headers",
- "//tensorflow/core:framework_headers_lib",
+ "//tensorflow/core:experimental_dataset_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
"//third_party/eigen3",
- "@protobuf_archive//:protobuf_headers",
],
- alwayslink = 1,
)
-cc_library(
+tf_kernel_library(
name = "prefetching_kernels",
srcs = ["prefetching_kernels.cc"],
deps = [
- "//tensorflow/core:core_cpu_headers_lib",
- "//tensorflow/core:framework_headers_lib",
- "//third_party/eigen3",
- "@protobuf_archive//:protobuf_headers",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:experimental_dataset_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
],
- alwayslink = 1,
)
-cc_library(
+tf_kernel_library(
name = "directed_interleave_dataset_op",
srcs = ["directed_interleave_dataset_op.cc"],
deps = [
- "//tensorflow/core:framework_headers_lib",
+ "//tensorflow/core:experimental_dataset_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
"//third_party/eigen3",
- "@protobuf_archive//:protobuf_headers",
],
- alwayslink = 1,
)
-cc_library(
+tf_kernel_library(
name = "csv_dataset_op",
srcs = ["csv_dataset_op.cc"],
deps = [
- "//tensorflow/core:framework_headers_lib",
- "//third_party/eigen3",
- "@protobuf_archive//:protobuf_headers",
+ "//tensorflow/core:experimental_dataset_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
],
- alwayslink = 1,
)
-cc_library(
+tf_kernel_library(
name = "ignore_errors_dataset_op",
srcs = ["ignore_errors_dataset_op.cc"],
deps = [
- "//tensorflow/core:framework_headers_lib",
+ "//tensorflow/core:experimental_dataset_ops_op_lib",
+ "//tensorflow/core:framework",
"//third_party/eigen3",
- "@protobuf_archive//:protobuf_headers",
],
- alwayslink = 1,
)
-cc_library(
+tf_kernel_library(
name = "lmdb_dataset_op",
srcs = ["lmdb_dataset_op.cc"],
deps = [
- "//tensorflow/core:framework_headers_lib",
+ "//tensorflow/core:experimental_dataset_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
"//third_party/eigen3",
"@lmdb",
- "@protobuf_archive//:protobuf_headers",
],
)
-cc_library(
+tf_kernel_library(
name = "threadpool_dataset_op",
srcs = ["threadpool_dataset_op.cc"],
deps = [
- "//tensorflow/core:framework_headers_lib",
+ "//tensorflow/core:experimental_dataset_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
"//third_party/eigen3",
- "@protobuf_archive//:protobuf_headers",
],
- alwayslink = 1,
)
-cc_library(
+tf_kernel_library(
name = "unique_dataset_op",
srcs = ["unique_dataset_op.cc"],
deps = [
- "//tensorflow/core:framework_headers_lib",
+ "//tensorflow/core:experimental_dataset_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
"//third_party/eigen3",
- "@protobuf_archive//:protobuf_headers",
],
- alwayslink = 1,
)
-cc_library(
+tf_kernel_library(
name = "assert_next_dataset_op",
srcs = ["assert_next_dataset_op.cc"],
deps = [
- "//tensorflow/core:framework_headers_lib",
+ "//tensorflow/core:experimental_dataset_ops_op_lib",
+ "//tensorflow/core:framework",
"//third_party/eigen3",
- "@protobuf_archive//:protobuf_headers",
],
- alwayslink = 1,
)
-cc_library(
+tf_kernel_library(
name = "dataset_kernels",
deps = [
":assert_next_dataset_op",
@@ -132,8 +135,5 @@ cc_library(
":prefetching_kernels",
":threadpool_dataset_op",
":unique_dataset_op",
- "//tensorflow/core:framework_headers_lib",
- "//third_party/eigen3",
- "@protobuf_archive//:protobuf_headers",
],
)
diff --git a/tensorflow/contrib/data/kernels/assert_next_dataset_op.cc b/tensorflow/core/kernels/data/experimental/assert_next_dataset_op.cc
index c19a609780..3511cca0f5 100644
--- a/tensorflow/contrib/data/kernels/assert_next_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/assert_next_dataset_op.cc
@@ -147,8 +147,9 @@ class AssertNextDatasetOp : public UnaryDatasetOpKernel {
std::vector<PartialTensorShape> output_shapes_;
};
-REGISTER_KERNEL_BUILDER(Name("AssertNextDataset").Device(DEVICE_CPU),
- AssertNextDatasetOp);
+REGISTER_KERNEL_BUILDER(
+ Name("ExperimentalAssertNextDataset").Device(DEVICE_CPU),
+ AssertNextDatasetOp);
} // namespace
} // namespace data
diff --git a/tensorflow/contrib/data/kernels/csv_dataset_op.cc b/tensorflow/core/kernels/data/experimental/csv_dataset_op.cc
index 21ec50fb6b..7451ca4cb1 100644
--- a/tensorflow/contrib/data/kernels/csv_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/csv_dataset_op.cc
@@ -852,7 +852,8 @@ class CSVDatasetOp : public DatasetOpKernel {
}; // class CSVDatasetOp
// Register the kernel implementation for CSVDataset.
-REGISTER_KERNEL_BUILDER(Name("CSVDataset").Device(DEVICE_CPU), CSVDatasetOp);
+REGISTER_KERNEL_BUILDER(Name("ExperimentalCSVDataset").Device(DEVICE_CPU),
+ CSVDatasetOp);
} // namespace
} // namespace data
diff --git a/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc b/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.cc
index a5321620bf..c47a9099c4 100644
--- a/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.cc
@@ -272,8 +272,9 @@ class DirectedInterleaveDatasetOp : public DatasetOpKernel {
};
};
-REGISTER_KERNEL_BUILDER(Name("DirectedInterleaveDataset").Device(DEVICE_CPU),
- DirectedInterleaveDatasetOp);
+REGISTER_KERNEL_BUILDER(
+ Name("ExperimentalDirectedInterleaveDataset").Device(DEVICE_CPU),
+ DirectedInterleaveDatasetOp);
} // namespace
} // namespace data
diff --git a/tensorflow/contrib/data/kernels/identity_indexed_dataset.cc b/tensorflow/core/kernels/data/experimental/identity_indexed_dataset.cc
index c3cb45dbf7..2141f118ca 100644
--- a/tensorflow/contrib/data/kernels/identity_indexed_dataset.cc
+++ b/tensorflow/core/kernels/data/experimental/identity_indexed_dataset.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/data/kernels/indexed_dataset.h"
+#include "tensorflow/core/kernels/data/experimental/indexed_dataset.h"
#include "tensorflow/core/lib/core/errors.h"
namespace tensorflow {
@@ -147,8 +147,9 @@ class IdentityIndexedDatasetOp : public IndexedDatasetOpKernel {
};
};
-REGISTER_KERNEL_BUILDER(Name("IdentityIndexedDataset").Device(DEVICE_CPU),
- IdentityIndexedDatasetOp);
+REGISTER_KERNEL_BUILDER(
+ Name("ExperimentalIdentityIndexedDataset").Device(DEVICE_CPU),
+ IdentityIndexedDatasetOp);
} // namespace
} // namespace data
diff --git a/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc b/tensorflow/core/kernels/data/experimental/ignore_errors_dataset_op.cc
index beec344534..b34377c642 100644
--- a/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/ignore_errors_dataset_op.cc
@@ -15,7 +15,6 @@ limitations under the License.
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
-#include "tensorflow/core/lib/random/random.h"
namespace tensorflow {
namespace data {
@@ -133,8 +132,9 @@ class IgnoreErrorsDatasetOp : public UnaryDatasetOpKernel {
};
};
-REGISTER_KERNEL_BUILDER(Name("IgnoreErrorsDataset").Device(DEVICE_CPU),
- IgnoreErrorsDatasetOp);
+REGISTER_KERNEL_BUILDER(
+ Name("ExperimentalIgnoreErrorsDataset").Device(DEVICE_CPU),
+ IgnoreErrorsDatasetOp);
} // namespace
} // namespace data
diff --git a/tensorflow/contrib/data/kernels/indexed_dataset.cc b/tensorflow/core/kernels/data/experimental/indexed_dataset.cc
index ced8ab0d60..75ea462f40 100644
--- a/tensorflow/contrib/data/kernels/indexed_dataset.cc
+++ b/tensorflow/core/kernels/data/experimental/indexed_dataset.cc
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/data/kernels/indexed_dataset.h"
+#include "tensorflow/core/kernels/data/experimental/indexed_dataset.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor_shape.h"
@@ -361,12 +361,14 @@ class IndexedDatasetGet : public OpKernel {
};
REGISTER_KERNEL_BUILDER(
- Name("MaterializedIndexDatasetHandle").Device(DEVICE_CPU),
+ Name("ExperimentalMaterializedIndexDatasetHandle").Device(DEVICE_CPU),
MaterializedHandleOp);
-REGISTER_KERNEL_BUILDER(Name("IndexedDatasetMaterialize").Device(DEVICE_CPU),
- MaterializeDatasetOp);
-REGISTER_KERNEL_BUILDER(Name("IndexedDatasetGet").Device(DEVICE_CPU),
- IndexedDatasetGet);
+REGISTER_KERNEL_BUILDER(
+ Name("ExperimentalIndexedDatasetMaterialize").Device(DEVICE_CPU),
+ MaterializeDatasetOp);
+REGISTER_KERNEL_BUILDER(
+ Name("ExperimentalIndexedDatasetGet").Device(DEVICE_CPU),
+ IndexedDatasetGet);
} // namespace
} // namespace data
diff --git a/tensorflow/contrib/data/kernels/indexed_dataset.h b/tensorflow/core/kernels/data/experimental/indexed_dataset.h
index 7aa2d3fdbc..27a8360cbc 100644
--- a/tensorflow/contrib/data/kernels/indexed_dataset.h
+++ b/tensorflow/core/kernels/data/experimental/indexed_dataset.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CONTRIB_DATA_KERNELS_INDEXED_DATASET_H_
-#define TENSORFLOW_CONTRIB_DATA_KERNELS_INDEXED_DATASET_H_
+#ifndef TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_INDEXED_DATASET_H_
+#define TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_INDEXED_DATASET_H_
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -116,4 +116,4 @@ Status StoreIndexedDatasetInVariantTensor(IndexedDataset* dataset,
} // namespace data
} // namespace tensorflow
-#endif // TENSORFLOW_CONTRIB_DATA_KERNELS_INDEXED_DATASET_H_
+#endif // TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_INDEXED_DATASET_H_
diff --git a/tensorflow/contrib/data/kernels/lmdb_dataset_op.cc b/tensorflow/core/kernels/data/experimental/lmdb_dataset_op.cc
index d233c1f8ec..8a88d32f0c 100644
--- a/tensorflow/contrib/data/kernels/lmdb_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/lmdb_dataset_op.cc
@@ -210,7 +210,8 @@ class LMDBDatasetOp : public DatasetOpKernel {
};
};
-REGISTER_KERNEL_BUILDER(Name("LMDBDataset").Device(DEVICE_CPU), LMDBDatasetOp);
+REGISTER_KERNEL_BUILDER(Name("ExperimentalLMDBDataset").Device(DEVICE_CPU),
+ LMDBDatasetOp);
} // namespace
} // namespace data
diff --git a/tensorflow/core/kernels/data/experimental/prefetching_kernels.cc b/tensorflow/core/kernels/data/experimental/prefetching_kernels.cc
new file mode 100644
index 0000000000..2c6179d9f5
--- /dev/null
+++ b/tensorflow/core/kernels/data/experimental/prefetching_kernels.cc
@@ -0,0 +1,482 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <deque>
+
+#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
+#include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/resource_op_kernel.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/lib/random/random.h"
+#include "tensorflow/core/util/device_name_utils.h"
+
+namespace tensorflow {
+namespace data {
+namespace {
+
+struct BufferElement {
+ // The producer sets `status` if getting the input element fails.
+ Status status;
+ // The buffered data element.
+ std::vector<Tensor> value;
+};
+
+using FunctionBufferCallback = std::function<void(const BufferElement&)>;
+
+class FunctionBufferingResource : public ResourceBase {
+ public:
+ FunctionBufferingResource(FunctionLibraryRuntime* lib,
+ std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
+ const NameAttrList& func, int64 buffer_size,
+ const string& source_device,
+ const string& target_device,
+ const std::vector<Tensor>& func_args,
+ const DataTypeVector& output_types)
+ : lib_(lib),
+ pflr_(std::move(pflr)),
+ func_(func),
+ buffer_size_(buffer_size),
+ source_device_(source_device),
+ target_device_(target_device),
+ func_args_(func_args),
+ output_types_(output_types),
+ handle_(kInvalidHandle),
+ is_buffering_(false),
+ end_of_sequence_(false),
+ cancelled_(false) {}
+
+ ~FunctionBufferingResource() override {
+ Cancel();
+ }
+
+ string DebugString() override {
+ return strings::StrCat("FunctionBufferingResource. Size: ", buffer_size_,
+ "; target_device: ", target_device_);
+ }
+
+ // Instantiates the function the first time it's called. After that it caches
+ // the handle.
+ Status Instantiate() LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ // Re-use existing handle if it's been set, effectively caching it.
+ if (handle_ != kInvalidHandle) {
+ return Status::OK();
+ }
+ AttrValueMap attr_values = func_.attr();
+ FunctionLibraryRuntime::InstantiateOptions opts;
+ opts.target = target_device_;
+ return lib_->Instantiate(func_.name(), AttrSlice(&attr_values), opts,
+ &handle_);
+ }
+
+ // Returns true if we have reached the end of the sequence and the buffer
+ // is exhausted.
+ bool Finished() LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ return end_of_sequence_ && buffer_.empty();
+ }
+
+ // Cancels any buffering / prefetching going on.
+ void Cancel() LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ cancelled_ = true;
+ while (is_buffering_) {
+ cond_var_.wait(l);
+ }
+ }
+
+ // Cancels all pending operations and then clears out the state.
+ void Reset() LOCKS_EXCLUDED(mu_) {
+ Cancel();
+ mutex_lock l(mu_);
+ buffer_.clear();
+ requests_.clear();
+ is_buffering_ = false;
+ end_of_sequence_ = false;
+ cancelled_ = false;
+ }
+
+ // If the buffer has anything, runs `callback` on the first element in the
+ // buffer; otherwise enqueues `callback` to be run once an element becomes
+ // available, and starts buffering if it is not already in progress.
+ void MaybeGet(FunctionBufferCallback callback) LOCKS_EXCLUDED(mu_) {
+ bool start_buffering = false;
+ bool produced_output = false;
+ BufferElement buffer_element;
+ {
+ mutex_lock l(mu_);
+ if (!is_buffering_ && !end_of_sequence_) {
+ start_buffering = true;
+ }
+ if (!buffer_.empty()) {
+ produced_output = true;
+ std::swap(buffer_element, buffer_.front());
+ buffer_.pop_front();
+ } else {
+ produced_output = false;
+ requests_.push_back(std::move(callback));
+ }
+ }
+ if (produced_output) {
+ callback(buffer_element);
+ }
+ if (start_buffering) {
+ FillBuffer();
+ }
+ }
+
+ private:
+ void FillBuffer() LOCKS_EXCLUDED(mu_) {
+ FunctionLibraryRuntime::Handle handle;
+ std::vector<FunctionBufferCallback> cancellation_callbacks;
+ std::vector<BufferElement> cancellation_buffer_elements;
+ bool cancelled = false;
+ {
+ mutex_lock l(mu_);
+ handle = handle_;
+ if (cancelled_) {
+ cancelled = true;
+ // Run through and fulfill all pending requests, if possible.
+ while (!requests_.empty()) {
+ if (!buffer_.empty()) {
+ cancellation_buffer_elements.push_back(std::move(buffer_.front()));
+ buffer_.pop_front();
+ cancellation_callbacks.push_back(std::move(requests_.front()));
+ requests_.pop_front();
+ } else {
+ LOG(ERROR) << "Buffer ran out of elements and we couldn't satisfy: "
+ << requests_.size() << " requests";
+ break;
+ }
+ }
+ is_buffering_ = false;
+ } else {
+ is_buffering_ = true;
+ }
+ }
+ if (cancelled) {
+ for (int i = 0; i < cancellation_callbacks.size(); ++i) {
+ cancellation_callbacks[i](cancellation_buffer_elements[i]);
+ }
+ cond_var_.notify_all();
+ return;
+ }
+ FunctionLibraryRuntime::Options opts;
+ // Copied from CapturedFunction::generate_step_id().
+ opts.step_id = -std::abs(static_cast<int64>(random::New64()));
+ opts.source_device = source_device_;
+ AllocatorAttributes arg_alloc_attr;
+ arg_alloc_attr.set_on_host(true);
+ opts.args_alloc_attrs.push_back(arg_alloc_attr);
+ for (const auto& dtype : output_types_) {
+ AllocatorAttributes ret_alloc_attrs;
+ if (DataTypeAlwaysOnHost(dtype)) {
+ ret_alloc_attrs.set_on_host(true);
+ }
+ opts.rets_alloc_attrs.push_back(ret_alloc_attrs);
+ }
+ if (opts.source_device != target_device_) {
+ opts.remote_execution = true;
+ }
+ opts.create_rendezvous = true;
+ auto* rets = new std::vector<Tensor>;
+ lib_->Run(opts, handle, func_args_, rets,
+ [this, rets](const Status& status) {
+ FunctionBufferCallback callback = nullptr;
+ BufferElement buffer_front;
+ bool restart_buffering = false;
+ {
+ mutex_lock l(mu_);
+ BufferElement buffer_element;
+ buffer_element.status = status;
+ if (status.ok()) {
+ buffer_element.value.swap(*rets);
+ } else {
+ end_of_sequence_ = true;
+ is_buffering_ = false;
+ }
+ buffer_.push_back(std::move(buffer_element));
+ if (!requests_.empty()) {
+ buffer_front = std::move(buffer_.front());
+ buffer_.pop_front();
+ callback = std::move(requests_.front());
+ requests_.pop_front();
+ }
+ if (buffer_.size() < buffer_size_ && !end_of_sequence_) {
+ restart_buffering = true;
+ } else {
+ // When the buffer is full, we don't want to call FillBuffer()
+ // again unless we are in the cancellation phase, in which case
+ // FillBuffer() performs the final cleanup after cancellation.
+ if (cancelled_) {
+ restart_buffering = true;
+ }
+ is_buffering_ = false;
+ }
+ }
+ if (callback != nullptr) {
+ callback(buffer_front);
+ }
+ if (restart_buffering) {
+ FillBuffer();
+ }
+ });
+ }
+
+ mutex mu_;
+ FunctionLibraryRuntime* lib_;
+ std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
+ NameAttrList func_;
+ const int64 buffer_size_;
+ const string source_device_;
+ const string target_device_;
+ const std::vector<Tensor> func_args_;
+ const DataTypeVector output_types_;
+ FunctionLibraryRuntime::Handle handle_ GUARDED_BY(mu_);
+ std::deque<BufferElement> buffer_ GUARDED_BY(mu_);
+ std::deque<FunctionBufferCallback> requests_ GUARDED_BY(mu_);
+ bool is_buffering_ GUARDED_BY(mu_);
+ bool end_of_sequence_ GUARDED_BY(mu_);
+ bool cancelled_ GUARDED_BY(mu_);
+ condition_variable cond_var_;
+};
+
+class FunctionBufferResourceHandleOp : public OpKernel {
+ public:
+ explicit FunctionBufferResourceHandleOp(OpKernelConstruction* ctx)
+ : OpKernel(ctx), flib_def_(nullptr) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("buffer_size", &buffer_size_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("container", &container_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("shared_name", &name_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
+ }
+
+ ~FunctionBufferResourceHandleOp() override {
+ if (cinfo_.resource_is_private_to_kernel()) {
+ if (!cinfo_.resource_manager()
+ ->Delete<FunctionBufferingResource>(cinfo_.container(),
+ cinfo_.name())
+ .ok()) {
+ // Do nothing; the resource may have been deleted by session resets.
+ }
+ }
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor* string_arg;
+ OP_REQUIRES_OK(ctx, ctx->input("string_arg", &string_arg));
+ std::vector<Tensor> func_args;
+ func_args.push_back(*string_arg);
+
+ const string& source_device = ctx->device()->name();
+
+ // Obtain and canonicalize target_device.
+ const Tensor* target_arg;
+ OP_REQUIRES_OK(ctx, ctx->input("target_device", &target_arg));
+ string target_device;
+ OP_REQUIRES_OK(ctx, DeviceNameUtils::CanonicalizeDeviceName(
+ target_arg->scalar<string>()(), source_device,
+ &target_device));
+
+ FunctionLibraryRuntime* lib = ctx->function_library();
+ OP_REQUIRES(ctx, lib != nullptr,
+ errors::Internal("No function library is provided."));
+
+ mutex_lock l(mu_);
+ if (!initialized_) {
+ OP_REQUIRES_OK(ctx, cinfo_.Init(ctx->resource_manager(), def()));
+ FunctionLibraryRuntime* clone_lib;
+ std::unique_ptr<ProcessFunctionLibraryRuntime> pflr;
+ OP_REQUIRES_OK(ctx, lib->Clone(&flib_def_, &pflr, &clone_lib));
+ // Create the resource.
+ FunctionBufferingResource* buffer;
+ OP_REQUIRES_OK(
+ ctx,
+ ctx->resource_manager()->LookupOrCreate<FunctionBufferingResource>(
+ cinfo_.container(), cinfo_.name(), &buffer,
+ [clone_lib, &pflr, &source_device, &target_device, func_args,
+ this](FunctionBufferingResource** ptr) {
+ *ptr = new FunctionBufferingResource(
+ clone_lib, std::move(pflr), func_, buffer_size_,
+ source_device, target_device, func_args, output_types_);
+ return Status::OK();
+ }));
+ core::ScopedUnref s(buffer);
+ OP_REQUIRES_OK(ctx, buffer->Instantiate());
+ initialized_ = true;
+ }
+
+ OP_REQUIRES_OK(ctx, MakeResourceHandleToOutput(
+ ctx, 0, cinfo_.container(), cinfo_.name(),
+ MakeTypeIndex<FunctionBufferingResource>()));
+ }
+
+ private:
+ mutex mu_;
+ ContainerInfo cinfo_ GUARDED_BY(mu_);
+ bool initialized_ GUARDED_BY(mu_) = false;
+ std::unique_ptr<FunctionLibraryDefinition> flib_def_;
+ NameAttrList func_;
+ int64 buffer_size_;
+ string container_;
+ string name_;
+ DataTypeVector output_types_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("ExperimentalFunctionBufferingResource")
+ .Device(DEVICE_CPU)
+ .HostMemory("resource")
+ .HostMemory("string_arg")
+ .HostMemory("target_device"),
+ FunctionBufferResourceHandleOp);
+REGISTER_KERNEL_BUILDER(Name("ExperimentalFunctionBufferingResource")
+ .Device(DEVICE_GPU)
+ .HostMemory("resource")
+ .HostMemory("string_arg")
+ .HostMemory("target_device"),
+ FunctionBufferResourceHandleOp);
+#if TENSORFLOW_USE_SYCL
+REGISTER_KERNEL_BUILDER(Name("ExperimentalFunctionBufferingResource")
+ .Device(DEVICE_SYCL)
+ .HostMemory("resource")
+ .HostMemory("string_arg")
+ .HostMemory("target_device"),
+ FunctionBufferResourceHandleOp);
+#endif // TENSORFLOW_USE_SYCL
+
+// Returns the next element from the buffer managed by a
+// FunctionBufferingResource, triggering further prefetching as needed.
+class FunctionBufferingResourceGetNextOp : public AsyncOpKernel {
+ public:
+ explicit FunctionBufferingResourceGetNextOp(OpKernelConstruction* ctx)
+ : AsyncOpKernel(ctx) {}
+
+ ~FunctionBufferingResourceGetNextOp() override {}
+
+ void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
+ ResourceHandle handle;
+ OP_REQUIRES_OK_ASYNC(
+ ctx, HandleFromInput(ctx, "function_buffer_resource", &handle), done);
+ FunctionBufferingResource* buffer = nullptr;
+ OP_REQUIRES_OK_ASYNC(
+ ctx, LookupResource<FunctionBufferingResource>(ctx, handle, &buffer),
+ done);
+
+ if (buffer->Finished()) {
+ buffer->Unref();
+ ctx->SetStatus(errors::OutOfRange("end_of_sequence"));
+ done();
+ return;
+ }
+
+ FunctionBufferCallback callback =
+ [ctx, buffer, done](const BufferElement& buffer_element) {
+ Status s = buffer_element.status;
+ if (!s.ok()) {
+ ctx->SetStatus(s);
+ buffer->Unref();
+ done();
+ return;
+ }
+ for (size_t i = 0; i < buffer_element.value.size(); ++i) {
+ ctx->set_output(i, buffer_element.value[i]);
+ }
+ buffer->Unref();
+ done();
+ };
+ buffer->MaybeGet(std::move(callback));
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("ExperimentalFunctionBufferingResourceGetNext")
+ .Device(DEVICE_CPU)
+ .HostMemory("function_buffer_resource"),
+ FunctionBufferingResourceGetNextOp);
+REGISTER_KERNEL_BUILDER(Name("ExperimentalFunctionBufferingResourceGetNext")
+ .Device(DEVICE_GPU)
+ .HostMemory("function_buffer_resource"),
+ FunctionBufferingResourceGetNextOp);
+#if TENSORFLOW_USE_SYCL
+REGISTER_KERNEL_BUILDER(Name("ExperimentalFunctionBufferingResourceGetNext")
+ .Device(DEVICE_SYCL)
+ .HostMemory("function_buffer_resource"),
+ FunctionBufferingResourceGetNextOp);
+#endif // TENSORFLOW_USE_SYCL
+
+// Resets the FunctionBufferingResource, cancelling all pending requests and
+// clearing out the buffer.
+class FunctionBufferingResourceResetOp : public OpKernel {
+ public:
+ explicit FunctionBufferingResourceResetOp(OpKernelConstruction* ctx)
+ : OpKernel(ctx) {}
+
+ ~FunctionBufferingResourceResetOp() override {}
+
+ void Compute(OpKernelContext* ctx) override {
+ ResourceHandle handle;
+ OP_REQUIRES_OK(ctx,
+ HandleFromInput(ctx, "function_buffer_resource", &handle));
+ FunctionBufferingResource* buffer = nullptr;
+ OP_REQUIRES_OK(
+ ctx, LookupResource<FunctionBufferingResource>(ctx, handle, &buffer));
+ core::ScopedUnref s(buffer);
+
+ buffer->Reset();
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("ExperimentalFunctionBufferingResourceReset")
+ .Device(DEVICE_CPU)
+ .HostMemory("function_buffer_resource"),
+ FunctionBufferingResourceResetOp);
+REGISTER_KERNEL_BUILDER(Name("ExperimentalFunctionBufferingResourceReset")
+ .Device(DEVICE_GPU)
+ .HostMemory("function_buffer_resource"),
+ FunctionBufferingResourceResetOp);
+#if TENSORFLOW_USE_SYCL
+REGISTER_KERNEL_BUILDER(Name("ExperimentalFunctionBufferingResourceReset")
+ .Device(DEVICE_SYCL)
+ .HostMemory("function_buffer_resource"),
+ FunctionBufferingResourceResetOp);
+#endif // TENSORFLOW_USE_SYCL
+
+class IteratorGetDeviceOp : public OpKernel {
+ public:
+ using OpKernel::OpKernel;
+
+ void Compute(OpKernelContext* ctx) override {
+ // NOTE(mrry): We do not currently validate that the handle
+ // corresponds to a real IteratorResource, because that symbol is
+ // not exposed from the framework library.
+ Tensor* device_name_t;
+ OP_REQUIRES_OK(ctx,
+ ctx->allocate_output(0, TensorShape({}), &device_name_t));
+ // NOTE(mrry): Since the operation's input is a resource, we must be
+ // colocated with it, and so we can simply return the current device's
+ // name without looking at the input.
+ device_name_t->scalar<string>()() = ctx->device()->name();
+ }
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("ExperimentalIteratorGetDevice").Device(DEVICE_CPU),
+ IteratorGetDeviceOp);
+
+} // namespace
+} // namespace data
+} // namespace tensorflow
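The consumer-facing contract of `MaybeGet` is callback-based: the callback receives a `BufferElement` whose `status` must be checked before `value` is used, and it may run either immediately (if the buffer is non-empty) or later from the buffering thread. A minimal sketch of a consumer, mirroring `FunctionBufferingResourceGetNextOp` above; `HandleResult` is a hypothetical function that forwards the result to the caller:

    // `buffer` is a FunctionBufferingResource* obtained via LookupResource.
    buffer->MaybeGet([buffer](const BufferElement& element) {
      if (!element.status.ok()) {
        HandleResult(element.status, {});           // error or end of sequence
      } else {
        HandleResult(Status::OK(), element.value);  // prefetched tensors
      }
      buffer->Unref();
    });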
diff --git a/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc b/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc
index 30fa97a636..c80493d3a1 100644
--- a/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc
@@ -209,10 +209,11 @@ class ThreadPoolDatasetOp : public UnaryDatasetOpKernel {
};
};
-REGISTER_KERNEL_BUILDER(Name("ThreadPoolHandle").Device(DEVICE_CPU),
+REGISTER_KERNEL_BUILDER(Name("ExperimentalThreadPoolHandle").Device(DEVICE_CPU),
ThreadPoolHandleOp);
-REGISTER_KERNEL_BUILDER(Name("ThreadPoolDataset").Device(DEVICE_CPU),
- ThreadPoolDatasetOp);
+REGISTER_KERNEL_BUILDER(
+ Name("ExperimentalThreadPoolDataset").Device(DEVICE_CPU),
+ ThreadPoolDatasetOp);
} // namespace
} // namespace data
diff --git a/tensorflow/contrib/data/kernels/unique_dataset_op.cc b/tensorflow/core/kernels/data/experimental/unique_dataset_op.cc
index 57fc5697a4..cd612e0eb2 100644
--- a/tensorflow/contrib/data/kernels/unique_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/unique_dataset_op.cc
@@ -199,8 +199,9 @@ class UniqueDatasetOp : public UnaryDatasetOpKernel {
HANDLE_TYPE(DT_INT64);
HANDLE_TYPE(DT_STRING);
default:
- LOG(FATAL) << "UniqueDataset unhandled data type: "
- << DataTypeString(lhs.dtype());
+ DCHECK(false) << "UniqueDataset unhandled data type: "
+ << DataTypeString(lhs.dtype());
+ return false;
}
}
};
@@ -215,7 +216,7 @@ class UniqueDatasetOp : public UnaryDatasetOpKernel {
};
};
-REGISTER_KERNEL_BUILDER(Name("UniqueDataset").Device(DEVICE_CPU),
+REGISTER_KERNEL_BUILDER(Name("ExperimentalUniqueDataset").Device(DEVICE_CPU),
UniqueDatasetOp);
} // namespace
diff --git a/tensorflow/core/kernels/data/filter_dataset_op.cc b/tensorflow/core/kernels/data/filter_dataset_op.cc
index bf0aecaf3c..00884314a9 100644
--- a/tensorflow/core/kernels/data/filter_dataset_op.cc
+++ b/tensorflow/core/kernels/data/filter_dataset_op.cc
@@ -14,11 +14,13 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
+#include "tensorflow/core/framework/stats_aggregator.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/captured_function.h"
#include "tensorflow/core/kernels/data/dataset.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/random/random.h"
+#include "tensorflow/core/lib/strings/str_util.h"
namespace tensorflow {
namespace data {
@@ -37,14 +39,6 @@ class FilterDatasetOp : public UnaryDatasetOpKernel {
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
- OpInputList inputs;
- OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs));
- std::vector<Tensor> other_arguments;
- other_arguments.reserve(inputs.size());
- for (const Tensor& t : inputs) {
- other_arguments.push_back(t);
- }
-
FunctionLibraryRuntime::Handle pred_handle;
OP_REQUIRES_OK(ctx,
ctx->function_library()->Instantiate(
@@ -61,9 +55,10 @@ class FilterDatasetOp : public UnaryDatasetOpKernel {
Node* ret_node = pred_body->ret_nodes[0];
Node* ret_input_node;
OP_REQUIRES_OK(ctx, ret_node->input_node(0, &ret_input_node));
+
std::unique_ptr<CapturedFunction> captured_func;
- OP_REQUIRES_OK(ctx, CapturedFunction::Create(
- func_, std::move(other_arguments), &captured_func));
+ OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments",
+ &captured_func));
if (ret_input_node->def().op() == "_Arg") {
int32 index = -1;
@@ -146,7 +141,13 @@ class FilterDatasetOp : public UnaryDatasetOpKernel {
class Iterator : public DatasetIterator<FilterDatasetBase> {
public:
explicit Iterator(const Params& params)
- : DatasetIterator<FilterDatasetBase>(params) {}
+ : DatasetIterator<FilterDatasetBase>(params),
+ filtered_elements_(0),
+ dropped_elements_(0) {
+ std::vector<string> components =
+ str_util::Split(params.prefix, "::", str_util::SkipEmpty());
+ prefix_end_ = components.back();
+ }
Status Initialize(IteratorContext* ctx) override {
TF_RETURN_IF_ERROR(
@@ -161,6 +162,7 @@ class FilterDatasetOp : public UnaryDatasetOpKernel {
// `input_impl_` and `f` are thread-safe. However, if multiple
// threads enter this method, outputs may be observed in a
// non-deterministic order.
+ auto stats_aggregator = ctx->stats_aggregator();
bool matched;
do {
{
@@ -183,8 +185,34 @@ class FilterDatasetOp : public UnaryDatasetOpKernel {
if (!matched) {
// Clear the output tensor list since it didn't match.
out_tensors->clear();
+ if (stats_aggregator) {
+ mutex_lock l(mu_);
+ dropped_elements_++;
+ stats_aggregator->AddScalar(
+ strings::StrCat(prefix_end_, "::dropped_elements"),
+ static_cast<float>((dropped_elements_)));
+ // TODO(shivaniagrawal): multiple pipelines would collect
+ // aggregated number of dropped elements for all the pipelines,
+ // exploit tagged_context here.
+ stats_aggregator->IncrementCounter(
+ prefix_end_, "dropped_elements", static_cast<float>(1));
+ }
}
} while (!matched);
+ // TODO(shivaniagrawal): add ratio of dropped_elements and
+ // filtered_elements as a histogram.
+ if (stats_aggregator) {
+ mutex_lock l(mu_);
+ filtered_elements_++;
+ stats_aggregator->AddScalar(
+ strings::StrCat(prefix_end_, "::filtered_elements"),
+ static_cast<float>((filtered_elements_)));
+ // TODO(shivaniagrawal): multiple pipelines would collect aggregated
+ // number of filtered elements for all the pipelines, exploit
+ // tagged_context here.
+ stats_aggregator->IncrementCounter(prefix_end_, "filtered_elements",
+ static_cast<float>(1));
+ }
*end_of_sequence = false;
return Status::OK();
}
@@ -197,6 +225,10 @@ class FilterDatasetOp : public UnaryDatasetOpKernel {
else
TF_RETURN_IF_ERROR(
writer->WriteScalar(full_name("input_impls_empty"), ""));
+ TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("filtered_elements"),
+ filtered_elements_));
+ TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("dropped_elements"),
+ dropped_elements_));
return Status::OK();
}
@@ -207,12 +239,19 @@ class FilterDatasetOp : public UnaryDatasetOpKernel {
input_impl_.reset();
else
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
+ TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("filtered_elements"),
+ &filtered_elements_));
+ TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("dropped_elements"),
+ &dropped_elements_));
return Status::OK();
}
private:
mutex mu_;
std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
+ int64 filtered_elements_ GUARDED_BY(mu_);
+ int64 dropped_elements_ GUARDED_BY(mu_);
+ string prefix_end_;
};
const DatasetBase* const input_;
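The new statistics are keyed off the last component of the iterator prefix. A small illustration of how `prefix_end_` and the stat names are derived; the prefix value is hypothetical:

    // Illustrative only: deriving stat names from an iterator prefix.
    string prefix = "Iterator::Model::Filter";
    std::vector<string> components =
        str_util::Split(prefix, "::", str_util::SkipEmpty());
    string prefix_end = components.back();  // "Filter"
    // Scalars recorded: "Filter::dropped_elements" and
    // "Filter::filtered_elements"; counters are incremented under the
    // "Filter" tag.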
diff --git a/tensorflow/core/kernels/data/flat_map_dataset_op.cc b/tensorflow/core/kernels/data/flat_map_dataset_op.cc
index e3c45ef86c..2fada22a21 100644
--- a/tensorflow/core/kernels/data/flat_map_dataset_op.cc
+++ b/tensorflow/core/kernels/data/flat_map_dataset_op.cc
@@ -39,18 +39,9 @@ class FlatMapDatasetOp : public UnaryDatasetOpKernel {
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
- OpInputList inputs;
- OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs));
- std::vector<Tensor> other_arguments;
- other_arguments.reserve(inputs.size());
- for (const Tensor& t : inputs) {
- other_arguments.push_back(t);
- }
-
std::unique_ptr<CapturedFunction> captured_func;
- OP_REQUIRES_OK(ctx, CapturedFunction::Create(
- func_, std::move(other_arguments), &captured_func));
-
+ OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments",
+ &captured_func));
*output = new Dataset(ctx, input, func_, std::move(captured_func),
output_types_, output_shapes_);
}
diff --git a/tensorflow/core/kernels/data/generator_dataset_op.cc b/tensorflow/core/kernels/data/generator_dataset_op.cc
index ac5cc1b2c1..b4367d5a11 100644
--- a/tensorflow/core/kernels/data/generator_dataset_op.cc
+++ b/tensorflow/core/kernels/data/generator_dataset_op.cc
@@ -86,8 +86,6 @@ class GeneratorDatasetOp::Dataset : public DatasetBase {
TF_RETURN_IF_ERROR(dataset()->init_func_->Instantiate(ctx));
TF_RETURN_IF_ERROR(dataset()->next_func_->Instantiate(ctx));
TF_RETURN_IF_ERROR(dataset()->finalize_func_->Instantiate(ctx));
- TF_RETURN_IF_ERROR(
- dataset()->init_func_->RunWithBorrowedArgs(ctx, {}, &state_));
return Status::OK();
}
@@ -96,6 +94,12 @@ class GeneratorDatasetOp::Dataset : public DatasetBase {
bool* end_of_sequence) override {
mutex_lock l(mu_);
+ if (!initialized_) {
+ TF_RETURN_IF_ERROR(
+ dataset()->init_func_->RunWithBorrowedArgs(ctx, {}, &state_));
+ initialized_ = true;
+ }
+
if (finalized_) {
*end_of_sequence = true;
return Status::OK();
@@ -123,6 +127,7 @@ class GeneratorDatasetOp::Dataset : public DatasetBase {
private:
mutex mu_;
+ bool initialized_ GUARDED_BY(mu_) = false;
bool finalized_ GUARDED_BY(mu_) = false;
std::vector<Tensor> state_ GUARDED_BY(mu_);
};
@@ -145,44 +150,18 @@ GeneratorDatasetOp::GeneratorDatasetOp(OpKernelConstruction* ctx)
void GeneratorDatasetOp::MakeDataset(OpKernelContext* ctx,
DatasetBase** output) {
- OpInputList init_func_other_args_input;
- OP_REQUIRES_OK(ctx, ctx->input_list("init_func_other_args",
- &init_func_other_args_input));
- std::vector<Tensor> init_func_other_args;
- init_func_other_args.reserve(init_func_other_args_input.size());
- for (const Tensor& t : init_func_other_args_input) {
- init_func_other_args.push_back(t);
- }
std::unique_ptr<CapturedFunction> init_func;
- OP_REQUIRES_OK(
- ctx, CapturedFunction::Create(init_func_, std::move(init_func_other_args),
- &init_func));
-
- OpInputList next_func_other_args_input;
- OP_REQUIRES_OK(ctx, ctx->input_list("next_func_other_args",
- &next_func_other_args_input));
- std::vector<Tensor> next_func_other_args;
- next_func_other_args.reserve(next_func_other_args_input.size());
- for (const Tensor& t : next_func_other_args_input) {
- next_func_other_args.push_back(t);
- }
+ OP_REQUIRES_OK(ctx, CapturedFunction::Create(
+ init_func_, ctx, "init_func_other_args", &init_func));
+
std::unique_ptr<CapturedFunction> next_func;
- OP_REQUIRES_OK(
- ctx, CapturedFunction::Create(next_func_, std::move(next_func_other_args),
- &next_func));
-
- OpInputList finalize_func_other_args_input;
- OP_REQUIRES_OK(ctx, ctx->input_list("finalize_func_other_args",
- &finalize_func_other_args_input));
- std::vector<Tensor> finalize_func_other_args;
- finalize_func_other_args.reserve(finalize_func_other_args_input.size());
- for (const Tensor& t : finalize_func_other_args_input) {
- finalize_func_other_args.push_back(t);
- }
- std::unique_ptr<CapturedFunction> finalize_func;
OP_REQUIRES_OK(ctx, CapturedFunction::Create(
- finalize_func_, std::move(finalize_func_other_args),
- &finalize_func));
+ next_func_, ctx, "next_func_other_args", &next_func));
+
+ std::unique_ptr<CapturedFunction> finalize_func;
+ OP_REQUIRES_OK(ctx, CapturedFunction::Create(finalize_func_, ctx,
+ "finalize_func_other_args",
+ &finalize_func));
*output =
new Dataset(ctx, std::move(init_func), std::move(next_func),
diff --git a/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc b/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc
index d6ee42a7c6..e7244ee208 100644
--- a/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc
+++ b/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc
@@ -30,8 +30,7 @@ namespace {
class GroupByReducerDatasetOp : public UnaryDatasetOpKernel {
public:
explicit GroupByReducerDatasetOp(OpKernelConstruction* ctx)
- : UnaryDatasetOpKernel(ctx),
- graph_def_version_(ctx->graph_def_version()) {
+ : UnaryDatasetOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("key_func", &key_func_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("init_func", &init_func_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("reduce_func", &reduce_func_));
@@ -421,7 +420,6 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel {
const std::vector<PartialTensorShape> output_shapes_;
};
- const int graph_def_version_;
DataTypeVector output_types_;
std::vector<PartialTensorShape> output_shapes_;
NameAttrList key_func_;
diff --git a/tensorflow/core/kernels/data/group_by_window_dataset_op.cc b/tensorflow/core/kernels/data/group_by_window_dataset_op.cc
index e4fa557598..14aefe5d54 100644
--- a/tensorflow/core/kernels/data/group_by_window_dataset_op.cc
+++ b/tensorflow/core/kernels/data/group_by_window_dataset_op.cc
@@ -31,8 +31,7 @@ namespace {
class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
public:
explicit GroupByWindowDatasetOp(OpKernelConstruction* ctx)
- : UnaryDatasetOpKernel(ctx),
- graph_def_version_(ctx->graph_def_version()) {
+ : UnaryDatasetOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("key_func", &key_func_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("reduce_func", &reduce_func_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("window_size_func", &window_size_func_));
@@ -42,50 +41,19 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
- // Get captured inputs for the key, reduce, and window_size functions.
- OpInputList key_func_other_argument_inputs;
- OP_REQUIRES_OK(ctx, ctx->input_list("key_func_other_arguments",
- &key_func_other_argument_inputs));
- std::vector<Tensor> key_func_other_arguments;
- key_func_other_arguments.reserve(key_func_other_argument_inputs.size());
- for (const Tensor& t : key_func_other_argument_inputs) {
- key_func_other_arguments.push_back(t);
- }
- OpInputList reduce_func_other_argument_inputs;
- OP_REQUIRES_OK(ctx, ctx->input_list("reduce_func_other_arguments",
- &reduce_func_other_argument_inputs));
- std::vector<Tensor> reduce_func_other_arguments;
- reduce_func_other_arguments.reserve(
- reduce_func_other_argument_inputs.size());
- for (const Tensor& t : reduce_func_other_argument_inputs) {
- reduce_func_other_arguments.push_back(t);
- }
- OpInputList window_size_func_other_argument_inputs;
- OP_REQUIRES_OK(ctx,
- ctx->input_list("window_size_func_other_arguments",
- &window_size_func_other_argument_inputs));
- std::vector<Tensor> window_size_func_other_arguments;
- window_size_func_other_arguments.reserve(
- window_size_func_other_argument_inputs.size());
- for (const Tensor& t : window_size_func_other_argument_inputs) {
- window_size_func_other_arguments.push_back(t);
- }
- // TODO(mrry): Refactor CapturedFunction to share the runtime
- // state between multiple functions?
std::unique_ptr<CapturedFunction> captured_key_func;
- OP_REQUIRES_OK(ctx, CapturedFunction::Create(
- key_func_, std::move(key_func_other_arguments),
- &captured_key_func));
+ OP_REQUIRES_OK(ctx, CapturedFunction::Create(key_func_, ctx,
+ "key_func_other_arguments",
+ &captured_key_func));
std::unique_ptr<CapturedFunction> captured_reduce_func;
- OP_REQUIRES_OK(
- ctx, CapturedFunction::Create(reduce_func_,
- std::move(reduce_func_other_arguments),
- &captured_reduce_func));
+ OP_REQUIRES_OK(ctx, CapturedFunction::Create(reduce_func_, ctx,
+ "reduce_func_other_arguments",
+ &captured_reduce_func));
std::unique_ptr<CapturedFunction> captured_window_size_func;
- OP_REQUIRES_OK(
- ctx, CapturedFunction::Create(
- window_size_func_, std::move(window_size_func_other_arguments),
- &captured_window_size_func));
+ OP_REQUIRES_OK(ctx,
+ CapturedFunction::Create(window_size_func_, ctx,
+ "window_size_func_other_arguments",
+ &captured_window_size_func));
*output = new Dataset(
ctx, input, key_func_, reduce_func_, window_size_func_,
@@ -538,7 +506,6 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
const std::vector<PartialTensorShape> output_shapes_;
};
- const int graph_def_version_;
DataTypeVector output_types_;
std::vector<PartialTensorShape> output_shapes_;
NameAttrList key_func_;
diff --git a/tensorflow/core/kernels/data/interleave_dataset_op.cc b/tensorflow/core/kernels/data/interleave_dataset_op.cc
index 0768f46665..0aa802b874 100644
--- a/tensorflow/core/kernels/data/interleave_dataset_op.cc
+++ b/tensorflow/core/kernels/data/interleave_dataset_op.cc
@@ -39,14 +39,6 @@ class InterleaveDatasetOp : public UnaryDatasetOpKernel {
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
- OpInputList inputs;
- OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs));
- std::vector<Tensor> other_arguments;
- other_arguments.reserve(inputs.size());
- for (const Tensor& t : inputs) {
- other_arguments.push_back(t);
- }
-
const Tensor* cycle_length_t;
OP_REQUIRES_OK(ctx, ctx->input("cycle_length", &cycle_length_t));
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(cycle_length_t->shape()),
@@ -66,8 +58,8 @@ class InterleaveDatasetOp : public UnaryDatasetOpKernel {
errors::InvalidArgument("block_length must be greater than zero."));
std::unique_ptr<CapturedFunction> captured_func;
- OP_REQUIRES_OK(ctx, CapturedFunction::Create(
- func_, std::move(other_arguments), &captured_func));
+ OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments",
+ &captured_func));
*output =
new Dataset(ctx, input, func_, std::move(captured_func), cycle_length,
diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc
index 30c6585ba2..7a833668ac 100644
--- a/tensorflow/core/kernels/data/iterator_ops.cc
+++ b/tensorflow/core/kernels/data/iterator_ops.cc
@@ -44,43 +44,6 @@ namespace {
const char kIteratorVariantTypeName[] = "tensorflow::Iterator";
-Status VerifyTypesMatch(const DataTypeVector& expected,
- const DataTypeVector& received) {
- if (expected.size() != received.size()) {
- return errors::InvalidArgument(
- "Number of components does not match: expected ", expected.size(),
- " types but got ", received.size(), ".");
- }
- for (size_t i = 0; i < expected.size(); ++i) {
- if (expected[i] != received[i]) {
- return errors::InvalidArgument("Data type mismatch at component ", i,
- ": expected ", DataTypeString(expected[i]),
- " but got ", DataTypeString(received[i]),
- ".");
- }
- }
- return Status::OK();
-}
-
-Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
- const std::vector<PartialTensorShape>& received) {
- if (expected.size() != received.size()) {
- return errors::InvalidArgument(
- "Number of components does not match: expected ", expected.size(),
- " shapes but got ", received.size(), ".");
- }
- for (size_t i = 0; i < expected.size(); ++i) {
- if (!expected[i].IsCompatibleWith(received[i])) {
- return errors::InvalidArgument("Incompatible shapes at component ", i,
- ": expected ", expected[i].DebugString(),
- " but got ", received[i].DebugString(),
- ".");
- }
- }
-
- return Status::OK();
-}
-
} // namespace
class IteratorResource : public ResourceBase {
@@ -696,6 +659,115 @@ class ToSingleElementOp : public AsyncOpKernel {
BackgroundWorker background_worker_;
};
+class ReduceDatasetOp : public AsyncOpKernel {
+ public:
+ explicit ReduceDatasetOp(OpKernelConstruction* ctx)
+ : AsyncOpKernel(ctx),
+ background_worker_(
+ ctx->env(),
+ strings::StrCat("reduce_thread_", SanitizeThreadSuffix(name()))) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &reduce_func_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("use_inter_op_parallelism",
+ &use_inter_op_parallelism_));
+ }
+
+ void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
+ // The call to `iterator->GetNext()` may block and depend on an
+ // inter-op thread pool thread, so we issue the call from the
+ // owned thread pool.
+ background_worker_.Schedule([this, ctx, done]() {
+ DatasetBase* dataset;
+ OP_REQUIRES_OK_ASYNC(
+ ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset), done);
+ OpInputList inputs;
+ OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("initial_state", &inputs),
+ done);
+ std::vector<Tensor> state(inputs.begin(), inputs.end());
+
+ std::unique_ptr<CapturedFunction> captured_func;
+ OP_REQUIRES_OK_ASYNC(
+ ctx,
+ CapturedFunction::Create(reduce_func_, ctx, "other_arguments",
+ use_inter_op_parallelism_, &captured_func),
+ done);
+
+ IteratorContext iter_ctx(ctx);
+ OP_REQUIRES_OK_ASYNC(ctx, captured_func->Instantiate(&iter_ctx), done);
+
+ std::unique_ptr<IteratorBase> iterator;
+ OP_REQUIRES_OK_ASYNC(
+ ctx, dataset->MakeIterator(&iter_ctx, "ReduceIterator", &iterator),
+ done);
+
+ // NOTE(jsimsa): We must destroy the iterator before calling `done()`, to
+ // avoid destruction races.
+ IteratorBase* raw_iterator = iterator.release();
+ auto cleanup = gtl::MakeCleanup([raw_iterator, done] {
+ delete raw_iterator;
+ done();
+ });
+
+ // Iterate through the input dataset.
+ Status status;
+ while (true) {
+ std::vector<Tensor> next_input_element;
+ bool end_of_input;
+ status = raw_iterator->GetNext(&iter_ctx, &next_input_element,
+ &end_of_input);
+ if (!status.ok() || end_of_input) {
+ break;
+ }
+
+ // Run the reduce function to update the current state.
+ std::vector<Tensor> args;
+ args.reserve(state.size() + next_input_element.size());
+ std::copy(state.begin(), state.end(), std::back_inserter(args));
+ std::copy(next_input_element.begin(), next_input_element.end(),
+ std::back_inserter(args));
+
+ std::vector<Tensor> reduce_func_output;
+ status =
+ captured_func->Run(&iter_ctx, std::move(args), &reduce_func_output);
+ if (!status.ok()) {
+ break;
+ }
+ std::swap(reduce_func_output, state);
+ }
+
+ if (!status.ok()) {
+ ctx->SetStatus(status);
+ return;
+ }
+ for (int i = 0; i < state.size(); ++i) {
+ OP_REQUIRES_ASYNC(
+ ctx, state[i].dtype() == output_types_[i],
+ errors::InvalidArgument(
+ "The result does not match the expected type for component ", i,
+ ". Expected: ", DataTypeString(output_types_[i]),
+ ". Actual: ", DataTypeString(state[i].dtype()), "."),
+ done);
+ OP_REQUIRES_ASYNC(
+ ctx, output_shapes_[i].IsCompatibleWith(state[i].shape()),
+ errors::InvalidArgument(
+ "The result does not match the expected shape for component ",
+ i, ". Expected: ", output_shapes_[i].DebugString(),
+ ". Actual: ", state[i].shape().DebugString(), "."),
+ done);
+ ctx->set_output(i, state[i]);
+ }
+ });
+ }
+
+ private:
+ NameAttrList reduce_func_;
+ DataTypeVector output_types_;
+ std::vector<PartialTensorShape> output_shapes_;
+ bool use_inter_op_parallelism_;
+ BackgroundWorker background_worker_;
+};
+
class OneShotIteratorOp : public AsyncOpKernel {
public:
explicit OneShotIteratorOp(OpKernelConstruction* ctx)
@@ -1183,6 +1255,8 @@ REGISTER_KERNEL_BUILDER(Name("AnonymousIterator").Device(DEVICE_GPU),
AnonymousIteratorHandleOp);
REGISTER_KERNEL_BUILDER(Name("DatasetToSingleElement").Device(DEVICE_CPU),
ToSingleElementOp);
+REGISTER_KERNEL_BUILDER(Name("ReduceDataset").Device(DEVICE_CPU),
+ ReduceDatasetOp);
REGISTER_KERNEL_BUILDER(Name("OneShotIterator").Device(DEVICE_CPU),
OneShotIteratorOp);
REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE_CPU),
diff --git a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
index 85e49355d3..b4c7f9e510 100644
--- a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#define EIGEN_USE_THREADS
+#include <atomic>
#include <utility>
#include "tensorflow/core/common_runtime/function.h"
@@ -26,6 +27,7 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/tracing.h"
namespace tensorflow {
@@ -35,11 +37,12 @@ namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
// description of the following op.
+// TODO(b/116852688): Make coordination between the performance model and this
+// transformation more robust.
class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
public:
explicit MapAndBatchDatasetOp(OpKernelConstruction* ctx)
: UnaryDatasetOpKernel(ctx),
- graph_def_version_(ctx->graph_def_version()),
op_version_(ctx->def().op() == "MapAndBatchDataset" ? 1 : 2) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
@@ -49,14 +52,6 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
protected:
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
- OpInputList inputs;
- OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs));
- std::vector<Tensor> other_arguments;
- other_arguments.reserve(inputs.size());
- for (const Tensor& t : inputs) {
- other_arguments.push_back(t);
- }
-
int64 batch_size;
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "batch_size", &batch_size));
OP_REQUIRES(
@@ -77,7 +72,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
case 2:
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "num_parallel_calls",
&num_parallel_calls));
- OP_REQUIRES(ctx, num_parallel_calls > 0,
+ OP_REQUIRES(ctx,
+ num_parallel_calls > 0 || num_parallel_calls == kAutoTune,
errors::InvalidArgument(
"num_parallel_calls must be greater than zero."));
break;
@@ -92,8 +88,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
ParseScalarArgument(ctx, "drop_remainder", &drop_remainder));
std::unique_ptr<CapturedFunction> captured_func;
- OP_REQUIRES_OK(ctx, CapturedFunction::Create(
- func_, std::move(other_arguments), &captured_func));
+ OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments",
+ &captured_func));
*output = new Dataset(ctx, input, batch_size, num_parallel_calls,
drop_remainder, output_types_, output_shapes_, func_,
@@ -190,7 +186,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
class Iterator : public DatasetIterator<Dataset> {
public:
explicit Iterator(const Params& params)
- : DatasetIterator<Dataset>(params) {}
+ : DatasetIterator<Dataset>(params),
+ num_parallel_calls_(params.dataset->num_parallel_calls_) {}
~Iterator() override {
mutex_lock l(mu_);
@@ -204,8 +201,16 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
}
Status Initialize(IteratorContext* ctx) override {
- SetMetadata(ctx, "batch_size", dataset()->batch_size_);
- SetMetadata(ctx, "parallelism", dataset()->num_parallel_calls_);
+ mutex_lock l(mu_);
+ AddConstantParameter(ctx, "batch_size", dataset()->batch_size_);
+ if (num_parallel_calls_ == kAutoTune) {
+ num_parallel_calls_ = 1;
+ AddTunableParameter(ctx, "parallelism",
+ &num_parallel_calls_ /* value */, 1 /* min */,
+ port::NumSchedulableCPUs() /* max */, &cond_var_);
+ } else {
+ AddConstantParameter(ctx, "parallelism", num_parallel_calls_);
+ }
TF_RETURN_IF_ERROR(
dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
return dataset()->captured_func_->Instantiate(ctx);
@@ -220,14 +225,14 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
EnsureRunnerThreadStarted(ctx);
while (batch_results_.empty() ||
batch_results_.front()->num_calls > 0) {
- StopWork(ctx);
+ RecordStop(ctx);
cond_var_.wait(l);
- StartWork(ctx);
+ RecordStart(ctx);
}
std::swap(result, batch_results_.front());
batch_results_.pop_front();
+ cond_var_.notify_all();
}
- cond_var_.notify_all();
return ProcessResult(ctx, result, out_tensors, end_of_sequence);
}
@@ -330,11 +335,9 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
void CallCompleted(const std::shared_ptr<BatchResult>& result)
LOCKS_EXCLUDED(mu_) {
- {
- mutex_lock l(mu_);
- num_calls_--;
- result->num_calls--;
- }
+ mutex_lock l(mu_);
+ num_calls_--;
+ result->num_calls--;
cond_var_.notify_all();
}
@@ -427,11 +430,6 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
result->output_allocated = true;
}
- int MaxBatchResults() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- return (dataset()->num_parallel_calls_ + dataset()->batch_size_ - 1) /
- dataset()->batch_size_;
- }
-
Status ProcessResult(IteratorContext* ctx,
const std::shared_ptr<BatchResult>& result,
std::vector<Tensor>* out_tensors,
@@ -480,31 +478,34 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
void RunnerThread(const std::shared_ptr<IteratorContext>& ctx)
LOCKS_EXCLUDED(mu_) {
std::vector<std::pair<std::shared_ptr<BatchResult>, int64>> new_calls;
- new_calls.reserve(dataset()->num_parallel_calls_);
- StartWork(ctx.get());
+ RecordStart(ctx.get());
auto stop_cleanup =
- gtl::MakeCleanup([this, &ctx]() { StopWork(ctx.get()); });
+ gtl::MakeCleanup([this, &ctx]() { RecordStop(ctx.get()); });
+ new_calls.reserve(num_parallel_calls_);
+ auto busy = [this]() EXCLUSIVE_LOCKS_REQUIRED(mu_) -> bool {
+ int64 num_parallel_calls = num_parallel_calls_;
+ int64 max_batch_results =
+ (num_parallel_calls + dataset()->batch_size_ - 1) /
+ dataset()->batch_size_;
+ return num_calls_ >= num_parallel_calls ||
+ (batch_results_.size() > max_batch_results ||
+ (batch_results_.size() == max_batch_results &&
+ call_counter_ % dataset()->batch_size_ == 0));
+ };
while (true) {
{
mutex_lock l(mu_);
- while (!cancelled_ &&
- (num_calls_ >= dataset()->num_parallel_calls_ ||
- batch_results_.size() > MaxBatchResults() ||
- (batch_results_.size() == MaxBatchResults() &&
- call_counter_ % dataset()->batch_size_ == 0))) {
- StopWork(ctx.get());
+ while (!cancelled_ && busy()) {
+ RecordStop(ctx.get());
cond_var_.wait(l);
- StartWork(ctx.get());
+ RecordStart(ctx.get());
}
if (cancelled_) {
return;
}
- while (num_calls_ < dataset()->num_parallel_calls_ &&
- (batch_results_.size() < MaxBatchResults() ||
- (batch_results_.size() == MaxBatchResults() &&
- call_counter_ % dataset()->batch_size_ != 0))) {
+ while (!busy()) {
if (call_counter_ % dataset()->batch_size_ == 0) {
batch_results_.emplace_back(
new BatchResult(dataset()->batch_size_));
@@ -648,6 +649,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
// user specified level of parallelism and there are slots available in
// the `batch_results_` buffer.
condition_variable cond_var_;
+ // Identifies the maximum number of parallel calls.
+ std::atomic<int64> num_parallel_calls_;
// Counts the number of outstanding calls for this batch.
int64 num_calls_ GUARDED_BY(mu_) = 0;
// Counts the total number of calls.
@@ -671,7 +674,6 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
const Eigen::ThreadPoolDevice* device_; // not owned
};
- const int graph_def_version_;
const int op_version_;
DataTypeVector output_types_;
std::vector<PartialTensorShape> output_shapes_;
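
The map_and_batch changes above drop the MaxBatchResults() helper in favor of a busy() predicate that reads the (now possibly autotuned) parallelism value once and derives the cap on buffered batches from it with a ceiling division. A minimal standalone sketch of that arithmetic, using plain standard-library types and hypothetical numbers rather than the TensorFlow-internal classes:

    #include <cstdint>
    #include <iostream>

    // Ceiling division: how many batches are needed to hold `num_parallel_calls`
    // in-flight elements when each batch holds `batch_size` of them.
    int64_t MaxBatchResults(int64_t num_parallel_calls, int64_t batch_size) {
      return (num_parallel_calls + batch_size - 1) / batch_size;
    }

    // Mirrors the throttle condition: stop scheduling new calls when either the
    // call budget is exhausted or the output buffer is (about to be) full.
    bool Busy(int64_t num_calls, int64_t num_parallel_calls, int64_t batch_size,
              int64_t buffered_batches, int64_t call_counter) {
      const int64_t max_batch_results = MaxBatchResults(num_parallel_calls, batch_size);
      return num_calls >= num_parallel_calls ||
             buffered_batches > max_batch_results ||
             (buffered_batches == max_batch_results &&
              call_counter % batch_size == 0);
    }

    int main() {
      // Hypothetical numbers: 10 parallel calls, batches of 4 -> cap of 3 buffered batches.
      std::cout << MaxBatchResults(10, 4) << "\n";              // 3
      std::cout << Busy(/*num_calls=*/10, 10, 4, 1, 5) << "\n"; // 1 (call budget used up)
      std::cout << Busy(/*num_calls=*/2, 10, 4, 1, 5) << "\n";  // 0 (room for more work)
    }
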
diff --git a/tensorflow/core/kernels/data/map_dataset_op.cc b/tensorflow/core/kernels/data/map_dataset_op.cc
index af301e2b42..f112e1dc43 100644
--- a/tensorflow/core/kernels/data/map_dataset_op.cc
+++ b/tensorflow/core/kernels/data/map_dataset_op.cc
@@ -38,18 +38,10 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
- OpInputList inputs;
- OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs));
- std::vector<Tensor> other_arguments;
- other_arguments.reserve(inputs.size());
- for (const Tensor& t : inputs) {
- other_arguments.push_back(t);
- }
-
std::unique_ptr<CapturedFunction> captured_func;
- OP_REQUIRES_OK(ctx, CapturedFunction::Create(
- func_, std::move(other_arguments),
- use_inter_op_parallelism_, &captured_func));
+ OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments",
+ use_inter_op_parallelism_,
+ &captured_func));
*output = new Dataset(ctx, input, func_, std::move(captured_func),
output_types_, output_shapes_);
diff --git a/tensorflow/core/kernels/data/model_dataset_op.cc b/tensorflow/core/kernels/data/model_dataset_op.cc
index c7f929dbc1..9aa505f4f1 100644
--- a/tensorflow/core/kernels/data/model_dataset_op.cc
+++ b/tensorflow/core/kernels/data/model_dataset_op.cc
@@ -17,11 +17,14 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/dataset.h"
#include "tensorflow/core/lib/random/random.h"
+#include "tensorflow/core/platform/cpu_info.h"
namespace tensorflow {
namespace data {
namespace {
+const int kOptimizationPeriodThresholdMs = 60 * EnvTime::kSecondsToMicros;
+
class ModelDatasetOp : public UnaryDatasetOpKernel {
public:
explicit ModelDatasetOp(OpKernelConstruction* ctx)
@@ -71,9 +74,16 @@ class ModelDatasetOp : public UnaryDatasetOpKernel {
class Iterator : public DatasetIterator<Dataset> {
public:
explicit Iterator(const Params& params)
- : DatasetIterator<Dataset>(params), model_(new model::Model()) {}
+ : DatasetIterator<Dataset>(params),
+ model_(std::make_shared<model::Model>()) {}
- ~Iterator() override { model_->OutputToFile(); }
+ ~Iterator() override {
+ // Signal the optimize thread to terminate. The thread will then be
+ // joined when we delete `this->optimize_thread_`.
+ mutex_lock l(mu_);
+ cancelled_ = true;
+ cond_var_.notify_all();
+ }
Status Initialize(IteratorContext* ctx) override {
IteratorContext ctx_with_model(CreateParams(ctx));
@@ -85,6 +95,7 @@ class ModelDatasetOp : public UnaryDatasetOpKernel {
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
mutex_lock l(mu_);
+ TF_RETURN_IF_ERROR(EnsureOptimizeThreadStarted(ctx));
IteratorContext ctx_with_model(CreateParams(ctx));
return input_impl_->GetNext(&ctx_with_model, out_tensors,
end_of_sequence);
@@ -111,8 +122,53 @@ class ModelDatasetOp : public UnaryDatasetOpKernel {
}
private:
+ Status EnsureOptimizeThreadStarted(IteratorContext* ctx)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ if (!optimize_thread_) {
+ std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
+ optimize_thread_.reset(ctx->env()->StartThread(
+ {}, "optimize_thread",
+ [this, new_ctx]() { OptimizeThread(new_ctx); }));
+ }
+ return Status::OK();
+ }
+
+ void OptimizeThread(const std::shared_ptr<IteratorContext>& ctx) {
+ int64 last_optimization_ms = 0;
+ int64 optimization_period_ms = 10;
+ while (true) {
+ {
+ mutex_lock l(mu_);
+ while (!cancelled_ &&
+ last_optimization_ms + optimization_period_ms >=
+ ctx->env()->NowMicros() / EnvTime::kMillisToMicros) {
+ cond_var_.wait_for(
+ l, std::chrono::milliseconds(
+ last_optimization_ms + optimization_period_ms -
+ ctx->env()->NowMicros() / EnvTime::kMillisToMicros));
+ }
+ if (cancelled_) return;
+ }
+ model_->Optimize(port::NumSchedulableCPUs());
+ // Exponentially increase the period of running the optimization
+ // until a threshold is reached.
+ if (optimization_period_ms < kOptimizationPeriodThresholdMs) {
+ if (optimization_period_ms << 1 < kOptimizationPeriodThresholdMs) {
+ optimization_period_ms <<= 1;
+ } else {
+ optimization_period_ms = kOptimizationPeriodThresholdMs;
+ }
+ }
+ last_optimization_ms =
+ ctx->env()->NowMicros() / EnvTime::kMillisToMicros;
+ }
+ }
+
mutex mu_;
+ condition_variable cond_var_;
std::shared_ptr<model::Model> model_;
+ std::unique_ptr<Thread> optimize_thread_ GUARDED_BY(mu_);
+ bool cancelled_ GUARDED_BY(mu_) = false;
std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
};
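
The optimize thread introduced above reruns the model optimization on a period that doubles after each pass until it reaches a fixed threshold. A standalone sketch of just that schedule, assuming a 10 ms starting period and a 60-second cap standing in for the constants used in the diff:

    #include <cstdint>
    #include <iostream>

    int main() {
      const int64_t kThresholdMs = 60 * 1000;  // assumed 60-second cap, in ms
      int64_t period_ms = 10;                  // assumed initial period
      for (int i = 0; i < 16; ++i) {
        std::cout << "run " << i << ": next optimization in " << period_ms << " ms\n";
        // Double the period, saturating at the threshold rather than overshooting it.
        if (period_ms < kThresholdMs) {
          period_ms = (period_ms << 1 < kThresholdMs) ? period_ms << 1 : kThresholdMs;
        }
      }
      // The period walks 10, 20, 40, ... and then stays pinned at 60000.
    }
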
diff --git a/tensorflow/contrib/data/kernels/prefetching_kernels.cc b/tensorflow/core/kernels/data/multi_device_iterator_ops.cc
index 078de717e0..d909b9e9d3 100644
--- a/tensorflow/contrib/data/kernels/prefetching_kernels.cc
+++ b/tensorflow/core/kernels/data/multi_device_iterator_ops.cc
@@ -19,6 +19,8 @@ limitations under the License.
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_op_kernel.h"
+#include "tensorflow/core/kernels/data/dataset_utils.h"
+#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/util/device_name_utils.h"
@@ -27,506 +29,6 @@ namespace tensorflow {
namespace data {
namespace {
-struct BufferElement {
- // The producer sets `status` if getting the input element fails.
- Status status;
- // The buffered data element.
- std::vector<Tensor> value;
-};
-
-using FunctionBufferCallback = std::function<void(const BufferElement&)>;
-
-class FunctionBufferingResource : public ResourceBase {
- public:
- FunctionBufferingResource(FunctionLibraryRuntime* lib,
- std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
- const NameAttrList& func, int64 buffer_size,
- const string& source_device,
- const string& target_device,
- const std::vector<Tensor>& func_args,
- const DataTypeVector& output_types)
- : lib_(lib),
- pflr_(std::move(pflr)),
- func_(func),
- buffer_size_(buffer_size),
- source_device_(source_device),
- target_device_(target_device),
- func_args_(func_args),
- output_types_(output_types),
- handle_(kInvalidHandle),
- is_buffering_(false),
- end_of_sequence_(false),
- cancelled_(false) {}
-
- ~FunctionBufferingResource() override {
- Cancel();
- }
-
- string DebugString() override {
- return strings::StrCat("FunctionBufferingResource. Size: ", buffer_size_,
- "; target_device: ", target_device_);
- }
-
- // Instantiates the function the first time it's called. After that it caches
- // the handle.
- Status Instantiate() LOCKS_EXCLUDED(mu_) {
- mutex_lock l(mu_);
- // Re-use existing handle if it's been set, effectively caching it.
- if (handle_ != kInvalidHandle) {
- return Status::OK();
- }
- AttrValueMap attr_values = func_.attr();
- FunctionLibraryRuntime::InstantiateOptions opts;
- opts.target = target_device_;
- return lib_->Instantiate(func_.name(), AttrSlice(&attr_values), opts,
- &handle_);
- }
-
- // Returns true if we've got to the end of the sequence and exhausted the
- // buffer.
- bool Finished() LOCKS_EXCLUDED(mu_) {
- mutex_lock l(mu_);
- return end_of_sequence_ && buffer_.empty();
- }
-
- // Cancels any buffering / prefetching going on.
- void Cancel() LOCKS_EXCLUDED(mu_) {
- mutex_lock l(mu_);
- cancelled_ = true;
- while (is_buffering_) {
- cond_var_.wait(l);
- }
- }
-
- // Cancels all pending operations and then clears out the state.
- void Reset() LOCKS_EXCLUDED(mu_) {
- Cancel();
- mutex_lock l(mu_);
- buffer_.clear();
- requests_.clear();
- is_buffering_ = false;
- end_of_sequence_ = false;
- cancelled_ = false;
- }
-
- // If the buffer has anything, runs `callback` on the first element in the
- // buffer, else schedules the `callback` to be called. Requires `args` and
- // `lib` in case more function calls need to be scheduled.
- void MaybeGet(FunctionBufferCallback callback) LOCKS_EXCLUDED(mu_) {
- bool start_buffering = false;
- bool produced_output = false;
- BufferElement buffer_element;
- {
- mutex_lock l(mu_);
- if (!is_buffering_ && !end_of_sequence_) {
- start_buffering = true;
- }
- if (!buffer_.empty()) {
- produced_output = true;
- std::swap(buffer_element, buffer_.front());
- buffer_.pop_front();
- } else {
- produced_output = false;
- requests_.push_back(std::move(callback));
- }
- }
- if (produced_output) {
- callback(buffer_element);
- }
- if (start_buffering) {
- FillBuffer();
- }
- }
-
- private:
- void FillBuffer() LOCKS_EXCLUDED(mu_) {
- FunctionLibraryRuntime::Handle handle;
- std::vector<FunctionBufferCallback> cancellation_callbacks;
- std::vector<BufferElement> cancellation_buffer_elements;
- bool cancelled = false;
- {
- mutex_lock l(mu_);
- handle = handle_;
- if (cancelled_) {
- cancelled = true;
- // Run through and fulfill all pending requests, if possible.
- while (!requests_.empty()) {
- if (!buffer_.empty()) {
- cancellation_buffer_elements.push_back(std::move(buffer_.front()));
- buffer_.pop_front();
- cancellation_callbacks.push_back(std::move(requests_.front()));
- requests_.pop_front();
- } else {
- LOG(ERROR) << "Buffer ran out of elements and we couldn't satisfy: "
- << requests_.size() << " requests";
- break;
- }
- }
- is_buffering_ = false;
- } else {
- is_buffering_ = true;
- }
- }
- if (cancelled) {
- for (int i = 0; i < cancellation_callbacks.size(); ++i) {
- cancellation_callbacks[i](cancellation_buffer_elements[i]);
- }
- cond_var_.notify_all();
- return;
- }
- FunctionLibraryRuntime::Options opts;
- // Copied from CapturedFunction::generate_step_id();
- opts.step_id = -std::abs(static_cast<int64>(random::New64()));
- opts.source_device = source_device_;
- AllocatorAttributes arg_alloc_attr;
- arg_alloc_attr.set_on_host(true);
- opts.args_alloc_attrs.push_back(arg_alloc_attr);
- for (const auto& dtype : output_types_) {
- AllocatorAttributes ret_alloc_attrs;
- if (DataTypeAlwaysOnHost(dtype)) {
- ret_alloc_attrs.set_on_host(true);
- }
- opts.rets_alloc_attrs.push_back(ret_alloc_attrs);
- }
- if (opts.source_device != target_device_) {
- opts.remote_execution = true;
- }
- opts.create_rendezvous = true;
- auto* rets = new std::vector<Tensor>;
- lib_->Run(opts, handle, func_args_, rets,
- [this, rets](const Status& status) {
- FunctionBufferCallback callback = nullptr;
- BufferElement buffer_front;
- bool restart_buffering = false;
- {
- mutex_lock l(mu_);
- BufferElement buffer_element;
- buffer_element.status = status;
- if (status.ok()) {
- buffer_element.value.swap(*rets);
- } else {
- end_of_sequence_ = true;
- is_buffering_ = false;
- }
- buffer_.push_back(std::move(buffer_element));
- if (!requests_.empty()) {
- buffer_front = std::move(buffer_.front());
- buffer_.pop_front();
- callback = std::move(requests_.front());
- requests_.pop_front();
- }
- if (buffer_.size() < buffer_size_ && !end_of_sequence_) {
- restart_buffering = true;
- } else {
- // When the buffer is full, we don't want to call
- // FillBuffer() unless we're in cancellation phase in which
- // case FillBuffer() will do the final cleanup post
- // cancellation.
- if (cancelled_) {
- restart_buffering = true;
- }
- is_buffering_ = false;
- }
- }
- if (callback != nullptr) {
- callback(buffer_front);
- }
- if (restart_buffering) {
- FillBuffer();
- }
- });
- }
-
- mutex mu_;
- FunctionLibraryRuntime* lib_;
- std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
- NameAttrList func_;
- const int64 buffer_size_;
- const string source_device_;
- const string target_device_;
- const std::vector<Tensor> func_args_;
- const DataTypeVector output_types_;
- FunctionLibraryRuntime::Handle handle_ GUARDED_BY(mu_);
- std::deque<BufferElement> buffer_ GUARDED_BY(mu_);
- std::deque<FunctionBufferCallback> requests_ GUARDED_BY(mu_);
- bool is_buffering_ GUARDED_BY(mu_);
- bool end_of_sequence_ GUARDED_BY(mu_);
- bool cancelled_ GUARDED_BY(mu_);
- condition_variable cond_var_;
-};
-
-class FunctionBufferResourceHandleOp : public OpKernel {
- public:
- explicit FunctionBufferResourceHandleOp(OpKernelConstruction* ctx)
- : OpKernel(ctx), flib_def_(nullptr) {
- OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_));
- OP_REQUIRES_OK(ctx, ctx->GetAttr("buffer_size", &buffer_size_));
- OP_REQUIRES_OK(ctx, ctx->GetAttr("container", &container_));
- OP_REQUIRES_OK(ctx, ctx->GetAttr("shared_name", &name_));
- OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
- }
-
- ~FunctionBufferResourceHandleOp() override {
- if (cinfo_.resource_is_private_to_kernel()) {
- if (!cinfo_.resource_manager()
- ->Delete<FunctionBufferingResource>(cinfo_.container(),
- cinfo_.name())
- .ok()) {
- // Do nothing; the resource can have been deleted by session resets.
- }
- }
- }
-
- void Compute(OpKernelContext* ctx) override {
- const Tensor* string_arg;
- OP_REQUIRES_OK(ctx, ctx->input("string_arg", &string_arg));
- std::vector<Tensor> func_args;
- func_args.push_back(*string_arg);
-
- const string& source_device = ctx->device()->name();
-
- // Obtain and canonicalize target_device.
- const Tensor* target_arg;
- OP_REQUIRES_OK(ctx, ctx->input("target_device", &target_arg));
- string target_device;
- OP_REQUIRES_OK(ctx, DeviceNameUtils::CanonicalizeDeviceName(
- target_arg->scalar<string>()(), source_device,
- &target_device));
-
- FunctionLibraryRuntime* lib = ctx->function_library();
- OP_REQUIRES(ctx, lib != nullptr,
- errors::Internal("No function library is provided."));
-
- mutex_lock l(mu_);
- if (!initialized_) {
- OP_REQUIRES_OK(ctx, cinfo_.Init(ctx->resource_manager(), def()));
- FunctionLibraryRuntime* clone_lib;
- std::unique_ptr<ProcessFunctionLibraryRuntime> pflr;
- OP_REQUIRES_OK(ctx, lib->Clone(&flib_def_, &pflr, &clone_lib));
- // Create the resource.
- FunctionBufferingResource* buffer;
- OP_REQUIRES_OK(
- ctx,
- ctx->resource_manager()->LookupOrCreate<FunctionBufferingResource>(
- cinfo_.container(), cinfo_.name(), &buffer,
- [clone_lib, &pflr, &source_device, &target_device, func_args,
- this](FunctionBufferingResource** ptr) {
- *ptr = new FunctionBufferingResource(
- clone_lib, std::move(pflr), func_, buffer_size_,
- source_device, target_device, func_args, output_types_);
- return Status::OK();
- }));
- core::ScopedUnref s(buffer);
- OP_REQUIRES_OK(ctx, buffer->Instantiate());
- initialized_ = true;
- }
-
- OP_REQUIRES_OK(ctx, MakeResourceHandleToOutput(
- ctx, 0, cinfo_.container(), cinfo_.name(),
- MakeTypeIndex<FunctionBufferingResource>()));
- }
-
- private:
- mutex mu_;
- ContainerInfo cinfo_ GUARDED_BY(mu_);
- bool initialized_ GUARDED_BY(mu_) = false;
- std::unique_ptr<FunctionLibraryDefinition> flib_def_;
- NameAttrList func_;
- int64 buffer_size_;
- string container_;
- string name_;
- DataTypeVector output_types_;
-};
-
-REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResource")
- .Device(DEVICE_CPU)
- .HostMemory("resource")
- .HostMemory("string_arg")
- .HostMemory("target_device"),
- FunctionBufferResourceHandleOp);
-REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResource")
- .Device(DEVICE_GPU)
- .HostMemory("resource")
- .HostMemory("string_arg")
- .HostMemory("target_device"),
- FunctionBufferResourceHandleOp);
-#if TENSORFLOW_USE_SYCL
-REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResource")
- .Device(DEVICE_SYCL)
- .HostMemory("resource")
- .HostMemory("string_arg")
- .HostMemory("target_device"),
- FunctionBufferResourceHandleOp);
-#endif // TENSORFLOW_USE_SYCL
-
-// Prefetches and fills up a buffer by calling a function that provides the
-// elements to buffer.
-class FunctionBufferingResourceGetNextOp : public AsyncOpKernel {
- public:
- explicit FunctionBufferingResourceGetNextOp(OpKernelConstruction* ctx)
- : AsyncOpKernel(ctx) {}
-
- ~FunctionBufferingResourceGetNextOp() override {}
-
- void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
- ResourceHandle handle;
- OP_REQUIRES_OK_ASYNC(
- ctx, HandleFromInput(ctx, "function_buffer_resource", &handle), done);
- FunctionBufferingResource* buffer = nullptr;
- OP_REQUIRES_OK_ASYNC(
- ctx, LookupResource<FunctionBufferingResource>(ctx, handle, &buffer),
- done);
-
- if (buffer->Finished()) {
- buffer->Unref();
- ctx->SetStatus(errors::OutOfRange("end_of_sequence"));
- done();
- return;
- }
-
- FunctionBufferCallback callback =
- [ctx, buffer, done](const BufferElement& buffer_element) {
- Status s = buffer_element.status;
- if (!s.ok()) {
- ctx->SetStatus(s);
- buffer->Unref();
- done();
- return;
- }
- for (size_t i = 0; i < buffer_element.value.size(); ++i) {
- ctx->set_output(i, buffer_element.value[i]);
- }
- buffer->Unref();
- done();
- };
- buffer->MaybeGet(std::move(callback));
- }
-};
-
-REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResourceGetNext")
- .Device(DEVICE_CPU)
- .HostMemory("function_buffer_resource"),
- FunctionBufferingResourceGetNextOp);
-REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResourceGetNext")
- .Device(DEVICE_GPU)
- .HostMemory("function_buffer_resource"),
- FunctionBufferingResourceGetNextOp);
-#if TENSORFLOW_USE_SYCL
-REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResourceGetNext")
- .Device(DEVICE_SYCL)
- .HostMemory("function_buffer_resource"),
- FunctionBufferingResourceGetNextOp);
-#endif // TENSORFLOW_USE_SYCL
-
-// Resets the FunctionBufferingResource, cancelling all pending requests and
-// clearing out the buffer.
-class FunctionBufferingResourceResetOp : public OpKernel {
- public:
- explicit FunctionBufferingResourceResetOp(OpKernelConstruction* ctx)
- : OpKernel(ctx) {}
-
- ~FunctionBufferingResourceResetOp() override {}
-
- void Compute(OpKernelContext* ctx) override {
- ResourceHandle handle;
- OP_REQUIRES_OK(ctx,
- HandleFromInput(ctx, "function_buffer_resource", &handle));
- FunctionBufferingResource* buffer = nullptr;
- OP_REQUIRES_OK(
- ctx, LookupResource<FunctionBufferingResource>(ctx, handle, &buffer));
- core::ScopedUnref s(buffer);
-
- buffer->Reset();
- }
-};
-
-REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResourceReset")
- .Device(DEVICE_CPU)
- .HostMemory("function_buffer_resource"),
- FunctionBufferingResourceResetOp);
-REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResourceReset")
- .Device(DEVICE_GPU)
- .HostMemory("function_buffer_resource"),
- FunctionBufferingResourceResetOp);
-#if TENSORFLOW_USE_SYCL
-REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResourceReset")
- .Device(DEVICE_SYCL)
- .HostMemory("function_buffer_resource"),
- FunctionBufferingResourceResetOp);
-#endif // TENSORFLOW_USE_SYCL
-
-class IteratorGetDeviceOp : public OpKernel {
- public:
- using OpKernel::OpKernel;
-
- void Compute(OpKernelContext* ctx) override {
- // NOTE(mrry): We do not currently validate that the handle
- // corresponds to a real IteratorResource, because that symbol is
- // not exposed from the framework library.
- Tensor* device_name_t;
- OP_REQUIRES_OK(ctx,
- ctx->allocate_output(0, TensorShape({}), &device_name_t));
- // NOTE(mrry): Since the operation's input is a resource, we must be
- // colocated with it, and so we can simply return the current device's
- // name without looking at the input.
- device_name_t->scalar<string>()() = ctx->device()->name();
- }
-};
-
-REGISTER_KERNEL_BUILDER(Name("IteratorGetDevice").Device(DEVICE_CPU),
- IteratorGetDeviceOp);
-
-Status VerifyTypesMatch(const DataTypeVector& expected,
- const DataTypeVector& received) {
- if (expected.size() != received.size()) {
- return errors::InvalidArgument(
- "Number of components does not match: expected ", expected.size(),
- " types but got ", received.size(), ".");
- }
- for (size_t i = 0; i < expected.size(); ++i) {
- if (expected[i] != received[i]) {
- return errors::InvalidArgument("Data type mismatch at component ", i,
- ": expected ", DataTypeString(expected[i]),
- " but got ", DataTypeString(received[i]),
- ".");
- }
- }
- return Status::OK();
-}
-
-Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
- const std::vector<PartialTensorShape>& received) {
- if (expected.size() != received.size()) {
- return errors::InvalidArgument(
- "Number of components does not match: expected ", expected.size(),
- " shapes but got ", received.size(), ".");
- }
- for (size_t i = 0; i < expected.size(); ++i) {
- if (!expected[i].IsCompatibleWith(received[i])) {
- return errors::InvalidArgument("Incompatible shapes at component ", i,
- ": expected ", expected[i].DebugString(),
- " but got ", received[i].DebugString(),
- ".");
- }
- }
-
- return Status::OK();
-}
-
-string SanitizeThreadSuffix(string suffix) {
- string clean;
- for (int i = 0; i < suffix.size(); ++i) {
- const char ch = suffix[i];
- if ((ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') ||
- (ch >= '0' && ch <= '9') || ch == '_' || ch == '-') {
- clean += ch;
- } else {
- clean += '_';
- }
- }
- return clean;
-}
-
struct HostBufferElement {
Status status;
bool end_of_sequence;
@@ -550,7 +52,7 @@ class MultiDeviceIterator : public ResourceBase {
flib_def_(std::move(flib_def)),
pflr_(std::move(pflr)),
lib_(lib) {
- CHECK_NOTNULL(lib_);
+ DCHECK(lib_ != nullptr);
}
string DebugString() override {
@@ -621,24 +123,28 @@ class MultiDeviceIterator : public ResourceBase {
incarnation_id_(incarnation_id),
host_iterator_(std::move(host_iterator)) {}
- ~MultiDeviceBuffer() { Reset(); }
+ ~MultiDeviceBuffer() {
+ {
+ mutex_lock l(mu_);
+ if (!background_thread_started_) return;
+ }
+ Reset();
+ }
void Reset() LOCKS_EXCLUDED(mu_) {
{
mutex_lock l(mu_);
- if (background_thread_finished_) {
- return;
- }
-
- cancelled_ = true;
- // Wake up the background thread.
- for (int i = 0; i < size_; ++i) {
- buffer_[i].cond_var.notify_all();
- }
+ if (!background_thread_finished_) {
+ cancelled_ = true;
+ // Wake up the background thread.
+ for (int i = 0; i < size_; ++i) {
+ buffer_[i].cond_var.notify_all();
+ }
- // Make sure background thread has finished first.
- while (!background_thread_finished_) {
- shutdown_cond_var_.wait(l);
+ // Make sure background thread has finished first.
+ while (!background_thread_finished_) {
+ shutdown_cond_var_.wait(l);
+ }
}
}
RunPendingCallbacks();
@@ -674,7 +180,7 @@ class MultiDeviceIterator : public ResourceBase {
buffer_[shard_num].cond_var.notify_all();
}
} else {
- if (background_thread_finished_) {
+ if (end_of_iterator_) {
produced_output = true;
elem.end_of_sequence = true;
} else {
@@ -711,8 +217,12 @@ class MultiDeviceIterator : public ResourceBase {
while (!buffer_[i].callbacks.empty()) {
if (buffer_[i].data.empty()) {
HostBufferElement elem;
- elem.status =
- errors::Cancelled("Cancelled and buffer not filled.");
+ if (end_of_iterator_) {
+ elem.end_of_sequence = true;
+ } else {
+ elem.status =
+ errors::Cancelled("Cancelled and buffer not filled.");
+ }
cancellation_elements.push_back(std::move(elem));
} else {
cancellation_elements.push_back(
@@ -731,6 +241,10 @@ class MultiDeviceIterator : public ResourceBase {
}
void BackgroundThread(IteratorContext* ctx) {
+ {
+ mutex_lock l(mu_);
+ background_thread_started_ = true;
+ }
std::unique_ptr<IteratorContext> cleanup(ctx);
int shard_to_fetch = 0;
while (true) {
@@ -781,6 +295,7 @@ class MultiDeviceIterator : public ResourceBase {
{
mutex_lock l(mu_);
background_thread_finished_ = true;
+ end_of_iterator_ = true;
shutdown_cond_var_.notify_all();
}
RunPendingCallbacks();
@@ -799,6 +314,8 @@ class MultiDeviceIterator : public ResourceBase {
mutex mu_;
std::unique_ptr<Thread> background_thread_ GUARDED_BY(mu_);
bool background_thread_finished_ GUARDED_BY(mu_) = false;
+ bool background_thread_started_ GUARDED_BY(mu_) = false;
+ bool end_of_iterator_ GUARDED_BY(mu_) = false;
bool cancelled_ GUARDED_BY(mu_) = false;
condition_variable shutdown_cond_var_ GUARDED_BY(mu_);
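
The reworked MultiDeviceBuffer destructor and Reset() above follow a common cooperative-shutdown pattern: set a cancelled flag under the lock, wake the background thread, and block until it reports that it has finished. A minimal standalone sketch of that handshake with std::thread, with a trivial wait loop assumed in place of the real prefetching work:

    #include <condition_variable>
    #include <iostream>
    #include <mutex>
    #include <thread>

    class Worker {
     public:
      Worker() : thread_([this] { Run(); }) {}

      ~Worker() {
        {
          std::unique_lock<std::mutex> l(mu_);
          cancelled_ = true;
          cond_var_.notify_all();  // wake the background thread
          // Make sure the background loop has observed the cancellation before we
          // tear anything down, mirroring the wait on `shutdown_cond_var_`.
          shutdown_cond_var_.wait(l, [this] { return finished_; });
        }
        thread_.join();
      }

     private:
      void Run() {
        std::unique_lock<std::mutex> l(mu_);
        while (!cancelled_) {
          cond_var_.wait(l);  // stand-in for the real buffering work
        }
        finished_ = true;
        shutdown_cond_var_.notify_all();
      }

      std::mutex mu_;
      std::condition_variable cond_var_;
      std::condition_variable shutdown_cond_var_;
      bool cancelled_ = false;
      bool finished_ = false;
      std::thread thread_;
    };

    int main() {
      Worker w;  // constructing and immediately destroying exercises the handshake
      std::cout << "clean shutdown\n";
    }
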
diff --git a/tensorflow/core/kernels/data/optional_ops.cc b/tensorflow/core/kernels/data/optional_ops.cc
index 6180df5af2..2ab5c83082 100644
--- a/tensorflow/core/kernels/data/optional_ops.cc
+++ b/tensorflow/core/kernels/data/optional_ops.cc
@@ -108,11 +108,8 @@ class OptionalFromValueOp : public OpKernel {
void Compute(OpKernelContext* ctx) override {
OpInputList components_input;
OP_REQUIRES_OK(ctx, ctx->input_list("components", &components_input));
- std::vector<Tensor> components;
- components.reserve(components_input.size());
- for (const Tensor& component_t : components_input) {
- components.push_back(component_t);
- }
+ std::vector<Tensor> components(components_input.begin(),
+ components_input.end());
OP_REQUIRES_OK(
ctx, WriteOptionalWithValueToOutput(ctx, 0, std::move(components)));
}
@@ -216,6 +213,14 @@ static Status OptionalDeviceCopy(
std::vector<Tensor> to_values;
to_values.reserve(from_values.size());
for (const Tensor& t : from_values) {
+ if (t.dtype() == DT_VARIANT) {
+ // TODO(b/116349787): Implement support for nested variants.
+ return errors::Unimplemented(
+ "Support for copying nested variants to device has not yet been "
+ "implemented.");
+ }
+ }
+ for (const Tensor& t : from_values) {
if (DMAHelper::CanUseDMA(&t)) {
Tensor tmp(t.dtype());
TF_RETURN_IF_ERROR(copy(t, &tmp));
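
The OptionalDeviceCopy change above scans all values for nested DT_VARIANT tensors before copying any of them, so the operation fails up front instead of leaving a partially copied result. A tiny standalone sketch of that validate-then-act split, with plain strings assumed to stand in for tensors and "variant" marking the unsupported case:

    #include <iostream>
    #include <string>
    #include <vector>

    // Returns false (an error) without copying anything if any element is
    // unsupported; only copies once the whole input has been validated.
    bool CopyIfSupported(const std::vector<std::string>& from,
                         std::vector<std::string>* to) {
      for (const std::string& v : from) {
        if (v == "variant") return false;  // reject before doing any work
      }
      to->assign(from.begin(), from.end());
      return true;
    }

    int main() {
      std::vector<std::string> out;
      std::cout << CopyIfSupported({"int32", "float"}, &out) << " " << out.size() << "\n";   // 1 2
      std::cout << CopyIfSupported({"int32", "variant"}, &out) << " " << out.size() << "\n"; // 0 2 (untouched)
    }
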
diff --git a/tensorflow/core/kernels/data/padded_batch_dataset_op.cc b/tensorflow/core/kernels/data/padded_batch_dataset_op.cc
index 73eeafd797..7b01c3b4e0 100644
--- a/tensorflow/core/kernels/data/padded_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/padded_batch_dataset_op.cc
@@ -207,7 +207,7 @@ class PaddedBatchDatasetOp : public UnaryDatasetOpKernel {
: DatasetIterator<Dataset>(params) {}
Status Initialize(IteratorContext* ctx) override {
- SetMetadata(ctx, "batch_size", dataset()->batch_size_);
+ AddConstantParameter(ctx, "batch_size", dataset()->batch_size_);
return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
}
diff --git a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
index aa5e613e24..2bb38bf0b9 100644
--- a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include <atomic>
#include <deque>
#include <utility>
@@ -44,14 +45,6 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
- OpInputList inputs;
- OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs));
- std::vector<Tensor> other_arguments;
- other_arguments.reserve(inputs.size());
- for (const Tensor& t : inputs) {
- other_arguments.push_back(t);
- }
-
int64 cycle_length = 0;
OP_REQUIRES_OK(ctx,
ParseScalarArgument(ctx, "cycle_length", &cycle_length));
@@ -83,8 +76,8 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
std::unique_ptr<CapturedFunction> captured_func;
OP_REQUIRES_OK(
- ctx, CapturedFunction::Create(
- interleave_func_, std::move(other_arguments), &captured_func));
+ ctx, CapturedFunction::Create(interleave_func_, ctx, "other_arguments",
+ &captured_func));
*output =
new Dataset(ctx, input, interleave_func_, std::move(captured_func),
@@ -252,7 +245,7 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
}
Status Initialize(IteratorContext* ctx) override {
- SetMetadata(ctx, "parallelism", dataset()->cycle_length_);
+ AddConstantParameter(ctx, "parallelism", dataset()->cycle_length_);
TF_RETURN_IF_ERROR(
dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
return dataset()->captured_func_->Instantiate(ctx);
@@ -352,13 +345,13 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
if (must_wait_for_input) {
// Wait for elements to become available.
- StopWork(ctx);
+ RecordStop(ctx);
if (dataset()->sloppy_) {
sloppy_cond_var_.wait(l);
} else {
workers_[interleave_indices_[next_index_]].cond_var.wait(l);
}
- StartWork(ctx);
+ RecordStart(ctx);
}
}
return errors::Cancelled(
@@ -626,11 +619,11 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
// std::function arguments are copy-constructable, so we pass raw
// pointers, and then immediately wrap them to ensure correct ownership.
- StartWork(ctx.get());
+ RecordStart(ctx.get());
auto cleanup = gtl::MakeCleanup([this, thread_index, ctx] {
mutex_lock l(mu_);
workers_[thread_index].cond_var.notify_all();
- StopWork(ctx.get());
+ RecordStop(ctx.get());
});
bool make_new_iterator;
{
@@ -668,9 +661,9 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
if (read_new_input) {
mutex_lock l(mu_);
while (!cancelled_ && !workers_[thread_index].is_producing) {
- StopWork(ctx.get());
+ RecordStop(ctx.get());
workers_[thread_index].cond_var.wait(l);
- StartWork(ctx.get());
+ RecordStart(ctx.get());
}
if (cancelled_) return;
// Copy the input tensors so that we do not need to block on `mu_`
@@ -720,9 +713,9 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
// Wait for space in the prefetch queue.
while (!cancelled_ && workers_[thread_index].outputs.size() ==
dataset()->buffer_output_elements_) {
- StopWork(ctx.get());
+ RecordStop(ctx.get());
workers_[thread_index].cond_var.wait(l);
- StartWork(ctx.get());
+ RecordStart(ctx.get());
}
if (cancelled_) return;
tf_shared_lock ckpt_l(ckpt_mu_);
@@ -771,9 +764,9 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
// Wait for space in the prefetch queue.
while (!cancelled_ && workers_[thread_index].outputs.size() ==
dataset()->buffer_output_elements_) {
- StopWork(ctx.get());
+ RecordStop(ctx.get());
workers_[thread_index].cond_var.wait(l);
- StartWork(ctx.get());
+ RecordStart(ctx.get());
}
if (cancelled_) return;
@@ -1091,6 +1084,9 @@ REGISTER_KERNEL_BUILDER(Name("ParallelInterleaveDataset").Device(DEVICE_CPU),
// The above design choices were made with automated optimizations in mind,
// isolating the degree of parallelism as the single tunable knob of this
// implementation.
+//
+// TODO(b/116852688): Make coordination between the performance model and this
+// transformation more robust.
class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
public:
explicit ParallelInterleaveDatasetV2Op(OpKernelConstruction* ctx)
@@ -1102,9 +1098,6 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
- OpInputList inputs;
- OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs));
-
int64 cycle_length = 0;
OP_REQUIRES_OK(ctx,
ParseScalarArgument(ctx, "cycle_length", &cycle_length));
@@ -1120,7 +1113,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
int64 num_parallel_calls;
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "num_parallel_calls",
&num_parallel_calls));
- OP_REQUIRES(ctx, num_parallel_calls > 0,
+ OP_REQUIRES(ctx, num_parallel_calls > 0 || num_parallel_calls == kAutoTune,
errors::InvalidArgument(
"num_parallel_calls must be greater than zero."));
OP_REQUIRES(
@@ -1128,16 +1121,10 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
errors::InvalidArgument(
"num_parallel_calls must less than or equal to cycle_length."));
- // TODO(b/114267189): Use `other_arguments(inputs.begin(), inputs.end());`.
- std::vector<Tensor> other_arguments;
- other_arguments.reserve(inputs.size());
- for (const Tensor& t : inputs) {
- other_arguments.push_back(t);
- }
std::unique_ptr<CapturedFunction> captured_func;
OP_REQUIRES_OK(
- ctx, CapturedFunction::Create(
- interleave_func_, std::move(other_arguments), &captured_func));
+ ctx, CapturedFunction::Create(interleave_func_, ctx, "other_arguments",
+ &captured_func));
*output = new Dataset(ctx, input, interleave_func_,
std::move(captured_func), cycle_length, block_length,
@@ -1230,6 +1217,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
public:
explicit Iterator(const Params& params)
: DatasetIterator<Dataset>(params),
+ num_parallel_calls_(params.dataset->num_parallel_calls_),
args_list_(params.dataset->cycle_length_),
current_elements_(params.dataset->cycle_length_),
element_in_use_(params.dataset->cycle_length_, false),
@@ -1250,7 +1238,16 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
}
Status Initialize(IteratorContext* ctx) override {
- SetMetadata(ctx, "parallelism", dataset()->num_parallel_calls_);
+ mutex_lock l(mu_);
+ if (num_parallel_calls_ == kAutoTune) {
+ num_parallel_calls_ = 1;
+ AddTunableParameter(ctx, "parallelism",
+ &num_parallel_calls_ /* value */, 1 /* min */,
+ dataset()->cycle_length_ /* max */, &cond_var_);
+ } else {
+ AddConstantParameter(ctx, "parallelism", num_parallel_calls_);
+ }
+ AddConstantParameter(ctx, "cycle_length", dataset()->cycle_length_);
TF_RETURN_IF_ERROR(
dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
return dataset()->captured_func_->Instantiate(ctx);
@@ -1266,9 +1263,9 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
EnsureRunnerThreadStarted(ctx);
while (invocation_results_.empty() &&
(!end_of_input_ || num_open_ > 0)) {
- StopWork(ctx);
+ RecordStop(ctx);
cond_var_.wait(l);
- StartWork(ctx);
+ RecordStart(ctx);
}
if (!invocation_results_.empty()) {
std::swap(result, invocation_results_.front());
@@ -1277,11 +1274,11 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
*end_of_sequence = true;
return Status::OK();
}
+ cond_var_.notify_all();
}
- cond_var_.notify_all();
- StopWork(ctx);
+ RecordStop(ctx);
result->notification.WaitForNotification();
- StartWork(ctx);
+ RecordStart(ctx);
} while (result->skip);
if (result->status.ok()) {
@@ -1405,8 +1402,8 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
const std::shared_ptr<IteratorContext>& ctx, int64 cycle_index,
const std::vector<std::shared_ptr<InvocationResult>>& results)
LOCKS_EXCLUDED(mu_) {
- StartWork(ctx.get());
- auto cleanup = gtl::MakeCleanup([this, ctx] { StopWork(ctx.get()); });
+ RecordStart(ctx.get());
+ auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
bool end_of_input = false;
for (auto& result : results) {
if (!end_of_input) {
@@ -1424,60 +1421,66 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
// Release the ownership of the cycle element iterator, closing the
// iterator if end of input was encountered.
- {
- if (end_of_input) {
- current_elements_[cycle_index].reset();
- }
- mutex_lock l(mu_);
- element_in_use_[cycle_index] = false;
- num_calls_--;
- if (end_of_input) {
- args_list_[cycle_index].clear();
- num_open_--;
- }
+ if (end_of_input) {
+ current_elements_[cycle_index].reset();
+ }
+ mutex_lock l(mu_);
+ element_in_use_[cycle_index] = false;
+ num_calls_--;
+ if (end_of_input) {
+ args_list_[cycle_index].clear();
+ num_open_--;
}
cond_var_.notify_all();
}
- int64 MaxInvocationResults() {
- return dataset()->cycle_length_ * dataset()->block_length_;
- }
-
// Method responsible for 1) creating iterators out of input elements, 2)
// determining the order in which elements are fetched from the iterators,
// and 3) scheduling the fetching of the elements to a threadpool.
//
// This method runs in the `runner_thread` background thread.
void RunnerThread(const std::shared_ptr<IteratorContext>& ctx) {
- StartWork(ctx.get());
- auto cleanup = gtl::MakeCleanup([this, ctx] { StopWork(ctx.get()); });
+ RecordStart(ctx.get());
+ auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
+ auto busy = [this]() EXCLUSIVE_LOCKS_REQUIRED(mu_) -> bool {
+ return element_in_use_[cycle_index_] ||
+ num_calls_ >= num_parallel_calls_ ||
+ invocation_results_.size() >=
+ dataset()->cycle_length_ * dataset()->block_length_;
+ };
while (true) {
- {
- mutex_lock l(mu_);
- // Wait until this thread is cancelled, the end of input has been
- // reached, or the cycle element at the `cycle_index_` position is
- // not in use and there is space in the `invocation_results_` queue.
- while (!cancelled_ && (!end_of_input_ || num_open_ > 0) &&
- (element_in_use_[cycle_index_] ||
- num_calls_ >= dataset()->num_parallel_calls_ ||
- invocation_results_.size() >= MaxInvocationResults())) {
- StopWork(ctx.get());
- cond_var_.wait(l);
- StartWork(ctx.get());
- }
+ mutex_lock l(mu_);
+ // Wait until this thread is cancelled, the end of input has been
+ // reached, or the cycle element at the `cycle_index_` position is
+ // not in use and there is space in the `invocation_results_` queue.
+ while (!cancelled_ && (!end_of_input_ || num_open_ > 0) && busy()) {
+ RecordStop(ctx.get());
+ cond_var_.wait(l);
+ RecordStart(ctx.get());
+ }
- if (cancelled_ || (end_of_input_ && num_open_ == 0)) {
- return;
- }
+ if (cancelled_ || (end_of_input_ && num_open_ == 0)) {
+ return;
+ }
- while (!element_in_use_[cycle_index_] &&
- (!end_of_input_ || num_open_ > 0) &&
- num_calls_ < dataset()->num_parallel_calls_ &&
- invocation_results_.size() < MaxInvocationResults()) {
- if (!current_elements_[cycle_index_]) {
- // Try to create a new iterator from the next input element.
- Status status = input_impl_->GetNext(
- ctx.get(), &args_list_[cycle_index_], &end_of_input_);
+ while ((!end_of_input_ || num_open_ > 0) && !busy()) {
+ if (!current_elements_[cycle_index_]) {
+ // Try to create a new iterator from the next input element.
+ Status status = input_impl_->GetNext(
+ ctx.get(), &args_list_[cycle_index_], &end_of_input_);
+ if (!status.ok()) {
+ invocation_results_.emplace_back(new InvocationResult());
+ std::shared_ptr<InvocationResult>& result =
+ invocation_results_.back();
+ result->status.Update(status);
+ result->notification.Notify();
+ break;
+ }
+ if (!end_of_input_) {
+ Status status = MakeIteratorFromInputElement(
+ ctx.get(), args_list_[cycle_index_], cycle_index_,
+ dataset()->captured_func_.get(), prefix(),
+ &current_elements_[cycle_index_]);
if (!status.ok()) {
invocation_results_.emplace_back(new InvocationResult());
std::shared_ptr<InvocationResult>& result =
@@ -1486,39 +1489,25 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
result->notification.Notify();
break;
}
- if (!end_of_input_) {
- Status status = MakeIteratorFromInputElement(
- ctx.get(), args_list_[cycle_index_], cycle_index_,
- dataset()->captured_func_.get(), prefix(),
- &current_elements_[cycle_index_]);
- if (!status.ok()) {
- invocation_results_.emplace_back(new InvocationResult());
- std::shared_ptr<InvocationResult>& result =
- invocation_results_.back();
- result->status.Update(status);
- result->notification.Notify();
- break;
- }
- ++num_open_;
- }
+ ++num_open_;
}
- if (current_elements_[cycle_index_]) {
- // Pre-allocate invocation results for outputs to be fetched
- // and then fetch the outputs asynchronously.
- std::vector<std::shared_ptr<InvocationResult>> results;
- results.reserve(dataset()->block_length_);
- for (int i = 0; i < dataset()->block_length_; ++i) {
- invocation_results_.emplace_back(new InvocationResult());
- results.push_back(invocation_results_.back());
- }
- num_calls_++;
- element_in_use_[cycle_index_] = true;
- thread_pool_->Schedule(std::bind(&Iterator::FetchOutputs, this,
- ctx, cycle_index_,
- std::move(results)));
+ }
+ if (current_elements_[cycle_index_]) {
+ // Pre-allocate invocation results for outputs to be fetched
+ // and then fetch the outputs asynchronously.
+ std::vector<std::shared_ptr<InvocationResult>> results;
+ results.reserve(dataset()->block_length_);
+ for (int i = 0; i < dataset()->block_length_; ++i) {
+ invocation_results_.emplace_back(new InvocationResult());
+ results.push_back(invocation_results_.back());
}
- cycle_index_ = (cycle_index_ + 1) % dataset()->cycle_length_;
+ num_calls_++;
+ element_in_use_[cycle_index_] = true;
+ thread_pool_->Schedule(std::bind(&Iterator::FetchOutputs, this,
+ ctx, cycle_index_,
+ std::move(results)));
}
+ cycle_index_ = (cycle_index_ + 1) % dataset()->cycle_length_;
}
cond_var_.notify_all();
}
@@ -1621,6 +1610,9 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
// and there are elements left to be fetched.
condition_variable cond_var_;
+ // Identifies the maximum number of parallel calls.
+ std::atomic<int64> num_parallel_calls_;
+
// Iterator for input elements.
std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
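
Both runner loops in this file now bracket their condition-variable waits with RecordStop/RecordStart so time spent blocked is not attributed to useful work, and gate new work behind a single busy() lambda evaluated under the lock. A standalone sketch of that loop shape, with stubbed bookkeeping functions standing in for the profiling hooks and a fixed parallelism budget assumed for illustration:

    #include <condition_variable>
    #include <cstdio>
    #include <mutex>

    std::mutex mu;
    std::condition_variable cond_var;
    int num_calls = 0;
    const int num_parallel_calls = 2;  // assumed fixed budget for the sketch
    bool cancelled = false;

    // Stand-ins for RecordStart/RecordStop: mark the thread as active or blocked.
    void RecordStart() { std::puts("thread active"); }
    void RecordStop() { std::puts("thread blocked"); }

    bool Busy() { return num_calls >= num_parallel_calls; }  // requires mu held

    void RunnerOnce() {
      std::unique_lock<std::mutex> l(mu);
      while (!cancelled && Busy()) {
        RecordStop();   // about to block: stop charging time to this thread
        cond_var.wait(l);
        RecordStart();  // woken up: start charging time again
      }
      if (cancelled) return;
      while (!Busy()) ++num_calls;  // schedule work until the budget is used up
    }

    int main() {
      RunnerOnce();  // fills the budget: num_calls becomes 2
      std::printf("num_calls=%d\n", num_calls);
    }
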
diff --git a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
index 0795987431..6abe6c8338 100644
--- a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
@@ -44,25 +44,17 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
protected:
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
- OpInputList inputs;
- OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs));
- std::vector<Tensor> other_arguments;
- other_arguments.reserve(inputs.size());
- for (const Tensor& t : inputs) {
- other_arguments.push_back(t);
- }
-
int32 num_parallel_calls;
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "num_parallel_calls",
&num_parallel_calls));
- OP_REQUIRES(ctx, num_parallel_calls > 0,
+ OP_REQUIRES(ctx, num_parallel_calls > 0 || num_parallel_calls == kAutoTune,
errors::InvalidArgument(
"num_parallel_calls must be greater than zero."));
std::unique_ptr<CapturedFunction> captured_func;
- OP_REQUIRES_OK(ctx, CapturedFunction::Create(
- func_, std::move(other_arguments),
- use_inter_op_parallelism_, &captured_func));
+ OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments",
+ use_inter_op_parallelism_,
+ &captured_func));
*output = new Dataset(ctx, input, func_, num_parallel_calls, output_types_,
output_shapes_, use_inter_op_parallelism_,
diff --git a/tensorflow/core/kernels/data/parallel_map_iterator.cc b/tensorflow/core/kernels/data/parallel_map_iterator.cc
index 0b6e587881..8393024c51 100644
--- a/tensorflow/core/kernels/data/parallel_map_iterator.cc
+++ b/tensorflow/core/kernels/data/parallel_map_iterator.cc
@@ -14,17 +14,21 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/parallel_map_iterator.h"
+#include <atomic>
#include <deque>
#include <functional>
#include <utility>
#include <vector>
#include "tensorflow/core/lib/gtl/cleanup.h"
+#include "tensorflow/core/platform/cpu_info.h"
namespace tensorflow {
namespace data {
namespace {
+// TODO(b/116852688): Make coordination between the performance model and this
+// transformation more robust.
class ParallelMapIterator : public DatasetBaseIterator {
public:
explicit ParallelMapIterator(
@@ -39,11 +43,6 @@ class ParallelMapIterator : public DatasetBaseIterator {
num_parallel_calls_(num_parallel_calls) {}
~ParallelMapIterator() override {
- // TODO(mrry): Replace this cancellation logic with a
- // CancellationManager. The syntax would be more heavyweight,
- // but it would be possible to thread a cancellation manager
- // through the IteratorContext to upstream,
- // potentially-blocking iterators, when we add these.
mutex_lock l(mu_);
// Cancel the runner thread.
cancelled_ = true;
@@ -55,7 +54,17 @@ class ParallelMapIterator : public DatasetBaseIterator {
}
Status Initialize(IteratorContext* ctx) override {
- SetMetadata(ctx, "parallelism", num_parallel_calls_);
+ mutex_lock l(mu_);
+ if (num_parallel_calls_ == kAutoTune) {
+ num_parallel_calls_ = 1;
+ // TODO(jsimsa): Surface the number of threads used by `ctx->runner()` and
+ // use it here for the maximum.
+ AddTunableParameter(ctx, "parallelism", &num_parallel_calls_ /* value */,
+ 1 /* min */, port::NumSchedulableCPUs() /* max */,
+ &cond_var_);
+ } else {
+ AddConstantParameter(ctx, "parallelism", num_parallel_calls_);
+ }
TF_RETURN_IF_ERROR(
input_dataset_->MakeIterator(ctx, prefix(), &input_impl_));
if (init_func_) {
@@ -71,17 +80,17 @@ class ParallelMapIterator : public DatasetBaseIterator {
mutex_lock l(mu_);
EnsureRunnerThreadStarted(ctx);
while (invocation_results_.empty()) {
- StopWork(ctx);
+ RecordStop(ctx);
cond_var_.wait(l);
- StartWork(ctx);
+ RecordStart(ctx);
}
std::swap(result, invocation_results_.front());
invocation_results_.pop_front();
+ cond_var_.notify_all();
}
- cond_var_.notify_all();
- StopWork(ctx);
+ RecordStop(ctx);
result->notification.WaitForNotification();
- StartWork(ctx);
+ RecordStart(ctx);
return ProcessResult(result, out_tensors, end_of_sequence);
}
@@ -182,9 +191,9 @@ class ParallelMapIterator : public DatasetBaseIterator {
{
mutex_lock l(mu_);
num_calls_--;
+ cond_var_.notify_all();
}
result->notification.Notify();
- cond_var_.notify_all();
}
void CallFunction(const std::shared_ptr<IteratorContext>& ctx,
@@ -199,9 +208,8 @@ class ParallelMapIterator : public DatasetBaseIterator {
return;
}
- // Call `func_(input_element)`, store the result in
- // `result->return_values`, and notify `result->notification` to unblock
- // a consumer.
+ // Call `func_(input_element)`, store the result in `result->return_values`,
+ // and notify `result->notification` to unblock a consumer.
auto done = [this, result](Status status) {
result->status.Update(status);
CallCompleted(result);
@@ -211,8 +219,6 @@ class ParallelMapIterator : public DatasetBaseIterator {
std::move(done));
}
- int64 MaxInvocationResults() { return num_parallel_calls_; }
-
Status ProcessResult(const std::shared_ptr<InvocationResult>& result,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) {
@@ -232,31 +238,33 @@ class ParallelMapIterator : public DatasetBaseIterator {
}
void RunnerThread(const std::shared_ptr<IteratorContext>& ctx) {
- StartWork(ctx.get());
- auto cleanup = gtl::MakeCleanup([this, ctx] { StopWork(ctx.get()); });
+ RecordStart(ctx.get());
+ auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
std::vector<std::shared_ptr<InvocationResult>> new_calls;
new_calls.reserve(num_parallel_calls_);
+ auto busy = [this]() EXCLUSIVE_LOCKS_REQUIRED(mu_) -> bool {
+ int64 num_parallel_calls = num_parallel_calls_;
+ return num_calls_ >= num_parallel_calls ||
+ invocation_results_.size() >= num_parallel_calls;
+ };
while (true) {
{
mutex_lock l(mu_);
- while (!cancelled_ &&
- (num_calls_ >= num_parallel_calls_ ||
- invocation_results_.size() >= MaxInvocationResults())) {
- StopWork(ctx.get());
+ while (!cancelled_ && busy()) {
+ RecordStop(ctx.get());
cond_var_.wait(l);
- StartWork(ctx.get());
+ RecordStart(ctx.get());
}
if (cancelled_) {
return;
}
- while (num_calls_ < num_parallel_calls_ &&
- invocation_results_.size() < MaxInvocationResults()) {
+ while (!busy()) {
invocation_results_.emplace_back(new InvocationResult());
new_calls.push_back(invocation_results_.back());
num_calls_++;
}
+ cond_var_.notify_all();
}
- cond_var_.notify_all();
for (const auto& call : new_calls) {
CallFunction(ctx, call);
}
@@ -305,7 +313,6 @@ class ParallelMapIterator : public DatasetBaseIterator {
const DatasetBase* const input_dataset_; // Not owned.
const std::function<Status(IteratorContext*)> init_func_;
const ParallelMapIteratorFunction map_func_;
- const int32 num_parallel_calls_;
// Used for coordination between the main thread and the runner thread.
mutex mu_;
// Used for coordination between the main thread and the runner thread. In
@@ -314,6 +321,8 @@ class ParallelMapIterator : public DatasetBaseIterator {
// parallelism and there are slots available in the `invocation_results_`
// buffer.
condition_variable cond_var_;
+ // Identifies the maximum number of parallel calls.
+ std::atomic<int64> num_parallel_calls_;
// Counts the number of outstanding calls.
int64 num_calls_ GUARDED_BY(mu_) = 0;
std::unique_ptr<IteratorBase> input_impl_;
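
In parallel_map_iterator.cc, num_parallel_calls_ becomes a std::atomic<int64> rather than a const int32, because the autotuner may raise it while the runner thread is reading it; the busy() lambda therefore snapshots the value into a local once per check so both comparisons see the same number. A minimal standalone sketch of that pattern, with a hypothetical tuner thread bumping the knob:

    #include <atomic>
    #include <cstdint>
    #include <iostream>
    #include <thread>

    std::atomic<int64_t> num_parallel_calls{1};

    // Snapshot the knob once so the two comparisons below cannot disagree if the
    // tuner changes it mid-check.
    bool Busy(int64_t num_calls, int64_t queued_results) {
      const int64_t parallelism = num_parallel_calls.load();
      return num_calls >= parallelism || queued_results >= parallelism;
    }

    int main() {
      std::thread tuner([] { num_parallel_calls.store(4); });  // hypothetical autotune step
      tuner.join();
      std::cout << Busy(/*num_calls=*/2, /*queued_results=*/1) << "\n";  // 0 after the bump
    }
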
diff --git a/tensorflow/core/kernels/data/parse_example_dataset_op.cc b/tensorflow/core/kernels/data/parse_example_dataset_op.cc
index 0cf5db017b..c28c06da62 100644
--- a/tensorflow/core/kernels/data/parse_example_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parse_example_dataset_op.cc
@@ -87,11 +87,8 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel {
"Expected len(dense_defaults) == len(dense_keys) but got: ",
dense_default_tensors.size(), " vs. ", dense_keys_.size()));
- std::vector<Tensor> dense_defaults;
- dense_defaults.reserve(dense_default_tensors.size());
- for (const Tensor& dense_default_t : dense_default_tensors) {
- dense_defaults.push_back(dense_default_t);
- }
+ std::vector<Tensor> dense_defaults(dense_default_tensors.begin(),
+ dense_default_tensors.end());
for (int d = 0; d < dense_keys_.size(); ++d) {
const Tensor& def_value = dense_defaults[d];
diff --git a/tensorflow/core/kernels/data/prefetch_dataset_op.cc b/tensorflow/core/kernels/data/prefetch_dataset_op.cc
index 52c421caee..754ed772db 100644
--- a/tensorflow/core/kernels/data/prefetch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/prefetch_dataset_op.cc
@@ -103,18 +103,18 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
+ auto stats_aggregator = ctx->stats_aggregator();
{
mutex_lock l(mu_);
- auto stats_aggregator = ctx->stats_aggregator();
TF_RETURN_IF_ERROR(EnsurePrefetchThreadStarted(ctx));
// Wait until the next element in the buffer has been
// produced, or we are shutting down.
while (!cancelled_ && buffer_.empty() && !prefetch_thread_finished_ &&
auto_tuner_.buffer_limit() != 0) {
auto_tuner_.RecordEmpty();
- StopWork(ctx);
+ RecordStop(ctx);
cond_var_.wait(l);
- StartWork(ctx);
+ RecordStart(ctx);
}
if (cancelled_) {
@@ -136,6 +136,14 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
mutex_lock parent_l(parent_mu_);
mutex_lock l(mu_);
+ if (stats_aggregator) {
+ stats_aggregator->AddScalar(
+ strings::StrCat(prefix_end_, "::buffer_size"),
+ static_cast<float>(buffer_.size()));
+ stats_aggregator->AddScalar(
+ strings::StrCat(prefix_end_, "::buffer_capacity"),
+ static_cast<float>(auto_tuner_.buffer_limit()));
+ }
return input_impl_->GetNext(ctx, out_tensors, end_of_sequence);
}
@@ -219,6 +227,12 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
strings::StrCat(prefix_end_, "::buffer_utilization"),
{static_cast<float>(buffer_.size()) /
static_cast<float>(auto_tuner_.buffer_limit())});
+ stats_aggregator->AddScalar(
+ strings::StrCat(prefix_end_, "::buffer_size"),
+ static_cast<float>(buffer_.size()));
+ stats_aggregator->AddScalar(
+ strings::StrCat(prefix_end_, "::buffer_capacity"),
+ static_cast<float>(auto_tuner_.buffer_limit()));
}
// A new element is available. Forward the status from computing it, and
// (if we successfully got an element) the output values.
@@ -255,8 +269,8 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
//
// It owns the iterator context passed to it.
void PrefetchThread(const std::shared_ptr<IteratorContext>& ctx) {
- StartWork(ctx.get());
- auto cleanup = gtl::MakeCleanup([this, ctx] { StopWork(ctx.get()); });
+ RecordStart(ctx.get());
+ auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
while (true) {
std::vector<Tensor> value;
@@ -264,9 +278,9 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
{
mutex_lock l(mu_);
while (!cancelled_ && buffer_.size() >= auto_tuner_.buffer_limit()) {
- StopWork(ctx.get());
+ RecordStop(ctx.get());
cond_var_.wait(l);
- StartWork(ctx.get());
+ RecordStart(ctx.get());
}
if (cancelled_) {
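
The prefetch iterator now also publishes the current buffer size and buffer capacity as scalar statistics on every GetNext call, alongside the existing utilization statistic. A standalone sketch of the two gauges and the ratio behind the utilization value, with a simple print assumed in place of the StatsAggregator interface and placeholder statistic names:

    #include <cstdio>
    #include <deque>

    // Stand-in for stats_aggregator->AddScalar(name, value).
    void AddScalar(const char* name, float value) { std::printf("%s = %g\n", name, value); }

    int main() {
      std::deque<int> buffer = {1, 2, 3};   // three prefetched elements
      const std::size_t buffer_limit = 8;   // assumed auto-tuned capacity

      AddScalar("prefetch::buffer_size", static_cast<float>(buffer.size()));
      AddScalar("prefetch::buffer_capacity", static_cast<float>(buffer_limit));
      // The utilization statistic is just the ratio of the two gauges above.
      AddScalar("prefetch::buffer_utilization",
                static_cast<float>(buffer.size()) / static_cast<float>(buffer_limit));
    }
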
diff --git a/tensorflow/core/kernels/data/scan_dataset_op.cc b/tensorflow/core/kernels/data/scan_dataset_op.cc
index 6e515d6cc8..2a911aa368 100644
--- a/tensorflow/core/kernels/data/scan_dataset_op.cc
+++ b/tensorflow/core/kernels/data/scan_dataset_op.cc
@@ -32,8 +32,7 @@ namespace {
class ScanDatasetOp : public UnaryDatasetOpKernel {
public:
explicit ScanDatasetOp(OpKernelConstruction* ctx)
- : UnaryDatasetOpKernel(ctx),
- graph_def_version_(ctx->graph_def_version()) {
+ : UnaryDatasetOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("Tstate", &state_types_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
@@ -45,23 +44,12 @@ class ScanDatasetOp : public UnaryDatasetOpKernel {
OpInputList initial_state_inputs;
OP_REQUIRES_OK(ctx,
ctx->input_list("initial_state", &initial_state_inputs));
- std::vector<Tensor> initial_state;
- initial_state.reserve(initial_state_inputs.size());
- for (const Tensor& t : initial_state_inputs) {
- initial_state.push_back(t);
- }
-
- OpInputList inputs;
- OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs));
- std::vector<Tensor> other_arguments;
- other_arguments.reserve(inputs.size());
- for (const Tensor& t : inputs) {
- other_arguments.push_back(t);
- }
+ std::vector<Tensor> initial_state(initial_state_inputs.begin(),
+ initial_state_inputs.end());
std::unique_ptr<CapturedFunction> captured_func;
- OP_REQUIRES_OK(ctx, CapturedFunction::Create(
- func_, std::move(other_arguments), &captured_func));
+ OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments",
+ &captured_func));
*output = new Dataset(ctx, input, func_, std::move(initial_state),
std::move(captured_func), state_types_, output_types_,
@@ -269,7 +257,6 @@ class ScanDatasetOp : public UnaryDatasetOpKernel {
const std::vector<PartialTensorShape> output_shapes_;
};
- const int graph_def_version_;
DataTypeVector state_types_;
DataTypeVector output_types_;
std::vector<PartialTensorShape> output_shapes_;
diff --git a/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc b/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc
index f5314f7a75..c8abfb9eb5 100644
--- a/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc
+++ b/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc
@@ -34,16 +34,18 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel {
&stats_aggregator_resource));
core::ScopedUnref unref_stats_aggregator(stats_aggregator_resource);
- *output = new Dataset(ctx, input, stats_aggregator_resource);
+ *output = new Dataset(ctx, input, ctx->input(1), stats_aggregator_resource);
}
private:
class Dataset : public DatasetBase {
public:
explicit Dataset(OpKernelContext* ctx, const DatasetBase* input,
+ const Tensor& resource_handle,
StatsAggregatorResource* stats_aggregator_resource)
: DatasetBase(DatasetContext(ctx)),
input_(input),
+ resource_handle_(resource_handle),
stats_aggregator_resource_(stats_aggregator_resource) {
input_->Ref();
stats_aggregator_resource_->Ref();
@@ -75,8 +77,13 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel {
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override {
- return errors::Unimplemented("%s does not support serialization",
- DebugString());
+ Node* input_graph_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
+ Node* resource_handle_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddTensor(resource_handle_, &resource_handle_node));
+ TF_RETURN_IF_ERROR(b->AddDataset(
+ this, {input_graph_node, resource_handle_node}, output));
+ return Status::OK();
}
private:
@@ -111,16 +118,14 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel {
protected:
Status SaveInternal(IteratorStateWriter* writer) override {
- mutex_lock l(mu_);
- TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
- return Status::OK();
+ return errors::Unimplemented(dataset()->DebugString(),
+ " does not support checkpointing");
}
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
- mutex_lock l(mu_);
- TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
- return Status::OK();
+ return errors::Unimplemented(dataset()->DebugString(),
+ " does not support checkpointing");
}
private:
@@ -129,6 +134,7 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel {
};
const DatasetBase* const input_;
+ const Tensor resource_handle_;
StatsAggregatorResource* stats_aggregator_resource_;
};
};
diff --git a/tensorflow/core/kernels/data/tensor_dataset_op.cc b/tensorflow/core/kernels/data/tensor_dataset_op.cc
index e1cefd23d8..ca4ea25b89 100644
--- a/tensorflow/core/kernels/data/tensor_dataset_op.cc
+++ b/tensorflow/core/kernels/data/tensor_dataset_op.cc
@@ -33,11 +33,7 @@ class TensorDatasetOp : public DatasetOpKernel {
OP_REQUIRES_OK(ctx, ctx->input_list("components", &inputs));
// TODO(mrry): Validate that the shapes of the "components" tensors match
// the "shapes" attr.;
- std::vector<Tensor> components;
- components.reserve(inputs.size());
- for (const Tensor& t : inputs) {
- components.push_back(t);
- }
+ std::vector<Tensor> components(inputs.begin(), inputs.end());
*output = new Dataset(ctx, std::move(components));
}
diff --git a/tensorflow/core/kernels/data/window_dataset_op.cc b/tensorflow/core/kernels/data/window_dataset_op.cc
index 3975086841..ac44623ce2 100644
--- a/tensorflow/core/kernels/data/window_dataset_op.cc
+++ b/tensorflow/core/kernels/data/window_dataset_op.cc
@@ -33,22 +33,44 @@ class WindowDatasetOp : public UnaryDatasetOpKernel {
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
int64 window_size = 0;
- OP_REQUIRES_OK(
- ctx, ParseScalarArgument<int64>(ctx, "window_size", &window_size));
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "size", &window_size));
OP_REQUIRES(
ctx, window_size > 0,
errors::InvalidArgument("Window size must be greater than zero."));
- *output = new Dataset(ctx, window_size, input);
+ int64 window_shift = 0;
+ OP_REQUIRES_OK(ctx,
+ ParseScalarArgument<int64>(ctx, "shift", &window_shift));
+ OP_REQUIRES(
+ ctx, window_shift > 0,
+ errors::InvalidArgument("Window shift must be greater than zero."));
+
+ int64 window_stride = 0;
+ OP_REQUIRES_OK(ctx,
+ ParseScalarArgument<int64>(ctx, "stride", &window_stride));
+ OP_REQUIRES(
+ ctx, window_stride > 0,
+ errors::InvalidArgument("Window stride must be greater than zero."));
+
+ bool drop_remainder;
+ OP_REQUIRES_OK(
+ ctx, ParseScalarArgument<bool>(ctx, "drop_remainder", &drop_remainder));
+
+ *output = new Dataset(ctx, input, window_size, window_shift, window_stride,
+ drop_remainder);
}
private:
class Dataset : public DatasetBase {
public:
- Dataset(OpKernelContext* ctx, int64 window_size, const DatasetBase* input)
+ Dataset(OpKernelContext* ctx, const DatasetBase* input, int64 window_size,
+ int64 window_shift, int64 window_stride, bool drop_remainder)
: DatasetBase(DatasetContext(ctx)),
+ input_(input),
window_size_(window_size),
- input_(input) {
+ window_shift_(window_shift),
+ window_stride_(window_stride),
+ drop_remainder_(drop_remainder) {
input_->Ref();
}
@@ -72,7 +94,8 @@ class WindowDatasetOp : public UnaryDatasetOpKernel {
}
string DebugString() const override {
- return strings::StrCat("WindowDatasetOp(", window_size_, ")::Dataset");
+ return strings::StrCat("WindowDatasetOp(", window_size_, window_shift_,
+ window_stride_, drop_remainder_, ")::Dataset");
}
protected:
@@ -81,10 +104,19 @@ class WindowDatasetOp : public UnaryDatasetOpKernel {
Node** output) const override {
Node* input_graph_node = nullptr;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
- Node* window_size = nullptr;
- TF_RETURN_IF_ERROR(b->AddScalar(window_size_, &window_size));
+ Node* window_size_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddScalar(window_size_, &window_size_node));
+ Node* window_shift_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddScalar(window_shift_, &window_shift_node));
+ Node* window_stride_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddScalar(window_stride_, &window_stride_node));
+ Node* drop_remainder_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddScalar(drop_remainder_, &drop_remainder_node));
TF_RETURN_IF_ERROR(
- b->AddDataset(this, {input_graph_node, window_size}, output));
+ b->AddDataset(this,
+ {input_graph_node, window_size_node, window_shift_node,
+ window_stride_node, drop_remainder_node},
+ output));
return Status::OK();
}
@@ -101,37 +133,79 @@ class WindowDatasetOp : public UnaryDatasetOpKernel {
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
- // Each row of `window_elements` is a tuple of tensors from the
- // input iterator.
+ const int64 window_size = dataset()->window_size_;
+ const int64 window_shift = dataset()->window_shift_;
+ const int64 window_stride = dataset()->window_stride_;
std::vector<std::vector<Tensor>> window_elements;
+ Status status = Status::OK();
{
mutex_lock l(mu_);
- if (!input_impl_) {
+ if (!input_impl_ && buffer_.empty()) {
*end_of_sequence = true;
return Status::OK();
}
- window_elements.reserve(dataset()->window_size_);
- *end_of_sequence = false;
- for (int i = 0; i < dataset()->window_size_ && !*end_of_sequence;
- ++i) {
- std::vector<Tensor> window_element_tuple;
- TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, &window_element_tuple,
- end_of_sequence));
- if (!*end_of_sequence) {
- window_elements.emplace_back(std::move(window_element_tuple));
- } else {
- input_impl_.reset();
+
+ // Add elements to the buffer.
+ size_t target_size = TargetBufferSize(window_size, window_stride);
+ if (input_impl_) {
+ *end_of_sequence = false;
+ for (size_t i = buffer_.size();
+ i < target_size && !*end_of_sequence; ++i) {
+ std::vector<Tensor> element;
+ Status status =
+ input_impl_->GetNext(ctx, &element, end_of_sequence);
+ if (!*end_of_sequence) {
+ buffer_.emplace_back(std::move(element), status);
+ } else {
+ input_impl_.reset();
+ }
}
}
+
+ // If there are not enough elements and `drop_remainder` is set, we do
+ // not wish to return a smaller window.
+ if (buffer_.empty() ||
+ (dataset()->drop_remainder_ && buffer_.size() < target_size)) {
+ DCHECK(*end_of_sequence);
+ return Status::OK();
+ }
+
+ int num_elements = 1 + (buffer_.size() - 1) / window_stride;
+ window_elements.reserve(num_elements);
+ for (size_t i = 0; i < num_elements; ++i) {
+ status.Update(buffer_[window_stride * i].status);
+ if (!status.ok()) {
+ break;
+ }
+ window_elements.emplace_back(buffer_[window_stride * i].result);
+ }
+
+ // Shift the window, discarding elements if necessary.
+ int buffer_size = buffer_.size();
+ if (window_shift >= buffer_size) {
+ for (size_t i = buffer_size; input_impl_ && i < window_shift; ++i) {
+ bool end_of_input;
+ std::vector<Tensor> element;
+ // Ignore non-error status of discarded elements.
+ input_impl_->GetNext(ctx, &element, &end_of_input).IgnoreError();
+ if (end_of_input) {
+ input_impl_.reset();
+ }
+ }
+ buffer_.clear();
+ } else {
+ buffer_.erase(buffer_.begin(), buffer_.begin() + window_shift);
+ }
}
- if (window_elements.empty()) {
- DCHECK(*end_of_sequence);
- return Status::OK();
+ if (!status.ok()) {
+ return status;
}
+ // Construct output tensors.
const size_t num_tuple_components = window_elements[0].size();
const int64 num_window_elements = window_elements.size();
+ *end_of_sequence = false;
for (size_t idx = 0; idx < num_tuple_components; ++idx) {
DatasetBase* window_dataset;
std::vector<std::vector<Tensor>> window_component_elements;
@@ -154,7 +228,6 @@ class WindowDatasetOp : public UnaryDatasetOpKernel {
TF_RETURN_IF_ERROR(StoreDatasetInVariantTensor(window_dataset,
&out_tensors->back()));
}
- *end_of_sequence = false;
return Status::OK();
}
@@ -167,6 +240,20 @@ class WindowDatasetOp : public UnaryDatasetOpKernel {
} else {
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
}
+ // Save buffer.
+ TF_RETURN_IF_ERROR(writer->WriteScalar(strings::StrCat("buffer_size"),
+ buffer_.size()));
+ for (int64 i = 0; i < buffer_.size(); i++) {
+ TF_RETURN_IF_ERROR(WriteStatusLocked(writer, i, buffer_[i].status));
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(strings::StrCat("buffer[", i, "].size"),
+ buffer_[i].result.size()));
+ for (int64 j = 0; j < buffer_[i].result.size(); j++) {
+ TF_RETURN_IF_ERROR(
+ writer->WriteTensor(strings::StrCat("buffer[", i, "][", j, "]"),
+ buffer_[i].result[j]));
+ }
+ }
return Status::OK();
}
@@ -178,22 +265,92 @@ class WindowDatasetOp : public UnaryDatasetOpKernel {
} else {
input_impl_.reset();
}
+ // Restore buffer.
+ int64 buffer_size;
+ TF_RETURN_IF_ERROR(
+ reader->ReadScalar(strings::StrCat("buffer_size"), &buffer_size));
+ buffer_.resize(buffer_size);
+ for (int64 i = 0; i < buffer_size; i++) {
+ int64 vector_size;
+ TF_RETURN_IF_ERROR(ReadStatusLocked(reader, i, &buffer_[i].status));
+ TF_RETURN_IF_ERROR(reader->ReadScalar(
+ strings::StrCat("buffer[", i, "].size"), &vector_size));
+ buffer_[i].result.resize(vector_size);
+ for (int64 j = 0; j < vector_size; j++) {
+ TF_RETURN_IF_ERROR(
+ reader->ReadTensor(strings::StrCat("buffer[", i, "][", j, "]"),
+ &buffer_[i].result[j]));
+ }
+ }
return Status::OK();
}
private:
+ struct InvocationResult {
+ InvocationResult() = default;
+ InvocationResult(std::vector<Tensor>&& result, const Status& status)
+ : result(result), status(status) {}
+
+ std::vector<Tensor> result;
+ Status status;
+ };
+
+ Status WriteStatusLocked(IteratorStateWriter* writer, size_t index,
+ const Status& status)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ TF_RETURN_IF_ERROR(writer->WriteScalar(
+ CodeKey(index), static_cast<int64>(status.code())));
+ if (!status.ok()) {
+ TF_RETURN_IF_ERROR(writer->WriteScalar(ErrorMessageKey(index),
+ status.error_message()));
+ }
+ return Status::OK();
+ }
+
+ Status ReadStatusLocked(IteratorStateReader* reader, size_t index,
+ Status* status) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ int64 code_int;
+ TF_RETURN_IF_ERROR(reader->ReadScalar(CodeKey(index), &code_int));
+ error::Code code = static_cast<error::Code>(code_int);
+
+ if (code != error::Code::OK) {
+ string error_message;
+ TF_RETURN_IF_ERROR(
+ reader->ReadScalar(ErrorMessageKey(index), &error_message));
+ *status = Status(code, error_message);
+ } else {
+ *status = Status::OK();
+ }
+ return Status::OK();
+ }
+
+ string CodeKey(size_t index) {
+ return full_name(strings::StrCat("buffer[", index, "].code"));
+ }
+
+ string ErrorMessageKey(size_t index) {
+ return full_name(strings::StrCat("buffer[", index, "].error_message"));
+ }
+
+ size_t TargetBufferSize(int64 window_size, int64 window_stride) {
+ return (window_size - 1) * window_stride + 1;
+ }
+
mutex mu_;
+ std::deque<InvocationResult> buffer_ GUARDED_BY(mu_);
std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
};
- const int64 window_size_;
const DatasetBase* const input_;
+ const int64 window_size_;
+ const int64 window_shift_;
+ const int64 window_stride_;
+ const bool drop_remainder_;
};
};
REGISTER_KERNEL_BUILDER(Name("WindowDataset").Device(DEVICE_CPU),
WindowDatasetOp);
-
} // namespace
} // namespace data
} // namespace tensorflow
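
The WindowDataset changes above buffer TargetBufferSize = (size - 1) * stride + 1 elements, emit every stride-th buffered element as one window, and advance the input by shift, dropping a short final window only when drop_remainder is set. A minimal sketch of that arithmetic over an in-memory vector, ignoring checkpointing and per-element statuses (Windows is an illustrative helper, not part of the patch):

    #include <algorithm>
    #include <cstddef>
    #include <iostream>
    #include <vector>

    // Emit the windows produced by size/shift/stride/drop_remainder over
    // `input`, mirroring the buffering logic of the iterator above.
    std::vector<std::vector<int>> Windows(const std::vector<int>& input,
                                          std::size_t size, std::size_t shift,
                                          std::size_t stride,
                                          bool drop_remainder) {
      const std::size_t target = (size - 1) * stride + 1;  // TargetBufferSize
      std::vector<std::vector<int>> out;
      for (std::size_t start = 0; start < input.size(); start += shift) {
        const std::size_t buffered = std::min(input.size() - start, target);
        if (drop_remainder && buffered < target) break;
        std::vector<int> window;
        for (std::size_t i = 0; i < buffered; i += stride) {
          window.push_back(input[start + i]);
        }
        out.push_back(window);
      }
      return out;
    }

    int main() {
      // {0..6} with size=3, shift=2, stride=1, drop_remainder=true
      // yields {0,1,2}, {2,3,4}, {4,5,6}.
      for (const auto& w : Windows({0, 1, 2, 3, 4, 5, 6}, 3, 2, 1, true)) {
        for (int v : w) std::cout << v << ' ';
        std::cout << '\n';
      }
      return 0;
    }
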
diff --git a/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc b/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
index 2a25459194..76afd6f18c 100644
--- a/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
@@ -17,7 +17,7 @@ limitations under the License.
#define EIGEN_USE_GPU
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "external/cub_archive/cub/util_ptx.cuh"
+#include "third_party/cub/util_ptx.cuh"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/kernels/depthwise_conv_op.h"
#include "tensorflow/core/platform/types.h"
diff --git a/tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc b/tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc
index 862a97723f..e7882acc80 100644
--- a/tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc
@@ -35,10 +35,10 @@ limitations under the License.
#define EIGEN_USE_GPU
-#include "external/cub_archive/cub/device/device_radix_sort.cuh"
-#include "external/cub_archive/cub/device/device_reduce.cuh"
-#include "external/cub_archive/cub/iterator/constant_input_iterator.cuh"
-#include "external/cub_archive/cub/thread/thread_operators.cuh"
+#include "third_party/cub/device/device_radix_sort.cuh"
+#include "third_party/cub/device/device_reduce.cuh"
+#include "third_party/cub/iterator/constant_input_iterator.cuh"
+#include "third_party/cub/thread/thread_operators.cuh"
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
diff --git a/tensorflow/core/kernels/eigen_cuboid_convolution.h b/tensorflow/core/kernels/eigen_cuboid_convolution.h
index c41fbc42d3..6a9a2accd8 100644
--- a/tensorflow/core/kernels/eigen_cuboid_convolution.h
+++ b/tensorflow/core/kernels/eigen_cuboid_convolution.h
@@ -40,8 +40,8 @@ namespace internal {
// at the given vertical and horizontal offsets.
//
// "Virtual matrix" dimensions:
-// *0: kernelChannels * kernelDepth * kernelRows * kernelCols;
-// 1: out_depth * out_height * out_width; * OTHERS (e.g batches, etc...)
+// *0: kernelChannels * kernelPlanes * kernelRows * kernelCols
+// 1: out_planes * out_height * out_width * OTHERS (e.g batches, etc...)
//
// *) extracted patches are continuous in memory (innermost dimension assuming
// col major layout)
@@ -113,6 +113,11 @@ class TensorContractionInputMapper<
m_num_patches = tensor.impl().dimensions()[NumDims - 5];
}
+ // Strides for navigating through the single patch.
+ m_patch_plane_stride = m_patch_depth;
+ m_patch_row_stride = m_patch_planes * m_patch_plane_stride;
+ m_patch_col_stride = m_patch_rows * m_patch_row_stride;
+
// Strides for the output tensor.
// IMPORTANT: These strides are used to locate an element in a patch at a
// depth zero (channel), which is not quite the same as "traditional"
@@ -166,6 +171,13 @@ class TensorContractionInputMapper<
m_fastNumPatches = internal::TensorIntDivisor<Index>(m_num_patches);
+ m_fastPatchPlaneStride =
+ internal::TensorIntDivisor<Index>(m_patch_plane_stride);
+ m_fastPatchRowStride =
+ internal::TensorIntDivisor<Index>(m_patch_row_stride);
+ m_fastPatchColStride =
+ internal::TensorIntDivisor<Index>(m_patch_col_stride);
+
m_fastInputPlaneStride =
internal::TensorIntDivisor<Index>(m_patch_plane_inflate_strides);
m_fastInputRowStride =
@@ -195,6 +207,10 @@ class TensorContractionInputMapper<
m_patch_cols = base_mapper.m_patch_cols;
m_num_patches = base_mapper.m_num_patches;
+ m_patch_plane_stride = base_mapper.m_patch_plane_stride;
+ m_patch_row_stride = base_mapper.m_patch_row_stride;
+ m_patch_col_stride = base_mapper.m_patch_col_stride;
+
m_rowStride = base_mapper.m_rowStride;
m_colStride = base_mapper.m_colStride;
m_patchStride = base_mapper.m_patchStride;
@@ -234,6 +250,9 @@ class TensorContractionInputMapper<
m_outputPlanesRows = base_mapper.m_outputPlanesRows;
m_fastNumPatches = base_mapper.m_fastNumPatches;
+ m_fastPatchPlaneStride = base_mapper.m_fastPatchPlaneStride;
+ m_fastPatchRowStride = base_mapper.m_fastPatchRowStride;
+ m_fastPatchColStride = base_mapper.m_fastPatchColStride;
m_fastInputPlaneStride = base_mapper.m_fastInputPlaneStride;
m_fastInputRowStride = base_mapper.m_fastInputRowStride;
m_fastInputColStride = base_mapper.m_fastInputColStride;
@@ -305,9 +324,9 @@ class TensorContractionInputMapper<
}
EIGEN_DEVICE_FUNC
- EIGEN_ALWAYS_INLINE Index patchDepth() const { return m_patch_depth; }
+ EIGEN_ALWAYS_INLINE Index patchDepth() const { return m_planeInputStride; }
EIGEN_DEVICE_FUNC
- EIGEN_ALWAYS_INLINE Index patchPlanes() const { return m_patch_planes; }
+ EIGEN_ALWAYS_INLINE Index patchPlanes() const { return m_rowStride; }
EIGEN_DEVICE_FUNC
EIGEN_ALWAYS_INLINE Index patchRows() const { return m_patch_rows; }
EIGEN_DEVICE_FUNC
@@ -391,14 +410,13 @@ class TensorContractionInputMapper<
const Index patchOffset = patchId / m_fastDimZero;
const Index colOffset = patchOffset / m_fastColStride;
- const Index inputCol = colIndex + colOffset;
-
const Index rowOffset =
(patchOffset - colOffset * m_colStride) / m_fastRowStride;
- const Index inputRow = rowIndex + rowOffset;
-
const Index planeOffset =
patchOffset - colOffset * m_colStride - rowOffset * m_rowStride;
+
+ const Index inputCol = colIndex + colOffset;
+ const Index inputRow = rowIndex + rowOffset;
const Index inputPlane = planeIndex + planeOffset;
if (inputCol < 0 || inputCol >= m_inputCols || inputRow < 0 ||
@@ -524,12 +542,13 @@ class TensorContractionInputMapper<
eigen_assert((patchId + packetSize - 1) / m_fastDimZero == patchOffset);
const Index colOffset = patchOffset / m_fastColStride;
- const Index inputCol = colIndex + colOffset;
const Index rowOffset =
(patchOffset - colOffset * m_colStride) / m_fastRowStride;
- const Index inputRow = rowIndex + rowOffset;
const Index planeOffset =
patchOffset - colOffset * m_colStride - rowOffset * m_rowStride;
+
+ const Index inputCol = colIndex + colOffset;
+ const Index inputRow = rowIndex + rowOffset;
const Index inputPlane = planeIndex + planeOffset;
if (inputCol < 0 || inputRow < 0 || inputPlane < 0 ||
@@ -564,7 +583,7 @@ class TensorContractionInputMapper<
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void computeBaseIndices(
Index patchIndex, Index& planeIndex, Index& rowIndex, Index& colIndex,
Index& otherIndex) const {
- const int NumInputDims = array_size<
+ const size_t NumInputDims = array_size<
typename TensorEvaluator<ArgType, Device>::Dimensions>::value;
// Check if patchIndex might contain batch and other dimensions.
@@ -594,7 +613,12 @@ class TensorContractionInputMapper<
Index m_patch_cols; // number of columns in the patch
Index m_num_patches; // number of patches to extract
- // Strides for the output tensor.
+ // Strides for navigating through the single patch.
+ Index m_patch_plane_stride;
+ Index m_patch_row_stride;
+ Index m_patch_col_stride;
+
+ // Strides for the output tensor (depth is not part of the stride).
Index m_rowStride;
Index m_colStride;
Index m_patchStride;
@@ -637,6 +661,10 @@ class TensorContractionInputMapper<
// Fast representation of various divisors.
internal::TensorIntDivisor<Index> m_fastNumPatches;
+ internal::TensorIntDivisor<Index> m_fastPatchPlaneStride;
+ internal::TensorIntDivisor<Index> m_fastPatchRowStride;
+ internal::TensorIntDivisor<Index> m_fastPatchColStride;
+
internal::TensorIntDivisor<Index> m_fastInputPlaneStride;
internal::TensorIntDivisor<Index> m_fastInputRowStride;
internal::TensorIntDivisor<Index> m_fastInputColStride;
@@ -750,13 +778,62 @@ class TensorContractionSubMapper<
return m_base_mapper.nonStandardPatches();
}
+ // Max(Col|Row|Plane|Depth): compute the upper limit for the column, row,
+ // plane and depth index respectively that fits into the peeled_k elements
+ // starting at m_depth_offset.
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index maxCol(const Index peeled_k) const {
+ const Index max_col =
+ fastPatchColStride().divide(m_depth_offset + peeled_k);
+ return std::min<Index>(1 + max_col, patchCols());
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index maxRow(const Index peeled_k,
+ const Index col) const {
+ const Index max_row = fastPatchRowStride().divide(
+ m_depth_offset + peeled_k - col * patchColStride());
+ return std::min<Index>(1 + max_row, patchRows());
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index maxPlane(const Index peeled_k, const Index col,
+ const Index row) const {
+ const Index max_plane = fastPatchPlaneStride().divide(
+ m_depth_offset + peeled_k - col * patchColStride() -
+ row * patchRowStride());
+ return std::min<Index>(1 + max_plane, patchPlanes());
+ }
+
+ // MaxDepth uses only the remaining number of elements in the peeled_k.
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index maxDepth(const Index num_elements,
+ const Index start_depth) const {
+ return std::min<Index>(start_depth + num_elements, patchDepth());
+ }
+
+ // Every register matters in this code, so sometimes to prevent register
+ // spilling, instead of the variable that you would expect to see, we use
+ // another one that is guaranteed to have the same value. E.g. patch depth
+ // is always the same as input depth, and it's also the same as the input
+ // plane stride. A number of other parameters have similar relations.
+
+ typedef internal::TensorIntDivisor<Index> IndexDivisor;
+
EIGEN_DEVICE_FUNC
EIGEN_ALWAYS_INLINE Index patchDepth() const {
- return m_base_mapper.m_patch_depth;
+ eigen_assert(m_base_mapper.m_patch_depth ==
+ m_base_mapper.m_planeInputStride &&
+ "Patch depth must be equal to plane input stride.");
+ return m_base_mapper.m_planeInputStride;
}
+
EIGEN_DEVICE_FUNC
EIGEN_ALWAYS_INLINE Index patchPlanes() const {
- return m_base_mapper.m_patch_planes;
+ eigen_assert(m_base_mapper.m_patch_planes == m_base_mapper.m_rowStride &&
+ "Patch planes must be equal to row stride.");
+ return m_base_mapper.m_rowStride;
}
EIGEN_DEVICE_FUNC
EIGEN_ALWAYS_INLINE Index patchRows() const {
@@ -768,6 +845,36 @@ class TensorContractionSubMapper<
}
EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index patchPlaneStride() const {
+ eigen_assert(patchDepth() == m_base_mapper.m_patch_plane_stride &&
+ "Patch depth must be equal to patch plane stride.");
+ return patchDepth();
+ }
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index patchRowStride() const {
+ return m_base_mapper.m_patch_row_stride;
+ }
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index patchColStride() const {
+ return m_base_mapper.m_patch_col_stride;
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE IndexDivisor fastPatchPlaneStride() const {
+ eigen_assert(patchDepth() == m_base_mapper.m_patch_plane_stride &&
+ "Patch depth must be equal to patch plane stride.");
+ return m_base_mapper.m_fastDimZero; // patch_depth
+ }
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE IndexDivisor fastPatchRowStride() const {
+ return m_base_mapper.m_fastPatchRowStride;
+ }
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE IndexDivisor fastPatchColStride() const {
+ return m_base_mapper.m_fastPatchColStride;
+ }
+
+ EIGEN_DEVICE_FUNC
EIGEN_ALWAYS_INLINE Packet packetNoPadding(const Index depth,
const Index baseIndex) const {
const Index inputIndex = depth + baseIndex;
@@ -832,8 +939,7 @@ class TensorContractionSubMapper<
EIGEN_DEVICE_FUNC
EIGEN_ALWAYS_INLINE Index depthOffset() const {
- const Index patchOffset = m_depth_offset % m_base_mapper.patchDepth();
- return patchOffset;
+ return m_depth_offset % patchDepth();
}
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper
@@ -859,24 +965,29 @@ class TensorContractionSubMapper<
// matrix" constructed from extracted volume patches) in contiguous memory.
//
// Given column major input (A0 beside A1 in memory):
-// A0 B0 C0 D0 E0 F0 G0 H0 ...
-// A1 B1 C1 D1 E1 F1 G1 H1 ...
-// A2 B2 C2 D2 E2 F2 G2 H2 ...
-// A3 B3 C3 D3 E3 F3 G3 H3 ...
-// A4 B4 C4 D4 E4 F4 G4 H4 ...
-// A5 B5 C5 D5 E5 F5 G5 H5 ...
-// A6 B6 C6 D6 E6 F6 G6 H6 ...
-// A7 B7 C7 D7 E7 F7 G7 H7 ...
+// A0 B0 C0 D0 E0 F0 G0 H0 ... Z0
+// A1 B1 C1 D1 E1 F1 G1 H1 ... Z1
+// A2 B2 C2 D2 E2 F2 G2 H2 ... Z2
+// A3 B3 C3 D3 E3 F3 G3 H3 ... Z3
+// A4 B4 C4 D4 E4 F4 G4 H4 ... Z4
+// A5 B5 C5 D5 E5 F5 G5 H5 ... Z5
+// A6 B6 C6 D6 E6 F6 G6 H6 ... Z6
+// A7 B7 C7 D7 E7 F7 G7 H7 ... Z7
// A8 ...
// ...
//
-// Packing yields row major output (A0 beside A1 in memory):
-// A0 A1 A2 A3 A4 A5 A6 A7
-// B0 B1 B2 B3 B4 B5 B6 B7
-// C0 ...
+// *) A, B, C, ... - patches extracted from the original input.
+// *) A0, A1, A2 ... - values from the same patch at different offsets.
+//
+// The traversal (packed rhs memory) order (B0 beside A0 in memory):
+// A0 B0 C0 D0 A1 B1 C1 D1 ...
+// E0 F0 G0 H0 E1 F1 G1 H1 ...
// ...
+// Z0 Z1 Z2 Z3 Z4 Z5 Z6 Z7 ... <- doesn't belong to any block (nr = 4)
+//
+// This traversal order must be the same as in default gemm_pack_rhs defined in
+// GeneralBlockPanelKernel.h.
//
-// *) A, B, C, ... - patches extracted from the original input.
// *) nr - number of registers along the 'n' dimension.
// See GeneralBlockPanelKernel.h and "Anatomy of High-Performance Matrix
// Multiplication" paper.
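
The traversal-order comment above says that, for nr == 4, the packed rhs interleaves one coefficient from each of four consecutive patches before moving to the next within-patch offset. A minimal scalar sketch of that ordering over a column-major depth x cols buffer, with no packets, padding, or sub-mappers (PackRhsNr4 is an illustrative helper):

    #include <cstddef>
    #include <iostream>
    #include <vector>

    // Pack a column-major (depth x cols) matrix in the nr == 4 traversal
    // order: for each group of 4 columns, write one coefficient per column
    // before advancing to the next within-column offset. Leftover columns
    // (cols % 4) are copied one at a time, matching the nr == 1 tail loop.
    std::vector<float> PackRhsNr4(const std::vector<float>& m,
                                  std::size_t depth, std::size_t cols) {
      std::vector<float> block;
      block.reserve(depth * cols);
      const std::size_t packet_cols4 = (cols / 4) * 4;
      for (std::size_t j = 0; j < packet_cols4; j += 4) {
        for (std::size_t k = 0; k < depth; ++k) {
          for (std::size_t c = 0; c < 4; ++c) {
            block.push_back(m[(j + c) * depth + k]);  // column-major index
          }
        }
      }
      for (std::size_t j = packet_cols4; j < cols; ++j) {
        for (std::size_t k = 0; k < depth; ++k) {
          block.push_back(m[j * depth + k]);
        }
      }
      return block;
    }

    int main() {
      // 2x4 column-major matrix: columns A, B, C, D hold {A0,A1}, {B0,B1}, ...
      std::vector<float> m = {0, 1, 10, 11, 20, 21, 30, 31};
      for (float v : PackRhsNr4(m, /*depth=*/2, /*cols=*/4))
        std::cout << v << ' ';
      std::cout << '\n';  // 0 10 20 30 1 11 21 31  (A0 B0 C0 D0 A1 B1 C1 D1)
      return 0;
    }
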
@@ -905,7 +1016,11 @@ struct gemm_pack_rhs<
nocontract_t, contract_t, packet_size, inner_dim_contiguous,
inner_dim_reordered, Alignment>
SubMapper;
+
typedef SubMapper DataMapper;
+ typedef typename packet_traits<Scalar>::type Packet;
+
+ EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
EIGEN_DEVICE_FUNC
EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
@@ -914,9 +1029,6 @@ struct gemm_pack_rhs<
eigen_assert(stride == 0);
eigen_assert(offset == 0);
- EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
- typedef typename packet_traits<Scalar>::type Packet;
-
const Index packet_cols4 = (cols / 4) * 4;
const Index peeled_k = (depth / packet_size) * packet_size;
const bool non_standard_patches = rhs.nonStandardPatches();
@@ -929,81 +1041,58 @@ struct gemm_pack_rhs<
Index k = 0;
if ((packet_size % 4) == 0 && !non_standard_patches) {
- const Index patch_depth = rhs.patchDepth();
-
- if ((patch_depth % packet_size) == 0) {
- const Index patch_cols = rhs.patchCols();
- const Index patch_rows = rhs.patchRows();
- const Index patch_planes = rhs.patchPlanes();
-
- const Index startCol = rhs.colOffset();
- const Index max_cols = std::min<Index>(
- Eigen::divup(peeled_k, patch_rows * patch_planes * patch_depth) +
- startCol,
- patch_cols);
-
- for (Index c = startCol; c < max_cols; ++c) {
- eigen_assert(k < peeled_k);
-
- const Index startRow = (c == startCol) ? rhs.rowOffset() : 0;
- const Index max_rows = std::min<Index>(
- Eigen::divup(
- peeled_k - c * patch_rows * patch_planes * patch_depth,
- patch_planes * patch_depth) +
- startRow,
- patch_rows);
+ // FAST PATH:
+ // Iterate over patch columns, rows and planes if we know that a single
+ // packet does not span across multiple planes, rows or columns.
+ if ((rhs.patchDepth() % packet_size) == 0) {
+ const Index start_col = rhs.colOffset();
+ const Index max_col = rhs.maxCol(peeled_k);
+
+ for (Index c = start_col; c < max_col; ++c) {
+ eigen_assert(k <= peeled_k);
+
+ const Index start_row = (c == start_col) ? rhs.rowOffset() : 0;
+ const Index max_row = rhs.maxRow(peeled_k, c);
const bool pad_col0 = dm0.padCol(c);
const bool pad_col1 = dm1.padCol(c);
const bool pad_col2 = dm2.padCol(c);
const bool pad_col3 = dm3.padCol(c);
- for (Index r = startRow; r < max_rows; ++r) {
- eigen_assert(k < peeled_k);
+ for (Index r = start_row; r < max_row; ++r) {
+ eigen_assert(k <= peeled_k);
- const Index startPlane =
- ((c == startCol) && (r == startRow)) ? rhs.planeOffset() : 0;
- const Index max_planes = std::min<Index>(
- Eigen::divup(
- peeled_k -
- c * patch_rows * patch_planes * patch_depth - // col
- r * patch_planes * patch_depth, // row
- patch_depth) +
- startPlane,
- patch_planes);
+ const Index start_plane = ((c == start_col) && (r == start_row))
+ ? rhs.planeOffset()
+ : 0;
+ const Index max_plane = rhs.maxPlane(peeled_k, c, r);
- const bool pad_row0 = dm0.padRow(r);
- const bool pad_row1 = dm1.padRow(r);
- const bool pad_row2 = dm2.padRow(r);
- const bool pad_row3 = dm3.padRow(r);
+ const bool pad_row0 = pad_col0 || dm0.padRow(r);
+ const bool pad_row1 = pad_col1 || dm1.padRow(r);
+ const bool pad_row2 = pad_col2 || dm2.padRow(r);
+ const bool pad_row3 = pad_col3 || dm3.padRow(r);
- for (Index p = startPlane; p < max_planes; ++p) {
- eigen_assert(k < peeled_k);
+ for (Index p = start_plane; p < max_plane; ++p) {
+ eigen_assert(k <= peeled_k);
- const bool pad0 = pad_col0 || pad_row0 || dm0.padPlane(p);
- const bool pad1 = pad_col1 || pad_row1 || dm1.padPlane(p);
- const bool pad2 = pad_col2 || pad_row2 || dm2.padPlane(p);
- const bool pad3 = pad_col3 || pad_row3 || dm3.padPlane(p);
+ const bool pad0 = pad_row0 || dm0.padPlane(p);
+ const bool pad1 = pad_row1 || dm1.padPlane(p);
+ const bool pad2 = pad_row2 || dm2.padPlane(p);
+ const bool pad3 = pad_row3 || dm3.padPlane(p);
const Index idx0 = dm0.baseIndex(p, r, c);
const Index idx1 = dm1.baseIndex(p, r, c);
const Index idx2 = dm2.baseIndex(p, r, c);
const Index idx3 = dm3.baseIndex(p, r, c);
- const Index startDepth =
- ((c == startCol) && (r == startRow) && (p == startPlane))
+ const Index start_depth =
+ ((c == start_col) && (r == start_row) && (p == start_plane))
? rhs.depthOffset()
: 0;
- const Index max_depth = std::min<Index>(
- peeled_k -
- c * patch_rows * patch_planes * patch_depth - // col
- r * patch_planes * patch_depth - // row
- p * patch_depth + // plane
- startDepth,
- patch_depth);
- eigen_assert((max_depth - startDepth) % packet_size == 0);
-
- for (Index d = startDepth; d < max_depth; d += packet_size) {
+ const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth);
+ eigen_assert((max_depth - start_depth) % packet_size == 0);
+
+ for (Index d = start_depth; d < max_depth; d += packet_size) {
eigen_assert(k < peeled_k);
PacketBlock<Packet, 4> kernel;
kernel.packet[0] = pad0 ? pset1<Packet>(Scalar(0))
@@ -1026,20 +1115,12 @@ struct gemm_pack_rhs<
}
}
- for (; k < peeled_k; k += packet_size) {
- PacketBlock<Packet, 4> kernel;
- kernel.packet[0] = dm0.loadPacketFast(k);
- kernel.packet[1] = dm1.loadPacketFast(k);
- kernel.packet[2] = dm2.loadPacketFast(k);
- kernel.packet[3] = dm3.loadPacketFast(k);
- ptranspose(kernel);
- pstoreu(block + 0 * packet_size, kernel.packet[0]);
- pstoreu(block + 1 * packet_size, kernel.packet[1]);
- pstoreu(block + 2 * packet_size, kernel.packet[2]);
- pstoreu(block + 3 * packet_size, kernel.packet[3]);
- block += 4 * packet_size;
- }
+ // The loop above should fill peeled_k elements.
+ eigen_assert(peeled_k == k);
+
} else {
+ // Packet can span multiple planes, rows or columns, so we have to go
+ // through the slower "standard" path.
for (; k < peeled_k; k += packet_size) {
PacketBlock<Packet, 4> kernel;
kernel.packet[0] = dm0.loadPacketStandard(k);
@@ -1055,7 +1136,9 @@ struct gemm_pack_rhs<
}
}
}
- if (!rhs.nonStandardPatches()) {
+
+ // Copy the remaining coefficients of the column block after the peeled_k.
+ if (!non_standard_patches) {
for (; k < depth; k++) {
block[0] = dm0.loadCoeffStandard(k);
block[1] = dm1.loadCoeffStandard(k);
@@ -1074,7 +1157,7 @@ struct gemm_pack_rhs<
}
}
- // copy the remaining columns one at a time (nr==1)
+ // Copy the remaining columns one at a time (nr==1).
for (Index j2 = packet_cols4; j2 < cols; ++j2) {
const SubMapper dm0 = rhs.getLinearMapper(0, j2);
for (Index k = 0; k < depth; k++) {
@@ -1113,6 +1196,9 @@ struct gemm_pack_rhs<
inner_dim_reordered, Alignment>
SubMapper;
typedef SubMapper DataMapper;
+ typedef typename packet_traits<Scalar>::type Packet;
+
+ EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
EIGEN_DEVICE_FUNC
EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
@@ -1121,9 +1207,6 @@ struct gemm_pack_rhs<
eigen_assert(stride == 0);
eigen_assert(offset == 0);
- EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
- typedef typename packet_traits<Scalar>::type Packet;
-
const int packet_size = 2;
const Index packet_cols4 = (cols / 4) * 4;
@@ -1138,56 +1221,39 @@ struct gemm_pack_rhs<
Index k = 0;
if (!non_standard_patches) {
- const Index patch_depth = rhs.patchDepth();
-
- if ((patch_depth % packet_size) == 0) {
- const Index patch_cols = rhs.patchCols();
- const Index patch_rows = rhs.patchRows();
- const Index patch_planes = rhs.patchPlanes();
-
- const Index startCol = rhs.colOffset();
- const Index max_cols = std::min<Index>(
- Eigen::divup(peeled_k, patch_rows * patch_planes * patch_depth) +
- startCol,
- patch_cols);
-
- for (Index c = startCol; c < max_cols; ++c) {
- eigen_assert(k < peeled_k);
-
- const Index startRow = (c == startCol) ? rhs.rowOffset() : 0;
- const Index max_rows = std::min<Index>(
- Eigen::divup(
- peeled_k - c * patch_rows * patch_planes * patch_depth,
- patch_planes * patch_depth) +
- startRow,
- patch_rows);
+ // FAST PATH:
+ // Iterate over patch columns, rows and planes if we know that a single
+ // packet does not span across multiple planes, rows or columns.
+ if ((rhs.patchDepth() % packet_size) == 0) {
+ const Index start_col = rhs.colOffset();
+ const Index max_col = rhs.maxCol(peeled_k);
+
+ for (Index c = start_col; c < max_col; ++c) {
+ eigen_assert(k <= peeled_k);
+
+ const Index start_row = (c == start_col) ? rhs.rowOffset() : 0;
+ const Index max_row = rhs.maxRow(peeled_k, c);
const bool pad_col0 = dm0.padCol(c);
const bool pad_col1 = dm1.padCol(c);
const bool pad_col2 = dm2.padCol(c);
const bool pad_col3 = dm3.padCol(c);
- for (Index r = startRow; r < max_rows; ++r) {
- eigen_assert(k < peeled_k);
+ for (Index r = start_row; r < max_row; ++r) {
+ eigen_assert(k <= peeled_k);
- const Index startPlane =
- ((c == startCol) && (r == startRow)) ? rhs.planeOffset() : 0;
- const Index max_planes = std::min<Index>(
- Eigen::divup(
- peeled_k -
- c * patch_rows * patch_planes * patch_depth - // col
- r * patch_planes * patch_depth, // row
- patch_depth) +
- startPlane,
- patch_planes);
+ const Index start_plane = ((c == start_col) && (r == start_row))
+ ? rhs.planeOffset()
+ : 0;
+ const Index max_plane = rhs.maxPlane(peeled_k, c, r);
const bool pad_row0 = dm0.padRow(r);
const bool pad_row1 = dm1.padRow(r);
const bool pad_row2 = dm2.padRow(r);
const bool pad_row3 = dm3.padRow(r);
- for (Index p = startPlane; p < max_planes; ++p) {
- eigen_assert(k < peeled_k);
+ for (Index p = start_plane; p < max_plane; ++p) {
+ eigen_assert(k <= peeled_k);
const bool pad0 = pad_col0 || pad_row0 || dm0.padPlane(p);
const bool pad1 = pad_col1 || pad_row1 || dm1.padPlane(p);
@@ -1199,20 +1265,14 @@ struct gemm_pack_rhs<
const Index idx2 = dm2.baseIndex(p, r, c);
const Index idx3 = dm3.baseIndex(p, r, c);
- const Index startDepth =
- ((c == startCol) && (r == startRow) && (p == startPlane))
+ const Index start_depth =
+ ((c == start_col) && (r == start_row) && (p == start_plane))
? rhs.depthOffset()
: 0;
- const Index max_depth = std::min<Index>(
- peeled_k -
- c * patch_rows * patch_planes * patch_depth - // col
- r * patch_planes * patch_depth - // row
- p * patch_depth + // plane
- startDepth,
- patch_depth);
- eigen_assert((max_depth - startDepth) % packet_size == 0);
-
- for (Index d = startDepth; d < max_depth; d += packet_size) {
+ const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth);
+ eigen_assert((max_depth - start_depth) % packet_size == 0);
+
+ for (Index d = start_depth; d < max_depth; d += packet_size) {
eigen_assert(k < peeled_k);
PacketBlock<Packet, 2> kernel0;
PacketBlock<Packet, 2> kernel1;
@@ -1237,21 +1297,9 @@ struct gemm_pack_rhs<
}
}
- for (; k < peeled_k; k += packet_size) {
- PacketBlock<Packet, 2> kernel0;
- PacketBlock<Packet, 2> kernel1;
- kernel0.packet[0] = dm0.loadPacketFast(k);
- kernel0.packet[1] = dm1.loadPacketFast(k);
- kernel1.packet[0] = dm2.loadPacketFast(k);
- kernel1.packet[1] = dm3.loadPacketFast(k);
- ptranspose(kernel0);
- ptranspose(kernel1);
- pstoreu(block + 0 * packet_size, kernel0.packet[0]);
- pstoreu(block + 1 * packet_size, kernel1.packet[0]);
- pstoreu(block + 2 * packet_size, kernel0.packet[1]);
- pstoreu(block + 3 * packet_size, kernel1.packet[1]);
- block += 4 * packet_size;
- }
+ // The loop above should fill peeled_k elements.
+ eigen_assert(peeled_k == k);
+
} else {
for (; k < peeled_k; k += packet_size) {
PacketBlock<Packet, 2> kernel0;
@@ -1270,6 +1318,8 @@ struct gemm_pack_rhs<
}
}
}
+
+ // Copy the remaining coefficients of the column block after the peeled_k.
if (!rhs.nonStandardPatches()) {
for (; k < depth; k++) {
block[0] = dm0.loadCoeffStandard(k);
@@ -1289,7 +1339,7 @@ struct gemm_pack_rhs<
}
}
- // copy the remaining columns one at a time (nr==1)
+ // Copy the remaining columns one at a time (nr==1).
for (Index j2 = packet_cols4; j2 < cols; ++j2) {
const SubMapper dm0 = rhs.getLinearMapper(0, j2);
for (Index k = 0; k < depth; k++) {
@@ -1328,6 +1378,8 @@ struct gemm_pack_rhs<
SubMapper;
typedef SubMapper DataMapper;
+ EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
+
EIGEN_DEVICE_FUNC
EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
Index depth, Index cols, Index stride = 0,
@@ -1335,8 +1387,6 @@ struct gemm_pack_rhs<
eigen_assert(stride == 0);
eigen_assert(offset == 0);
- EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
-
const Index packet_cols4 = (cols / 4) * 4;
for (Index j2 = 0; j2 < packet_cols4; j2 += 4) {
@@ -1364,7 +1414,7 @@ struct gemm_pack_rhs<
}
}
- // copy the remaining columns one at a time (nr==1)
+ // Copy the remaining columns one at a time (nr==1).
for (Index j2 = packet_cols4; j2 < cols; ++j2) {
const SubMapper dm0 = rhs.getLinearMapper(0, j2);
for (Index k = 0; k < depth; k++) {
@@ -1454,7 +1504,7 @@ CuboidConvolution(const Input& input, const Kernel& kernel,
isColMajor ? kern.dimensions()[1] : kern.dimensions()[3];
// Spatial size of the kernel.
- const TensorIndex kernelDepth =
+ const TensorIndex kernelPlanes =
isColMajor ? kern.dimensions()[2] : kern.dimensions()[2];
const TensorIndex kernelRows =
isColMajor ? kern.dimensions()[3] : kern.dimensions()[1];
@@ -1474,27 +1524,27 @@ CuboidConvolution(const Input& input, const Kernel& kernel,
const TensorIndex inputCols =
isColMajor ? in.dimension(3) : in.dimension(NumDims - 4);
- TensorIndex out_depth;
+ TensorIndex out_planes;
TensorIndex out_height;
TensorIndex out_width;
switch (padding_type) {
case PADDING_VALID:
- out_depth = Eigen::divup(inputPlanes - kernelDepth + 1,
- static_cast<TensorIndex>(stridePlanes));
+ out_planes = Eigen::divup(inputPlanes - kernelPlanes + 1,
+ static_cast<TensorIndex>(stridePlanes));
out_height = Eigen::divup(inputRows - kernelRows + 1,
static_cast<TensorIndex>(strideRows));
out_width = Eigen::divup(inputCols - kernelCols + 1,
static_cast<TensorIndex>(strideCols));
break;
case PADDING_SAME:
- out_depth =
+ out_planes =
Eigen::divup(inputPlanes, static_cast<TensorIndex>(stridePlanes));
out_height =
Eigen::divup(inputRows, static_cast<TensorIndex>(strideRows));
out_width = Eigen::divup(inputCols, static_cast<TensorIndex>(strideCols));
break;
default:
- out_depth = 0;
+ out_planes = 0;
out_height = 0;
out_width = 0;
eigen_assert(false && "unexpected padding");
@@ -1503,9 +1553,9 @@ CuboidConvolution(const Input& input, const Kernel& kernel,
DSizes<TensorIndex, 2> kernel_dims;
if (isColMajor) {
kernel_dims[0] = kernelFilters;
- kernel_dims[1] = kernelChannels * kernelDepth * kernelRows * kernelCols;
+ kernel_dims[1] = kernelChannels * kernelPlanes * kernelRows * kernelCols;
} else {
- kernel_dims[0] = kernelChannels * kernelDepth * kernelRows * kernelCols;
+ kernel_dims[0] = kernelChannels * kernelPlanes * kernelRows * kernelCols;
kernel_dims[1] = kernelFilters;
}
@@ -1516,15 +1566,15 @@ CuboidConvolution(const Input& input, const Kernel& kernel,
DSizes<TensorIndex, 2> pre_contract_dims;
if (isColMajor) {
pre_contract_dims[0] =
- kernelChannels * kernelDepth * kernelRows * kernelCols;
- pre_contract_dims[1] = out_depth * out_height * out_width;
+ kernelChannels * kernelPlanes * kernelRows * kernelCols;
+ pre_contract_dims[1] = out_planes * out_height * out_width;
for (int i = 4; i < NumDims; ++i) {
pre_contract_dims[1] *= in.dimension(i);
}
} else {
pre_contract_dims[1] =
- kernelChannels * kernelDepth * kernelRows * kernelCols;
- pre_contract_dims[0] = out_depth * out_height * out_width;
+ kernelChannels * kernelPlanes * kernelRows * kernelCols;
+ pre_contract_dims[0] = out_planes * out_height * out_width;
for (int i = 0; i < NumDims - 4; ++i) {
pre_contract_dims[0] *= in.dimension(i);
}
@@ -1543,7 +1593,7 @@ CuboidConvolution(const Input& input, const Kernel& kernel,
DSizes<TensorIndex, NumDims> post_contract_dims;
if (isColMajor) {
post_contract_dims[0] = kernelFilters;
- post_contract_dims[1] = out_depth;
+ post_contract_dims[1] = out_planes;
post_contract_dims[2] = out_height;
post_contract_dims[3] = out_width;
for (int i = 4; i < NumDims; ++i) {
@@ -1551,7 +1601,7 @@ CuboidConvolution(const Input& input, const Kernel& kernel,
}
} else {
post_contract_dims[NumDims - 1] = kernelFilters;
- post_contract_dims[NumDims - 2] = out_depth;
+ post_contract_dims[NumDims - 2] = out_planes;
post_contract_dims[NumDims - 3] = out_height;
post_contract_dims[NumDims - 4] = out_width;
for (int i = 0; i < NumDims - 4; ++i) {
@@ -1564,13 +1614,13 @@ CuboidConvolution(const Input& input, const Kernel& kernel,
kernel.reshape(kernel_dims)
.contract(input
.extract_volume_patches(
- kernelDepth, kernelRows, kernelCols, stridePlanes,
+ kernelPlanes, kernelRows, kernelCols, stridePlanes,
strideRows, strideCols, padding_type)
.reshape(pre_contract_dims),
contract_dims)
.reshape(post_contract_dims),
input
- .extract_volume_patches(kernelDepth, kernelRows, kernelCols,
+ .extract_volume_patches(kernelPlanes, kernelRows, kernelCols,
stridePlanes, strideRows, strideCols,
padding_type)
.reshape(pre_contract_dims)
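
The maxCol/maxRow/maxPlane/maxDepth helpers introduced above clamp the fast-path loops so that exactly peeled_k coefficients starting at m_depth_offset are packed. A minimal sketch of the same bound arithmetic with plain integers (the extents and offsets below are arbitrary example values, and plain divisions stand in for the TensorIntDivisor fast paths):

    #include <algorithm>
    #include <cstdint>
    #include <iostream>

    int main() {
      // Patch extents and the derived strides, as computed in the mapper:
      //   plane_stride = depth, row_stride = planes * plane_stride,
      //   col_stride   = rows * row_stride.
      const int64_t depth = 4, planes = 3, rows = 5, cols = 7;
      const int64_t plane_stride = depth;
      const int64_t row_stride = planes * plane_stride;
      const int64_t col_stride = rows * row_stride;

      const int64_t depth_offset = 6;  // first coefficient of this block
      const int64_t peeled_k = 40;     // coefficients packed by the fast path

      // Upper bounds for the column/row/plane loops (cf. maxCol/maxRow/maxPlane).
      const int64_t max_col = std::min<int64_t>(
          1 + (depth_offset + peeled_k) / col_stride, cols);
      const int64_t col = 0;  // e.g. first iteration of the column loop
      const int64_t max_row = std::min<int64_t>(
          1 + (depth_offset + peeled_k - col * col_stride) / row_stride, rows);
      const int64_t row = 0;
      const int64_t max_plane = std::min<int64_t>(
          1 + (depth_offset + peeled_k - col * col_stride - row * row_stride) /
                  plane_stride,
          planes);

      // maxDepth clamps the innermost loop with the coefficients remaining.
      const int64_t k = 0, start_depth = depth_offset % depth;
      const int64_t max_depth =
          std::min<int64_t>(start_depth + (peeled_k - k), depth);

      std::cout << max_col << ' ' << max_row << ' ' << max_plane << ' '
                << max_depth << '\n';
      return 0;
    }
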
diff --git a/tensorflow/core/kernels/eigen_spatial_convolutions.h b/tensorflow/core/kernels/eigen_spatial_convolutions.h
index a4dff4b91c..e926d73f87 100644
--- a/tensorflow/core/kernels/eigen_spatial_convolutions.h
+++ b/tensorflow/core/kernels/eigen_spatial_convolutions.h
@@ -22,8 +22,36 @@ namespace Eigen {
namespace internal {
-// TODO: Consolidate this part of the code with the image patch extraction code
-// since they are both very similar.
+// WARNING: Most of the code here implicitly assumes that the matrix is in
+// ColMajor layout. This is guaranteed by the tensor contraction (see
+// TensorContraction.h).
+//
+// Inside Eigen a tensor contraction is represented by a matrix multiplication.
+// We don't want to actually extract image patches and reshape the result into
+// a matrix (this involves allocating huge extra memory), so the patch
+// extraction and reshape operations are implicit.
+//
+// TensorContractionInputMapper takes a matrix index and returns the coefficient
+// (or the packet) of the "virtual tensor", that would be at that index if we
+// were to actually reshape the result of patch extraction.
+//
+// TensorContractionSubMapper provides a similar view into the "virtual matrix"
+// at the given vertical and horizontal offsets.
+//
+// "Virtual matrix" dimensions:
+// *0: kernelChannels * kernelRows * kernelCols;
+// 1: out_height * out_width; * OTHERS (e.g batches, etc...)
+//
+// *) extracted patches are continuous in memory (innermost dimension assuming
+// col major layout)
+//
+// With these dimensions:
+// row - offset within a single patch (in code: patchId)
+// col - index of the extracted patch (in code: patchIndex)
+// patchIndex ∈ [0..num_patches * OTHERS] (batch and other dimensions)
+//
+// TODO(ezhulenev): Consolidate this part of the code with the image patch
+// extraction code since they are both very similar.
template <typename NewDimension, DenseIndex Rows, DenseIndex Cols,
typename ArgType, typename Device, typename Scalar_, typename Index,
typename nocontract_t, typename contract_t, int Side, int packet_size,
@@ -77,12 +105,17 @@ class TensorContractionInputMapper<
m_patch_cols = tensor.impl().dimensions()[2];
m_num_patches = tensor.impl().dimensions()[3];
} else {
- const int NumDims = tensor.impl().dimensions().size();
+ const size_t NumDims = tensor.impl().dimensions().size();
patch_depth = tensor.impl().dimensions()[NumDims - 1];
patch_rows = tensor.impl().dimensions()[NumDims - 2];
m_patch_cols = tensor.impl().dimensions()[NumDims - 3];
m_num_patches = tensor.impl().dimensions()[NumDims - 4];
}
+
+ // Strides for navigating through the single patch.
+ m_patch_row_stride = patch_depth;
+ m_patch_col_stride = patch_rows * m_patch_row_stride;
+
m_patch_row_inflate_strides = tensor.impl().rowInflateStride();
m_patch_col_inflate_strides = tensor.impl().colInflateStride();
@@ -111,6 +144,10 @@ class TensorContractionInputMapper<
m_rowPaddingTop = tensor.impl().rowPaddingTop();
m_colPaddingLeft = tensor.impl().colPaddingLeft();
+ m_fastPatchRowStride =
+ internal::TensorIntDivisor<Index>(m_patch_row_stride);
+ m_fastPatchColStride =
+ internal::TensorIntDivisor<Index>(m_patch_col_stride);
m_fastInputRowStride =
internal::TensorIntDivisor<Index>(m_patch_row_inflate_strides);
m_fastInputColStride =
@@ -126,6 +163,10 @@ class TensorContractionInputMapper<
: m_impl(base_mapper.m_impl) {
m_patch_cols = base_mapper.m_patch_cols;
m_num_patches = base_mapper.m_num_patches;
+
+ m_patch_row_stride = base_mapper.m_patch_row_stride;
+ m_patch_col_stride = base_mapper.m_patch_col_stride;
+
m_patch_row_inflate_strides = base_mapper.m_patch_row_inflate_strides;
m_patch_col_inflate_strides = base_mapper.m_patch_col_inflate_strides;
@@ -148,6 +189,8 @@ class TensorContractionInputMapper<
m_rowPaddingTop = base_mapper.m_rowPaddingTop;
m_colPaddingLeft = base_mapper.m_colPaddingLeft;
+ m_fastPatchRowStride = base_mapper.m_fastPatchRowStride;
+ m_fastPatchColStride = base_mapper.m_fastPatchColStride;
m_fastInputRowStride = base_mapper.m_fastInputRowStride;
m_fastInputColStride = base_mapper.m_fastInputColStride;
m_fastNumPatches = base_mapper.m_fastNumPatches;
@@ -238,6 +281,8 @@ class TensorContractionInputMapper<
nocontract_t, contract_t, packet_size, inner_dim_contiguous,
inner_dim_reordered, Alignment>;
+ // Load coefficient from a patch specified by the "within patch offset"
+ // (patchId) and the precomputed indices of the first element of the patch.
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE Scalar loadCoeff(Index patchId, Index rowIndex,
Index colIndex, Index otherIndex) const {
@@ -250,6 +295,7 @@ class TensorContractionInputMapper<
(m_patch_col_inflate_strides == 1)
? inputCol
: ((inputCol >= 0) ? (inputCol / m_fastInputColStride) : 0);
+
const Index rowOffset = patchOffset - colOffset * m_colStride;
const Index inputRow = rowIndex + rowOffset * m_in_row_strides;
const Index origInputRow =
@@ -268,6 +314,8 @@ class TensorContractionInputMapper<
return m_impl.coeff(inputIndex);
}
+ // This is the same as loadCoeff(...), but optimized for all `inflate_strides`
+ // and `in_strides` equal to 1 (template specialization without templates).
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE Scalar loadCoeffStandard(Index patchId, Index rowIndex,
Index colIndex,
@@ -276,10 +324,9 @@ class TensorContractionInputMapper<
// Find the offset of the element wrt the location of the first element.
const Index patchOffset = patchId / m_fastDimZero;
-
const Index colOffset = patchOffset / m_fastColStride;
- const Index inputCol = colIndex + colOffset;
const Index rowOffset = patchOffset - colOffset * m_colStride;
+ const Index inputCol = colIndex + colOffset;
const Index inputRow = rowIndex + rowOffset;
if (inputCol < 0 || inputCol >= m_inputCols || inputRow < 0 ||
inputRow >= m_inputRows) {
@@ -291,6 +338,8 @@ class TensorContractionInputMapper<
return m_impl.coeff(inputIndex);
}
+ // Load packet from a patch specified by the "within patch offset"
+ // (patchId) and the precomputed indices of the first element of the patch.
EIGEN_DEVICE_FUNC
EIGEN_ALWAYS_INLINE Packet loadPacket(Index patchId, Index rowIndex,
Index colIndex,
@@ -318,12 +367,14 @@ class TensorContractionInputMapper<
if ((patchDepth() % packetSize) == 0) {
return loadPacketFast(patchId, rowIndex, colIndex, otherIndex);
} else {
+ // Offsets and input calculation here are identical to
+ // loadCoeffStandard(...), but repeated twice.
+
const Index patchOffsets[2] = {
patchId / m_fastDimZero, (patchId + packetSize - 1) / m_fastDimZero};
const Index colOffsets[2] = {patchOffsets[0] / m_fastColStride,
patchOffsets[1] / m_fastColStride};
-
const Index inputCols[2] = {colIndex + colOffsets[0],
colIndex + colOffsets[1]};
if (inputCols[0] >= m_inputCols || inputCols[1] < 0) {
@@ -371,8 +422,8 @@ class TensorContractionInputMapper<
eigen_assert((patchId + packetSize - 1) / m_fastDimZero == patchOffset);
const Index colOffset = patchOffset / m_fastColStride;
- const Index inputCol = colIndex + colOffset;
const Index rowOffset = patchOffset - colOffset * m_colStride;
+ const Index inputCol = colIndex + colOffset;
const Index inputRow = rowIndex + rowOffset;
if (inputCol < 0 || inputRow < 0 || inputCol >= m_inputCols ||
inputRow >= m_inputRows) {
@@ -401,7 +452,7 @@ class TensorContractionInputMapper<
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void computeBaseIndices(
Index patchIndex, Index& rowIndex, Index& colIndex,
Index& otherIndex) const {
- const int NumInputDims = array_size<
+ const size_t NumInputDims = array_size<
typename TensorEvaluator<ArgType, Device>::Dimensions>::value;
otherIndex = (NumInputDims == 3) ? 0 : patchIndex / m_fastNumPatches;
const Index patch2DIndex = (NumInputDims == 3)
@@ -414,8 +465,15 @@ class TensorContractionInputMapper<
rowIndex = rowIndex * m_row_strides - m_rowPaddingTop;
}
- Index m_patch_cols; // number of colums in the patch
- Index m_num_patches; // number of patches to extract.
+ Index m_patch_cols; // number of columns in the patch
+ Index m_num_patches; // number of patches to extract.
+
+ // Strides for navigating through the single patch.
+ Index m_patch_row_stride;
+ Index m_patch_col_stride;
+ internal::TensorIntDivisor<Index> m_fastPatchRowStride;
+ internal::TensorIntDivisor<Index> m_fastPatchColStride;
+
Index m_patch_row_inflate_strides; // the strides for row inflation in the
// image patch
Index m_patch_col_inflate_strides; // the strides for col inflation in the
@@ -549,6 +607,40 @@ class TensorContractionSubMapper<
return m_base_mapper.nonStandardPatches();
}
+ // Max(Col|Row|Depth): compute the upper limit for the column, row and depth
+ // index respectively that fits into the peeled_k elements starting at
+ // m_depth_offset.
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index maxCol(const Index peeled_k) const {
+ const Index max_col =
+ fastPatchColStride().divide(m_depth_offset + peeled_k);
+ return std::min<Index>(1 + max_col, patchCols());
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index maxRow(const Index peeled_k,
+ const Index col) const {
+ const Index max_row = fastPatchRowStride().divide(
+ m_depth_offset + peeled_k - col * patchColStride());
+ return std::min<Index>(1 + max_row, patchRows());
+ }
+
+ // MaxDepth uses only the remaining number of elements in the peeled_k.
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index maxDepth(const Index num_elements,
+ const Index start_depth) const {
+ return std::min<Index>(start_depth + num_elements, patchDepth());
+ }
+
+ // Every register matters in this code, so sometimes to prevent register
+ // spilling, instead of the variable that you would expect to see, we use
+ // another one that is guaranteed to have the same value. E.g. patch depth
+ // is always the same as input depth, and it's also the same as the input
+ // row stride. A number of other parameters have similar relations.
+
+ typedef internal::TensorIntDivisor<Index> IndexDivisor;
+
EIGEN_DEVICE_FUNC
EIGEN_ALWAYS_INLINE Index patchDepth() const {
return m_base_mapper.m_rowInputStride;
@@ -563,6 +655,28 @@ class TensorContractionSubMapper<
}
EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index patchRowStride() const {
+ eigen_assert(patchDepth() == m_base_mapper.m_patch_row_stride &&
+ "Patch depth must be equal to patch row stride.");
+ return patchDepth();
+ }
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index patchColStride() const {
+ return m_base_mapper.m_patch_col_stride;
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE IndexDivisor fastPatchRowStride() const {
+ eigen_assert(patchDepth() == m_base_mapper.m_patch_row_stride &&
+ "Patch depth must be equal to patch row stride.");
+ return m_base_mapper.m_fastDimZero; // patch_depth
+ }
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE IndexDivisor fastPatchColStride() const {
+ return m_base_mapper.m_fastPatchColStride;
+ }
+
+ EIGEN_DEVICE_FUNC
EIGEN_ALWAYS_INLINE Packet packetNoPadding(const Index depth,
const Index baseIndex) const {
const Index inputIndex = depth + baseIndex;
@@ -603,8 +717,7 @@ class TensorContractionSubMapper<
EIGEN_DEVICE_FUNC
EIGEN_ALWAYS_INLINE Index depthOffset() const {
- const Index patchOffset = m_depth_offset % m_base_mapper.patchDepth();
- return patchOffset;
+ return m_depth_offset % patchDepth();
}
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper
@@ -617,12 +730,44 @@ class TensorContractionSubMapper<
Index m_depth_offset; // First row in the input matrix
Index m_col_offset; // First col in the input matrix
- Index m_rowIndex; // precomputed row index corresponding to the col offset
- Index m_colIndex; // precomputed col index corresponding to the col offset
- Index
- m_otherIndex; // precomputed other index corresponding to the col offset
+ // Knowing that: col_offset == patchIndex * OTHERS, we keep precomputed base
+ // indices for the first element in a patch specified by col_offset
+ // (see computeBaseIndices(...) for details).
+ Index m_rowIndex;
+ Index m_colIndex;
+ Index m_otherIndex;
};
+// Arrange a block of the right input matrix (in our case it's always a "virtual
+// matrix" constructed from extracted image patches) in contiguous memory.
+//
+// Given column major input (A0 beside A1 in memory):
+// A0 B0 C0 D0 E0 F0 G0 H0 ... Z0
+// A1 B1 C1 D1 E1 F1 G1 H1 ... Z1
+// A2 B2 C2 D2 E2 F2 G2 H2 ... Z2
+// A3 B3 C3 D3 E3 F3 G3 H3 ... Z3
+// A4 B4 C4 D4 E4 F4 G4 H4 ... Z4
+// A5 B5 C5 D5 E5 F5 G5 H5 ... Z5
+// A6 B6 C6 D6 E6 F6 G6 H6 ... Z6
+// A7 B7 C7 D7 E7 F7 G7 H7 ... Z7
+// A8 ...
+// ...
+//
+// *) A, B, C, ... - patches extracted from the original input.
+// *) A0, A1, A2 ... - values from the same patch at different offsets.
+//
+// The traversal (packed rhs memory) order (B0 beside A0 in memory):
+// A0 B0 C0 D0 A1 B1 C1 D1 ...
+// E0 F0 G0 H0 E1 F1 G1 H1 ...
+// ...
+// Z0 Z1 Z2 Z3 Z4 Z5 Z6 Z7 ... <- doesn't belong to any block (nr = 4)
+//
+// This traversal order must be the same as in default gemm_pack_rhs defined in
+// GeneralBlockPanelKernel.h.
+//
+// *) nr - number of registers along the 'n' dimension.
+// See GeneralBlockPanelKernel.h and "Anatomy of High-Performance Matrix
+// Multiplication" paper.
template <typename NewDimension, DenseIndex Rows, DenseIndex Cols,
typename ArgType, typename Device, typename Scalar, typename Index,
typename nocontract_t, typename contract_t, int packet_size,
@@ -649,9 +794,9 @@ struct gemm_pack_rhs<
inner_dim_reordered, Alignment>
SubMapper;
typedef SubMapper DataMapper;
+ typedef typename packet_traits<Scalar>::type Packet;
- EIGEN_DEVICE_FUNC
- static inline Index ceil_div(Index a, Index b) { return (a + b - 1) / b; }
+ EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
EIGEN_DEVICE_FUNC
EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
@@ -660,9 +805,6 @@ struct gemm_pack_rhs<
eigen_assert(stride == 0);
eigen_assert(offset == 0);
- EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
- typedef typename packet_traits<Scalar>::type Packet;
-
const Index packet_cols4 = (cols / 4) * 4;
const Index peeled_k = (depth / packet_size) * packet_size;
const bool non_standard_patches = rhs.nonStandardPatches();
@@ -675,30 +817,27 @@ struct gemm_pack_rhs<
Index k = 0;
if ((packet_size % 4) == 0 && !non_standard_patches) {
- const Index patch_depth = rhs.patchDepth();
- if ((patch_depth % packet_size) == 0) {
- const Index patch_cols = rhs.patchCols();
- const Index patch_rows = rhs.patchRows();
-
- const Index startCol = rhs.colOffset();
- const Index max_cols = std::min<Index>(
- ceil_div(peeled_k, patch_rows * patch_depth) + startCol,
- patch_cols);
-
- for (Index c = startCol; c < max_cols; ++c) {
- eigen_assert(k < peeled_k);
- const Index startRow = (c == startCol) ? rhs.rowOffset() : 0;
- const Index max_rows = std::min<Index>(
- ceil_div(peeled_k - c * patch_rows * patch_depth, patch_depth) +
- startRow,
- patch_rows);
+ // FAST PATH:
+ // Iterate over patch columns and rows if we know that a single
+ // packet does not span across multiple rows or columns.
+ if ((rhs.patchDepth() % packet_size) == 0) {
+ const Index start_col = rhs.colOffset();
+ const Index max_col = rhs.maxCol(peeled_k);
+
+ for (Index c = start_col; c < max_col; ++c) {
+ eigen_assert(k <= peeled_k);
+
+ const Index start_row = (c == start_col) ? rhs.rowOffset() : 0;
+ const Index max_row = rhs.maxRow(peeled_k, c);
const bool pad_col0 = dm0.padCol(c);
const bool pad_col1 = dm1.padCol(c);
const bool pad_col2 = dm2.padCol(c);
const bool pad_col3 = dm3.padCol(c);
- for (Index r = startRow; r < max_rows; ++r) {
- eigen_assert(k < peeled_k);
+
+ for (Index r = start_row; r < max_row; ++r) {
+ eigen_assert(k <= peeled_k);
+
const bool pad0 = pad_col0 || dm0.padRow(r);
const bool pad1 = pad_col1 || dm1.padRow(r);
const bool pad2 = pad_col2 || dm2.padRow(r);
@@ -709,14 +848,13 @@ struct gemm_pack_rhs<
const Index idx2 = dm2.baseIndex(r, c);
const Index idx3 = dm3.baseIndex(r, c);
- const Index startDepth =
- ((c == startCol) && (r == startRow)) ? rhs.depthOffset() : 0;
- const Index max_depth =
- std::min<Index>(peeled_k - c * patch_rows * patch_depth -
- r * patch_depth + startDepth,
- patch_depth);
- eigen_assert((max_depth - startDepth) % packet_size == 0);
- for (Index d = startDepth; d < max_depth; d += packet_size) {
+ const Index start_depth = ((c == start_col) && (r == start_row))
+ ? rhs.depthOffset()
+ : 0;
+ const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth);
+ eigen_assert((max_depth - start_depth) % packet_size == 0);
+
+ for (Index d = start_depth; d < max_depth; d += packet_size) {
eigen_assert(k < peeled_k);
PacketBlock<Packet, 4> kernel;
kernel.packet[0] = pad0 ? pset1<Packet>(Scalar(0))
@@ -738,19 +876,9 @@ struct gemm_pack_rhs<
}
}
- for (; k < peeled_k; k += packet_size) {
- PacketBlock<Packet, 4> kernel;
- kernel.packet[0] = dm0.loadPacketFast(k);
- kernel.packet[1] = dm1.loadPacketFast(k);
- kernel.packet[2] = dm2.loadPacketFast(k);
- kernel.packet[3] = dm3.loadPacketFast(k);
- ptranspose(kernel);
- pstoreu(block + 0 * packet_size, kernel.packet[0]);
- pstoreu(block + 1 * packet_size, kernel.packet[1]);
- pstoreu(block + 2 * packet_size, kernel.packet[2]);
- pstoreu(block + 3 * packet_size, kernel.packet[3]);
- block += 4 * packet_size;
- }
+ // The loop above should fill peeled_k elements.
+ eigen_assert(peeled_k == k);
+
} else {
for (; k < peeled_k; k += packet_size) {
PacketBlock<Packet, 4> kernel;
@@ -767,6 +895,8 @@ struct gemm_pack_rhs<
}
}
}
+
+ // Copy the remaining coefficients of the column block after the peeled_k.
if (!rhs.nonStandardPatches()) {
for (; k < depth; k++) {
block[0] = dm0.loadCoeffStandard(k);
@@ -824,9 +954,9 @@ struct gemm_pack_rhs<
Alignment>
SubMapper;
typedef SubMapper DataMapper;
+ typedef typename packet_traits<Scalar>::type Packet;
- EIGEN_DEVICE_FUNC
- static inline Index ceil_div(Index a, Index b) { return (a + b - 1) / b; }
+ EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
EIGEN_DEVICE_FUNC
EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
@@ -835,9 +965,6 @@ struct gemm_pack_rhs<
eigen_assert(stride == 0);
eigen_assert(offset == 0);
- EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
- typedef typename packet_traits<Scalar>::type Packet;
-
const int packet_size = 2;
const Index packet_cols4 = (cols / 4) * 4;
const Index peeled_k = (depth / packet_size) * packet_size;
@@ -851,30 +978,27 @@ struct gemm_pack_rhs<
Index k = 0;
if (!non_standard_patches) {
- const Index patch_depth = rhs.patchDepth();
- if ((patch_depth % packet_size) == 0) {
- const Index patch_cols = rhs.patchCols();
- const Index patch_rows = rhs.patchRows();
-
- const Index startCol = rhs.colOffset();
- const Index max_cols = std::min<Index>(
- ceil_div(peeled_k, patch_rows * patch_depth) + startCol,
- patch_cols);
-
- for (Index c = startCol; c < max_cols; ++c) {
- eigen_assert(k < peeled_k);
- const Index startRow = (c == startCol) ? rhs.rowOffset() : 0;
- const Index max_rows = std::min<Index>(
- ceil_div(peeled_k - c * patch_rows * patch_depth, patch_depth) +
- startRow,
- patch_rows);
+ // FAST PATH:
+ // Iterate over patch columns and rows if we know that a single
+          // packet does not span across multiple rows or columns.
+ if ((rhs.patchDepth() % packet_size) == 0) {
+ const Index start_col = rhs.colOffset();
+ const Index max_col = rhs.maxCol(peeled_k);
+
+ for (Index c = start_col; c < max_col; ++c) {
+ eigen_assert(k <= peeled_k);
+
+ const Index start_row = (c == start_col) ? rhs.rowOffset() : 0;
+ const Index max_row = rhs.maxRow(peeled_k, c);
const bool pad_col0 = dm0.padCol(c);
const bool pad_col1 = dm1.padCol(c);
const bool pad_col2 = dm2.padCol(c);
const bool pad_col3 = dm3.padCol(c);
- for (Index r = startRow; r < max_rows; ++r) {
- eigen_assert(k < peeled_k);
+
+ for (Index r = start_row; r < max_row; ++r) {
+ eigen_assert(k <= peeled_k);
+
const bool pad0 = pad_col0 || dm0.padRow(r);
const bool pad1 = pad_col1 || dm1.padRow(r);
const bool pad2 = pad_col2 || dm2.padRow(r);
@@ -885,14 +1009,13 @@ struct gemm_pack_rhs<
const Index idx2 = dm2.baseIndex(r, c);
const Index idx3 = dm3.baseIndex(r, c);
- const Index startDepth =
- ((c == startCol) && (r == startRow)) ? rhs.depthOffset() : 0;
- const Index max_depth =
- std::min<Index>(peeled_k - c * patch_rows * patch_depth -
- r * patch_depth + startDepth,
- patch_depth);
- eigen_assert((max_depth - startDepth) % packet_size == 0);
- for (Index d = startDepth; d < max_depth; d += packet_size) {
+ const Index start_depth = ((c == start_col) && (r == start_row))
+ ? rhs.depthOffset()
+ : 0;
+ const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth);
+ eigen_assert((max_depth - start_depth) % packet_size == 0);
+
+ for (Index d = start_depth; d < max_depth; d += packet_size) {
eigen_assert(k < peeled_k);
PacketBlock<Packet, 2> kernel0;
PacketBlock<Packet, 2> kernel1;
@@ -916,22 +1039,12 @@ struct gemm_pack_rhs<
}
}
- for (; k < peeled_k; k += packet_size) {
- PacketBlock<Packet, 2> kernel0;
- PacketBlock<Packet, 2> kernel1;
- kernel0.packet[0] = dm0.loadPacketFast(k);
- kernel0.packet[1] = dm1.loadPacketFast(k);
- kernel1.packet[0] = dm2.loadPacketFast(k);
- kernel1.packet[1] = dm3.loadPacketFast(k);
- ptranspose(kernel0);
- ptranspose(kernel1);
- pstoreu(block + 0 * packet_size, kernel0.packet[0]);
- pstoreu(block + 1 * packet_size, kernel1.packet[0]);
- pstoreu(block + 2 * packet_size, kernel0.packet[1]);
- pstoreu(block + 3 * packet_size, kernel1.packet[1]);
- block += 4 * packet_size;
- }
+ // The loop above should fill peeled_k elements.
+ eigen_assert(peeled_k == k);
+
} else {
+ // Packet can span multiple rows or columns, so we have to go
+          // through the slower "standard" path.
for (; k < peeled_k; k += packet_size) {
PacketBlock<Packet, 2> kernel0;
PacketBlock<Packet, 2> kernel1;
@@ -949,7 +1062,9 @@ struct gemm_pack_rhs<
}
}
}
- if (!rhs.nonStandardPatches()) {
+
+ // Copy the remaining coefficients of the column block after the peeled_k.
+ if (!non_standard_patches) {
for (; k < depth; k++) {
block[0] = dm0.loadCoeffStandard(k);
block[1] = dm1.loadCoeffStandard(k);
@@ -968,7 +1083,7 @@ struct gemm_pack_rhs<
}
}
- // copy the remaining columns one at a time (nr==1)
+ // Copy the remaining columns one at a time (nr==1).
for (Index j2 = packet_cols4; j2 < cols; ++j2) {
const SubMapper dm0 = rhs.getLinearMapper(0, j2);
for (Index k = 0; k < depth; k++) {
@@ -1006,8 +1121,7 @@ struct gemm_pack_rhs<
SubMapper;
typedef SubMapper DataMapper;
- EIGEN_DEVICE_FUNC
- static inline Index ceil_div(Index a, Index b) { return (a + b - 1) / b; }
+ EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
EIGEN_DEVICE_FUNC
EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
@@ -1016,8 +1130,6 @@ struct gemm_pack_rhs<
eigen_assert(stride == 0);
eigen_assert(offset == 0);
- EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
-
const Index packet_cols4 = (cols / 4) * 4;
for (Index j2 = 0; j2 < packet_cols4; j2 += 4) {
@@ -1045,7 +1157,7 @@ struct gemm_pack_rhs<
}
}
- // copy the remaining columns one at a time (nr==1)
+ // Copy the remaining columns one at a time (nr==1).
for (Index j2 = packet_cols4; j2 < cols; ++j2) {
const SubMapper dm0 = rhs.getLinearMapper(0, j2);
for (Index k = 0; k < depth; k++) {
diff --git a/tensorflow/core/kernels/extract_volume_patches_op.cc b/tensorflow/core/kernels/extract_volume_patches_op.cc
new file mode 100644
index 0000000000..52cd078a35
--- /dev/null
+++ b/tensorflow/core/kernels/extract_volume_patches_op.cc
@@ -0,0 +1,197 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+/*
+See extract_image_patches_op* files and docs for extract_image_patches in
+../ops/image_ops.cc.
+
+Rates are not supported as of now, but the comments hint at how to edit the
+code when support for rates is added.
+*/
+
+#define USE_EIGEN_TENSOR
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/kernels/extract_volume_patches_op.h"
+#include <vector>
+#include "tensorflow/core/framework/numeric_op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/kernels/bounds_check.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/util/tensor_format.h"
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+static inline void ParseAttributeVec5(OpKernelConstruction* context,
+ const string& attr_name,
+ std::vector<int32>* attr) {
+ OP_REQUIRES_OK(context, context->GetAttr(attr_name, attr));
+ OP_REQUIRES(
+ context, (*attr)[0] == 1 && (*attr)[4] == 1,
+ errors::Unimplemented("Only support ", attr_name, " across space."));
+ OP_REQUIRES(context, (*attr)[1] >= 1 && (*attr)[2] >= 1 && (*attr)[3] >= 1,
+ errors::OutOfRange(attr_name, " is out of range."));
+}
+
+template <typename Device, typename T>
+class ExtractVolumePatchesOp : public UnaryOp<T> {
+ public:
+ explicit ExtractVolumePatchesOp(OpKernelConstruction* context)
+ : UnaryOp<T>(context) {
+ ParseAttributeVec5(context, "ksizes", &ksizes_);
+ ParseAttributeVec5(context, "strides", &strides_);
+ // ParseAttributeVec5(context, "rates", &rates_);
+ OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ // Input tensor is of the following dimensions:
+ // [ batch, in_planes, in_rows, in_cols, channels ]
+ const Tensor& input = context->input(0);
+ OP_REQUIRES(context, input.dims() == 5,
+ errors::InvalidArgument("input must be 5-dimensional",
+ input.shape().DebugString()));
+
+ const int batch = input.dim_size(0);
+ const int in_planes = input.dim_size(1);
+ const int in_rows = input.dim_size(2);
+ const int in_cols = input.dim_size(3);
+ const int depth = input.dim_size(4);
+
+ const int ksize_planes = ksizes_[1];
+ const int ksize_rows = ksizes_[2];
+ const int ksize_cols = ksizes_[3];
+
+ const int stride_planes = strides_[1];
+ const int stride_rows = strides_[2];
+ const int stride_cols = strides_[3];
+
+ /*
+ // TODO(hsgkim): enable rates
+ // Rates are disabled as of now due to Eigen's definitions of
+    // `extract_volume_patch` functions; none of them accepts rates
+    // as an argument, so rates are fixed to (1, 1, 1, 1, 1). A
+    // workaround has to be found for this.
+ // In order to enable rates, uncomment the following lines and use
+ // ksize_*_eff instead of ksize_* for the second argument of
+ // GetWindowedOutputSize calls.
+
+ const int rate_planes = rates_[1];
+ const int rate_rows = rates_[2];
+ const int rate_cols = rates_[3];
+
+ const int ksize_planes_eff = ksize_planes +
+ (ksize_planes - 1) * (rate_planes - 1);
+ const int ksize_rows_eff = ksize_rows + (ksize_rows - 1) * (rate_rows - 1);
+ const int ksize_cols_eff = ksize_cols + (ksize_cols - 1) * (rate_cols - 1);
+ */
+
+ int64 out_planes = 0, out_rows = 0, out_cols = 0;
+ int64 pad_planes = 0, pad_rows = 0, pad_cols = 0;
+ OP_REQUIRES_OK(context,
+ GetWindowedOutputSize(in_planes, ksize_planes, stride_planes,
+ padding_, &out_planes, &pad_planes));
+ OP_REQUIRES_OK(context,
+ GetWindowedOutputSize(in_rows, ksize_rows, stride_rows,
+ padding_, &out_rows, &pad_rows));
+ OP_REQUIRES_OK(context,
+ GetWindowedOutputSize(in_cols, ksize_cols, stride_cols,
+ padding_, &out_cols, &pad_cols));
+
+ const std::vector<int64> out_sizes = {
+ batch, out_planes, out_rows, out_cols,
+ ksize_planes * ksize_rows * ksize_cols * depth};
+ TensorShape out_shape(out_sizes);
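// [Editorial note, not part of the patch] A worked example of the shape
// computation above, with hypothetical values: for SAME padding,
// GetWindowedOutputSize returns out = ceil(in / stride), so an input of shape
// [2, 5, 8, 8, 3] with ksizes = [1, 2, 2, 2, 1] and strides = [1, 2, 2, 2, 1]
// gives out_planes = 3, out_rows = 4, out_cols = 4 and a last dimension of
// ksize_planes * ksize_rows * ksize_cols * depth = 2 * 2 * 2 * 3 = 24,
// i.e. an output shape of [2, 3, 4, 4, 24].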
+
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
+
+ // If there is nothing to compute, return.
+ if (out_shape.num_elements() == 0) {
+ return;
+ }
+
+ functor::ExtractVolumePatchesForward<Device, T>()(
+ context->eigen_device<Device>(), input.tensor<T, 5>(), ksize_planes,
+ ksize_rows, ksize_cols, stride_planes, stride_rows, stride_cols,
+ /* rate_planes, rate_rows, rate_cols, */
+ BrainPadding2EigenPadding(padding_), output->tensor<T, 5>());
+ }
+
+ private:
+ std::vector<int32> ksizes_;
+ std::vector<int32> strides_;
+ // std::vector<int32> rates_;
+
+ Padding padding_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(ExtractVolumePatchesOp);
+};
+
+// Registration of the CPU implementations.
+#define REGISTER(T) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("ExtractVolumePatches").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
+ ExtractVolumePatchesOp<CPUDevice, T>);
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER);
+
+#undef REGISTER
+
+#if GOOGLE_CUDA
+
+// Forward declarations of the functor specializations for GPU.
+namespace functor {
+
+// clang-format off
+#define DECLARE_GPU_SPEC(T) \
+ template <> \
+ void ExtractVolumePatchesForward<GPUDevice, T>::operator()( \
+ const GPUDevice& d, typename TTypes<T, 5>::ConstTensor input, \
+ int patch_planes, int patch_rows, int patch_cols, \
+ int stride_planes, int stride_rows, int stride_cols, \
+ /* int rate_planes, int rate_rows, int rate_cols, */ \
+ const Eigen::PaddingType& padding, \
+ typename TTypes<T, 5>::Tensor output); \
+ extern template struct ExtractVolumePatchesForward<GPUDevice, T>;
+// clang-format on
+
+TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
+
+#undef DECLARE_GPU_SPEC
+
+} // namespace functor
+
+// Registration of the GPU implementations.
+#define REGISTER(T) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("ExtractVolumePatches").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
+ ExtractVolumePatchesOp<GPUDevice, T>);
+
+TF_CALL_GPU_NUMBER_TYPES(REGISTER);
+
+#undef REGISTER
+
+#endif // GOOGLE_CUDA
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/extract_volume_patches_op.h b/tensorflow/core/kernels/extract_volume_patches_op.h
new file mode 100644
index 0000000000..7e0502b770
--- /dev/null
+++ b/tensorflow/core/kernels/extract_volume_patches_op.h
@@ -0,0 +1,58 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_KERNELS_EXTRACT_VOLUME_PATCHES_OP_H_
+#define TENSORFLOW_KERNELS_EXTRACT_VOLUME_PATCHES_OP_H_
+
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/kernels/eigen_volume_patch.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+namespace functor {
+
+template <typename Device, typename T>
+struct ExtractVolumePatchesForward {
+ void operator()(const Device& d, typename TTypes<T, 5>::ConstTensor input,
+ int patch_planes, int patch_rows, int patch_cols,
+ int stride_planes, int stride_rows, int stride_cols,
+ /* int rate_planes, int rate_rows, int rate_cols, */
+ const Eigen::PaddingType& padding,
+ typename TTypes<T, 5>::Tensor output) {
+ const int64 N = std::max(input.size(), output.size());
+ if (N <= std::numeric_limits<Index32>::max()) {
+ auto output_32bit = To32Bit(output);
+ output_32bit.device(d) =
+ To32Bit(input)
+ .extract_volume_patches(patch_cols, patch_rows, patch_planes,
+ stride_cols, stride_rows, stride_planes,
+ padding)
+ .reshape(output_32bit.dimensions());
+ } else {
+ output.device(d) =
+ input
+ .extract_volume_patches(patch_cols, patch_rows, patch_planes,
+ stride_cols, stride_rows, stride_planes,
+ padding)
+ .reshape(output.dimensions());
+ }
+ }
+};
+
+} // namespace functor
+} // namespace tensorflow
+
+#endif // TENSORFLOW_KERNELS_EXTRACT_VOLUME_PATCHES_OP_H_
diff --git a/tensorflow/core/kernels/extract_volume_patches_op_gpu.cu.cc b/tensorflow/core/kernels/extract_volume_patches_op_gpu.cu.cc
new file mode 100644
index 0000000000..c636493602
--- /dev/null
+++ b/tensorflow/core/kernels/extract_volume_patches_op_gpu.cu.cc
@@ -0,0 +1,38 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/core/kernels/extract_volume_patches_op.h"
+#include "tensorflow/core/framework/register_types.h"
+
+namespace tensorflow {
+
+typedef Eigen::GpuDevice GPUDevice;
+
+namespace functor {
+
+#define REGISTER(T) template struct ExtractVolumePatchesForward<GPUDevice, T>;
+
+TF_CALL_GPU_NUMBER_TYPES(REGISTER);
+
+#undef REGISTER
+
+} // end namespace functor
+} // end namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/fuzzing/BUILD b/tensorflow/core/kernels/fuzzing/BUILD
index 8bfa40304e..f2e0b2558f 100644
--- a/tensorflow/core/kernels/fuzzing/BUILD
+++ b/tensorflow/core/kernels/fuzzing/BUILD
@@ -43,4 +43,6 @@ tf_ops_fuzz_target_lib("example_proto_fast_parsing")
tf_ops_fuzz_target_lib("parse_tensor_op")
+tf_ops_fuzz_target_lib("decode_compressed")
+
tf_ops_fuzz_target_lib("decode_json_example")
diff --git a/tensorflow/core/kernels/fuzzing/decode_compressed_fuzz.cc b/tensorflow/core/kernels/fuzzing/decode_compressed_fuzz.cc
new file mode 100644
index 0000000000..0a56f4b63f
--- /dev/null
+++ b/tensorflow/core/kernels/fuzzing/decode_compressed_fuzz.cc
@@ -0,0 +1,45 @@
+/* Copyright 2018 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/kernels/fuzzing/fuzz_session.h"
+
+namespace tensorflow {
+namespace fuzzing {
+
+class FuzzDecodeCompressed : public FuzzStringInputOp {
+ void BuildGraph(const Scope& scope) override {
+ auto input =
+ tensorflow::ops::Placeholder(scope.WithOpName("input1"), DT_STRING);
+ auto d1 = tensorflow::ops::DecodeCompressed(
+ scope.WithOpName("d1"), input,
+ tensorflow::ops::DecodeCompressed::CompressionType(""));
+ auto d2 = tensorflow::ops::DecodeCompressed(
+ scope.WithOpName("d2"), input,
+ tensorflow::ops::DecodeCompressed::CompressionType("ZLIB"));
+ auto d3 = tensorflow::ops::DecodeCompressed(
+ scope.WithOpName("d3"), input,
+ tensorflow::ops::DecodeCompressed::CompressionType("GZIP"));
+ Scope grouper =
+ scope.WithControlDependencies(std::vector<tensorflow::Operation>{
+ d1.output.op(), d2.output.op(), d3.output.op()});
+ (void)tensorflow::ops::NoOp(grouper.WithOpName("output"));
+ }
+};
+
+STANDARD_TF_FUZZ_FUNCTION(FuzzDecodeCompressed);
+
+} // namespace fuzzing
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/fuzzing/parse_tensor_op_fuzz.cc b/tensorflow/core/kernels/fuzzing/parse_tensor_op_fuzz.cc
index c90ad2cfeb..ada1235449 100644
--- a/tensorflow/core/kernels/fuzzing/parse_tensor_op_fuzz.cc
+++ b/tensorflow/core/kernels/fuzzing/parse_tensor_op_fuzz.cc
@@ -31,9 +31,37 @@ class FuzzParseTensor : public FuzzSession {
}
void FuzzImpl(const uint8_t* data, size_t size) final {
+ // We need to be sure that we don't request too many elements (i.e., we
+    // don't make ASAN OOM). In theory, a tensor shape can have an arbitrarily
+    // large number of elements, up to the limit of the memory available to the OS.
+ // However, due to the tracing done in ASAN, after 2^32 bytes of requested
+ // memory we would get a crash in the fuzzer (see b/34190148). Hence, let's
+ // try parsing the proto here, check that the size (if valid) is below a
+ // maximum threshold (using 2^20 for convenience), and then run the
+ // remainder of the fuzzer testing. Of course, this duplicates some work
+ // but it's better than repeating the investigation whenever Autofuzz
+ // detects another similar OOM.
+ string as_string = string(reinterpret_cast<const char*>(data), size);
+ TensorProto proto;
+ if (!ParseProtoUnlimited(&proto, as_string)) {
+ LOG(WARNING) << "Unable to parse proto of tensor\n";
+ return;
+ }
+ if (!TensorShape::IsValid(proto.tensor_shape())) {
+ LOG(WARNING) << "Invalid tensor shape\n";
+ return;
+ }
+ TensorShape shape(proto.tensor_shape());
+ const int64 num_elements = shape.num_elements();
+ const int64 max_num_elements = 1 << 20;
+ if (num_elements > max_num_elements) {
+ LOG(WARNING) << "Requiring a tensor with too many elements\n";
+ return;
+ }
+
+ // Now we can do the actual fuzz implementation
Tensor input_tensor(tensorflow::DT_STRING, TensorShape({}));
- input_tensor.scalar<string>()() =
- string(reinterpret_cast<const char*>(data), size);
+ input_tensor.scalar<string>()() = as_string;
// TODO(b/32704451): Don't just ignore the ::tensorflow::Status object!
RunOneInput(input_tensor).IgnoreError();
}
diff --git a/tensorflow/core/kernels/gather_nd_op_cpu_impl.h b/tensorflow/core/kernels/gather_nd_op_cpu_impl.h
index 277ee2be02..1c78de253e 100644
--- a/tensorflow/core/kernels/gather_nd_op_cpu_impl.h
+++ b/tensorflow/core/kernels/gather_nd_op_cpu_impl.h
@@ -114,7 +114,7 @@ struct GatherNdSlice<CPUDevice, T, Index, IXDIM> {
generator::GatherNdSliceGenerator<T, Index, IXDIM> gather_nd_generator(
slice_size, Tindices, Tparams, Tout, &error_loc);
-#ifdef INTEL_MKL
+#if defined(INTEL_MKL) && defined(ENABLE_MKL)
// Eigen implementation below is not highly performant. gather_nd_generator
// does not seem to be called in parallel, leading to very poor performance.
// Additionally, since it uses scalar (Tscratch) to invoke 'generate', it
@@ -126,12 +126,12 @@ struct GatherNdSlice<CPUDevice, T, Index, IXDIM> {
const Eigen::array<Eigen::DenseIndex, 1> loc{i};
gather_nd_generator(loc);
}
-#else // INTEL_MKL
+#else // INTEL_MKL && ENABLE_MKL
Tscratch.device(d) = Tscratch.reshape(reshape_dims)
.broadcast(broadcast_dims)
.generate(gather_nd_generator)
.sum();
-#endif
+#endif // INTEL_MKL && ENABLE_MKL
// error_loc() returns -1 if there's no out-of-bounds index,
// otherwise it returns the location of an OOB index in Tindices.
diff --git a/tensorflow/core/kernels/histogram_op_gpu.cu.cc b/tensorflow/core/kernels/histogram_op_gpu.cu.cc
index a88e9b0ddc..374a05850e 100644
--- a/tensorflow/core/kernels/histogram_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/histogram_op_gpu.cu.cc
@@ -18,7 +18,7 @@ limitations under the License.
#define EIGEN_USE_GPU
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "external/cub_archive/cub/device/device_histogram.cuh"
+#include "third_party/cub/device/device_histogram.cuh"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
diff --git a/tensorflow/core/kernels/logging_ops.cc b/tensorflow/core/kernels/logging_ops.cc
index 6b6a14e9a7..1ded012f3c 100644
--- a/tensorflow/core/kernels/logging_ops.cc
+++ b/tensorflow/core/kernels/logging_ops.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include <iostream>
+#include "absl/strings/str_split.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
@@ -90,6 +91,59 @@ class PrintOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("Print").Device(DEVICE_CPU), PrintOp);
+class PrintV2Op : public OpKernel {
+ public:
+ explicit PrintV2Op(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_stream", &output_stream_));
+
+ auto output_stream_index =
+ std::find(std::begin(valid_output_streams_),
+ std::end(valid_output_streams_), output_stream_);
+
+ if (output_stream_index == std::end(valid_output_streams_)) {
+ string error_msg = strings::StrCat(
+ "Unknown output stream: ", output_stream_, ", Valid streams are:");
+ for (auto valid_stream : valid_output_streams_) {
+ strings::StrAppend(&error_msg, " ", valid_stream);
+ }
+ OP_REQUIRES(ctx, false, errors::InvalidArgument(error_msg));
+ }
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor* input_;
+ OP_REQUIRES_OK(ctx, ctx->input("input", &input_));
+ const string& msg = input_->scalar<string>()();
+
+ if (output_stream_ == "stdout") {
+ std::cout << msg << std::endl;
+ } else if (output_stream_ == "stderr") {
+ std::cerr << msg << std::endl;
+ } else if (output_stream_ == "log(info)") {
+ LOG(INFO) << msg << std::endl;
+ } else if (output_stream_ == "log(warning)") {
+ LOG(WARNING) << msg << std::endl;
+ } else if (output_stream_ == "log(error)") {
+ LOG(ERROR) << msg << std::endl;
+ } else {
+ string error_msg = strings::StrCat(
+ "Unknown output stream: ", output_stream_, ", Valid streams are:");
+ for (auto valid_stream : valid_output_streams_) {
+ strings::StrAppend(&error_msg, " ", valid_stream);
+ }
+ OP_REQUIRES(ctx, false, errors::InvalidArgument(error_msg));
+ }
+ }
+
+  const char* valid_output_streams_[5] = {"stdout", "stderr", "log(info)",
+ "log(warning)", "log(error)"};
+
+ private:
+ string output_stream_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("PrintV2").Device(DEVICE_CPU), PrintV2Op);
+
class TimestampOp : public OpKernel {
public:
explicit TimestampOp(OpKernelConstruction* context) : OpKernel(context) {}
diff --git a/tensorflow/core/kernels/logging_ops_test.cc b/tensorflow/core/kernels/logging_ops_test.cc
index 5e6958f364..a259d995fa 100644
--- a/tensorflow/core/kernels/logging_ops_test.cc
+++ b/tensorflow/core/kernels/logging_ops_test.cc
@@ -23,11 +23,33 @@ limitations under the License.
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
namespace tensorflow {
namespace {
+class PrintingV2GraphTest : public OpsTestBase {
+ protected:
+ Status Init(const string& output_stream = "log(warning)") {
+ TF_CHECK_OK(NodeDefBuilder("op", "PrintV2")
+ .Input(FakeInput(DT_STRING))
+ .Attr("output_stream", output_stream)
+ .Finalize(node_def()));
+ return InitOp();
+ }
+};
+
+TEST_F(PrintingV2GraphTest, StringSuccess) {
+ TF_ASSERT_OK(Init());
+ AddInputFromArray<string>(TensorShape({}), {"bar"});
+ TF_ASSERT_OK(RunOpKernel());
+}
+
+TEST_F(PrintingV2GraphTest, InvalidOutputStream) {
+ ASSERT_NE(::tensorflow::Status::OK(), (Init("invalid_output_stream")));
+}
+
class PrintingGraphTest : public OpsTestBase {
protected:
Status Init(DataType input_type1, DataType input_type2, string msg = "",
diff --git a/tensorflow/core/kernels/matmul_op.cc b/tensorflow/core/kernels/matmul_op.cc
index 79967aab38..4ad390a411 100644
--- a/tensorflow/core/kernels/matmul_op.cc
+++ b/tensorflow/core/kernels/matmul_op.cc
@@ -578,7 +578,7 @@ struct MatMulFunctor<SYCLDevice, T> {
.Label("cublas"), \
MatMulOp<GPUDevice, T, true /* cublas */>)
-#if defined(INTEL_MKL)
+#if defined(INTEL_MKL) && defined(ENABLE_MKL)
// MKL does not support half, bfloat16 and int32 types for
// matrix-multiplication, so register the kernel to use default Eigen based
@@ -606,9 +606,9 @@ TF_CALL_double(REGISTER_CPU);
TF_CALL_complex64(REGISTER_CPU_EIGEN);
TF_CALL_complex128(REGISTER_CPU_EIGEN);
TF_CALL_double(REGISTER_CPU_EIGEN);
-#endif
+#endif // INTEL_MKL_DNN_ONLY
-#else // INTEL MKL
+#else // INTEL_MKL && ENABLE_MKL
TF_CALL_float(REGISTER_CPU);
TF_CALL_double(REGISTER_CPU);
TF_CALL_half(REGISTER_CPU);
@@ -616,7 +616,7 @@ TF_CALL_bfloat16(REGISTER_CPU);
TF_CALL_int32(REGISTER_CPU);
TF_CALL_complex64(REGISTER_CPU);
TF_CALL_complex128(REGISTER_CPU);
-#endif
+#endif // INTEL_MKL && ENABLE_MKL
#if GOOGLE_CUDA
TF_CALL_float(REGISTER_GPU);
diff --git a/tensorflow/core/kernels/mkl_batch_matmul_op.cc b/tensorflow/core/kernels/mkl_batch_matmul_op.cc
index 0841395dc3..bc135de11e 100644
--- a/tensorflow/core/kernels/mkl_batch_matmul_op.cc
+++ b/tensorflow/core/kernels/mkl_batch_matmul_op.cc
@@ -223,10 +223,12 @@ class BatchMatMulMkl : public OpKernel {
Name("BatchMatMul").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"), \
BatchMatMulMkl<CPUDevice, TYPE>)
+#ifdef ENABLE_MKL
TF_CALL_float(REGISTER_BATCH_MATMUL_MKL);
TF_CALL_double(REGISTER_BATCH_MATMUL_MKL);
TF_CALL_complex64(REGISTER_BATCH_MATMUL_MKL);
TF_CALL_complex128(REGISTER_BATCH_MATMUL_MKL);
+#endif // ENABLE_MKL
} // end namespace tensorflow
#endif
diff --git a/tensorflow/core/kernels/mkl_conv_ops_test.cc b/tensorflow/core/kernels/mkl_conv_ops_test.cc
new file mode 100644
index 0000000000..a055351337
--- /dev/null
+++ b/tensorflow/core/kernels/mkl_conv_ops_test.cc
@@ -0,0 +1,407 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/cc/ops/const_op.h"
+#include "tensorflow/cc/ops/nn_ops.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/public/session.h"
+
+#if defined(INTEL_MKL_DNN_ONLY)
+#include "third_party/intel_mkl_dnn/include/mkldnn.h"
+#include "tensorflow/core/util/mkl_util.h"
+#endif
+
+// TODO(ezhulenev): Add numerical tests that will compare results of default
+// (aka Eigen) convolutions with MKL convolutions.
+
+// -------------------------------------------------------------------------- //
+// Performance Benchmarks. //
+// -------------------------------------------------------------------------- //
+
+// Compare performance of default Tensorflow convolution kernels (Eigen) with
+// MKL kernels on CPU.
+
+// Before running these benchmarks, configure the OpenMP environment variables:
+// export KMP_BLOCKTIME=0
+// export OMP_NUM_THREADS=${num_threads}
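// [Editorial note, not part of the patch] For example (the bazel target and
// benchmark flag below are assumptions; adjust to your build setup):
//   export KMP_BLOCKTIME=0
//   export OMP_NUM_THREADS=8
//   bazel run -c opt //tensorflow/core/kernels:mkl_conv_ops_test -- --benchmarks=all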
+
+namespace tensorflow {
+
+struct Conv2DDimensions {
+ Conv2DDimensions(int n, int h, int w, int c, int fc, int fh, int fw)
+ : input_batches(n),
+ input_height(h),
+ input_width(w),
+ input_depth(c),
+ filter_count(fc),
+ filter_height(fh),
+ filter_width(fw) {}
+
+ int input_batches;
+ int input_height;
+ int input_width;
+ int input_depth;
+ int filter_count;
+ int filter_height;
+ int filter_width;
+};
+
+static Tensor GetRandomTensor(const TensorShape& shape) {
+ Tensor tensor(DT_FLOAT, TensorShape(shape));
+ tensor.flat<float>() = tensor.flat<float>().setRandom();
+ return tensor;
+}
+
+// Get a random Tensor for the Conv2D input.
+static Tensor GetRandomInputTensor(const Conv2DDimensions& dims) {
+ return GetRandomTensor({dims.input_batches, dims.input_height,
+ dims.input_width, dims.input_depth});
+}
+
+// Get a random Tensor for the Conv2D filter.
+static Tensor GetRandomFilterTensor(const Conv2DDimensions& dims) {
+ return GetRandomTensor({dims.filter_height, dims.filter_width,
+ dims.input_depth, dims.filter_count});
+}
+
+// Get a random Tensor for the Conv2D output (assuming SAME padding).
+static Tensor GetRandomOutputTensor(const Conv2DDimensions& dims) {
+ return GetRandomTensor({dims.input_batches, dims.input_height,
+ dims.input_width, dims.filter_count});
+}
+
+// Get a Tensor encoding Conv2D input shape.
+static Tensor GetInputSizesTensor(const Conv2DDimensions& dims) {
+ return test::AsTensor<int32>({dims.input_batches, dims.input_height,
+ dims.input_width, dims.input_depth});
+}
+
+// Get a Tensor encoding Conv2D filter shape.
+static Tensor GetFilterSizesTensor(const Conv2DDimensions& dims) {
+ return test::AsTensor<int32>({dims.filter_height, dims.filter_width,
+ dims.input_depth, dims.filter_count});
+}
+
+#if defined(INTEL_MKL_DNN_ONLY)
+static Tensor NonMklTensor() {
+ MklDnnShape non_mkl_shape;
+ non_mkl_shape.SetMklTensor(false);
+
+ auto size = static_cast<int64>(non_mkl_shape.GetSerializeBufferSize());
+ Tensor tensor(DT_UINT8, {size});
+
+ non_mkl_shape.SerializeMklDnnShape(tensor.flat<uint8>().data(),
+ size * sizeof(uint8));
+ return tensor;
+}
+#endif
+
+static Graph* DefaultConv2D(const Conv2DDimensions& dims) {
+ auto* graph = new Graph(OpRegistry::Global());
+
+ Tensor input_t = GetRandomInputTensor(dims);
+ Tensor filter_t = GetRandomFilterTensor(dims);
+
+ Node* input = test::graph::Constant(graph, input_t, "input");
+ Node* filter = test::graph::Constant(graph, filter_t, "filter");
+
+ Node* conv2d;
+ TF_CHECK_OK(NodeBuilder(graph->NewName("conv_2d"), "Conv2D")
+ .Input(input)
+ .Input(filter)
+ .Attr("T", DT_FLOAT)
+ .Attr("strides", {1, 1, 1, 1})
+ .Attr("padding", "SAME")
+ .Finalize(graph, &conv2d));
+
+ return graph;
+}
+
+#if defined(INTEL_MKL_DNN_ONLY)
+static Graph* MklConv2D(const Conv2DDimensions& dims) {
+ auto* graph = new Graph(OpRegistry::Global());
+
+ Tensor input_t = GetRandomInputTensor(dims);
+ Tensor filter_t = GetRandomFilterTensor(dims);
+
+ Node* input = test::graph::Constant(graph, input_t, "input");
+ Node* filter = test::graph::Constant(graph, filter_t, "filter");
+
+ Node* not_mkl_shape = test::graph::Constant(graph, NonMklTensor(), "not_mkl");
+
+ Node* conv2d;
+ TF_CHECK_OK(NodeBuilder(graph->NewName("mkl_conv_2d"), "_MklConv2D")
+ .Input(input)
+ .Input(filter)
+ .Input(not_mkl_shape)
+ .Input(not_mkl_shape)
+ .Attr("T", DT_FLOAT)
+ .Attr("strides", {1, 1, 1, 1})
+ .Attr("padding", "SAME")
+ .Attr("_kernel", "MklOp")
+ .Finalize(graph, &conv2d));
+
+ return graph;
+}
+#endif
+
+static Graph* DefaultConv2DBwdInput(const Conv2DDimensions& dims) {
+ auto* graph = new Graph(OpRegistry::Global());
+
+ Tensor input_sizes_t = GetInputSizesTensor(dims);
+ Tensor filter_t = GetRandomFilterTensor(dims);
+ Tensor out_backprop_t = GetRandomOutputTensor(dims); // assuming SAME padding
+
+ Node* input_sizes =
+ test::graph::Constant(graph, input_sizes_t, "input_sizes");
+ Node* filter = test::graph::Constant(graph, filter_t, "filter");
+ Node* out_backprop =
+ test::graph::Constant(graph, out_backprop_t, "out_backprop");
+
+ Node* conv2d_bwd_input;
+ TF_CHECK_OK(
+ NodeBuilder(graph->NewName("conv_2d_bwd_input"), "Conv2DBackpropInput")
+ .Input(input_sizes)
+ .Input(filter)
+ .Input(out_backprop)
+ .Attr("T", DT_FLOAT)
+ .Attr("strides", {1, 1, 1, 1})
+ .Attr("padding", "SAME")
+ .Finalize(graph, &conv2d_bwd_input));
+
+ return graph;
+}
+
+#if defined(INTEL_MKL_DNN_ONLY)
+static Graph* MklConv2DBwdInput(const Conv2DDimensions& dims) {
+ auto* graph = new Graph(OpRegistry::Global());
+
+ Tensor input_sizes_t = GetInputSizesTensor(dims);
+ Tensor filter_t = GetRandomFilterTensor(dims);
+ Tensor out_backprop_t = GetRandomOutputTensor(dims); // assuming SAME padding
+
+ Node* input_sizes =
+ test::graph::Constant(graph, input_sizes_t, "input_sizes");
+ Node* filter = test::graph::Constant(graph, filter_t, "filter");
+ Node* out_backprop =
+ test::graph::Constant(graph, out_backprop_t, "out_backprop");
+
+ Node* not_mkl_shape = test::graph::Constant(graph, NonMklTensor(), "not_mkl");
+
+ Node* conv2d_bwd_input;
+ TF_CHECK_OK(NodeBuilder(graph->NewName("conv_2d_bwd_input"),
+ "_MklConv2DBackpropInput")
+ .Input(input_sizes)
+ .Input(filter)
+ .Input(out_backprop)
+ .Input(not_mkl_shape)
+ .Input(not_mkl_shape)
+ .Input(not_mkl_shape)
+ .Attr("T", DT_FLOAT)
+ .Attr("strides", {1, 1, 1, 1})
+ .Attr("padding", "SAME")
+ .Attr("_kernel", "MklOp")
+ .Finalize(graph, &conv2d_bwd_input));
+
+ return graph;
+}
+#endif
+
+static Graph* DefaultConv2DBwdFilter(const Conv2DDimensions& dims) {
+ auto* graph = new Graph(OpRegistry::Global());
+
+ Tensor input_t = GetRandomInputTensor(dims);
+ Tensor filter_sizes_t = GetFilterSizesTensor(dims);
+ Tensor filter_t = GetRandomFilterTensor(dims);
+ Tensor out_backprop_t = GetRandomOutputTensor(dims); // assuming SAME padding
+
+ Node* input = test::graph::Constant(graph, input_t, "input");
+ Node* filter_sizes =
+ test::graph::Constant(graph, filter_sizes_t, "filter_sizes");
+ Node* out_backprop =
+ test::graph::Constant(graph, out_backprop_t, "out_backprop");
+
+ Node* conv2d_bwd_filter;
+ TF_CHECK_OK(
+ NodeBuilder(graph->NewName("conv_2d_bwd_filter"), "Conv2DBackpropFilter")
+ .Input(input)
+ .Input(filter_sizes)
+ .Input(out_backprop)
+ .Attr("T", DT_FLOAT)
+ .Attr("strides", {1, 1, 1, 1})
+ .Attr("padding", "SAME")
+ .Finalize(graph, &conv2d_bwd_filter));
+
+ return graph;
+}
+
+#if defined(INTEL_MKL_DNN_ONLY)
+static Graph* MklConv2DBwdFilter(const Conv2DDimensions& dims) {
+ Graph* graph = new Graph(OpRegistry::Global());
+
+ Tensor input_t = GetRandomInputTensor(dims);
+ Tensor filter_sizes_t = GetFilterSizesTensor(dims);
+ Tensor filter_t = GetRandomFilterTensor(dims);
+ Tensor out_backprop_t = GetRandomOutputTensor(dims); // assuming SAME padding
+
+ Node* input = test::graph::Constant(graph, input_t, "input");
+ Node* filter_sizes =
+ test::graph::Constant(graph, filter_sizes_t, "filter_sizes");
+ Node* out_backprop =
+ test::graph::Constant(graph, out_backprop_t, "out_backprop");
+
+ Node* not_mkl_shape = test::graph::Constant(graph, NonMklTensor(), "not_mkl");
+
+ Node* conv2d_bwd_filter;
+ TF_CHECK_OK(NodeBuilder(graph->NewName("conv_2d_bwd_filter"),
+ "_MklConv2DBackpropFilter")
+ .Input(input)
+ .Input(filter_sizes)
+ .Input(out_backprop)
+ .Input(not_mkl_shape)
+ .Input(not_mkl_shape)
+ .Input(not_mkl_shape)
+ .Attr("T", DT_FLOAT)
+ .Attr("strides", {1, 1, 1, 1})
+ .Attr("padding", "SAME")
+ .Attr("_kernel", "MklOp")
+ .Finalize(graph, &conv2d_bwd_filter));
+
+ return graph;
+}
+#endif
+
+// Macro argument names: ---------------------------------------------------- //
+// N: batch size
+// H: height
+// W: width
+// C: channels
+// FC: filter count
+// FH: filter height
+// FW: filter width
+
+#define BM_CONCAT(a, b) a##b
+
+#define BM_NAME(p, type, N, H, W, C, FC, FH, FW) \
+ BM_CONCAT(BM_##p##_##type##_in_##N##_##H##_##W##_##C, _f_##FC##_##FH##_##FW)
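// [Editorial note, not part of the patch] For illustration, a call such as
//   BM_NAME(Conv2D_Default, cpu, 32, 28, 28, 96, 128, 3, 3)
// expands to the single identifier
//   BM_Conv2D_Default_cpu_in_32_28_28_96_f_128_3_3
// which is used both as the benchmark function name and as the argument to
// BENCHMARK() below.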
+
+// Flops computation in these benchmarks is the same as in
+// eigen_benchmark_cpu_test.cc.
+
+#define BM_Conv2DT(kind, N, H, W, C, FC, FH, FW, type, LABEL) \
+ static void BM_NAME(Conv2D_##kind, type, N, H, W, C, FC, FH, \
+ FW)(int iters) { \
+ testing::SetLabel(LABEL); \
+ \
+ int64 num_computed_elements = (N) * (H) * (W) * (FC); \
+ int64 flops_per_iter = num_computed_elements * ((C) * (FH) * (FW)); \
+ testing::ItemsProcessed(static_cast<int64>(iters) * flops_per_iter); \
+ \
+ Conv2DDimensions dims(N, H, W, C, FC, FW, FH); \
+ test::Benchmark(#type, BM_CONCAT(kind, Conv2D)(dims)).Run(iters); \
+ } \
+ BENCHMARK(BM_NAME(Conv2D_##kind, type, N, H, W, C, FC, FH, FW))
+
+#if defined(INTEL_MKL_DNN_ONLY)
+#define BM_Conv2D(N, H, W, C, FC, FH, FW, type, LABEL) \
+ BM_Conv2DT(Default, N, H, W, C, FC, FH, FW, type, LABEL); \
+ BM_Conv2DT(Mkl, N, H, W, C, FC, FH, FW, type, LABEL);
+#else
+#define BM_Conv2D(N, H, W, C, FC, FH, FW, type, LABEL) \
+ BM_Conv2DT(Default, N, H, W, C, FC, FH, FW, type, LABEL);
+#endif
+
+#define BM_Conv2DBwdInputT(kind, N, H, W, C, FC, FH, FW, type, LABEL) \
+ static void BM_NAME(Conv2DBwdInput_##kind, type, N, H, W, C, FC, FH, \
+ FW)(int iters) { \
+ testing::SetLabel(LABEL); \
+ \
+ int64 num_computed_elements = (N) * (H) * (W) * (C); \
+ int64 flops_per_iter = num_computed_elements * ((C) * (FH) * (FW)); \
+ testing::ItemsProcessed(static_cast<int64>(iters) * flops_per_iter); \
+ \
+ Conv2DDimensions dims(N, H, W, C, FC, FW, FH); \
+ test::Benchmark(#type, BM_CONCAT(kind, Conv2DBwdInput)(dims)).Run(iters); \
+ } \
+ BENCHMARK(BM_NAME(Conv2DBwdInput_##kind, type, N, H, W, C, FC, FH, FW))
+
+#if defined(INTEL_MKL_DNN_ONLY)
+#define BM_Conv2DBwdInput(N, H, W, C, FC, FH, FW, type, LABEL) \
+ BM_Conv2DBwdInputT(Default, N, H, W, C, FC, FH, FW, type, LABEL); \
+ BM_Conv2DBwdInputT(Mkl, N, H, W, C, FC, FH, FW, type, LABEL);
+#else
+#define BM_Conv2DBwdInput(N, H, W, C, FC, FH, FW, type, LABEL) \
+ BM_Conv2DBwdInputT(Default, N, H, W, C, FC, FH, FW, type, LABEL);
+#endif
+
+#define BM_Conv2DBwdFilterT(kind, N, H, W, C, FC, FH, FW, type, LABEL) \
+ static void BM_NAME(Conv2DBwdFilter_##kind, type, N, H, W, C, FC, FH, \
+ FW)(int iters) { \
+ testing::SetLabel(LABEL); \
+ \
+ int64 num_computed_elements = (FH) * (FW) * (C) * (FC); \
+ int64 flops_per_iter = num_computed_elements * ((N) * (H) * (W)); \
+ testing::ItemsProcessed(static_cast<int64>(iters) * flops_per_iter); \
+ \
+ Conv2DDimensions dims(N, H, W, C, FC, FW, FH); \
+ test::Benchmark(#type, BM_CONCAT(kind, Conv2DBwdFilter)(dims)).Run(iters); \
+ } \
+ BENCHMARK(BM_NAME(Conv2DBwdFilter_##kind, type, N, H, W, C, FC, FH, FW))
+
+#if defined(INTEL_MKL_DNN_ONLY)
+#define BM_Conv2DBwdFilter(N, H, W, C, FC, FH, FW, type, LABEL) \
+ BM_Conv2DBwdFilterT(Default, N, H, W, C, FC, FH, FW, type, LABEL); \
+ BM_Conv2DBwdFilterT(Mkl, N, H, W, C, FC, FH, FW, type, LABEL);
+#else
+#define BM_Conv2DBwdFilter(N, H, W, C, FC, FH, FW, type, LABEL) \
+ BM_Conv2DBwdFilterT(Default, N, H, W, C, FC, FH, FW, type, LABEL);
+#endif
+
+// ImageNet Convolutions ---------------------------------------------------- //
+
+BM_Conv2D(32, 28, 28, 96, 128, 3, 3, cpu, "conv3a_00_3x3");
+BM_Conv2D(32, 28, 28, 16, 32, 5, 5, cpu, "conv3a_00_5x5");
+BM_Conv2D(32, 28, 28, 128, 192, 3, 3, cpu, "conv3_00_3x3");
+BM_Conv2D(32, 28, 28, 32, 96, 5, 5, cpu, "conv3_00_5x5");
+BM_Conv2D(32, 14, 14, 96, 204, 3, 3, cpu, "conv4a_00_3x3");
+BM_Conv2D(32, 14, 14, 16, 48, 5, 5, cpu, "conv4a_00_5x5");
+BM_Conv2D(32, 14, 14, 112, 224, 3, 3, cpu, "conv4b_00_3x3");
+
+BM_Conv2DBwdInput(32, 28, 28, 96, 128, 3, 3, cpu, "conv3a_00_3x3");
+BM_Conv2DBwdInput(32, 28, 28, 16, 32, 5, 5, cpu, "conv3a_00_5x5");
+BM_Conv2DBwdInput(32, 28, 28, 128, 192, 3, 3, cpu, "conv3_00_3x3");
+BM_Conv2DBwdInput(32, 28, 28, 32, 96, 5, 5, cpu, "conv3_00_5x5");
+BM_Conv2DBwdInput(32, 14, 14, 96, 204, 3, 3, cpu, "conv4a_00_3x3");
+BM_Conv2DBwdInput(32, 14, 14, 16, 48, 5, 5, cpu, "conv4a_00_5x5");
+BM_Conv2DBwdInput(32, 14, 14, 112, 224, 3, 3, cpu, "conv4b_00_3x3");
+
+BM_Conv2DBwdFilter(32, 28, 28, 96, 128, 3, 3, cpu, "conv3a_00_3x3");
+BM_Conv2DBwdFilter(32, 28, 28, 16, 32, 5, 5, cpu, "conv3a_00_5x5");
+BM_Conv2DBwdFilter(32, 28, 28, 128, 192, 3, 3, cpu, "conv3_00_3x3");
+BM_Conv2DBwdFilter(32, 28, 28, 32, 96, 5, 5, cpu, "conv3_00_5x5");
+BM_Conv2DBwdFilter(32, 14, 14, 96, 204, 3, 3, cpu, "conv4a_00_3x3");
+BM_Conv2DBwdFilter(32, 14, 14, 16, 48, 5, 5, cpu, "conv4a_00_5x5");
+BM_Conv2DBwdFilter(32, 14, 14, 112, 224, 3, 3, cpu, "conv4b_00_3x3");
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/mkl_matmul_op.cc b/tensorflow/core/kernels/mkl_matmul_op.cc
index 077d62ce32..f4788f4851 100644
--- a/tensorflow/core/kernels/mkl_matmul_op.cc
+++ b/tensorflow/core/kernels/mkl_matmul_op.cc
@@ -217,7 +217,7 @@ class MklMatMulOp : public OpKernel {
reinterpret_cast<const MKL_Complex16*>(b), ldb, &beta,
reinterpret_cast<MKL_Complex16*>(c), ldc);
}
-#endif
+#endif // !INTEL_MKL_DNN_ONLY
};
#define REGISTER_CPU(T) \
@@ -225,6 +225,7 @@ class MklMatMulOp : public OpKernel {
Name("MatMul").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
MklMatMulOp<CPUDevice, T, false /* cublas, ignored for CPU */>);
+#ifdef ENABLE_MKL
// TODO(inteltf) Consider template specialization when adding/removing
// additional types
TF_CALL_float(REGISTER_CPU);
@@ -233,7 +234,8 @@ TF_CALL_float(REGISTER_CPU);
TF_CALL_double(REGISTER_CPU);
TF_CALL_complex64(REGISTER_CPU);
TF_CALL_complex128(REGISTER_CPU);
-#endif
+#endif // !INTEL_MKL_DNN_ONLY
+#endif // ENABLE_MKL
} // namespace tensorflow
#endif // INTEL_MKL
diff --git a/tensorflow/core/kernels/mkl_slice_op.cc b/tensorflow/core/kernels/mkl_slice_op.cc
new file mode 100644
index 0000000000..d63e14adf6
--- /dev/null
+++ b/tensorflow/core/kernels/mkl_slice_op.cc
@@ -0,0 +1,358 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// See docs in ../ops/array_ops.cc.
+
+#ifdef INTEL_MKL
+#ifndef INTEL_MKL_ML_ONLY
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/prefetch.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+#include "mkldnn.hpp"
+#include "tensorflow/core/util/mkl_util.h"
+
+using mkldnn::stream;
+using mkldnn::view;
+
+namespace tensorflow {
+
+namespace {
+
+gtl::InlinedVector<int64, 4> IntTensorToInt64Vec(const Tensor& tensor) {
+ gtl::InlinedVector<int64, 4> out;
+ if (tensor.dtype() == DT_INT32) {
+ for (int64 i = 0; i < tensor.NumElements(); ++i) {
+ out.push_back(tensor.flat<int32>()(i));
+ }
+ } else if (tensor.dtype() == DT_INT64) {
+ for (int64 i = 0; i < tensor.NumElements(); ++i) {
+ out.push_back(tensor.flat<int64>()(i));
+ }
+ } else {
+ // tensor must be either int32 or int64
+ DCHECK(false);
+ }
+ return out;
+}
+
+} // namespace
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+
+// A version of SharedValidation (slice_op.h) written for input that is in
+// either Mkl layout or Tensorflow layout.
+// Shared code that validates input shapes and checks for identity; it does
+// not depend on the type T, which reduces code size by not duplicating the
+// validation for every T (float, double, int32, etc.).
+static void ValidateMklInputs(OpKernelContext* context, bool* is_identity,
+ gtl::InlinedVector<int64, 4>* begin,
+ gtl::InlinedVector<int64, 4>* size) {
+ const int kInputTensorIndex = 0;
+ const int kInputBeginIndex = 1;
+ const int kInputSizeIndex = 2;
+ const Tensor& input = MklGetInput(context, kInputTensorIndex);
+ const Tensor& begin_tensor = MklGetInput(context, kInputBeginIndex);
+ const Tensor& size_tensor = MklGetInput(context, kInputSizeIndex);
+
+ MklDnnShape input_mkl_shape, begin_mkl_shape, size_mkl_shape;
+ GetMklShape(context, kInputTensorIndex, &input_mkl_shape);
+ GetMklShape(context, kInputBeginIndex, &begin_mkl_shape);
+ GetMklShape(context, kInputSizeIndex, &size_mkl_shape);
+
+ // Begin and size tensors cannot be in MklDnn layout.
+ DCHECK_EQ(begin_mkl_shape.IsMklTensor(), false);
+ DCHECK_EQ(size_mkl_shape.IsMklTensor(), false);
+
+ TensorShape input_tf_shape = input_mkl_shape.IsMklTensor()
+ ? input_mkl_shape.GetTfShape()
+ : input.shape();
+ const int input_dims = input_tf_shape.dims();
+
+ OP_REQUIRES(
+ context, context->op_kernel().IsLegacyVector(begin_tensor.shape()) &&
+ context->op_kernel().IsLegacyVector(size_tensor.shape()) &&
+ begin_tensor.NumElements() == input_dims &&
+ size_tensor.NumElements() == input_dims,
+ errors::InvalidArgument(
+ "Expected begin and size arguments to be 1-D tensors of size ",
+ input_dims, ", but got shapes ", begin_tensor.shape().DebugString(),
+ " and ", size_tensor.shape().DebugString(), " instead."));
+
+ *begin = IntTensorToInt64Vec(begin_tensor);
+ *size = IntTensorToInt64Vec(size_tensor);
+ for (int i = 0; i < input_dims; ++i) {
+ if ((*size)[i] == -1) {
+ // A size[i] of -1 means "all elements from begin[i] to dim_size(i)".
+ (*size)[i] = input_tf_shape.dim_size(i) - (*begin)[i];
+ }
+ }
+
+ *is_identity = true;
+ for (int i = 0; i < input_dims; ++i) {
+ int64 b = (*begin)[i];
+ int64 s = (*size)[i];
+ if (input_tf_shape.dim_size(i) == 0) {
+ OP_REQUIRES(
+ context, b == 0 && s == 0,
+ errors::InvalidArgument("Expected begin[", i, "] == 0 (got ", b,
+ ") and size[", i, "] == 0 ", "(got ", s,
+ ") when ", "input.dim_size(", i, ") == 0"));
+ } else {
+ OP_REQUIRES(context, 0 <= b && b <= input_tf_shape.dim_size(i),
+ errors::InvalidArgument("Expected begin[", i, "] in [0, ",
+ input_tf_shape.dim_size(i),
+ "], but got ", b));
+ OP_REQUIRES(context, 0 <= s && b + s <= input_tf_shape.dim_size(i),
+ errors::InvalidArgument("Expected size[", i, "] in [0, ",
+ input_tf_shape.dim_size(i) - b,
+ "], but ", "got ", s));
+ }
+ const bool take_all = (b == 0) && (s == input_tf_shape.dim_size(i));
+ (*is_identity) &= take_all;
+ }
+}
+
+// A version of SharedSliceCommonCases function written for input tensor
+// that may be in MklDnn layout or in Tensorflow layout.
+template <typename T>
+static void CheckCommonCasesForMklInputs(OpKernelContext* context,
+ gtl::InlinedVector<int64, 4>* begin,
+ gtl::InlinedVector<int64, 4>* size,
+ bool* done) {
+ bool is_identity = true;
+ *done = false;
+
+ ValidateMklInputs(context, &is_identity, begin, size);
+ if (!context->status().ok()) return;
+
+ const Tensor& input = MklGetInput(context, 0);
+ MklDnnShape input_mkl_shape;
+ GetMklShape(context, 0, &input_mkl_shape);
+
+ if (is_identity) {
+ VLOG(1) << "Slice identity";
+ context->set_output(0, input);
+ // Mkl metadata tensor in this case can just be forwarded from input to
+ // output.
+ AllocateOutputSetMklShape(context, 0, input_mkl_shape);
+ *done = true;
+ }
+}
+
+// MKL-DNN implementation of Slice
+template <typename Device, typename T>
+class MklDnnSliceOp : public OpKernel {
+ public:
+ explicit MklDnnSliceOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ ~MklDnnSliceOp() {}
+
+ void Compute(OpKernelContext* context) override {
+ gtl::InlinedVector<int64, 4> begin;
+ gtl::InlinedVector<int64, 4> size;
+ bool done = false;
+
+ CheckCommonCasesForMklInputs<T>(context, &begin, &size, &done);
+ if (!context->status().ok() || done == true) return;
+
+    // Although MKL-DNN supports tensors with more than 8 and fewer than 12
+    // dimensions, we mimic the functionality of the Eigen Slice op for CPU,
+    // which handles at most 8 dimensions.
+ if (begin.size() >= 8) {
+ OP_REQUIRES(
+ context, false,
+ errors::Unimplemented("MklDnnSliceOp : Unhandled input dimensions"));
+ }
+
+ ComputeMklDnnSlice(context, begin, size);
+ }
+
+ private:
+ // Slice op implemented using MKL-DNN APIs.
+ void ComputeMklDnnSlice(OpKernelContext* context,
+ const gtl::InlinedVector<int64, 4>& begin,
+ const gtl::InlinedVector<int64, 4>& size) {
+ try {
+ // MKL-DNN API usage below is guided by description at:
+ // https://github.com/01org/mkl-dnn/issues/69
+ //
+ // Relevant part of the description is copied below:
+ //
+ // Let's say you want to copy a part of memory into another buffer (and
+ // probably change the format). Then your steps are:
+ //
+ // 1. create memory primitive descriptor in_mem_pd and memory primitive
+ // in_mem_p for the entire source data.
+ // 2. create view primitive descriptor in_submem_pd based on in_mem_pd,
+ // initial offsets, and sub-sizes
+ // 3. create memory primitive descriptor out_mem_pd and memory primitive
+ // out_mem_p for the output (the logical sizes should match sub-sizes
+ // used in step 2, but the format might be arbitrary)
+ // 4. create reorder primitive descriptor reorder_pd based on in_submem_pd
+ // and out_mem_pd
+ // 5. create reorder primitive itself based on reorder_pd, in_mem_p, and
+ // out_mem_p.
+ //
+ // Please notice that there is no view primitive. There is only view
+ // primitive descriptor. And the reorder uses source memory as input but
+ // traverses it according to a view in_submem_pd.
+
+ auto cpu_engine = engine(engine::cpu, 0);
+ MklDnnData<T> src(&cpu_engine);
+ MklDnnData<T> output(&cpu_engine);
+
+ // Populate offsets and sizes in memory::dims format based on vector.
+ memory::dims begin_dims = {};
+ begin_dims.resize(begin.size());
+ for (size_t i = 0; i < begin.size(); ++i) begin_dims[i] = begin[i];
+ memory::dims size_dims = {};
+ bool empty = false;
+ size_dims.resize(size.size());
+ for (size_t i = 0; i < size.size(); ++i) {
+ size_dims[i] = size[i];
+ if (size_dims[i] == 0) empty = true;
+ }
+
+ Tensor* output_tensor = nullptr;
+ MklDnnShape output_mkl_shape;
+
+      // If any requested slice size is 0, the result is empty.
+      // Just return an empty output tensor and a dummy Mkl-shape tensor.
+ if (empty) { // for empty dims
+ auto shape_to = MklDnnDimsToTFShape(size_dims);
+ AllocateOutputSetMklShape(context, 0, &output_tensor, shape_to,
+ output_mkl_shape);
+ return;
+ }
+
+ // Step 1 (as per above description) - Create memory for user data.
+ // We use blocked format here to describe input tensor.
+ const Tensor& input_tensor = MklGetInput(context, 0);
+ MklDnnShape input_mkl_shape;
+ GetMklShape(context, 0, &input_mkl_shape);
+
+ if (input_mkl_shape.IsMklTensor()) {
+ auto input_mkl_format = input_mkl_shape.GetTfDataFormat();
+ auto input_tf_format = MklDnnDataFormatToTFDataFormat(input_mkl_format);
+ begin_dims = MklDnnDimsInNCHW(begin_dims, input_tf_format);
+ size_dims = MklDnnDimsInNCHW(size_dims, input_tf_format);
+ auto input_md = input_mkl_shape.GetMklLayout();
+ src.SetUsrMem(input_md, &input_tensor);
+ } else {
+ // Initialize input dimensions and strides to be used when input is not
+ // in MklDnn layout.
+ memory::dims input_dims, input_strides;
+ input_dims = TFShapeToMklDnnDims(input_tensor.shape());
+ input_strides = CalculateTFStrides(input_dims);
+ // Create input memory descriptor.
+ auto input_md =
+ MklDnnData<T>::CreateBlockedMemDesc(input_dims, input_strides);
+ src.SetUsrMem(input_md, &input_tensor);
+ }
+
+ // Step 2 - create view primitive descriptor
+ auto view_pd =
+ view::primitive_desc(src.GetUsrMemPrimDesc(), size_dims, begin_dims)
+ .dst_primitive_desc();
+ auto output_strides = CalculateTFStrides(size_dims);
+ auto output_md =
+ MklDnnData<T>::CreateBlockedMemDesc(size_dims, output_strides);
+ auto output_pd = memory::primitive_desc(output_md, cpu_engine);
+
+ // Step 3 - Create memory for output. If input is in MklDnn layout, then
+ // output is also in MklDnn layout. Otherwise, output is in Tensorflow
+ // layout.
+ AllocateOutputTensor(context, input_mkl_shape, &output_pd, size_dims,
+ &output_tensor, &output_mkl_shape);
+ DCHECK(output_tensor);
+ DCHECK_EQ(input_mkl_shape.IsMklTensor(), output_mkl_shape.IsMklTensor());
+ output.SetUsrMem(output_md, output_tensor);
+
+ std::vector<primitive> net;
+ // Step 4 - create reorder primitive desc between view_pd and output_pd.
+ auto reorder_pd =
+ reorder::primitive_desc(view_pd, output.GetUsrMemPrimDesc());
+ // Step 5 - create reorder primitive itself.
+ net.push_back(reorder(reorder_pd, *src.GetUsrMem(), *output.GetUsrMem()));
+ // Execute the reorder primitive.
+ stream(stream::kind::eager).submit(net).wait();
+ } catch (mkldnn::error& e) {
+ string error_msg = "Status: " + std::to_string(e.status) + ", message: " +
+ string(e.message) + ", in file " + string(__FILE__) +
+ ":" + std::to_string(__LINE__);
+ OP_REQUIRES_OK(
+ context,
+ errors::Aborted("Operation received an exception:", error_msg));
+ }
+ }
+
+ private:
+ void AllocateOutputTensor(OpKernelContext* context,
+ const MklDnnShape& input_mkl_shape,
+ memory::primitive_desc* output_pd,
+ const memory::dims& output_dims,
+ Tensor** output_tensor,
+ MklDnnShape* output_mkl_shape) {
+ DCHECK(output_tensor);
+ DCHECK(output_mkl_shape);
+
+ TensorShape output_tf_shape;
+
+ if (input_mkl_shape.IsMklTensor()) {
+ // Since input tensor is in Mkl layout, output tensor will be in Mkl
+ // layout.
+
+ // Allocate shape of Mkl tensor.
+ output_mkl_shape->SetMklTensor(true);
+ output_mkl_shape->SetMklLayout(output_pd);
+ output_mkl_shape->SetElemType(MklDnnType<T>());
+ output_mkl_shape->SetTfLayout(input_mkl_shape.GetDimension(), output_dims,
+ input_mkl_shape.GetTfDataFormat());
+
+ output_tf_shape.AddDim((output_pd->get_size() / sizeof(T)) + 1);
+ } else {
+ // If input is not in Mkl layout, then output won't be in Mkl layout.
+ output_mkl_shape->SetMklTensor(false);
+ output_tf_shape = MklDnnDimsToTFShape(output_dims);
+ }
+
+ AllocateOutputSetMklShape(context, 0, output_tensor, output_tf_shape,
+ *output_mkl_shape);
+ }
+};
+
+// MKL-DNN Slice registration
+#define REGISTER_MKL_SLICE(type) \
+ REGISTER_KERNEL_BUILDER(Name("_MklSlice") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .HostMemory("begin") \
+ .HostMemory("size") \
+ .Label(mkl_op_registry::kMklOpLabel), \
+ MklDnnSliceOp<CPUDevice, type>);
+
+TF_CALL_float(REGISTER_MKL_SLICE);
+#undef REGISTER_MKL_SLICE
+
+} // namespace tensorflow
+
+#endif // INTEL_MKL_DNN
+#endif // INTEL_MKL
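A minimal standalone sketch of the five-step view/reorder recipe quoted in the kernel comment above, assuming the same MKL-DNN 0.x C++ API the kernel uses; the buffer names and the fixed 2x4x8x8 nchw shape are illustrative only.

    #include <vector>
    #include "mkldnn.hpp"
    using namespace mkldnn;

    // Copies a 1x4x8x8 sub-tensor starting at batch offset 1 out of a 2x4x8x8
    // source, possibly changing format, without any separate "view" primitive.
    void CopySubTensor(const float* src_buf, float* dst_buf) {
      engine cpu_engine(engine::cpu, 0);
      // Step 1: memory primitive descriptor and memory for the entire source.
      memory::desc in_md({2, 4, 8, 8}, memory::data_type::f32, memory::format::nchw);
      memory::primitive_desc in_pd(in_md, cpu_engine);
      memory in_mem(in_pd, const_cast<float*>(src_buf));
      // Step 2: view primitive descriptor from sub-sizes and initial offsets.
      auto view_pd = view::primitive_desc(in_pd, {1, 4, 8, 8}, {1, 0, 0, 0})
                         .dst_primitive_desc();
      // Step 3: memory for the output; logical sizes match the sub-sizes.
      memory::desc out_md({1, 4, 8, 8}, memory::data_type::f32, memory::format::nchw);
      memory::primitive_desc out_pd(out_md, cpu_engine);
      memory out_mem(out_pd, dst_buf);
      // Steps 4 and 5: reorder primitive descriptor over (view, output) and the
      // reorder itself, which reads the source but traverses it through the view.
      auto reorder_pd = reorder::primitive_desc(view_pd, out_pd);
      std::vector<primitive> net;
      net.push_back(reorder(reorder_pd, in_mem, out_mem));
      stream(stream::kind::eager).submit(net).wait();
    }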
diff --git a/tensorflow/core/kernels/multinomial_op.cc b/tensorflow/core/kernels/multinomial_op.cc
index 7a64788448..82dfece4a2 100644
--- a/tensorflow/core/kernels/multinomial_op.cc
+++ b/tensorflow/core/kernels/multinomial_op.cc
@@ -75,7 +75,7 @@ struct MultinomialFunctor<CPUDevice, T, OutputType> {
// lambda. Since we want to let each worker have its own copy, we pass
// "gen" by reference and explicitly do a copy assignment here.
random::PhiloxRandom gen_copy = gen;
- // Skip takes units of 128 bytes. +3 is so rounding doesn't lead to
+ // Skip takes units of 128 bits. +3 is so rounding doesn't lead to
// us using the same state in different batches.
gen_copy.Skip(start_row * (num_samples + 3) / 4);
random::SimplePhilox simple_philox(&gen_copy);
diff --git a/tensorflow/core/kernels/partitioned_function_ops.cc b/tensorflow/core/kernels/partitioned_function_ops.cc
index fc1c9003aa..fdb4c84c46 100644
--- a/tensorflow/core/kernels/partitioned_function_ops.cc
+++ b/tensorflow/core/kernels/partitioned_function_ops.cc
@@ -97,7 +97,13 @@ class PartitionedCallOp : public AsyncOpKernel {
OP_REQUIRES_ASYNC(ctx, fbody != nullptr,
errors::Internal("Could not find handle ", handle),
done);
+ // Pass the global op registry as the default registry when creating the
+ // graph, so that graph optimization passes can look up all possible ops
+ // by name.
auto graph = tensorflow::MakeUnique<Graph>(fbody->graph->flib_def());
+ FunctionLibraryDefinition global_flib(OpRegistry::Global(), {});
+ TF_CHECK_OK(
+ graph.get()->AddFunctionLibrary(global_flib.ToProto()));
CopyGraph(*fbody->graph, graph.get());
OP_REQUIRES_OK_ASYNC(ctx, PinResourceArgs(graph.get(), args), done);
@@ -250,9 +256,11 @@ class PartitionedCallOp : public AsyncOpKernel {
VLOG(3) << "Partitioned function '" << func_.name() << "', yielding "
<< partitions.size() << " shards.";
- const FunctionLibraryDefinition* flib_def = &graph->flib_def();
for (const auto& partition : partitions) {
- std::unique_ptr<Graph> subgraph(new Graph(flib_def));
+ std::unique_ptr<Graph> subgraph(new Graph(graph->flib_def()));
+ FunctionLibraryDefinition global_flib(OpRegistry::Global(), {});
+ TF_CHECK_OK(
+ subgraph.get()->AddFunctionLibrary(global_flib.ToProto()));
GraphConstructorOptions opts;
opts.allow_internal_ops = true;
opts.expect_device_spec = true;
diff --git a/tensorflow/core/kernels/queue_base.h b/tensorflow/core/kernels/queue_base.h
index 5fb1c92f94..272aa3b4f5 100644
--- a/tensorflow/core/kernels/queue_base.h
+++ b/tensorflow/core/kernels/queue_base.h
@@ -19,6 +19,7 @@ limitations under the License.
#include <deque>
#include <vector>
+#include "absl/base/macros.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/queue_interface.h"
#include "tensorflow/core/framework/tensor.h"
@@ -82,6 +83,9 @@ class QueueBase : public QueueInterface {
// NOTE(mrry): This method is deprecated. Use
// `tensorflow::batch_util::CopySliceToElement()` defined in
// "./batch_util.h" instead.
+ ABSL_DEPRECATED(
+ "Use `tensorflow::batch_util::CopySliceToElement()` defined in "
+ "\"./batch_util.h\" instead.")
static Status CopyElementToSlice(const Tensor& element, Tensor* parent,
int64 index);
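For readers unfamiliar with the macro introduced above, a minimal sketch of what ABSL_DEPRECATED does at a call site; the function names below are hypothetical.

    #include "absl/base/macros.h"

    ABSL_DEPRECATED("Use NewCopyHelper() instead.")
    int OldCopyHelper(int x);

    int Caller(int x) {
      // Compilers that honor the attribute emit a deprecation warning carrying
      // the message above at this call site; the code still compiles and runs.
      return OldCopyHelper(x);
    }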
diff --git a/tensorflow/core/kernels/random_op.cc b/tensorflow/core/kernels/random_op.cc
index e37232539f..04a53697c0 100644
--- a/tensorflow/core/kernels/random_op.cc
+++ b/tensorflow/core/kernels/random_op.cc
@@ -231,7 +231,13 @@ class RandomUniformIntOp : public OpKernel {
errors::InvalidArgument("maxval must be 0-D, got shape ",
maxval.shape().DebugString()));
- // Verify that minval < maxval
+ // Allocate output, and exit early if possible
+ Tensor* output;
+ OP_REQUIRES_OK(ctx, AllocateOutputWithShape(ctx, shape, 0, &output));
+ if (output->NumElements() == 0) return;
+
+ // Verify that minval < maxval. This check intentionally happens after the
+ // early exit for empty output: producing zero samples from an impossible
+ // range is fine.
IntType lo = minval.scalar<IntType>()();
IntType hi = maxval.scalar<IntType>()();
OP_REQUIRES(
@@ -243,8 +249,6 @@ class RandomUniformIntOp : public OpKernel {
Distribution;
Distribution dist(lo, hi);
- Tensor* output;
- OP_REQUIRES_OK(ctx, AllocateOutputWithShape(ctx, shape, 0, &output));
auto output_flat = output->flat<IntType>();
functor::FillPhiloxRandom<Device, Distribution>()(
ctx, ctx->eigen_device<Device>(),
diff --git a/tensorflow/core/kernels/reduction_gpu_kernels.cu.h b/tensorflow/core/kernels/reduction_gpu_kernels.cu.h
index 88b3c2ac76..bb8254eaac 100644
--- a/tensorflow/core/kernels/reduction_gpu_kernels.cu.h
+++ b/tensorflow/core/kernels/reduction_gpu_kernels.cu.h
@@ -21,11 +21,11 @@ limitations under the License.
#define EIGEN_USE_GPU
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "external/cub_archive/cub/device/device_reduce.cuh"
-#include "external/cub_archive/cub/device/device_segmented_reduce.cuh"
-#include "external/cub_archive/cub/iterator/counting_input_iterator.cuh"
-#include "external/cub_archive/cub/iterator/transform_input_iterator.cuh"
-#include "external/cub_archive/cub/warp/warp_reduce.cuh"
+#include "third_party/cub/device/device_reduce.cuh"
+#include "third_party/cub/device/device_segmented_reduce.cuh"
+#include "third_party/cub/iterator/counting_input_iterator.cuh"
+#include "third_party/cub/iterator/transform_input_iterator.cuh"
+#include "third_party/cub/warp/warp_reduce.cuh"
#include "cuda/include/cuComplex.h"
#include "tensorflow/core/kernels/reduction_ops.h"
#include "tensorflow/core/lib/core/bits.h"
diff --git a/tensorflow/core/kernels/reduction_ops_max.cc b/tensorflow/core/kernels/reduction_ops_max.cc
index 9cf953f4bf..8bfa44b2d0 100644
--- a/tensorflow/core/kernels/reduction_ops_max.cc
+++ b/tensorflow/core/kernels/reduction_ops_max.cc
@@ -50,6 +50,8 @@ TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
.TypeConstraint<int64>("Tidx") \
.HostMemory("reduction_indices"), \
ReductionOp<GPUDevice, type, int64, Eigen::internal::MaxReducer<type>>);
+
+REGISTER_GPU_KERNELS(Eigen::half);
REGISTER_GPU_KERNELS(float);
REGISTER_GPU_KERNELS(double);
REGISTER_GPU_KERNELS(int64);
diff --git a/tensorflow/core/kernels/reduction_ops_sum.cc b/tensorflow/core/kernels/reduction_ops_sum.cc
index e4ca89eca3..5318d8c133 100644
--- a/tensorflow/core/kernels/reduction_ops_sum.cc
+++ b/tensorflow/core/kernels/reduction_ops_sum.cc
@@ -76,15 +76,7 @@ REGISTER_KERNEL_BUILDER(
.HostMemory("output")
.HostMemory("reduction_indices"),
ReductionOp<CPUDevice, int32, int64, Eigen::internal::SumReducer<int32>>);
-REGISTER_KERNEL_BUILDER(
- Name("Sum")
- .Device(DEVICE_GPU)
- .TypeConstraint<int64>("T")
- .TypeConstraint<int32>("Tidx")
- .HostMemory("input")
- .HostMemory("output")
- .HostMemory("reduction_indices"),
- ReductionOp<CPUDevice, int64, int32, Eigen::internal::SumReducer<int64>>);
+
#endif
#ifdef TENSORFLOW_USE_SYCL
diff --git a/tensorflow/core/kernels/resource_variable_ops.cc b/tensorflow/core/kernels/resource_variable_ops.cc
index 26705a8d34..427044ca67 100644
--- a/tensorflow/core/kernels/resource_variable_ops.cc
+++ b/tensorflow/core/kernels/resource_variable_ops.cc
@@ -51,7 +51,9 @@ limitations under the License.
#define EIGEN_USE_GPU
#endif
-#include "tensorflow/core/kernels/resource_variable_ops.h"
+#include <memory>
+#include <vector>
+
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/resource_mgr.h"
@@ -60,10 +62,12 @@ limitations under the License.
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/kernels/gather_functor.h"
+#include "tensorflow/core/kernels/resource_variable_ops.h"
#include "tensorflow/core/kernels/scatter_functor.h"
#include "tensorflow/core/kernels/training_op_helpers.h"
#include "tensorflow/core/kernels/variable_ops.h"
#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
@@ -72,6 +76,8 @@ limitations under the License.
namespace tensorflow {
REGISTER_RESOURCE_HANDLE_KERNEL(Var);
+REGISTER_KERNEL_BUILDER(Name("_VarHandlesOp").Device(DEVICE_CPU),
+ ResourceHandlesOp<Var>);
ReadVariableOp::ReadVariableOp(OpKernelConstruction* c) : OpKernel(c) {
OP_REQUIRES_OK(c, c->GetAttr("dtype", &dtype_));
@@ -101,13 +107,58 @@ void ReadVariableOp::Compute(OpKernelContext* ctx) {
ctx->set_output(0, t);
}
+ReadVariablesOp::ReadVariablesOp(OpKernelConstruction* c) : OpKernel(c) {
+ int n;
+ OP_REQUIRES_OK(c, c->GetAttr("N", &n));
+ OP_REQUIRES_OK(c, c->GetAttr("dtypes", &dtypes_));
+ OP_REQUIRES(c, n == dtypes_.size(),
+ errors::InvalidArgument(
+ "Mismatched number of arguments to ReadVariablesOp (", n,
+ " vs. ", dtypes_.size(), ")"));
+}
+
+void ReadVariablesOp::Compute(OpKernelContext* ctx) {
+ std::vector<std::unique_ptr<Var, core::RefCountDeleter>> variables(
+ dtypes_.size());
+ std::vector<const ResourceHandle*> handles(dtypes_.size());
+ for (size_t i = 0; i < dtypes_.size(); ++i) {
+ handles[i] = &HandleFromInput(ctx, i);
+ }
+ const auto status = LookupResources(ctx, handles, &variables);
+ OP_REQUIRES(ctx, status.ok(),
+ errors::FailedPrecondition(
+ "Error while reading resource variable. This could mean that "
+ "the variable was uninitialized. ",
+ status.ToString()));
+
+ for (size_t i = 0; i < dtypes_.size(); ++i) {
+ // We're acquiring a reference to the underlying buffer while
+ // holding a shared lock to guarantee ordering of reads and
+ // writes.
+ tf_shared_lock ml(*variables[i]->mu());
+ const Tensor& t = *variables[i]->tensor();
+ OP_REQUIRES(ctx, dtypes_[i] == t.dtype(),
+ errors::InvalidArgument(
+ "Trying to read variable ", handles[i]->name(),
+ " from Container: ", handles[i]->container(),
+ " with wrong dtype. Expected ", DataTypeString(dtypes_[i]),
+ " got ", DataTypeString(t.dtype())));
+ ctx->set_output(i, t);
+ }
+}
+
REGISTER_KERNEL_BUILDER(Name("ReadVariableOp").Device(DEVICE_CPU),
ReadVariableOp);
+REGISTER_KERNEL_BUILDER(Name("_ReadVariablesOp").Device(DEVICE_CPU),
+ ReadVariablesOp);
#if GOOGLE_CUDA
REGISTER_KERNEL_BUILDER(
Name("ReadVariableOp").Device(DEVICE_GPU).HostMemory("resource"),
ReadVariableOp);
+REGISTER_KERNEL_BUILDER(
+ Name("_ReadVariablesOp").Device(DEVICE_GPU).HostMemory("resources"),
+ ReadVariablesOp);
#define REGISTER_GPU_KERNELS(type) \
namespace functor { \
@@ -121,7 +172,12 @@ REGISTER_KERNEL_BUILDER(
.Device(DEVICE_GPU) \
.HostMemory("resource") \
.TypeConstraint<type>("dtype"), \
- ResourceHandleOp<Var>)
+ ResourceHandleOp<Var>) \
+ REGISTER_KERNEL_BUILDER(Name("_VarHandlesOp") \
+ .Device(DEVICE_GPU) \
+ .HostMemory("resources") \
+ .TypeConstraint<type>("dtypes"), \
+ ResourceHandlesOp<Var>)
TF_CALL_GPU_ALL_TYPES(REGISTER_GPU_KERNELS);
TF_CALL_int64(REGISTER_GPU_KERNELS);
diff --git a/tensorflow/core/kernels/resource_variable_ops.h b/tensorflow/core/kernels/resource_variable_ops.h
index 9b60106f13..cffb732c38 100644
--- a/tensorflow/core/kernels/resource_variable_ops.h
+++ b/tensorflow/core/kernels/resource_variable_ops.h
@@ -28,6 +28,16 @@ class ReadVariableOp : public OpKernel {
DataType dtype_;
};
+class ReadVariablesOp : public OpKernel {
+ public:
+ explicit ReadVariablesOp(OpKernelConstruction* c);
+ void Compute(OpKernelContext* ctx) override;
+ bool IsExpensive() override { return false; }
+
+ private:
+ DataTypeVector dtypes_;
+};
+
class DestroyResourceOp : public OpKernel {
public:
explicit DestroyResourceOp(OpKernelConstruction* ctx);
diff --git a/tensorflow/core/kernels/scatter_nd_op.cc b/tensorflow/core/kernels/scatter_nd_op.cc
index e0194605ce..2f8aede427 100644
--- a/tensorflow/core/kernels/scatter_nd_op.cc
+++ b/tensorflow/core/kernels/scatter_nd_op.cc
@@ -145,6 +145,7 @@ class ScatterNdUpdateOp : public OpKernel {
if (dtype_ == DT_RESOURCE) {
Var* v;
OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
+ core::ScopedUnref scoped_unref(v);
mutex_lock m(*v->mu());
DoCompute(c);
} else if (use_exclusive_lock_) {
diff --git a/tensorflow/core/kernels/searchsorted_op.cc b/tensorflow/core/kernels/searchsorted_op.cc
new file mode 100644
index 0000000000..dc627ac77a
--- /dev/null
+++ b/tensorflow/core/kernels/searchsorted_op.cc
@@ -0,0 +1,249 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/kernels/searchsorted_op.h"
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/kernels/bounds_check.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+namespace functor {
+template <typename T, typename OutType>
+struct UpperBoundFunctor<CPUDevice, T, OutType> {
+ static Status Compute(OpKernelContext* context,
+ const typename TTypes<T, 1>::ConstTensor& sorted_inputs,
+ const typename TTypes<T, 1>::ConstTensor& values,
+ int batch_size, int num_inputs, int num_values,
+ typename TTypes<OutType, 1>::Tensor* output) {
+ // TODO(eriche): If anyone ever needs this to be faster, we can multithread.
+ for (int b = 0; b < batch_size; ++b) {
+ const T* sorted_inputs_ptr = sorted_inputs.data() + b * num_inputs;
+ OutType* output_ptr = output->data() + b * num_values;
+ for (int i = 0; i < num_values; ++i) {
+ output_ptr[i] =
+ std::upper_bound(sorted_inputs_ptr, sorted_inputs_ptr + num_inputs,
+ values(i + b * num_values)) -
+ sorted_inputs_ptr;
+ }
+ }
+
+ return Status::OK();
+ }
+};
+
+template <typename T, typename OutType>
+struct LowerBoundFunctor<CPUDevice, T, OutType> {
+ static Status Compute(OpKernelContext* context,
+ const typename TTypes<T, 1>::ConstTensor& sorted_inputs,
+ const typename TTypes<T, 1>::ConstTensor& values,
+ int batch_size, int num_inputs, int num_values,
+ typename TTypes<OutType, 1>::Tensor* output) {
+ // TODO(eriche): If anyone ever needs this to be faster, we can multithread.
+ for (int b = 0; b < batch_size; ++b) {
+ const T* sorted_inputs_ptr = sorted_inputs.data() + b * num_inputs;
+ OutType* output_ptr = output->data() + b * num_values;
+ for (int i = 0; i < num_values; ++i) {
+ output_ptr[i] =
+ std::lower_bound(sorted_inputs_ptr, sorted_inputs_ptr + num_inputs,
+ values(i + b * num_values)) -
+ sorted_inputs_ptr;
+ }
+ }
+
+ return Status::OK();
+ }
+};
+} // namespace functor
+
+template <typename Device, typename T, typename OutType>
+class UpperBoundOp : public OpKernel {
+ public:
+ explicit UpperBoundOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor& sorted_inputs_t = ctx->input(0);
+ const Tensor& values_t = ctx->input(1);
+
+ // Both tensors must have the same batch dim_size.
+ OP_REQUIRES(ctx, sorted_inputs_t.dim_size(0) == values_t.dim_size(0),
+ Status(error::INVALID_ARGUMENT,
+ "Leading dim_size of both tensors must match."));
+
+ // this is required because we do indexing in int32 on the GPU
+ OP_REQUIRES(ctx, values_t.NumElements() < std::numeric_limits<int>::max(),
+ Status(error::INVALID_ARGUMENT,
+ "values tensor size must less than INT_MAX"));
+
+ Tensor* output_t;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, values_t.shape(), &output_t));
+
+ if (output_t->dtype() == DT_INT32) {
+ OP_REQUIRES(ctx,
+ FastBoundsCheck(sorted_inputs_t.dim_size(1),
+ std::numeric_limits<int>::max()),
+ errors::InvalidArgument("trailing dim_size must less than "
+ "INT_MAX for int32 output type, was ",
+ sorted_inputs_t.dim_size(1)));
+ }
+
+ auto output = output_t->template flat<OutType>();
+ const auto sorted_inputs = sorted_inputs_t.template flat<T>();
+ const auto values = values_t.template flat<T>();
+ OP_REQUIRES_OK(
+ ctx, functor::UpperBoundFunctor<Device, T, OutType>::Compute(
+ ctx, sorted_inputs, values, sorted_inputs_t.dim_size(0),
+ sorted_inputs_t.dim_size(1), values_t.dim_size(1), &output));
+ }
+};
+
+template <typename Device, typename T, typename OutType>
+class LowerBoundOp : public OpKernel {
+ public:
+ explicit LowerBoundOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor& sorted_inputs_t = ctx->input(0);
+ const Tensor& values_t = ctx->input(1);
+
+ // Both tensors must have the same batch dim_size.
+ OP_REQUIRES(ctx, sorted_inputs_t.dim_size(0) == values_t.dim_size(0),
+ Status(error::INVALID_ARGUMENT,
+ "Leading dim_size of both tensors must match."));
+
+ // this is required because we do indexing in int32 on the GPU
+ OP_REQUIRES(ctx, values_t.NumElements() < std::numeric_limits<int>::max(),
+ Status(error::INVALID_ARGUMENT,
+ "values tensor size must less than INT_MAX"));
+
+ Tensor* output_t;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, values_t.shape(), &output_t));
+
+ if (output_t->dtype() == DT_INT32) {
+ OP_REQUIRES(ctx,
+ FastBoundsCheck(sorted_inputs_t.dim_size(1),
+ std::numeric_limits<int>::max()),
+ errors::InvalidArgument("trailing dim_size must less than "
+ "INT_MAX for int32 output type, was ",
+ sorted_inputs_t.dim_size(1)));
+ }
+
+ auto output = output_t->template flat<OutType>();
+ const auto sorted_inputs = sorted_inputs_t.template flat<T>();
+ const auto values = values_t.template flat<T>();
+ OP_REQUIRES_OK(
+ ctx, functor::LowerBoundFunctor<Device, T, OutType>::Compute(
+ ctx, sorted_inputs, values, sorted_inputs_t.dim_size(0),
+ sorted_inputs_t.dim_size(1), values_t.dim_size(1), &output));
+ }
+};
+
+#define REGISTER_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER(Name("UpperBound") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int32>("out_type"), \
+ UpperBoundOp<CPUDevice, type, int32>);
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
+#undef REGISTER_KERNELS
+
+#define REGISTER_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER(Name("UpperBound") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int64>("out_type"), \
+ UpperBoundOp<CPUDevice, type, int64>);
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
+#undef REGISTER_KERNELS
+
+#if GOOGLE_CUDA
+
+#define REGISTER_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER(Name("UpperBound") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int32>("out_type"), \
+ UpperBoundOp<GPUDevice, type, int32>);
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
+#undef REGISTER_KERNELS
+
+#define REGISTER_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER(Name("UpperBound") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int64>("out_type"), \
+ UpperBoundOp<GPUDevice, type, int64>);
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
+#undef REGISTER_KERNELS
+
+#endif // GOOGLE_CUDA
+
+#define REGISTER_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER(Name("LowerBound") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int32>("out_type"), \
+ LowerBoundOp<CPUDevice, type, int32>);
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
+#undef REGISTER_KERNELS
+
+#define REGISTER_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER(Name("LowerBound") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int64>("out_type"), \
+ LowerBoundOp<CPUDevice, type, int64>);
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
+#undef REGISTER_KERNELS
+
+#if GOOGLE_CUDA
+
+#define REGISTER_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER(Name("LowerBound") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int32>("out_type"), \
+ LowerBoundOp<GPUDevice, type, int32>);
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
+#undef REGISTER_KERNELS
+
+#define REGISTER_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER(Name("LowerBound") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int64>("out_type"), \
+ LowerBoundOp<GPUDevice, type, int64>);
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
+#undef REGISTER_KERNELS
+
+#endif // GOOGLE_CUDA
+} // namespace tensorflow
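A small host-only sketch of the per-row semantics the CPU functors above implement, lifted out of the kernel machinery for readability; the flat row-major layout mirrors the functors, and the helper name is made up for illustration (the LowerBound variant is identical with std::lower_bound).

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    // result[b * num_values + i] is the upper-bound insertion index of
    // values[b * num_values + i] within row b of sorted_inputs, counted from
    // the start of that row.
    std::vector<int64_t> UpperBoundRows(const std::vector<float>& sorted_inputs,
                                        const std::vector<float>& values,
                                        int batch_size, int num_inputs,
                                        int num_values) {
      std::vector<int64_t> out(static_cast<size_t>(batch_size) * num_values);
      for (int b = 0; b < batch_size; ++b) {
        const float* row = sorted_inputs.data() + b * num_inputs;
        for (int i = 0; i < num_values; ++i) {
          out[b * num_values + i] =
              std::upper_bound(row, row + num_inputs, values[b * num_values + i]) -
              row;
        }
      }
      return out;
    }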
diff --git a/tensorflow/core/kernels/searchsorted_op.h b/tensorflow/core/kernels/searchsorted_op.h
new file mode 100644
index 0000000000..f075bf0fa2
--- /dev/null
+++ b/tensorflow/core/kernels/searchsorted_op.h
@@ -0,0 +1,52 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_KERNELS_SEARCHSORTED_OP_H_
+#define TENSORFLOW_CORE_KERNELS_SEARCHSORTED_OP_H_
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+namespace functor {
+
+template <typename Device, typename T, typename OutType>
+struct UpperBoundFunctor {
+ // For each entry in values, searches sorted_inputs and returns the greatest
+ // index at which the value could be inserted while keeping sorted order.
+ static Status Compute(OpKernelContext* context,
+ const typename TTypes<T, 1>::ConstTensor& sorted_inputs,
+ const typename TTypes<T, 1>::ConstTensor& values,
+ int batch_size, int num_inputs, int num_values,
+ typename TTypes<OutType, 1>::Tensor* output);
+};
+
+template <typename Device, typename T, typename OutType>
+struct LowerBoundFunctor {
+ // For each entry in values, searches sorted_inputs and returns the lowest
+ // index at which the value could be inserted while keeping sorted order.
+ static Status Compute(OpKernelContext* context,
+ const typename TTypes<T, 1>::ConstTensor& sorted_inputs,
+ const typename TTypes<T, 1>::ConstTensor& values,
+ int batch_size, int num_inputs, int num_values,
+ typename TTypes<OutType, 1>::Tensor* output);
+};
+} // namespace functor
+
+} // end namespace tensorflow
+#endif // TENSORFLOW_CORE_KERNELS_SEARCHSORTED_OP_H_
diff --git a/tensorflow/core/kernels/searchsorted_op_gpu.cu.cc b/tensorflow/core/kernels/searchsorted_op_gpu.cu.cc
new file mode 100644
index 0000000000..263b5bf298
--- /dev/null
+++ b/tensorflow/core/kernels/searchsorted_op_gpu.cu.cc
@@ -0,0 +1,126 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/core/kernels/searchsorted_op.h"
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/cuda_kernel_helper.h"
+
+namespace tensorflow {
+typedef Eigen::GpuDevice GPUDevice;
+
+namespace {
+template <typename T, typename OutType>
+__global__ void UpperBoundKernel(const T* sorted_inputs, int batch_size,
+ int sorted_inputs_size, int values_size,
+ const T* values, OutType* outputs) {
+ CUDA_1D_KERNEL_LOOP(work_unit_id, values_size * batch_size) {
+ int bid = work_unit_id / values_size;
+ T value = values[work_unit_id];
+ outputs[work_unit_id] = cuda_helper::upper_bound<T, OutType>(
+ sorted_inputs + bid * sorted_inputs_size, sorted_inputs_size, value);
+ }
+}
+
+template <typename T, typename OutType>
+__global__ void LowerBoundKernel(const T* sorted_inputs, int batch_size,
+ int sorted_inputs_size, int values_size,
+ const T* values, OutType* outputs) {
+ CUDA_1D_KERNEL_LOOP(work_unit_id, values_size * batch_size) {
+ int bid = work_unit_id / values_size;
+ T value = values[work_unit_id];
+ outputs[work_unit_id] = cuda_helper::lower_bound<T, OutType>(
+ sorted_inputs + bid * sorted_inputs_size, sorted_inputs_size, value);
+ }
+}
+} // namespace
+
+namespace functor {
+template <typename T, typename OutType>
+struct UpperBoundFunctor<GPUDevice, T, OutType> {
+ static Status Compute(OpKernelContext* context,
+ const typename TTypes<T, 1>::ConstTensor& sorted_inputs,
+ const typename TTypes<T, 1>::ConstTensor& values,
+ int batch_size, int num_inputs, int num_values,
+ typename TTypes<OutType, 1>::Tensor* output) {
+ const cudaStream_t& stream = GetCudaStream(context);
+ CudaLaunchConfig config =
+ GetCudaLaunchConfig(values.size(), context->eigen_gpu_device());
+
+ UpperBoundKernel<T>
+ <<<config.block_count, config.thread_per_block, 0, stream>>>(
+ sorted_inputs.data(), batch_size, num_inputs, num_values,
+ values.data(), output->data());
+
+ return Status::OK();
+ }
+};
+
+template <typename T, typename OutType>
+struct LowerBoundFunctor<GPUDevice, T, OutType> {
+ static Status Compute(OpKernelContext* context,
+ const typename TTypes<T, 1>::ConstTensor& sorted_inputs,
+ const typename TTypes<T, 1>::ConstTensor& values,
+ int batch_size, int num_inputs, int num_values,
+ typename TTypes<OutType, 1>::Tensor* output) {
+ const cudaStream_t& stream = GetCudaStream(context);
+ CudaLaunchConfig config =
+ GetCudaLaunchConfig(values.size(), context->eigen_gpu_device());
+
+ LowerBoundKernel<T>
+ <<<config.block_count, config.thread_per_block, 0, stream>>>(
+ sorted_inputs.data(), batch_size, num_inputs, num_values,
+ values.data(), output->data());
+
+ return Status::OK();
+ }
+};
+} // namespace functor
+
+#define REGISTER_GPU_SPEC(type) \
+ template struct functor::UpperBoundFunctor<GPUDevice, type, int32>;
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_GPU_SPEC);
+#undef REGISTER_GPU_SPEC
+
+#define REGISTER_GPU_SPEC(type) \
+ template struct functor::UpperBoundFunctor<GPUDevice, type, int64>;
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_GPU_SPEC);
+#undef REGISTER_GPU_SPEC
+
+#define REGISTER_GPU_SPEC(type) \
+ template struct functor::LowerBoundFunctor<GPUDevice, type, int32>;
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_GPU_SPEC);
+#undef REGISTER_GPU_SPEC
+
+#define REGISTER_GPU_SPEC(type) \
+ template struct functor::LowerBoundFunctor<GPUDevice, type, int64>;
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_GPU_SPEC);
+#undef REGISTER_GPU_SPEC
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/slice_op.cc b/tensorflow/core/kernels/slice_op.cc
index 77594479cb..97f77e45b6 100644
--- a/tensorflow/core/kernels/slice_op.cc
+++ b/tensorflow/core/kernels/slice_op.cc
@@ -411,7 +411,7 @@ class MklSliceOp : public OpKernel {
context->input(0).tensor<T, NDIM>(), indices, sizes);
}
};
-#endif
+#endif // INTEL_MKL
// Forward declarations of the functor specializations for declared in the
// sharded source files.
@@ -440,18 +440,14 @@ TF_CALL_ALL_TYPES(DECLARE_FOR_N);
#undef DECLARE_CPU_SPEC
} // namespace functor
-#ifndef INTEL_MKL
+#if defined(INTEL_MKL) && defined(ENABLE_MKL)
#define REGISTER_SLICE(type) \
REGISTER_KERNEL_BUILDER(Name("Slice") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.HostMemory("begin") \
.HostMemory("size"), \
- SliceOp<CPUDevice, type>)
-
-TF_CALL_POD_STRING_TYPES(REGISTER_SLICE);
-TF_CALL_QUANTIZED_TYPES(REGISTER_SLICE);
-#undef REGISTER_SLICE
+ MklSliceOp<CPUDevice, type>)
#else
#define REGISTER_SLICE(type) \
REGISTER_KERNEL_BUILDER(Name("Slice") \
@@ -459,12 +455,12 @@ TF_CALL_QUANTIZED_TYPES(REGISTER_SLICE);
.TypeConstraint<type>("T") \
.HostMemory("begin") \
.HostMemory("size"), \
- MklSliceOp<CPUDevice, type>)
+ SliceOp<CPUDevice, type>)
+#endif // INTEL_MKL && ENABLE_MKL
TF_CALL_POD_STRING_TYPES(REGISTER_SLICE);
TF_CALL_QUANTIZED_TYPES(REGISTER_SLICE);
#undef REGISTER_SLICE
-#endif // INTEL_MKL
#if GOOGLE_CUDA
// Forward declarations of the functor specializations for GPU.
diff --git a/tensorflow/core/kernels/split_lib_gpu.cu.cc b/tensorflow/core/kernels/split_lib_gpu.cu.cc
index 393818730b..a4a59dbcbc 100644
--- a/tensorflow/core/kernels/split_lib_gpu.cu.cc
+++ b/tensorflow/core/kernels/split_lib_gpu.cu.cc
@@ -54,6 +54,7 @@ void SplitCustom<Device, T>::operator()(
TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS);
TF_CALL_complex64(DEFINE_GPU_KERNELS);
TF_CALL_complex128(DEFINE_GPU_KERNELS);
+TF_CALL_int64(DEFINE_GPU_KERNELS);
TF_CALL_bfloat16(DEFINE_GPU_KERNELS);
#undef DEFINE_GPU_KERNELS
diff --git a/tensorflow/core/kernels/strided_slice_op.cc b/tensorflow/core/kernels/strided_slice_op.cc
index 7b537fef5b..f0575de4d9 100644
--- a/tensorflow/core/kernels/strided_slice_op.cc
+++ b/tensorflow/core/kernels/strided_slice_op.cc
@@ -306,6 +306,7 @@ class StridedSliceAssignOp : public OpKernel {
Var* v;
OP_REQUIRES_OK(context,
LookupResource(context, HandleFromInput(context, 0), &v));
+ core::ScopedUnref scoped_unref(v);
mutex_lock ml(*v->mu());
OP_REQUIRES_OK(context,
PrepareToUpdateVariable<Device, T>(context, v->tensor()));
diff --git a/tensorflow/core/kernels/string_format_op.cc b/tensorflow/core/kernels/string_format_op.cc
new file mode 100644
index 0000000000..e4a1887f8d
--- /dev/null
+++ b/tensorflow/core/kernels/string_format_op.cc
@@ -0,0 +1,65 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <iostream>
+#include "absl/strings/str_split.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+
+class StringFormatOp : public OpKernel {
+ public:
+ explicit StringFormatOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ string template_;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("template", &template_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("placeholder", &placeholder_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("summarize", &summarize_));
+
+ split_template_ = absl::StrSplit(template_, placeholder_);
+ int64 num_placeholders = split_template_.size() - 1;
+ OP_REQUIRES(ctx, ctx->num_inputs() == num_placeholders,
+ errors::InvalidArgument(strings::StrCat(
+ "num placeholders in template and num inputs must match: ",
+ num_placeholders, " vs. ", ctx->num_inputs())));
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ Tensor* formatted_string = nullptr;
+ OP_REQUIRES_OK(ctx,
+ ctx->allocate_output(0, TensorShape({}), &formatted_string));
+
+ string msg;
+ strings::StrAppend(&msg, split_template_[0].c_str());
+ for (int i = 0; i < ctx->num_inputs(); ++i) {
+ strings::StrAppend(&msg, ctx->input(i).SummarizeValue(summarize_, true));
+ strings::StrAppend(&msg, split_template_[i + 1].c_str());
+ }
+
+ formatted_string->scalar<string>()() = msg;
+ }
+
+ private:
+ int32 summarize_ = 0;
+ string placeholder_;
+ std::vector<std::string> split_template_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("StringFormat").Device(DEVICE_CPU),
+ StringFormatOp);
+
+} // end namespace tensorflow
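A minimal sketch of the split-and-interleave strategy StringFormatOp uses, with plain std::string values standing in for the tensor summaries; the helper name and example template are illustrative only.

    #include <string>
    #include <vector>
    #include "absl/strings/str_split.h"

    std::string FormatWithPlaceholders(const std::string& tmpl,
                                       const std::string& placeholder,
                                       const std::vector<std::string>& values) {
      // Splitting "x: %s, y: %s" on "%s" yields {"x: ", ", y: ", ""}; the op
      // checks values.size() == parts.size() - 1 at construction time.
      std::vector<std::string> parts = absl::StrSplit(tmpl, placeholder);
      std::string out = parts[0];
      for (size_t i = 0; i < values.size(); ++i) {
        out += values[i];     // summarized tensor i goes where placeholder i was
        out += parts[i + 1];
      }
      return out;
    }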
diff --git a/tensorflow/core/kernels/string_format_op_test.cc b/tensorflow/core/kernels/string_format_op_test.cc
new file mode 100644
index 0000000000..13130a5797
--- /dev/null
+++ b/tensorflow/core/kernels/string_format_op_test.cc
@@ -0,0 +1,66 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+
+namespace tensorflow {
+namespace {
+
+class StringFormatGraphTest : public OpsTestBase {
+ protected:
+ Status Init(int num_inputs, DataType input_type,
+ const string& template_ = "%s", const string& placeholder = "%s",
+ int summarize = 3) {
+ TF_CHECK_OK(NodeDefBuilder("op", "StringFormat")
+ .Input(FakeInput(num_inputs, input_type))
+ .Attr("template", template_)
+ .Attr("placeholder", placeholder)
+ .Attr("summarize", summarize)
+ .Finalize(node_def()));
+ return InitOp();
+ }
+};
+
+TEST_F(StringFormatGraphTest, Int32Success_7) {
+ TF_ASSERT_OK(Init(1, DT_INT32, "First tensor: %s"));
+
+ AddInputFromArray<int32>(TensorShape({7}), {1, 2, 3, 4, 5, 6, 7});
+ TF_ASSERT_OK(RunOpKernel());
+ Tensor expected(allocator(), DT_STRING, TensorShape({}));
+ test::FillValues<string>(&expected, {"First tensor: [1 2 3 ... 5 6 7]"});
+ test::ExpectTensorEqual<string>(expected, *GetOutput(0));
+}
+
+TEST_F(StringFormatGraphTest, Int32Success_3_3) {
+ TF_ASSERT_OK(Init(1, DT_INT32, "First tensor: %s", "%s", 1));
+
+ AddInputFromArray<int32>(TensorShape({3, 3}), {1, 2, 3, 4, 5, 6, 7, 8, 9});
+ TF_ASSERT_OK(RunOpKernel());
+ Tensor expected(allocator(), DT_STRING, TensorShape({}));
+ test::FillValues<string>(&expected, {"First tensor: [[1 ... 3]\n ..."
+ "\n [7 ... 9]]"});
+ test::ExpectTensorEqual<string>(expected, *GetOutput(0));
+}
+
+} // end namespace
+} // end namespace tensorflow
diff --git a/tensorflow/core/kernels/string_length_op.cc b/tensorflow/core/kernels/string_length_op.cc
index a6829b29d9..435a7abdca 100644
--- a/tensorflow/core/kernels/string_length_op.cc
+++ b/tensorflow/core/kernels/string_length_op.cc
@@ -14,13 +14,18 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/string_util.h"
namespace tensorflow {
namespace {
class StringLengthOp : public OpKernel {
public:
- using OpKernel::OpKernel;
+ explicit StringLengthOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ string unit;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("unit", &unit));
+ OP_REQUIRES_OK(ctx, ParseCharUnit(unit, &unit_));
+ }
void Compute(OpKernelContext* context) override {
const Tensor& input = context->input(0);
@@ -32,10 +37,22 @@ class StringLengthOp : public OpKernel {
auto src = input.flat<string>();
auto dst = output->flat<int32>();
- for (int n = 0; n < src.size(); ++n) {
- dst(n) = src(n).size();
+ switch (unit_) {
+ case CharUnit::BYTE:
+ for (int n = 0; n < src.size(); ++n) {
+ dst(n) = src(n).size();
+ }
+ break;
+ case CharUnit::UTF8_CHAR:
+ for (int n = 0; n < src.size(); ++n) {
+ dst(n) = UTF8StrLen(src(n));
+ }
+ break;
}
}
+
+ private:
+ CharUnit unit_ = CharUnit::BYTE;
};
REGISTER_KERNEL_BUILDER(Name("StringLength").Device(DEVICE_CPU),
diff --git a/tensorflow/core/kernels/string_util.cc b/tensorflow/core/kernels/string_util.cc
new file mode 100644
index 0000000000..3a9803a052
--- /dev/null
+++ b/tensorflow/core/kernels/string_util.cc
@@ -0,0 +1,63 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/kernels/string_util.h"
+
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace {
+inline bool IsTrailByte(char x) { return static_cast<signed char>(x) < -0x40; }
+} // namespace
+
+namespace tensorflow {
+
+// Sets encoding value based on str.
+Status ParseUnicodeEncoding(const string& str, UnicodeEncoding* encoding) {
+ if (str == "UTF8") {
+ *encoding = UnicodeEncoding::UTF8;
+ } else {
+ return errors::InvalidArgument(strings::StrCat(
+ "Invalid encoding \"", str, "\": Should be one of: BYTE"));
+ }
+ return Status::OK();
+}
+
+// Sets unit value based on str.
+Status ParseCharUnit(const string& str, CharUnit* unit) {
+ if (str == "BYTE") {
+ *unit = CharUnit::BYTE;
+ } else if (str == "UTF8_CHAR") {
+ *unit = CharUnit::UTF8_CHAR;
+ } else {
+ return errors::InvalidArgument(strings::StrCat(
+ "Invalid unit \"", str, "\": Should be one of: BYTE, UTF8_CHAR"));
+ }
+ return Status::OK();
+}
+
+// Return the number of Unicode characters in a UTF-8 string.
+// Result may be incorrect if the input string is not valid UTF-8.
+int32 UTF8StrLen(const string& string) {
+ const int32 byte_size = string.size();
+ const char* const end = string.data() + byte_size;
+ const char* ptr = string.data();
+ int32 skipped_count = 0;
+ while (ptr < end) {
+ skipped_count += IsTrailByte(*ptr++) ? 1 : 0;
+ }
+ const int32 result = byte_size - skipped_count;
+ return result;
+}
+
+} // namespace tensorflow
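A tiny check illustrating the trail-byte counting above, assuming UTF8StrLen is in scope; the byte values are the standard UTF-8 encoding of "héllo", where 'é' is the two-byte sequence 0xC3 0xA9.

    #include <cassert>

    void Utf8LenExample() {
      // 6 bytes in total, and only 0xA9 is a trail byte (0b10xxxxxx),
      // so UTF8StrLen returns 6 - 1 = 5 characters.
      assert(UTF8StrLen("h\xC3\xA9llo") == 5);
    }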
diff --git a/tensorflow/core/kernels/string_util.h b/tensorflow/core/kernels/string_util.h
new file mode 100644
index 0000000000..390cf57702
--- /dev/null
+++ b/tensorflow/core/kernels/string_util.h
@@ -0,0 +1,45 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_KERNELS_STRING_UTIL_H_
+#define TENSORFLOW_CORE_KERNELS_STRING_UTIL_H_
+
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+// Enumeration for unicode encodings. Used by ops such as
+// tf.strings.unicode_encode and tf.strings.unicode_decode.
+// TODO(edloper): Add support for:
+// UTF16, UTF32, UTF16BE, UTF32BE, UTF16LE, UTF32LE
+enum class UnicodeEncoding { UTF8 };
+
+// Enumeration for character units. Used by string ops such as
+// tf.strings.length and tf.substr.
+// TODO(edloper): Add support for: UTF32_CHAR, etc.
+enum class CharUnit { BYTE, UTF8_CHAR };
+
+// Sets `encoding` based on `str`.
+Status ParseUnicodeEncoding(const string& str, UnicodeEncoding* encoding);
+
+// Sets `unit` value based on `str`.
+Status ParseCharUnit(const string& str, CharUnit* unit);
+
+// Returns the number of Unicode characters in a UTF-8 string.
+// Result may be incorrect if the input string is not valid UTF-8.
+int32 UTF8StrLen(const string& string);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_KERNELS_STRING_UTIL_H_
diff --git a/tensorflow/core/kernels/tensor_array.cc b/tensorflow/core/kernels/tensor_array.cc
index 765467bc1e..0e6c0ddccc 100644
--- a/tensorflow/core/kernels/tensor_array.cc
+++ b/tensorflow/core/kernels/tensor_array.cc
@@ -62,7 +62,8 @@ TF_CALL_complex128(TENSOR_ARRAY_WRITE_OR_ADD_GPU);
}
#define TENSOR_ARRAY_SET_ZERO_CPU(T) TENSOR_ARRAY_SET_ZERO(CPUDevice, T)
-TF_CALL_NUMBER_TYPES(TENSOR_ARRAY_SET_ZERO_CPU)
+TF_CALL_NUMBER_TYPES(TENSOR_ARRAY_SET_ZERO_CPU);
+TF_CALL_bool(TENSOR_ARRAY_SET_ZERO_CPU);
#undef TENSOR_ARRAY_SET_ZERO_CPU
#if GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/tensor_array.h b/tensorflow/core/kernels/tensor_array.h
index e8dc4fad21..384a63e945 100644
--- a/tensorflow/core/kernels/tensor_array.h
+++ b/tensorflow/core/kernels/tensor_array.h
@@ -81,7 +81,8 @@ Status TensorSetZero(OpKernelContext* ctx, Tensor* value) {
Status TensorSetZero<Device, T>(OpKernelContext * ctx, Tensor * value);
#define TENSOR_ARRAY_SET_ZERO_CPU(T) TENSOR_ARRAY_SET_ZERO(CPUDevice, T)
-TF_CALL_NUMBER_TYPES(TENSOR_ARRAY_SET_ZERO_CPU)
+TF_CALL_NUMBER_TYPES(TENSOR_ARRAY_SET_ZERO_CPU);
+TF_CALL_bool(TENSOR_ARRAY_SET_ZERO_CPU);
#undef TENSOR_ARRAY_SET_ZERO_CPU
#if GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/tensor_array_ops.cc b/tensorflow/core/kernels/tensor_array_ops.cc
index fe93b91eb8..a97a71b344 100644
--- a/tensorflow/core/kernels/tensor_array_ops.cc
+++ b/tensorflow/core/kernels/tensor_array_ops.cc
@@ -259,6 +259,7 @@ REGISTER_KERNEL_BUILDER(Name("TensorArrayV3").Device(DEVICE_CPU),
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
TF_CALL_complex64(REGISTER_GPU);
TF_CALL_complex128(REGISTER_GPU);
+TF_CALL_int64(REGISTER_GPU);
REGISTER_GPU(bfloat16);
#undef REGISTER_GPU
@@ -576,6 +577,7 @@ TF_CALL_ALL_TYPES(REGISTER_READ)
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
TF_CALL_complex64(REGISTER_GPU);
TF_CALL_complex128(REGISTER_GPU);
+TF_CALL_int64(REGISTER_GPU);
REGISTER_GPU(bfloat16);
#undef REGISTER_GPU
@@ -1218,6 +1220,7 @@ TF_CALL_ALL_TYPES(REGISTER_SCATTER_AND_UNPACK);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
TF_CALL_complex64(REGISTER_GPU);
TF_CALL_complex128(REGISTER_GPU);
+TF_CALL_int64(REGISTER_GPU);
#undef REGISTER_GPU
#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/topk_op_gpu.cu.cc b/tensorflow/core/kernels/topk_op_gpu.cu.cc
index ca296d5aa0..2fbe1fe7cb 100644
--- a/tensorflow/core/kernels/topk_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/topk_op_gpu.cu.cc
@@ -20,9 +20,9 @@ limitations under the License.
#include <cmath>
#include <vector>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "external/cub_archive/cub/device/device_segmented_radix_sort.cuh"
-#include "external/cub_archive/cub/iterator/counting_input_iterator.cuh"
-#include "external/cub_archive/cub/iterator/transform_input_iterator.cuh"
+#include "third_party/cub/device/device_segmented_radix_sort.cuh"
+#include "third_party/cub/iterator/counting_input_iterator.cuh"
+#include "third_party/cub/iterator/transform_input_iterator.cuh"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
diff --git a/tensorflow/core/kernels/training_op_helpers.cc b/tensorflow/core/kernels/training_op_helpers.cc
index d3c4f62071..4262a5404b 100644
--- a/tensorflow/core/kernels/training_op_helpers.cc
+++ b/tensorflow/core/kernels/training_op_helpers.cc
@@ -15,13 +15,16 @@ limitations under the License.
#include "tensorflow/core/kernels/training_op_helpers.h"
+#include "tensorflow/core/util/ptr_util.h"
+
namespace tensorflow {
-mutex* GetTrainingVariableMutex(OpKernelContext* ctx, int input) {
+mutex* GetTrainingVariableMutex(OpKernelContext* ctx, int input,
+ Var** maybe_resource) {
+ *maybe_resource = nullptr;
if (ctx->input_dtype(input) == DT_RESOURCE) {
- Var* var;
- if (LookupResource(ctx, HandleFromInput(ctx, input), &var).ok()) {
- return var->mu();
+ if (LookupResource(ctx, HandleFromInput(ctx, input), maybe_resource).ok()) {
+ return (*maybe_resource)->mu();
} else {
ctx->CtxFailureWithWarning(
errors::Internal("Invalid variable reference."));
@@ -32,12 +35,13 @@ mutex* GetTrainingVariableMutex(OpKernelContext* ctx, int input) {
}
// MaybeLockVariableInputMutexesInOrder is a helper function to acquire mutexes
-// in address order to mitigate deadlock. Returns a vector of acquired mutexes.
-// Safe to pass duplicates - will only lock each distinct mutex once. If
-// do_lock is false, returns immediately. Note that this silently doesn't lock
-// mutexes for invalid variable references; in all usages this is followed by
-// GetInputTensor which will signal a failure.
-std::vector<mutex_lock> MaybeLockVariableInputMutexesInOrder(
+// in address order to mitigate deadlock. Returns a structure that, when
+// deleted, will release the acquired mutexes. Safe to pass duplicates - will
+// only lock each distinct mutex once. If do_lock is false, returns
+// immediately. Note that this silently doesn't lock mutexes for invalid
+// variable references; in all usages this is followed by GetInputTensor which
+// will signal a failure.
+VariableInputLockHolder MaybeLockVariableInputMutexesInOrder(
OpKernelContext* ctx, bool do_lock, const std::vector<int>& input_ids) {
bool any_resource = false;
for (auto i : input_ids) {
@@ -46,14 +50,16 @@ std::vector<mutex_lock> MaybeLockVariableInputMutexesInOrder(
break;
}
}
- std::vector<mutex_lock> locks;
if (!do_lock && !any_resource) {
- return locks;
+ return VariableInputLockHolder({}, {});
}
+ std::vector<Var*> vars;
std::vector<mutex*> mutexes;
std::vector<int> acquire_order;
for (auto input : input_ids) {
- mutex* mutex = GetTrainingVariableMutex(ctx, input);
+ Var* var;
+ mutex* mutex = GetTrainingVariableMutex(ctx, input, &var);
+ if (var) vars.push_back(var);
// Only lock each mutex once if duplicates exist (n^2 but n is 2 or 3).
if (std::find(mutexes.begin(), mutexes.end(), mutex) == mutexes.end()) {
acquire_order.push_back(mutexes.size());
@@ -63,13 +69,19 @@ std::vector<mutex_lock> MaybeLockVariableInputMutexesInOrder(
std::sort(acquire_order.begin(), acquire_order.end(),
[&mutexes](int a, int b) { return mutexes[a] < mutexes[b]; });
+ std::unique_ptr<std::vector<mutex_lock>> locks =
+ MakeUnique<std::vector<mutex_lock>>();
+ locks->reserve(acquire_order.size());
+
for (auto input : acquire_order) {
- mutex* mu = GetTrainingVariableMutex(ctx, input);
+ Var* var;
+ mutex* mu = GetTrainingVariableMutex(ctx, input, &var);
+ core::ScopedUnref scoped_unref(var);
if (mu != nullptr) {
- locks.emplace_back(*mu);
+ locks->emplace_back(*mu);
}
}
- return locks;
+ return VariableInputLockHolder(std::move(vars), std::move(locks));
}
void MaybeForwardRefInputToRefOutput(OpKernelContext* ctx, int input,
diff --git a/tensorflow/core/kernels/training_op_helpers.h b/tensorflow/core/kernels/training_op_helpers.h
index 071cb371a7..9f173a80f7 100644
--- a/tensorflow/core/kernels/training_op_helpers.h
+++ b/tensorflow/core/kernels/training_op_helpers.h
@@ -23,9 +23,42 @@ limitations under the License.
namespace tensorflow {
-mutex* GetTrainingVariableMutex(OpKernelContext* ctx, int input);
+// Returns a borrowed pointer to the mutex for the variable `input` in `ctx`.
+//
+// If `input` corresponds to a `DT_RESOURCE`-type variable input,
+// `*maybe_resource` will be updated to contain the underlying resource, and the
+// caller will be responsible for calling `Unref()` on that resource.
+mutex* GetTrainingVariableMutex(OpKernelContext* ctx, int input,
+ Var** maybe_resource);
-std::vector<mutex_lock> MaybeLockVariableInputMutexesInOrder(
+// Utility structure that releases a sequence of borrowed mutexes when it is
+// deleted.
+struct VariableInputLockHolder {
+ public:
+ VariableInputLockHolder(std::vector<Var*> vars,
+ std::unique_ptr<std::vector<mutex_lock>> locks)
+ : vars_(std::move(vars)), locks_(std::move(locks)) {}
+
+ VariableInputLockHolder(VariableInputLockHolder&& other)
+ : vars_(std::move(other.vars_)), locks_(std::move(other.locks_)) {}
+
+ ~VariableInputLockHolder() {
+ // Release the locks before unreffing the Vars, because each lock
+ // is potentially borrowed from a Var in vars_.
+ locks_.reset();
+ for (Var* var : vars_) {
+ var->Unref();
+ }
+ }
+
+ private:
+ std::vector<Var*> vars_;
+ // NOTE: Use a `std::unique_ptr` instead of moving in a vector directly,
+ // because a `std::vector<mutex_lock>` is not movable on all platforms.
+ std::unique_ptr<std::vector<mutex_lock>> locks_;
+};
+
+VariableInputLockHolder MaybeLockVariableInputMutexesInOrder(
OpKernelContext* ctx, bool do_lock, const std::vector<int>& input_ids);
void MaybeForwardRefInputToRefOutput(OpKernelContext* ctx, int input,
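A hedged usage sketch of the lock holder introduced above, assuming an OpKernelContext* ctx whose inputs 0 and 1 may be resource variables; the surrounding function is hypothetical.

    void ComputeWithOrderedLocks(OpKernelContext* ctx, bool use_exclusive_lock) {
      // Mutexes are acquired in address order and duplicates are locked once.
      auto locks = MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock,
                                                        {0, 1});
      // ... read or update the variables here ...
    }  // On scope exit the locks are released first, then the Vars are unreffed.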
diff --git a/tensorflow/core/kernels/training_ops.cc b/tensorflow/core/kernels/training_ops.cc
index 9a07ded17d..acf162deec 100644
--- a/tensorflow/core/kernels/training_ops.cc
+++ b/tensorflow/core/kernels/training_ops.cc
@@ -561,7 +561,9 @@ class ApplyAdadeltaOp : public OpKernel {
}
void Compute(OpKernelContext* ctx) override {
- mutex* mu = GetTrainingVariableMutex(ctx, 0);
+ Var* resource;
+ mutex* mu = GetTrainingVariableMutex(ctx, 0, &resource);
+ core::ScopedUnref scoped_unref(resource);
if (use_exclusive_lock_ && mu != nullptr) {
mutex_lock l1(*mu);
// Don't try to acquire a lock on the second ref as they share the same
@@ -710,7 +712,9 @@ class SparseApplyAdadeltaOp : public OpKernel {
}
void Compute(OpKernelContext* ctx) override {
- mutex* mu = GetTrainingVariableMutex(ctx, 0);
+ Var* var;
+ mutex* mu = GetTrainingVariableMutex(ctx, 0, &var);
+ core::ScopedUnref scoped_unref(var);
// mu_accum is actually the same mutex as mu_var since currently we use a
// global mutex.
//
diff --git a/tensorflow/core/kernels/transpose_op.cc b/tensorflow/core/kernels/transpose_op.cc
index 0f0f65c5a3..48e392c070 100644
--- a/tensorflow/core/kernels/transpose_op.cc
+++ b/tensorflow/core/kernels/transpose_op.cc
@@ -218,7 +218,7 @@ Status ConjugateTransposeCpuOp::DoTranspose(OpKernelContext* ctx,
perm, out);
}
-#if defined(INTEL_MKL)
+#if defined(INTEL_MKL) && defined(ENABLE_MKL)
#define REGISTER(T) \
REGISTER_KERNEL_BUILDER(Name("Transpose") \
.Device(DEVICE_CPU) \
@@ -230,11 +230,8 @@ Status ConjugateTransposeCpuOp::DoTranspose(OpKernelContext* ctx,
.TypeConstraint<T>("T") \
.HostMemory("perm"), \
MklConjugateTransposeCpuOp);
-TF_CALL_ALL_TYPES(REGISTER);
-#undef REGISTER
-
-#else // INTEL_MKL
+#else // INTEL_MKL && ENABLE_MKL
#define REGISTER(T) \
REGISTER_KERNEL_BUILDER(Name("Transpose") \
.Device(DEVICE_CPU) \
@@ -246,9 +243,10 @@ TF_CALL_ALL_TYPES(REGISTER);
.TypeConstraint<T>("T") \
.HostMemory("perm"), \
ConjugateTransposeCpuOp);
+#endif // INTEL_MKL && ENABLE_MKL
+
TF_CALL_ALL_TYPES(REGISTER)
#undef REGISTER
-#endif // INTEL_MKL
#if GOOGLE_CUDA
Status TransposeGpuOp::DoTranspose(OpKernelContext* ctx, const Tensor& in,
diff --git a/tensorflow/core/kernels/unicode_script_op.cc b/tensorflow/core/kernels/unicode_script_op.cc
new file mode 100644
index 0000000000..085e397eba
--- /dev/null
+++ b/tensorflow/core/kernels/unicode_script_op.cc
@@ -0,0 +1,53 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "unicode/errorcode.h" // TF:icu
+#include "unicode/uscript.h" // TF:icu
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+
+class UnicodeScriptOp : public OpKernel {
+ public:
+ explicit UnicodeScriptOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor* input_tensor;
+ OP_REQUIRES_OK(context, context->input("input", &input_tensor));
+ const auto& input_flat = input_tensor->flat<int32>();
+
+ Tensor* output_tensor = nullptr;
+ OP_REQUIRES_OK(context,
+ context->allocate_output("output", input_tensor->shape(),
+ &output_tensor));
+ auto output_flat = output_tensor->flat<int32>();
+
+ icu::ErrorCode status;
+ for (int i = 0; i < input_flat.size(); i++) {
+ UScriptCode script_code = uscript_getScript(input_flat(i), status);
+ if (status.isSuccess()) {
+ output_flat(i) = script_code;
+ } else {
+ output_flat(i) = -1;
+ status.reset();
+ }
+ }
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("UnicodeScript").Device(DEVICE_CPU),
+ UnicodeScriptOp);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/where_op_gpu.cu.h b/tensorflow/core/kernels/where_op_gpu.cu.h
index 8879d9dd4c..2255597651 100644
--- a/tensorflow/core/kernels/where_op_gpu.cu.h
+++ b/tensorflow/core/kernels/where_op_gpu.cu.h
@@ -21,10 +21,10 @@ limitations under the License.
#define EIGEN_USE_GPU
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "external/cub_archive/cub/device/device_reduce.cuh"
-#include "external/cub_archive/cub/device/device_select.cuh"
-#include "external/cub_archive/cub/iterator/counting_input_iterator.cuh"
-#include "external/cub_archive/cub/iterator/transform_input_iterator.cuh"
+#include "third_party/cub/device/device_reduce.cuh"
+#include "third_party/cub/device/device_select.cuh"
+#include "third_party/cub/iterator/counting_input_iterator.cuh"
+#include "third_party/cub/iterator/transform_input_iterator.cuh"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/bounds_check.h"
diff --git a/tensorflow/core/lib/core/threadpool.cc b/tensorflow/core/lib/core/threadpool.cc
index 99684ae47b..9ccd911b0e 100644
--- a/tensorflow/core/lib/core/threadpool.cc
+++ b/tensorflow/core/lib/core/threadpool.cc
@@ -17,6 +17,7 @@ limitations under the License.
#define EIGEN_USE_THREADS
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/platform/context.h"
#include "tensorflow/core/platform/denormal.h"
#include "tensorflow/core/platform/logging.h"
@@ -120,6 +121,54 @@ void ThreadPool::Schedule(std::function<void()> fn) {
impl_->Schedule(std::move(fn));
}
+int ThreadPool::NumShardsUsedByTransformRangeConcurrently(
+ const int64 block_size, const int64 total) {
+ if (block_size <= 0 || total <= 1 || total <= block_size ||
+ NumThreads() == 1) {
+ return 1;
+ }
+ return (total + block_size - 1) / block_size;
+}
+
+// This functionality is similar to parallelFor, except that reasoning about
+// the number of shards used is significantly easier.
+void ThreadPool::TransformRangeConcurrently(
+ const int64 block_size, const int64 total,
+ const std::function<void(int64, int64)>& fn) {
+ const int num_shards_used =
+ NumShardsUsedByTransformRangeConcurrently(block_size, total);
+ if (num_shards_used == 1) {
+ fn(0, total);
+ return;
+ }
+
+ // Adapted from Eigen's parallelFor implementation.
+ BlockingCounter counter(num_shards_used);
+ std::function<void(int64, int64)> handle_range =
+ [=, &handle_range, &counter, &fn](int64 first, int64 last) {
+ while (last - first > block_size) {
+ // Find something near the midpoint which is a multiple of block size.
+ const int64 mid = first + ((last - first) / 2 + block_size - 1) /
+ block_size * block_size;
+ Schedule([=, &handle_range]() { handle_range(mid, last); });
+ last = mid;
+ }
+ // Single block or less, execute directly.
+ fn(first, last);
+ counter.DecrementCount(); // The shard is done.
+ };
+ if (num_shards_used <= NumThreads()) {
+ // Avoid a thread hop by running the root of the tree and one block on the
+ // main thread.
+ handle_range(0, total);
+ } else {
+ // Execute the root in the thread pool to avoid running work on more than
+ // numThreads() threads.
+ Schedule([=, &handle_range]() { handle_range(0, total); });
+ }
+ counter.Wait();
+}
+
void ThreadPool::ParallelFor(int64 total, int64 cost_per_unit,
std::function<void(int64, int64)> fn) {
impl_->ParallelFor(total, cost_per_unit, std::move(fn));
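Ignoring the NumThreads() == 1 short-circuit, the shard count above is a plain
ceiling division. A small standalone check of that arithmetic, mirroring the
block_size/total pairs exercised by the test added below (the NumShards helper
here is illustrative, not the TensorFlow API):

    #include <cassert>
    #include <cstdint>

    // ceil(total / block_size), with degenerate inputs collapsed to one shard.
    int64_t NumShards(int64_t block_size, int64_t total) {
      if (block_size <= 0 || total <= block_size) return 1;
      return (total + block_size - 1) / block_size;
    }

    int main() {
      assert(NumShards(3, 6) == 2);  // exact multiple of the block size
      assert(NumShards(3, 7) == 3);  // one extra, partially filled shard
      assert(NumShards(1, 7) == 7);  // block_size 1: one shard per element
      return 0;
    }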
diff --git a/tensorflow/core/lib/core/threadpool.h b/tensorflow/core/lib/core/threadpool.h
index 74df7c84a4..e14ad7ac64 100644
--- a/tensorflow/core/lib/core/threadpool.h
+++ b/tensorflow/core/lib/core/threadpool.h
@@ -59,6 +59,20 @@ class ThreadPool {
// Schedules fn() for execution in the pool of threads.
void Schedule(std::function<void()> fn);
+ // Requires 0 < block_size <= total.
+ // Partitions [0, total) into k = NumShardsUsedByTransformRangeConcurrently(
+ // block_size, total) shards and calls fn(i*block_size, (i+1)*block_size) for
+ // the ith shard (i >= 0); the last shard gets fn(i*block_size, total).
+ // Note that when there aren't enough threads in the pool to achieve full
+ // parallelism, function calls will be automatically queued.
+ void TransformRangeConcurrently(const int64 block_size, const int64 total,
+ const std::function<void(int64, int64)>& fn);
+
+ // Returns the number of shards fn will be called with when
+ // TransformRangeConcurrently is invoked with these parameters.
+ int NumShardsUsedByTransformRangeConcurrently(const int64 block_size,
+ const int64 total);
+
// ParallelFor shards the "total" units of work assuming each unit of work
// having roughly "cost_per_unit" cost, in cycles. Each unit of work is
// indexed 0, 1, ..., total - 1. Each shard contains 1 or more units of work
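A minimal usage sketch of the new API: each shard receives a disjoint half-open
[first, last) range, so per-shard work only needs a single accumulation at the
end. The SumInBlocks helper, pool size, and block size below are illustrative
only:

    #include <atomic>
    #include <vector>

    #include "tensorflow/core/lib/core/threadpool.h"
    #include "tensorflow/core/platform/env.h"
    #include "tensorflow/core/platform/types.h"

    // Sums `values` by processing independent [first, last) shards in parallel.
    void SumInBlocks(const std::vector<tensorflow::int64>& values,
                     tensorflow::int64* out) {
      tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "demo", 8);
      std::atomic<tensorflow::int64> sum(0);
      const tensorflow::int64 block_size = 256;
      pool.TransformRangeConcurrently(
          block_size, static_cast<tensorflow::int64>(values.size()),
          [&values, &sum](tensorflow::int64 first, tensorflow::int64 last) {
            tensorflow::int64 local = 0;
            for (tensorflow::int64 i = first; i < last; ++i) local += values[i];
            sum += local;  // shards are disjoint; one atomic add per shard
          });
      *out = sum;
    }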
diff --git a/tensorflow/core/lib/core/threadpool_test.cc b/tensorflow/core/lib/core/threadpool_test.cc
index 320f3ebb83..db996b783f 100644
--- a/tensorflow/core/lib/core/threadpool_test.cc
+++ b/tensorflow/core/lib/core/threadpool_test.cc
@@ -61,6 +61,67 @@ TEST(ThreadPool, DoWork) {
}
}
+void RunSharding(int64 block_size, int64 total, ThreadPool* threads) {
+ mutex mu;
+ int64 num_shards = 0;
+ int64 num_done_work = 0;
+ std::vector<bool> work(total, false);
+ threads->TransformRangeConcurrently(
+ block_size, total,
+ [=, &mu, &num_shards, &num_done_work, &work](int64 start, int64 end) {
+ VLOG(1) << "Shard [" << start << "," << end << ")";
+ EXPECT_GE(start, 0);
+ EXPECT_LE(end, total);
+ mutex_lock l(mu);
+ ++num_shards;
+ for (; start < end; ++start) {
+ EXPECT_FALSE(work[start]); // No duplicate
+ ++num_done_work;
+ work[start] = true;
+ }
+ });
+ LOG(INFO) << block_size << " " << total;
+ const int64 num_workers = (total + block_size - 1) / block_size;
+ EXPECT_EQ(num_done_work, total);
+ if (num_workers < threads->NumThreads()) {
+ // If the intention is to limit the parallelism explicitly, we'd
+ // better honor it. Ideally, even if per_thread_max_parallelism >
+ // num_workers, the Shard() implementation should not
+ // over-shard. Unfortunately, ThreadPoolDevice::parallelFor
+ // tends to over-shard.
+ EXPECT_LE(num_shards, 1 + num_workers);
+ }
+}
+
+// Adapted from work_sharder_test.cc
+TEST(ThreadPool, TransformRangeConcurrently) {
+ ThreadPool threads(Env::Default(), "test", 16);
+ for (auto block_size : {1, 7, 10, 64, 100, 256, 1000, 9999}) {
+ for (auto diff : {0, 1, 11, 102, 1003, 10005, 1000007}) {
+ const int64 total = block_size + diff;
+ RunSharding(block_size, total, &threads);
+ }
+ }
+}
+
+TEST(ThreadPool, NumShardsUsedByTransformRangeConcurrently) {
+ ThreadPool threads(Env::Default(), "test", 16);
+ EXPECT_EQ(1, threads.NumShardsUsedByTransformRangeConcurrently(
+ 3 /* block_size */, 3 /* total */));
+ EXPECT_EQ(2, threads.NumShardsUsedByTransformRangeConcurrently(
+ 3 /* block_size */, 4 /* total */));
+ EXPECT_EQ(2, threads.NumShardsUsedByTransformRangeConcurrently(
+ 3 /* block_size */, 5 /* total */));
+ EXPECT_EQ(2, threads.NumShardsUsedByTransformRangeConcurrently(
+ 3 /* block_size */, 6 /* total */));
+ EXPECT_EQ(3, threads.NumShardsUsedByTransformRangeConcurrently(
+ 3 /* block_size */, 7 /* total */));
+ EXPECT_EQ(7, threads.NumShardsUsedByTransformRangeConcurrently(
+ 1 /* block_size */, 7 /* total */));
+ EXPECT_EQ(1, threads.NumShardsUsedByTransformRangeConcurrently(
+ 0 /* block_size */, 7 /* total */));
+}
+
TEST(ThreadPool, ParallelFor) {
Context outer_context(ContextKind::kThread);
// Make ParallelFor use as many threads as possible.
diff --git a/tensorflow/core/lib/io/record_reader.cc b/tensorflow/core/lib/io/record_reader.cc
index f93ebea771..e22adcd569 100644
--- a/tensorflow/core/lib/io/record_reader.cc
+++ b/tensorflow/core/lib/io/record_reader.cc
@@ -108,6 +108,59 @@ Status RecordReader::ReadChecksummed(uint64 offset, size_t n, string* result) {
return Status::OK();
}
+Status RecordReader::GetMetadata(Metadata* md) {
+ if (!md) {
+ return errors::InvalidArgument(
+ "Metadata object call to GetMetadata() was null");
+ }
+
+ // Compute the metadata of the TFRecord file if not cached.
+ if (!cached_metadata_) {
+ TF_RETURN_IF_ERROR(input_stream_->Reset());
+
+ int64 data_size = 0;
+ int64 entries = 0;
+
+ // Within the loop, offset always moves forward, so the loop is
+ // guaranteed to terminate: it breaks once EOF is reached and returns
+ // early on any other error.
+ uint64 offset = 0;
+ string record;
+ while (true) {
+ // Read header, containing size of data.
+ Status s = ReadChecksummed(offset, sizeof(uint64), &record);
+ if (!s.ok()) {
+ if (errors::IsOutOfRange(s)) {
+ // We should reach out of range when the record file is complete.
+ break;
+ }
+ return s;
+ }
+
+ // Read the length of the data.
+ const uint64 length = core::DecodeFixed64(record.data());
+
+ // Skip reading the actual data since we just want the number
+ // of records and the size of the data.
+ TF_RETURN_IF_ERROR(input_stream_->SkipNBytes(length + kFooterSize));
+ offset += kHeaderSize + length + kFooterSize;
+
+ // Increment running stats.
+ data_size += length;
+ ++entries;
+ }
+
+ cached_metadata_.reset(new Metadata());
+ cached_metadata_->stats.entries = entries;
+ cached_metadata_->stats.data_size = data_size;
+ cached_metadata_->stats.file_size =
+ data_size + (kHeaderSize + kFooterSize) * entries;
+ }
+
+ md->stats = cached_metadata_->stats;
+ return Status::OK();
+}
+
Status RecordReader::ReadRecord(uint64* offset, string* record) {
// Position the input stream.
int64 curr_pos = input_stream_->Tell();
diff --git a/tensorflow/core/lib/io/record_reader.h b/tensorflow/core/lib/io/record_reader.h
index 11af1366b0..17444660d4 100644
--- a/tensorflow/core/lib/io/record_reader.h
+++ b/tensorflow/core/lib/io/record_reader.h
@@ -66,6 +66,18 @@ class RecordReader {
static const size_t kHeaderSize = sizeof(uint64) + sizeof(uint32);
static const size_t kFooterSize = sizeof(uint32);
+ // Statistics (sizes are in units of bytes)
+ struct Stats {
+ int64 file_size = -1;
+ int64 data_size = -1;
+ int64 entries = -1; // Number of values
+ };
+
+ // Metadata for the TFRecord file.
+ struct Metadata {
+ Stats stats;
+ };
+
// Create a reader that will return log records from "*file".
// "*file" must remain live while this Reader is in use.
explicit RecordReader(
@@ -79,6 +91,17 @@ class RecordReader {
// OUT_OF_RANGE for end of file, or something else for an error.
Status ReadRecord(uint64* offset, string* record);
+ // Return the metadata of the Record file.
+ //
+ // The current implementation scans the file to completion,
+ // skipping over the data regions, to extract the metadata once
+ // on the first call to GetMetadata(). An improved implementation
+ // would change RecordWriter to write the metadata into TFRecord
+ // so that GetMetadata() could be a const method.
+ //
+ // 'metadata' must not be nullptr.
+ Status GetMetadata(Metadata* md);
+
private:
Status ReadChecksummed(uint64 offset, size_t n, string* result);
@@ -86,6 +109,8 @@ class RecordReader {
std::unique_ptr<InputStreamInterface> input_stream_;
bool last_read_failed_;
+ std::unique_ptr<Metadata> cached_metadata_;
+
TF_DISALLOW_COPY_AND_ASSIGN(RecordReader);
};
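A minimal sketch of how the metadata accessor might be used, assuming default
RecordReaderOptions and an uncompressed TFRecord file at `path`:

    #include <memory>
    #include <string>

    #include "tensorflow/core/lib/core/errors.h"
    #include "tensorflow/core/lib/io/record_reader.h"
    #include "tensorflow/core/platform/env.h"
    #include "tensorflow/core/platform/logging.h"

    tensorflow::Status LogRecordFileStats(const std::string& path) {
      std::unique_ptr<tensorflow::RandomAccessFile> file;
      TF_RETURN_IF_ERROR(
          tensorflow::Env::Default()->NewRandomAccessFile(path, &file));
      tensorflow::io::RecordReader reader(file.get());
      tensorflow::io::RecordReader::Metadata md;
      TF_RETURN_IF_ERROR(reader.GetMetadata(&md));
      LOG(INFO) << "entries=" << md.stats.entries
                << " data_size=" << md.stats.data_size
                << " file_size=" << md.stats.file_size;
      return tensorflow::Status::OK();
    }

Because the scan result is cached, repeated calls only pay the full-file pass
once per reader instance.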
diff --git a/tensorflow/core/lib/io/record_reader_writer_test.cc b/tensorflow/core/lib/io/record_reader_writer_test.cc
index 13bea1f8f1..a88d34d293 100644
--- a/tensorflow/core/lib/io/record_reader_writer_test.cc
+++ b/tensorflow/core/lib/io/record_reader_writer_test.cc
@@ -147,6 +147,13 @@ TEST(RecordReaderWriterTest, TestBasics) {
EXPECT_EQ("abc", record);
TF_CHECK_OK(reader.ReadRecord(&offset, &record));
EXPECT_EQ("defg", record);
+
+ io::RecordReader::Metadata md;
+ TF_ASSERT_OK(reader.GetMetadata(&md));
+ EXPECT_EQ(2, md.stats.entries);
+ EXPECT_EQ(7, md.stats.data_size);
+ // Two entries have 16 bytes of header/footer each.
+ EXPECT_EQ(39, md.stats.file_size);
}
}
}
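The file_size expected by this test follows directly from the framing constants
declared in record_reader.h: each record carries a 12-byte header (8-byte
length plus its 4-byte masked CRC) and a 4-byte footer (the masked CRC of the
data), so for the records "abc" and "defg":

    // kHeaderSize = sizeof(uint64) + sizeof(uint32) = 12, kFooterSize = 4.
    static_assert(3 + 4 == 7, "data_size of the two records");
    static_assert(7 + 2 * (12 + 4) == 39, "file_size = data_size + overhead");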
diff --git a/tensorflow/core/lib/jpeg/jpeg_mem.cc b/tensorflow/core/lib/jpeg/jpeg_mem.cc
index 50ed8bdb3b..f7a359eb5b 100644
--- a/tensorflow/core/lib/jpeg/jpeg_mem.cc
+++ b/tensorflow/core/lib/jpeg/jpeg_mem.cc
@@ -152,7 +152,9 @@ uint8* UncompressLow(const void* srcdata, FewerArgsForCompiler* argball) {
cinfo.scale_denom = ratio;
cinfo.dct_method = flags.dct_method;
- jpeg_start_decompress(&cinfo);
+ // Determine the output image size before attempting to decompress, to
+ // prevent OOMing during the decompress.
+ jpeg_calc_output_dimensions(&cinfo);
int64 total_size = static_cast<int64>(cinfo.output_height) *
static_cast<int64>(cinfo.output_width);
@@ -170,6 +172,8 @@ uint8* UncompressLow(const void* srcdata, FewerArgsForCompiler* argball) {
return nullptr;
}
+ jpeg_start_decompress(&cinfo);
+
JDIMENSION target_output_width = cinfo.output_width;
JDIMENSION target_output_height = cinfo.output_height;
JDIMENSION skipped_scanlines = 0;
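The reordering above works because libjpeg can report the scaled output
dimensions without decoding any scanlines. A condensed sketch of the resulting
call order; the 2^29-pixel cutoff and the helper name are illustrative, not the
exact bookkeeping UncompressLow performs:

    #include <cstdint>
    #include <cstdio>
    #include <jpeglib.h>

    // `cinfo` must already be prepared with jpeg_read_header() and the
    // scale/dct settings shown in the diff above.
    bool StartDecompressIfReasonable(jpeg_decompress_struct* cinfo) {
      jpeg_calc_output_dimensions(cinfo);  // fills output_width/output_height
      const int64_t total_size = static_cast<int64_t>(cinfo->output_height) *
                                 static_cast<int64_t>(cinfo->output_width);
      if (total_size >= (1LL << 29)) {
        jpeg_destroy_decompress(cinfo);  // refuse to allocate for this image
        return false;
      }
      jpeg_start_decompress(cinfo);  // only now are scanline buffers committed
      return true;
    }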
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index 7dbb18aa5d..c9f80df5e4 100644
--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc
@@ -1531,37 +1531,6 @@ REGISTER_OP("Size")
.Attr("out_type: {int32, int64} = DT_INT32")
.SetShapeFn(shape_inference::ScalarShape);
-namespace {
-
-// This SliceHelper processes the output shape of the `slice`
-// when the tensor of `sizes` is available.
-template <typename T>
-Status SliceHelper(InferenceContext* c, ShapeHandle begin_value,
- const Tensor* sizes_value,
- std::vector<DimensionHandle>* dims) {
- auto sizes_vec = sizes_value->vec<T>();
- for (int i = 0; i < sizes_value->NumElements(); ++i) {
- DimensionHandle dim = c->Dim(c->input(0), i);
- if (sizes_vec(i) != -1) {
- auto dim_val = c->Value(dim);
- if (sizes_vec(i) < 0) {
- return errors::InvalidArgument(
- "Out of bounds slicing on dimension ", i, " of length ", dim_val,
- ": sizes vector cannot be < -1, but was ", sizes_vec(i));
- }
-
- dims->emplace_back(c->MakeDim(sizes_vec(i)));
- } else {
- DimensionHandle result;
- TF_RETURN_IF_ERROR(c->Subtract(dim, c->Dim(begin_value, i), &result));
- dims->emplace_back(result);
- }
- }
-
- return Status::OK();
-}
-} // namespace
-
// --------------------------------------------------------------------------
REGISTER_OP("Slice")
.Input("input: T")
@@ -1570,83 +1539,22 @@ REGISTER_OP("Slice")
.Output("output: T")
.Attr("T: type")
.Attr("Index: {int32,int64}")
- .SetShapeFn([](InferenceContext* c) {
- ShapeHandle input = c->input(0);
- ShapeHandle begin_shape;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &begin_shape));
- ShapeHandle sizes_shape;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &sizes_shape));
+ .SetShapeFn(shape_inference::SliceShape);
- // Merge to check compatibility of begin and sizes tensors.
- TF_RETURN_IF_ERROR(c->Merge(begin_shape, sizes_shape, &begin_shape));
-
- DimensionHandle ndims = c->Dim(begin_shape, 0);
- if (c->ValueKnown(ndims)) {
- TF_RETURN_IF_ERROR(c->WithRank(input, c->Value(ndims), &input));
- }
-
- // NOTE(mrry): Use MakeShapeFromShapeTensor to handle partially-known
- // values, even though the `begin` value does not represent a shape.
- ShapeHandle begin_value;
- TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &begin_value));
-
- // We check the tensor value here and will only use
- // `MakeShapeFromShapeTensor` when `sizes_value` is null.
- // The reason is that `sizes`might contain -1, which can't
- // be represented (-1 in the ShapeHandle would mean "unknown".
- const Tensor* sizes_value = c->input_tensor(2);
-
- if (sizes_value != nullptr) {
- TF_RETURN_IF_ERROR(
- c->WithRank(begin_value, sizes_value->NumElements(), &begin_value));
- std::vector<DimensionHandle> dims;
- // If the begin and sizes tensors are available, then
- // we can be precise about the shape of the output.
- if (sizes_value->dtype() == DT_INT64) {
- TF_RETURN_IF_ERROR(
- SliceHelper<int64>(c, begin_value, sizes_value, &dims));
- } else {
- TF_RETURN_IF_ERROR(
- SliceHelper<int32>(c, begin_value, sizes_value, &dims));
- }
-
- c->set_output(0, c->MakeShape(dims));
- return Status::OK();
- } else {
- // In case `sizes` is not available (`sizes_value` is null),
- // we could try to use `MakeShapeFromShapeTensor` here.
- // If sizes contain -1, we will simply consider it as `Unknown`.
- // This is less than ideal but still an improvement of shape inference.
- // The following is an example that returns [None, 1, None] with this
- // code path:
- // z = tf.zeros((1, 2, 3))
- // m = tf.slice(z, [0, 0, 0], [tf.constant(1) + 0, 1, -1])
- // m.get_shape().as_list()
- ShapeHandle sizes_value;
- TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &sizes_value));
- if (c->RankKnown(sizes_value)) {
- TF_RETURN_IF_ERROR(
- c->WithRank(begin_value, c->Rank(sizes_value), &begin_value));
- std::vector<DimensionHandle> dims;
- dims.reserve(c->Rank(sizes_value));
- for (int i = 0; i < c->Rank(sizes_value); ++i) {
- dims.emplace_back(c->Dim(sizes_value, i));
- }
- c->set_output(0, c->MakeShape(dims));
- return Status::OK();
- }
-
- // We might know the rank of the input.
- if (c->RankKnown(input)) {
- c->set_output(0, c->UnknownShapeOfRank(c->Rank(input)));
- return Status::OK();
- } else {
- return shape_inference::UnknownShape(c);
- }
- }
-
- return Status::OK();
- });
+#ifdef INTEL_MKL
+REGISTER_OP("_MklSlice")
+ .Input("input: T")
+ .Input("begin: Index")
+ .Input("size: Index")
+ .Input("mkl_input: uint8")
+ .Input("mkl_begin: uint8")
+ .Input("mkl_size: uint8")
+ .Output("output: T")
+ .Output("mkl_output: uint8")
+ .Attr("T: type")
+ .Attr("Index: {int32,int64}")
+ .SetShapeFn(shape_inference::SliceShape);
+#endif
REGISTER_OP("StridedSlice")
.Input("input: T")
@@ -2595,6 +2503,116 @@ REGISTER_OP("ExtractImagePatches")
// --------------------------------------------------------------------------
+// To enable rates, uncomment all lines commented below and use ksize_*_eff
+// as the second parameter of all GetWindowedOutputSizeVerbose calls instead
+// of ksize_*.
+REGISTER_OP("ExtractVolumePatches")
+ .Input("input: T")
+ .Output("patches: T")
+ .Attr("ksizes: list(int) >= 5")
+ .Attr("strides: list(int) >= 5")
+ /* .Attr("rates: list(int) >= 5") */
+ .Attr("T: realnumbertype")
+ .Attr(GetPaddingAttrString())
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle input_shape;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 5, &input_shape));
+
+ std::vector<int32> ksizes;
+ TF_RETURN_IF_ERROR(c->GetAttr("ksizes", &ksizes));
+ if (ksizes.size() != 5) {
+ return errors::InvalidArgument(
+ "ExtractVolumePatches requires the ksizes attribute to contain 5 "
+ "values, but got: ",
+ ksizes.size());
+ }
+
+ std::vector<int32> strides;
+ TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
+ if (strides.size() != 5) {
+ return errors::InvalidArgument(
+ "ExtractVolumePatches requires the stride attribute to contain 5 "
+ "values, but got: ",
+ strides.size());
+ }
+
+ /*
+ // TODO(hsgkim): Enable rates.
+ // See extract_volume_patches_op.cc for why rates are disabled now.
+
+ std::vector<int32> rates;
+ TF_RETURN_IF_ERROR(c->GetAttr("rates", &rates));
+ if (rates.size() != 5) {
+ return errors::InvalidArgument(
+ "ExtractVolumePatches requires the rates attribute to contain 5 "
+ "values, but got: ",
+ rates.size());
+ }
+ */
+
+ int32 ksize_planes = ksizes[1];
+ int32 ksize_rows = ksizes[2];
+ int32 ksize_cols = ksizes[3];
+
+ int32 stride_planes = strides[1];
+ int32 stride_rows = strides[2];
+ int32 stride_cols = strides[3];
+
+ /*
+ int32 rate_planes = rates[1];
+ int32 rate_rows = rates[2];
+ int32 rate_cols = rates[3];
+
+ int32 ksize_planes_eff = ksize_planes +
+ (ksize_planes - 1) * (rate_planes - 1);
+ int32 ksize_rows_eff = ksize_rows + (ksize_rows - 1) * (rate_rows - 1);
+ int32 ksize_cols_eff = ksize_cols + (ksize_cols - 1) * (rate_cols - 1);
+ */
+
+ DimensionHandle batch_size_dim = c->Dim(input_shape, 0);
+ DimensionHandle in_planes_dim = c->Dim(input_shape, 1);
+ DimensionHandle in_rows_dim = c->Dim(input_shape, 2);
+ DimensionHandle in_cols_dim = c->Dim(input_shape, 3);
+ DimensionHandle output_depth_dim;
+ TF_RETURN_IF_ERROR(c->Multiply(c->Dim(input_shape, 4),
+ ksize_planes * ksize_rows * ksize_cols,
+ &output_depth_dim));
+
+ if (!c->ValueKnown(in_planes_dim) || !c->ValueKnown(in_rows_dim) ||
+ !c->ValueKnown(in_cols_dim)) {
+ ShapeHandle output_shape =
+ c->MakeShape({batch_size_dim, InferenceContext::kUnknownDim,
+ InferenceContext::kUnknownDim, InferenceContext::kUnknownDim,
+ output_depth_dim});
+ c->set_output(0, output_shape);
+ return Status::OK();
+ }
+ auto in_planes = c->Value(in_planes_dim);
+ auto in_rows = c->Value(in_rows_dim);
+ auto in_cols = c->Value(in_cols_dim);
+
+ Padding padding;
+ TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
+
+ int64 output_planes, output_rows, output_cols;
+ int64 padding_before, padding_after;
+ TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose(
+ in_planes, ksize_planes, stride_planes, padding, &output_planes,
+ &padding_before, &padding_after));
+ TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose(
+ in_rows, ksize_rows, stride_rows, padding, &output_rows,
+ &padding_before, &padding_after));
+ TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose(
+ in_cols, ksize_cols, stride_cols, padding, &output_cols,
+ &padding_before, &padding_after));
+ ShapeHandle output_shape =
+ c->MakeShape({batch_size_dim, output_planes, output_rows, output_cols,
+ output_depth_dim});
+ c->set_output(0, output_shape);
+ return Status::OK();
+ });
+
+// --------------------------------------------------------------------------
+
REGISTER_OP("Bitcast")
.Input("input: T")
.Output("output: type")
@@ -2916,6 +2934,34 @@ Status ScatterNdShape(InferenceContext* c) {
} // namespace
+REGISTER_OP("UpperBound")
+ .Input("sorted_inputs: T")
+ .Input("values: T")
+ .Output("output: out_type")
+ .Attr("T: type")
+ .Attr("out_type: {int32, int64} = DT_INT32")
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle unused_shape;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused_shape));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &unused_shape));
+ c->set_output(0, c->input(1));
+ return Status::OK();
+ });
+
+REGISTER_OP("LowerBound")
+ .Input("sorted_inputs: T")
+ .Input("values: T")
+ .Output("output: out_type")
+ .Attr("T: type")
+ .Attr("out_type: {int32, int64} = DT_INT32")
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle unused_shape;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused_shape));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &unused_shape));
+ c->set_output(0, c->input(1));
+ return Status::OK();
+ });
+
REGISTER_OP("ScatterNd")
.Input("indices: Tindices")
.Input("updates: T")
diff --git a/tensorflow/core/ops/boosted_trees_ops.cc b/tensorflow/core/ops/boosted_trees_ops.cc
index 7c4184bff4..b8cf538554 100644
--- a/tensorflow/core/ops/boosted_trees_ops.cc
+++ b/tensorflow/core/ops/boosted_trees_ops.cc
@@ -180,6 +180,8 @@ REGISTER_OP("BoostedTreesMakeStatsSummary")
return Status::OK();
});
+// TODO(nponomareva): when/if creating the new op for unbucketized data, rename
+// bucketized_features to features.
REGISTER_OP("BoostedTreesPredict")
.Input("tree_ensemble_handle: resource")
.Input("bucketized_features: num_bucketized_features * int32")
diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
index 57c6bda98b..43c14d83b5 100644
--- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
@@ -21532,6 +21532,421 @@ op {
}
}
op {
+ name: "ExperimentalAssertNextDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "transformations"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
+ name: "ExperimentalCSVDataset"
+ input_arg {
+ name: "filenames"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "compression_type"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "buffer_size"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "header"
+ type: DT_BOOL
+ }
+ input_arg {
+ name: "field_delim"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "use_quote_delim"
+ type: DT_BOOL
+ }
+ input_arg {
+ name: "na_value"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "select_cols"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "record_defaults"
+ type_list_attr: "output_types"
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_INT64
+ type: DT_STRING
+ }
+ }
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalDirectedInterleaveDataset"
+ input_arg {
+ name: "selector_input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "data_input_datasets"
+ type: DT_VARIANT
+ number_attr: "N"
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "N"
+ type: "int"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
+ name: "ExperimentalFunctionBufferingResource"
+ input_arg {
+ name: "string_arg"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "target_device"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "resource"
+ type: DT_RESOURCE
+ }
+ attr {
+ name: "shared_name"
+ type: "string"
+ }
+ attr {
+ name: "container"
+ type: "string"
+ }
+ attr {
+ name: "f"
+ type: "func"
+ }
+ attr {
+ name: "buffer_size"
+ type: "int"
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalFunctionBufferingResourceGetNext"
+ input_arg {
+ name: "function_buffer_resource"
+ type: DT_RESOURCE
+ }
+ output_arg {
+ name: "output"
+ type_list_attr: "output_types"
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalFunctionBufferingResourceReset"
+ input_arg {
+ name: "function_buffer_resource"
+ type: DT_RESOURCE
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalIdentityIndexedDataset"
+ input_arg {
+ name: "size"
+ type: DT_UINT64
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalIgnoreErrorsDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
+ name: "ExperimentalIndexedDatasetGet"
+ input_arg {
+ name: "materialized"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "index"
+ type: DT_UINT64
+ }
+ output_arg {
+ name: "components"
+ type_list_attr: "output_types"
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalIndexedDatasetMaterialize"
+ input_arg {
+ name: "dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "materialized"
+ type: DT_RESOURCE
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalIteratorGetDevice"
+ input_arg {
+ name: "resource"
+ type: DT_RESOURCE
+ }
+ output_arg {
+ name: "device"
+ type: DT_STRING
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalLMDBDataset"
+ input_arg {
+ name: "filenames"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalMaterializedIndexDatasetHandle"
+ output_arg {
+ name: "handle"
+ type: DT_RESOURCE
+ }
+ attr {
+ name: "container"
+ type: "string"
+ }
+ attr {
+ name: "shared_name"
+ type: "string"
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalThreadPoolDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "thread_pool"
+ type: DT_RESOURCE
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalThreadPoolHandle"
+ output_arg {
+ name: "handle"
+ type: DT_RESOURCE
+ }
+ attr {
+ name: "num_threads"
+ type: "int"
+ }
+ attr {
+ name: "max_intra_op_parallelism"
+ type: "int"
+ default_value {
+ i: 1
+ }
+ }
+ attr {
+ name: "display_name"
+ type: "string"
+ }
+ attr {
+ name: "container"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ attr {
+ name: "shared_name"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalUniqueDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
name: "Expm1"
input_arg {
name: "x"
@@ -21902,6 +22317,59 @@ op {
}
}
op {
+ name: "ExtractVolumePatches"
+ input_arg {
+ name: "input"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "patches"
+ type_attr: "T"
+ }
+ attr {
+ name: "ksizes"
+ type: "list(int)"
+ has_minimum: true
+ minimum: 5
+ }
+ attr {
+ name: "strides"
+ type: "list(int)"
+ has_minimum: true
+ minimum: 5
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_UINT8
+ type: DT_INT16
+ type: DT_INT8
+ type: DT_INT64
+ type: DT_BFLOAT16
+ type: DT_UINT16
+ type: DT_HALF
+ type: DT_UINT32
+ type: DT_UINT64
+ }
+ }
+ }
+ attr {
+ name: "padding"
+ type: "string"
+ allowed_values {
+ list {
+ s: "SAME"
+ s: "VALID"
+ }
+ }
+ }
+}
+op {
name: "FFT"
input_arg {
name: "input"
@@ -24052,6 +24520,85 @@ op {
}
}
op {
+ name: "FusedBatchNorm"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "scale"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "offset"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "mean"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "variance"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "batch_mean"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "batch_variance"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "reserve_space_1"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "reserve_space_2"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ }
+ }
+ }
+ attr {
+ name: "epsilon"
+ type: "float"
+ default_value {
+ f: 0.0001
+ }
+ }
+ attr {
+ name: "data_format"
+ type: "string"
+ default_value {
+ s: "NHWC"
+ }
+ allowed_values {
+ list {
+ s: "NHWC"
+ s: "NCHW"
+ }
+ }
+ }
+ attr {
+ name: "is_training"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
+}
+op {
name: "FusedBatchNormGrad"
input_arg {
name: "y_backprop"
@@ -24125,6 +24672,85 @@ op {
}
}
op {
+ name: "FusedBatchNormGrad"
+ input_arg {
+ name: "y_backprop"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "scale"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "reserve_space_1"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "reserve_space_2"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "x_backprop"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "scale_backprop"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "offset_backprop"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "reserve_space_3"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "reserve_space_4"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ }
+ }
+ }
+ attr {
+ name: "epsilon"
+ type: "float"
+ default_value {
+ f: 0.0001
+ }
+ }
+ attr {
+ name: "data_format"
+ type: "string"
+ default_value {
+ s: "NHWC"
+ }
+ allowed_values {
+ list {
+ s: "NHWC"
+ s: "NCHW"
+ }
+ }
+ }
+ attr {
+ name: "is_training"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
+}
+op {
name: "FusedBatchNormGradV2"
input_arg {
name: "y_backprop"
@@ -24292,6 +24918,179 @@ op {
}
}
op {
+ name: "FusedBatchNormGradV2"
+ input_arg {
+ name: "y_backprop"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "scale"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "reserve_space_1"
+ type_attr: "U"
+ }
+ input_arg {
+ name: "reserve_space_2"
+ type_attr: "U"
+ }
+ output_arg {
+ name: "x_backprop"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "scale_backprop"
+ type_attr: "U"
+ }
+ output_arg {
+ name: "offset_backprop"
+ type_attr: "U"
+ }
+ output_arg {
+ name: "reserve_space_3"
+ type_attr: "U"
+ }
+ output_arg {
+ name: "reserve_space_4"
+ type_attr: "U"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_BFLOAT16
+ type: DT_FLOAT
+ }
+ }
+ }
+ attr {
+ name: "U"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ }
+ }
+ }
+ attr {
+ name: "epsilon"
+ type: "float"
+ default_value {
+ f: 0.0001
+ }
+ }
+ attr {
+ name: "data_format"
+ type: "string"
+ default_value {
+ s: "NHWC"
+ }
+ allowed_values {
+ list {
+ s: "NHWC"
+ s: "NCHW"
+ }
+ }
+ }
+ attr {
+ name: "is_training"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
+}
+op {
+ name: "FusedBatchNormV2"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "scale"
+ type_attr: "U"
+ }
+ input_arg {
+ name: "offset"
+ type_attr: "U"
+ }
+ input_arg {
+ name: "mean"
+ type_attr: "U"
+ }
+ input_arg {
+ name: "variance"
+ type_attr: "U"
+ }
+ output_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "batch_mean"
+ type_attr: "U"
+ }
+ output_arg {
+ name: "batch_variance"
+ type_attr: "U"
+ }
+ output_arg {
+ name: "reserve_space_1"
+ type_attr: "U"
+ }
+ output_arg {
+ name: "reserve_space_2"
+ type_attr: "U"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ }
+ }
+ }
+ attr {
+ name: "U"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ }
+ }
+ }
+ attr {
+ name: "epsilon"
+ type: "float"
+ default_value {
+ f: 0.0001
+ }
+ }
+ attr {
+ name: "data_format"
+ type: "string"
+ default_value {
+ s: "NHWC"
+ }
+ }
+ attr {
+ name: "is_training"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
+}
+op {
name: "FusedBatchNormV2"
input_arg {
name: "x"
@@ -24339,6 +25138,7 @@ op {
allowed_values {
list {
type: DT_HALF
+ type: DT_BFLOAT16
type: DT_FLOAT
}
}
@@ -24449,6 +25249,12 @@ op {
default_value {
s: "NHWC"
}
+ allowed_values {
+ list {
+ s: "NHWC"
+ s: "NCHW"
+ }
+ }
}
attr {
name: "is_training"
@@ -29388,6 +30194,38 @@ op {
}
}
op {
+ name: "LowerBound"
+ input_arg {
+ name: "sorted_inputs"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "values"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "out_type"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ }
+ attr {
+ name: "out_type"
+ type: "type"
+ default_value {
+ type: DT_INT32
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+}
+op {
name: "MakeIterator"
input_arg {
name: "dataset"
@@ -35241,6 +36079,134 @@ op {
is_commutative: true
}
op {
+ name: "MultiDeviceIterator"
+ output_arg {
+ name: "handle"
+ type: DT_RESOURCE
+ }
+ attr {
+ name: "devices"
+ type: "list(string)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "shared_name"
+ type: "string"
+ }
+ attr {
+ name: "container"
+ type: "string"
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "MultiDeviceIteratorFromStringHandle"
+ input_arg {
+ name: "string_handle"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "multi_device_iterator"
+ type: DT_RESOURCE
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ default_value {
+ list {
+ }
+ }
+ has_minimum: true
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ default_value {
+ list {
+ }
+ }
+ has_minimum: true
+ }
+ is_stateful: true
+}
+op {
+ name: "MultiDeviceIteratorGetNextFromShard"
+ input_arg {
+ name: "multi_device_iterator"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "shard_num"
+ type: DT_INT32
+ }
+ input_arg {
+ name: "incarnation_id"
+ type: DT_INT64
+ }
+ output_arg {
+ name: "components"
+ type_list_attr: "output_types"
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "MultiDeviceIteratorInit"
+ input_arg {
+ name: "dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "multi_device_iterator"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "max_buffer_size"
+ type: DT_INT64
+ }
+ output_arg {
+ name: "incarnation_id"
+ type: DT_INT64
+ }
+ is_stateful: true
+}
+op {
+ name: "MultiDeviceIteratorToStringHandle"
+ input_arg {
+ name: "multi_device_iterator"
+ type: DT_RESOURCE
+ }
+ output_arg {
+ name: "string_handle"
+ type: DT_STRING
+ }
+ is_stateful: true
+}
+op {
name: "Multinomial"
input_arg {
name: "logits"
@@ -38880,6 +39846,30 @@ op {
is_stateful: true
}
op {
+ name: "PrintV2"
+ input_arg {
+ name: "input"
+ type: DT_STRING
+ }
+ attr {
+ name: "output_stream"
+ type: "string"
+ default_value {
+ s: "stderr"
+ }
+ allowed_values {
+ list {
+ s: "stdout"
+ s: "stderr"
+ s: "log(info)"
+ s: "log(warning)"
+ s: "log(error)"
+ }
+ }
+ }
+ is_stateful: true
+}
+op {
name: "PriorityQueue"
output_arg {
name: "handle"
@@ -44281,6 +45271,59 @@ op {
is_stateful: true
}
op {
+ name: "ReduceDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "initial_state"
+ type_list_attr: "Tstate"
+ }
+ input_arg {
+ name: "other_arguments"
+ type_list_attr: "Targuments"
+ }
+ output_arg {
+ name: "components"
+ type_list_attr: "output_types"
+ }
+ attr {
+ name: "f"
+ type: "func"
+ }
+ attr {
+ name: "Tstate"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "Targuments"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "use_inter_op_parallelism"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
+}
+op {
name: "ReduceJoin"
input_arg {
name: "inputs"
@@ -59848,6 +60891,29 @@ op {
}
}
op {
+ name: "Softplus"
+ input_arg {
+ name: "features"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "activations"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_BFLOAT16
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+}
+op {
name: "SoftplusGrad"
input_arg {
name: "gradients"
@@ -59984,6 +61050,33 @@ op {
}
}
op {
+ name: "SoftplusGrad"
+ input_arg {
+ name: "gradients"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "features"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "backprops"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_BFLOAT16
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+}
+op {
name: "Softsign"
input_arg {
name: "features"
@@ -60104,6 +61197,29 @@ op {
}
}
op {
+ name: "Softsign"
+ input_arg {
+ name: "features"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "activations"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_BFLOAT16
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+}
+op {
name: "SoftsignGrad"
input_arg {
name: "gradients"
@@ -60240,6 +61356,33 @@ op {
}
}
op {
+ name: "SoftsignGrad"
+ input_arg {
+ name: "gradients"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "features"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "backprops"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_BFLOAT16
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+}
+op {
name: "SpaceToBatch"
input_arg {
name: "input"
@@ -70188,6 +71331,43 @@ op {
}
}
op {
+ name: "StringFormat"
+ input_arg {
+ name: "inputs"
+ type_list_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type: DT_STRING
+ }
+ attr {
+ name: "T"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "template"
+ type: "string"
+ default_value {
+ s: "%s"
+ }
+ }
+ attr {
+ name: "placeholder"
+ type: "string"
+ default_value {
+ s: "%s"
+ }
+ }
+ attr {
+ name: "summarize"
+ type: "int"
+ default_value {
+ i: 3
+ }
+ }
+}
+op {
name: "StringJoin"
input_arg {
name: "inputs"
@@ -70224,6 +71404,30 @@ op {
}
}
op {
+ name: "StringLength"
+ input_arg {
+ name: "input"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "output"
+ type: DT_INT32
+ }
+ attr {
+ name: "unit"
+ type: "string"
+ default_value {
+ s: "BYTE"
+ }
+ allowed_values {
+ list {
+ s: "BYTE"
+ s: "UTF8_CHAR"
+ }
+ }
+ }
+}
+op {
name: "StringSplit"
input_arg {
name: "input"
@@ -74175,6 +75379,17 @@ op {
}
}
op {
+ name: "UnicodeScript"
+ input_arg {
+ name: "input"
+ type: DT_INT32
+ }
+ output_arg {
+ name: "output"
+ type: DT_INT32
+ }
+}
+op {
name: "UniformCandidateSampler"
input_arg {
name: "true_classes"
@@ -75267,6 +76482,38 @@ op {
is_stateful: true
}
op {
+ name: "UpperBound"
+ input_arg {
+ name: "sorted_inputs"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "values"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "out_type"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ }
+ attr {
+ name: "out_type"
+ type: "type"
+ default_value {
+ type: DT_INT32
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+}
+op {
name: "VarHandleOp"
output_arg {
name: "resource"
@@ -75602,9 +76849,21 @@ op {
type: DT_VARIANT
}
input_arg {
- name: "window_size"
+ name: "size"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "shift"
type: DT_INT64
}
+ input_arg {
+ name: "stride"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "drop_remainder"
+ type: DT_BOOL
+ }
output_arg {
name: "handle"
type: DT_VARIANT
@@ -75841,6 +77100,62 @@ op {
is_stateful: true
}
op {
+ name: "Xdivy"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "z"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+}
+op {
+ name: "Xlogy"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "z"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+}
+op {
name: "ZerosLike"
input_arg {
name: "x"
diff --git a/tensorflow/core/ops/cudnn_rnn_ops.cc b/tensorflow/core/ops/cudnn_rnn_ops.cc
index f78f7a897a..f84142c992 100644
--- a/tensorflow/core/ops/cudnn_rnn_ops.cc
+++ b/tensorflow/core/ops/cudnn_rnn_ops.cc
@@ -37,7 +37,6 @@ using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;
-
REGISTER_OP("CudnnRNNParamsSize")
.Input("num_layers: int32")
.Input("num_units: int32")
@@ -52,11 +51,16 @@ REGISTER_OP("CudnnRNNParamsSize")
.Attr("seed2: int = 0")
.Output("params_size: S")
.SetShapeFn([](InferenceContext* c) {
+ ShapeHandle unused;
+ // num_layers, num_units, and input_size should be scalars.
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
+
c->set_output(0, c->Vector(1));
return Status::OK();
});
-
REGISTER_OP("CudnnRNN")
.Input("input: T")
.Input("input_h: T")
@@ -248,7 +252,6 @@ REGISTER_OP("CudnnRNNParamsToCanonical")
return Status::OK();
});
-
REGISTER_OP("CudnnRNNCanonicalToParams")
.Input("num_layers: int32")
.Input("num_units: int32")
diff --git a/tensorflow/core/ops/cudnn_rnn_ops_test.cc b/tensorflow/core/ops/cudnn_rnn_ops_test.cc
index 2dd867561b..13c3b933f4 100644
--- a/tensorflow/core/ops/cudnn_rnn_ops_test.cc
+++ b/tensorflow/core/ops/cudnn_rnn_ops_test.cc
@@ -26,7 +26,16 @@ namespace tensorflow {
TEST(CudnnRNNOpsTest, ParamsSize_ShapeFn) {
ShapeInferenceTestOp op("CudnnRNNParamsSize");
- INFER_OK(op, "[1];[1];[1]", "[1]");
+ INFER_OK(op, "[];[];[]", "[1]");
+ INFER_OK(op, "?;[];[]", "[1]");
+ INFER_OK(op, "[];?;[]", "[1]");
+ INFER_OK(op, "[];[];?", "[1]");
+ INFER_OK(op, "[];?;?", "[1]");
+ INFER_OK(op, "?;?;?", "[1]");
+
+ INFER_ERROR("Shape must be rank 0 ", op, "[1,2];?;[]");
+ INFER_ERROR("Shape must be rank 0 ", op, "?;[2];[]");
+ INFER_ERROR("Shape must be rank 0 ", op, "?;?;[1]");
}
TEST(CudnnRNNOpsTest, ForwardLstm_ShapeFn) {
diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc
index 7d9e7b2d3f..71f4cc3c4c 100644
--- a/tensorflow/core/ops/dataset_ops.cc
+++ b/tensorflow/core/ops/dataset_ops.cc
@@ -396,14 +396,20 @@ REGISTER_OP("FilterByLastComponentDataset")
REGISTER_OP("WindowDataset")
.Input("input_dataset: variant")
- .Input("window_size: int64")
+ .Input("size: int64")
+ .Input("shift: int64")
+ .Input("stride: int64")
+ .Input("drop_remainder: bool")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn([](shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle unused;
- // batch_size should be a scalar.
+ // size, shift, stride, and drop_remainder should be scalars.
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
return shape_inference::ScalarShape(c);
});
@@ -750,6 +756,19 @@ REGISTER_OP("DatasetToSingleElement")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(IteratorGetNextShapeFn);
+REGISTER_OP("ReduceDataset")
+ .Input("input_dataset: variant")
+ .Input("initial_state: Tstate")
+ .Input("other_arguments: Targuments")
+ .Output("components: output_types")
+ .Attr("f: func")
+ .Attr("Tstate: list(type) >= 1")
+ .Attr("Targuments: list(type) >= 0")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .Attr("use_inter_op_parallelism: bool = true")
+ .SetShapeFn(IteratorGetNextShapeFn);
+
REGISTER_OP("IteratorToStringHandle")
.Input("resource_handle: resource")
.Output("string_handle: string")
@@ -926,4 +945,41 @@ REGISTER_OP("MapDefun")
return Status::OK();
});
+REGISTER_OP("MultiDeviceIterator")
+ .Output("handle: resource")
+ .Attr("devices: list(string) >= 1")
+ .Attr("shared_name: string")
+ .Attr("container: string")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(shape_inference::ScalarShape);
+
+REGISTER_OP("MultiDeviceIteratorInit")
+ .Input("dataset: variant")
+ .Input("multi_device_iterator: resource")
+ .Input("max_buffer_size: int64")
+ .Output("incarnation_id: int64")
+ .SetShapeFn(shape_inference::ScalarShape);
+
+REGISTER_OP("MultiDeviceIteratorGetNextFromShard")
+ .Input("multi_device_iterator: resource")
+ .Input("shard_num: int32")
+ .Input("incarnation_id: int64")
+ .Output("components: output_types")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(IteratorGetNextShapeFn);
+
+REGISTER_OP("MultiDeviceIteratorToStringHandle")
+ .Input("multi_device_iterator: resource")
+ .Output("string_handle: string")
+ .SetShapeFn(shape_inference::ScalarShape);
+
+REGISTER_OP("MultiDeviceIteratorFromStringHandle")
+ .Input("string_handle: string")
+ .Output("multi_device_iterator: resource")
+ .Attr("output_types: list(type) >= 0 = []")
+ .Attr("output_shapes: list(shape) >= 0 = []")
+ .SetShapeFn(shape_inference::ScalarShape);
+
} // namespace tensorflow
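The WindowDataset signature change replaces the single window_size input with
the four parameters of dataset windowing. Assuming the usual semantics (a
window starts every `shift` elements, takes up to `size` elements spaced
`stride` apart, and `drop_remainder` discards trailing windows with fewer than
`size` elements), a small sketch that enumerates the element indices per
window; this models the intended behaviour, not the kernel implementation:

    #include <cstdint>
    #include <vector>

    std::vector<std::vector<int64_t>> Windows(int64_t n, int64_t size,
                                              int64_t shift, int64_t stride,
                                              bool drop_remainder) {
      std::vector<std::vector<int64_t>> out;
      for (int64_t start = 0; start < n; start += shift) {
        std::vector<int64_t> w;
        for (int64_t k = 0; k < size; ++k) {
          const int64_t idx = start + k * stride;
          if (idx >= n) break;
          w.push_back(idx);
        }
        if (w.empty()) continue;
        if (drop_remainder && static_cast<int64_t>(w.size()) < size) continue;
        out.push_back(w);
      }
      return out;
    }
    // Windows(8, 3, 2, 1, true) -> {0, 1, 2}, {2, 3, 4}, {4, 5, 6}.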
diff --git a/tensorflow/core/ops/experimental_dataset_ops.cc b/tensorflow/core/ops/experimental_dataset_ops.cc
new file mode 100644
index 0000000000..f6bd5dce26
--- /dev/null
+++ b/tensorflow/core/ops/experimental_dataset_ops.cc
@@ -0,0 +1,207 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+
+namespace tensorflow {
+
+REGISTER_OP("ExperimentalDirectedInterleaveDataset")
+ .Input("selector_input_dataset: variant")
+ .Input("data_input_datasets: N * variant")
+ .Output("handle: variant")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .Attr("N: int >= 1")
+ .SetShapeFn(shape_inference::ScalarShape);
+
+REGISTER_OP("ExperimentalCSVDataset")
+ .Input("filenames: string")
+ .Input("compression_type: string")
+ .Input("buffer_size: int64")
+ .Input("header: bool")
+ .Input("field_delim: string")
+ .Input("use_quote_delim: bool")
+ .Input("na_value: string")
+ .Input("select_cols: int64")
+ .Input("record_defaults: output_types")
+ .Output("handle: variant")
+ .Attr("output_types: list({float,double,int32,int64,string}) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked
+ // stateful to inhibit constant folding.
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle unused;
+ // `filenames` must be a scalar or a vector.
+ TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused));
+ // `compression_type`, `buffer_size`, `header`, `field_delim`,
+ // `use_quote_delim`, `na_value` must be scalars
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));
+ // `select_cols` must be a vector
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 1, &unused));
+ // `record_defaults` must be lists of scalars
+ for (size_t i = 8; i < c->num_inputs(); ++i) {
+ shape_inference::ShapeHandle v;
+ TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(i), 1, &v));
+ if (c->Rank(c->input(i)) == 1 && c->Value(c->Dim(v, 0)) > 1) {
+ return errors::InvalidArgument(
+ "Shape of a default must be a length-0 or length-1 vector, or a "
+ "scalar.");
+ }
+ }
+ return shape_inference::ScalarShape(c);
+ });
+
+REGISTER_OP("ExperimentalIgnoreErrorsDataset")
+ .Input("input_dataset: variant")
+ .Output("handle: variant")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(shape_inference::ScalarShape);
+
+REGISTER_OP("ExperimentalUniqueDataset")
+ .Input("input_dataset: variant")
+ .Output("handle: variant")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(shape_inference::ScalarShape);
+
+REGISTER_OP("ExperimentalIteratorGetDevice")
+ .Input("resource: resource")
+ .Output("device: string")
+ .SetShapeFn(shape_inference::ScalarShape);
+
+REGISTER_OP("ExperimentalFunctionBufferingResource")
+ .Input("string_arg: string")
+ .Input("target_device: string")
+ .Output("resource: resource")
+ .Attr("shared_name: string")
+ .Attr("container: string")
+ .Attr("f: func")
+ .Attr("buffer_size: int")
+ .Attr("output_types: list(type)")
+ .SetShapeFn(shape_inference::UnknownShape);
+
+REGISTER_OP("ExperimentalFunctionBufferingResourceGetNext")
+ .Input("function_buffer_resource: resource")
+ .Attr("output_types: list(type)")
+ .Output("output: output_types")
+ .SetShapeFn(shape_inference::UnknownShape);
+
+REGISTER_OP("ExperimentalFunctionBufferingResourceReset")
+ .Input("function_buffer_resource: resource")
+ .SetShapeFn(shape_inference::UnknownShape);
+
+REGISTER_OP("ExperimentalThreadPoolDataset")
+ .Input("input_dataset: variant")
+ .Input("thread_pool: resource")
+ .Output("handle: variant")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(shape_inference::ScalarShape);
+
+REGISTER_OP("ExperimentalThreadPoolHandle")
+ .Output("handle: resource")
+ .SetShapeFn(shape_inference::ScalarShape)
+ .Attr("num_threads: int")
+ .Attr("max_intra_op_parallelism: int = 1")
+ .Attr("display_name: string")
+ .Attr("container: string = ''")
+ .Attr("shared_name: string = ''");
+
+REGISTER_OP("ExperimentalAssertNextDataset")
+ .Input("input_dataset: variant")
+ .Input("transformations: string")
+ .Output("handle: variant")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle unused;
+ // transformations should be a vector.
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
+ return shape_inference::ScalarShape(c);
+ });
+
+REGISTER_OP("ExperimentalLMDBDataset")
+ .Input("filenames: string")
+ .Output("handle: variant")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked
+ // stateful to inhibit constant folding.
+ .SetShapeFn(shape_inference::ScalarShape);
+
+REGISTER_OP("ExperimentalIdentityIndexedDataset")
+ .Input("size: uint64")
+ .Output("handle: variant")
+ .SetIsStateful()
+ .SetShapeFn(
+ shape_inference::ScalarShape); // TODO(saeta): check input shapes.
+
+///////////////////////////////////////////////////////////////////////////////
+// IndexedDataset Internals
+///////////////////////////////////////////////////////////////////////////////
+
+// Creates the handle.
+REGISTER_OP("ExperimentalMaterializedIndexDatasetHandle")
+ .Output("handle: resource")
+ .Attr("container: string")
+ .Attr("shared_name: string")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(shape_inference::ScalarShape);
+
+// Actually materialize the dataset into the handle.
+REGISTER_OP("ExperimentalIndexedDatasetMaterialize")
+ .Input("dataset: variant")
+ .Input("materialized: resource")
+ .SetShapeFn(shape_inference::NoOutputs);
+
+namespace {
+
+Status GetShapeFn(shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle unused;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
+ std::vector<PartialTensorShape> output_shapes;
+ TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes));
+ if (output_shapes.size() != c->num_outputs()) {
+ return errors::InvalidArgument(
+ "`output_shapes` must be the same length as `output_types` (",
+ output_shapes.size(), " vs. ", c->num_outputs(), ")");
+ }
+ for (size_t i = 0; i < output_shapes.size(); ++i) {
+ shape_inference::ShapeHandle output_shape_handle;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(
+ output_shapes[i], &output_shape_handle));
+ c->set_output(static_cast<int>(i), output_shape_handle);
+ }
+ return Status::OK();
+}
+
+} // namespace
+
+REGISTER_OP("ExperimentalIndexedDatasetGet")
+ .Input("materialized: resource")
+ .Input("index: uint64")
+ .Output("components: output_types")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(GetShapeFn);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/ops/logging_ops.cc b/tensorflow/core/ops/logging_ops.cc
index 639d211767..2034d3601b 100644
--- a/tensorflow/core/ops/logging_ops.cc
+++ b/tensorflow/core/ops/logging_ops.cc
@@ -20,6 +20,8 @@ limitations under the License.
namespace tensorflow {
+using shape_inference::InferenceContext;
+
REGISTER_OP("Assert")
.Input("condition: bool")
.Input("data: T")
@@ -44,6 +46,23 @@ REGISTER_OP("Print")
WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS("Print");
+REGISTER_OP("PrintV2")
+ .Input("input: string")
+ .SetIsStateful()
+ .Attr(
+ "output_stream: {'stdout', 'stderr', 'log(info)', "
+ "'log(warning)', 'log(error)'} = 'stderr'")
+ .SetShapeFn([](InferenceContext* c) {
+ // Make sure that the input is a scalar.
+ if (c->Rank(c->input(0)) != 0) {
+ return errors::InvalidArgument("input must be a scalar, but has rank: ",
+ c->Rank(c->input(0)));
+ }
+ return Status::OK();
+ });
+
+WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS("PrintV2");
+
// ----------------------------------------------------------------------------
// Operators that deal with SummaryProtos (encoded as DT_STRING tensors) as
// inputs or outputs in various ways.
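
PrintV2 takes a single already-formatted string scalar (typically produced by StringFormat) and writes it to one of the allowed streams; its shape function only rejects non-scalar inputs. A rough sketch of the stream selection is below; `PrintTo` is an illustrative stand-in, not the actual kernel, and folding the log(...) variants onto stderr is an assumption made only for this sketch.

#include <cstdio>
#include <string>

// Illustrative routing of an output_stream attr value; the real kernel sends
// the log(info|warning|error) variants through the logging facility instead.
void PrintTo(const std::string& output_stream, const std::string& msg) {
  if (output_stream == "stdout") {
    std::fprintf(stdout, "%s\n", msg.c_str());
  } else {  // "stderr", "log(info)", "log(warning)", "log(error)"
    std::fprintf(stderr, "%s\n", msg.c_str());
  }
}

int main() { PrintTo("stderr", "tensor contents would go here"); }
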
diff --git a/tensorflow/core/ops/math_grad.cc b/tensorflow/core/ops/math_grad.cc
index 07f876cb90..55dcc50325 100644
--- a/tensorflow/core/ops/math_grad.cc
+++ b/tensorflow/core/ops/math_grad.cc
@@ -549,6 +549,40 @@ Status PowGrad(const AttrSlice& attrs, FunctionDef* g) {
}
REGISTER_OP_GRADIENT("Pow", PowGrad);
+Status XlogyGrad(const AttrSlice& attrs, FunctionDef* g) {
+ // clang-format off
+ return GradForBinaryCwise(g, {
+ {{"zeros"}, "ZerosLike", {"x"}},
+ {{"is_x_zero"}, "NotEqual", {"x", "zeros"}},
+ {{"is_zero_cast"}, "Cast", {"is_x_zero"},
+ {{"SrcT", DT_BOOL}, {"DstT", "$T"}}},
+ {{"safe_logy"}, "Xlogy", {"is_zero_cast", "y"}},
+ {{"xlogygrad"}, "Xdivy", {"x", "y"}},
+ {{"gx"}, "Mul", {"safe_logy", "dz"}},
+ {{"gy"}, "Mul", {"xlogygrad", "dz"}},
+ });
+ // clang-format on
+}
+REGISTER_OP_GRADIENT("Xlogy", XlogyGrad);
+
+Status XdivyGrad(const AttrSlice& attrs, FunctionDef* g) {
+ // clang-format off
+ return GradForBinaryCwise(g, {
+ {{"zeros"}, "ZerosLike", {"x"}},
+ {{"is_x_zero"}, "NotEqual", {"x", "zeros"}},
+ {{"is_zero_cast"}, "Cast", {"is_x_zero"},
+ {{"SrcT", DT_BOOL}, {"DstT", "$T"}}},
+ {{"safe_divy"}, "Xdivy", {"is_zero_cast", "y"}},
+ {{"y2"}, "Square", {"y"}},
+ {{"negy2"}, "Neg", {"y2"}},
+ {{"xdivygrad"}, "Xdivy", {"x", "negy2"}},
+ {{"gx"}, "Mul", {"safe_divy", "dz"}},
+ {{"gy"}, "Mul", {"xdivygrad", "dz"}},
+ });
+ // clang-format on
+}
+REGISTER_OP_GRADIENT("Xdivy", XdivyGrad);
+
Status MaximumMinimumGradHelper(const string& comparator,
const AttrSlice& attrs, FunctionDef* g) {
// clang-format off
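
Away from x == 0, these registrations encode the identities d/dx xlogy(x, y) = log(y), d/dy xlogy(x, y) = x/y, d/dx xdivy(x, y) = 1/y, and d/dy xdivy(x, y) = -x/y^2, with the is_x_zero mask zeroing the x-gradient where x == 0. A small self-contained check of those formulas against central finite differences (plain C++, independent of the FunctionDef machinery above):

#include <cmath>
#include <cstdio>

// Forward definitions: xlogy(x, y) = x * log(y) and xdivy(x, y) = x / y,
// both defined to be 0 when x == 0.
double xlogy(double x, double y) { return x == 0.0 ? 0.0 : x * std::log(y); }
double xdivy(double x, double y) { return x == 0.0 ? 0.0 : x / y; }

int main() {
  const double x = 3.0, y = 2.0, eps = 1e-6;
  // Analytic gradients used by XlogyGrad/XdivyGrad above (away from x == 0).
  double dxlogy_dx = std::log(y), dxlogy_dy = x / y;
  double dxdivy_dx = 1.0 / y,     dxdivy_dy = -x / (y * y);
  // Central finite differences for comparison.
  std::printf("xlogy d/dx: %.6f vs %.6f\n", dxlogy_dx,
              (xlogy(x + eps, y) - xlogy(x - eps, y)) / (2 * eps));
  std::printf("xlogy d/dy: %.6f vs %.6f\n", dxlogy_dy,
              (xlogy(x, y + eps) - xlogy(x, y - eps)) / (2 * eps));
  std::printf("xdivy d/dx: %.6f vs %.6f\n", dxdivy_dx,
              (xdivy(x + eps, y) - xdivy(x - eps, y)) / (2 * eps));
  std::printf("xdivy d/dy: %.6f vs %.6f\n", dxdivy_dy,
              (xdivy(x, y + eps) - xdivy(x, y - eps)) / (2 * eps));
}
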
diff --git a/tensorflow/core/ops/math_grad_test.cc b/tensorflow/core/ops/math_grad_test.cc
index 5ee79809ac..9fc6b34147 100644
--- a/tensorflow/core/ops/math_grad_test.cc
+++ b/tensorflow/core/ops/math_grad_test.cc
@@ -909,6 +909,46 @@ TEST_F(MathGradTest, ComplexPow) {
}
#endif // TENSORFLOW_USE_SYCL
+TEST_F(MathGradTest, Xlogy) {
+ auto x = test::AsTensor<float>({0.f, 0.f, 2.f, 3.f, 4.f, 5.f},
+ TensorShape({2, 3}));
+ auto y = test::AsTensor<float>({.5f, 2.f}, TensorShape({2, 1}));
+ Tensor dx;
+ Tensor dy;
+ auto g = [](float x, float y) -> float { return x == 0. ? 0. : std::log(y); };
+ auto h = [](float x, float y) -> float { return x == 0. ? 0. : x / y; };
+ SymGrad("Xlogy", x, y, &dx, &dy);
+ test::ExpectClose(
+ dx, test::AsTensor<float>({g(0.f, .5f), g(0.f, 0.f), g(2.f, .5f),
+ g(3.f, 2.f), g(4.f, 2.f), g(5.f, 2.f)},
+ TensorShape({2, 3})));
+ test::ExpectClose(
+ dy, test::AsTensor<float>({h(0.f, .5f) + h(0.f, 0.f) + h(2.f, .5f),
+ h(3.f, 2.f) + h(4.f, 2.f) + h(5.f, 2.f)},
+ TensorShape({2, 1})));
+}
+
+TEST_F(MathGradTest, Xdivy) {
+ auto x = test::AsTensor<float>({0.f, 0.f, 2.f, 3.f, 4.f, 5.f},
+ TensorShape({2, 3}));
+ auto y = test::AsTensor<float>({.5f, 2.f}, TensorShape({2, 1}));
+ Tensor dx;
+ Tensor dy;
+ auto g = [](float x, float y) -> float { return x == 0. ? 0. : 1 / y; };
+ auto h = [](float x, float y) -> float {
+ return x == 0. ? 0. : -x / (y * y);
+ };
+ SymGrad("Xdivy", x, y, &dx, &dy);
+ test::ExpectClose(
+ dx, test::AsTensor<float>({g(0.f, .5f), g(0.f, 0.f), g(2.f, .5f),
+ g(3.f, 2.f), g(4.f, 2.f), g(5.f, 2.f)},
+ TensorShape({2, 3})));
+ test::ExpectClose(
+ dy, test::AsTensor<float>({h(0.f, .5f) + h(0.f, 0.f) + h(2.f, .5f),
+ h(3.f, 2.f) + h(4.f, 2.f) + h(5.f, 2.f)},
+ TensorShape({2, 1})));
+}
+
TEST_F(MathGradTest, Maximum) {
auto x = test::AsTensor<float>({-3.f, -2.f, -1.f, 1.f, 2.f, 3.f},
TensorShape({2, 3}));
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc
index 717263a9b0..3eff728f03 100644
--- a/tensorflow/core/ops/math_ops.cc
+++ b/tensorflow/core/ops/math_ops.cc
@@ -429,6 +429,20 @@ Returns (x - y)(x - y) element-wise.
[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
)doc");
+REGISTER_OP("Xlogy")
+ .Input("x: T")
+ .Input("y: T")
+ .Output("z: T")
+ .Attr("T: {half, float, double, complex64, complex128}")
+ .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
+
+REGISTER_OP("Xdivy")
+ .Input("x: T")
+ .Input("y: T")
+ .Output("z: T")
+ .Attr("T: {half, float, double, complex64, complex128}")
+ .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
+
#undef BINARY_FEWER
#undef BINARY_MORE
diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc
index 2485fa4717..d1d81b27cc 100644
--- a/tensorflow/core/ops/nn_ops.cc
+++ b/tensorflow/core/ops/nn_ops.cc
@@ -178,7 +178,7 @@ REGISTER_OP("FusedBatchNorm")
.Output("reserve_space_2: T")
.Attr("T: {float}")
.Attr("epsilon: float = 0.0001")
- .Attr("data_format: string = 'NHWC'")
+ .Attr(GetConvnetDataFormatAttrString())
.Attr("is_training: bool = true")
.SetShapeFn(shape_inference::FusedBatchNormShape);
@@ -196,7 +196,7 @@ REGISTER_OP("FusedBatchNormV2")
.Attr("T: {half, bfloat16, float}")
.Attr("U: {float}")
.Attr("epsilon: float = 0.0001")
- .Attr("data_format: string = 'NHWC'")
+ .Attr(GetConvnetDataFormatAttrString())
.Attr("is_training: bool = true")
.SetShapeFn(shape_inference::FusedBatchNormShape);
@@ -213,7 +213,7 @@ REGISTER_OP("FusedBatchNormGrad")
.Output("reserve_space_4: T")
.Attr("T: {float}")
.Attr("epsilon: float = 0.0001")
- .Attr("data_format: string = 'NHWC'")
+ .Attr(GetConvnetDataFormatAttrString())
.Attr("is_training: bool = true")
.SetShapeFn(shape_inference::FusedBatchNormGradShape);
@@ -231,7 +231,7 @@ REGISTER_OP("FusedBatchNormGradV2")
.Attr("T: {half, bfloat16, float}")
.Attr("U: {float}")
.Attr("epsilon: float = 0.0001")
- .Attr("data_format: string = 'NHWC'")
+ .Attr(GetConvnetDataFormatAttrString())
.Attr("is_training: bool = true")
.SetShapeFn(shape_inference::FusedBatchNormGradShape);
@@ -1009,32 +1009,30 @@ REGISTER_OP("SeluGrad")
.Attr("T: {half, bfloat16, float, double}")
.SetShapeFn(shape_inference::MergeBothInputsShapeFn);
-// TODO(b/111515541): change T to {half, bfloat16, float, double}
REGISTER_OP("Softplus")
.Input("features: T")
.Output("activations: T")
- .Attr("T: realnumbertype")
+ .Attr("T: {half, bfloat16, float, double}")
.SetShapeFn(shape_inference::UnchangedShape);
REGISTER_OP("SoftplusGrad")
.Input("gradients: T")
.Input("features: T")
.Output("backprops: T")
- .Attr("T: realnumbertype")
+ .Attr("T: {half, bfloat16, float, double}")
.SetShapeFn(shape_inference::MergeBothInputsShapeFn);
-// TODO(b/111515541): change T to {half, bfloat16, float, double}
REGISTER_OP("Softsign")
.Input("features: T")
.Output("activations: T")
- .Attr("T: realnumbertype")
+ .Attr("T: {half, bfloat16, float, double}")
.SetShapeFn(shape_inference::UnchangedShape);
REGISTER_OP("SoftsignGrad")
.Input("gradients: T")
.Input("features: T")
.Output("backprops: T")
- .Attr("T: realnumbertype")
+ .Attr("T: {half, bfloat16, float, double}")
.SetShapeFn(shape_inference::MergeBothInputsShapeFn);
// --------------------------------------------------------------------------
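
Restricting Softplus/Softsign (and their grads) to {half, bfloat16, float, double} matches the math: softplus(x) = log(1 + e^x) and softsign(x) = x / (1 + |x|) are only meaningful in floating point. A tiny sketch of the two activations, just to illustrate the formulas (the overflow-guard threshold here is an arbitrary choice for the sketch, not the kernels' exact cutoff):

#include <cmath>
#include <cstdio>

// softplus with a simple overflow guard: for large x, log1p(exp(x)) ~= x.
double softplus(double x) { return x > 30.0 ? x : std::log1p(std::exp(x)); }
double softsign(double x) { return x / (1.0 + std::fabs(x)); }

int main() {
  for (double x : {-2.0, 0.0, 2.0, 100.0}) {
    std::printf("x=%6.1f  softplus=%10.4f  softsign=%7.4f\n",
                x, softplus(x), softsign(x));
  }
}
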
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index 190f6aaa5b..abee803889 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -10039,6 +10039,421 @@ op {
}
}
op {
+ name: "ExperimentalAssertNextDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "transformations"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
+ name: "ExperimentalCSVDataset"
+ input_arg {
+ name: "filenames"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "compression_type"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "buffer_size"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "header"
+ type: DT_BOOL
+ }
+ input_arg {
+ name: "field_delim"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "use_quote_delim"
+ type: DT_BOOL
+ }
+ input_arg {
+ name: "na_value"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "select_cols"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "record_defaults"
+ type_list_attr: "output_types"
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_INT64
+ type: DT_STRING
+ }
+ }
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalDirectedInterleaveDataset"
+ input_arg {
+ name: "selector_input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "data_input_datasets"
+ type: DT_VARIANT
+ number_attr: "N"
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "N"
+ type: "int"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
+ name: "ExperimentalFunctionBufferingResource"
+ input_arg {
+ name: "string_arg"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "target_device"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "resource"
+ type: DT_RESOURCE
+ }
+ attr {
+ name: "shared_name"
+ type: "string"
+ }
+ attr {
+ name: "container"
+ type: "string"
+ }
+ attr {
+ name: "f"
+ type: "func"
+ }
+ attr {
+ name: "buffer_size"
+ type: "int"
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalFunctionBufferingResourceGetNext"
+ input_arg {
+ name: "function_buffer_resource"
+ type: DT_RESOURCE
+ }
+ output_arg {
+ name: "output"
+ type_list_attr: "output_types"
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalFunctionBufferingResourceReset"
+ input_arg {
+ name: "function_buffer_resource"
+ type: DT_RESOURCE
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalIdentityIndexedDataset"
+ input_arg {
+ name: "size"
+ type: DT_UINT64
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalIgnoreErrorsDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
+ name: "ExperimentalIndexedDatasetGet"
+ input_arg {
+ name: "materialized"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "index"
+ type: DT_UINT64
+ }
+ output_arg {
+ name: "components"
+ type_list_attr: "output_types"
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalIndexedDatasetMaterialize"
+ input_arg {
+ name: "dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "materialized"
+ type: DT_RESOURCE
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalIteratorGetDevice"
+ input_arg {
+ name: "resource"
+ type: DT_RESOURCE
+ }
+ output_arg {
+ name: "device"
+ type: DT_STRING
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalLMDBDataset"
+ input_arg {
+ name: "filenames"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalMaterializedIndexDatasetHandle"
+ output_arg {
+ name: "handle"
+ type: DT_RESOURCE
+ }
+ attr {
+ name: "container"
+ type: "string"
+ }
+ attr {
+ name: "shared_name"
+ type: "string"
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalThreadPoolDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "thread_pool"
+ type: DT_RESOURCE
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalThreadPoolHandle"
+ output_arg {
+ name: "handle"
+ type: DT_RESOURCE
+ }
+ attr {
+ name: "num_threads"
+ type: "int"
+ }
+ attr {
+ name: "max_intra_op_parallelism"
+ type: "int"
+ default_value {
+ i: 1
+ }
+ }
+ attr {
+ name: "display_name"
+ type: "string"
+ }
+ attr {
+ name: "container"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ attr {
+ name: "shared_name"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalUniqueDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
name: "Expm1"
input_arg {
name: "x"
@@ -10187,6 +10602,59 @@ op {
}
}
op {
+ name: "ExtractVolumePatches"
+ input_arg {
+ name: "input"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "patches"
+ type_attr: "T"
+ }
+ attr {
+ name: "ksizes"
+ type: "list(int)"
+ has_minimum: true
+ minimum: 5
+ }
+ attr {
+ name: "strides"
+ type: "list(int)"
+ has_minimum: true
+ minimum: 5
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_UINT8
+ type: DT_INT16
+ type: DT_INT8
+ type: DT_INT64
+ type: DT_BFLOAT16
+ type: DT_UINT16
+ type: DT_HALF
+ type: DT_UINT32
+ type: DT_UINT64
+ }
+ }
+ }
+ attr {
+ name: "padding"
+ type: "string"
+ allowed_values {
+ list {
+ s: "SAME"
+ s: "VALID"
+ }
+ }
+ }
+}
+op {
name: "FFT"
input_arg {
name: "input"
@@ -11406,6 +11874,12 @@ op {
default_value {
s: "NHWC"
}
+ allowed_values {
+ list {
+ s: "NHWC"
+ s: "NCHW"
+ }
+ }
}
attr {
name: "is_training"
@@ -11479,6 +11953,12 @@ op {
default_value {
s: "NHWC"
}
+ allowed_values {
+ list {
+ s: "NHWC"
+ s: "NCHW"
+ }
+ }
}
attr {
name: "is_training"
@@ -11563,6 +12043,12 @@ op {
default_value {
s: "NHWC"
}
+ allowed_values {
+ list {
+ s: "NHWC"
+ s: "NCHW"
+ }
+ }
}
attr {
name: "is_training"
@@ -11647,6 +12133,12 @@ op {
default_value {
s: "NHWC"
}
+ allowed_values {
+ list {
+ s: "NHWC"
+ s: "NCHW"
+ }
+ }
}
attr {
name: "is_training"
@@ -14536,6 +15028,38 @@ op {
}
}
op {
+ name: "LowerBound"
+ input_arg {
+ name: "sorted_inputs"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "values"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "out_type"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ }
+ attr {
+ name: "out_type"
+ type: "type"
+ default_value {
+ type: DT_INT32
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+}
+op {
name: "MakeIterator"
input_arg {
name: "dataset"
@@ -16780,6 +17304,134 @@ op {
is_commutative: true
}
op {
+ name: "MultiDeviceIterator"
+ output_arg {
+ name: "handle"
+ type: DT_RESOURCE
+ }
+ attr {
+ name: "devices"
+ type: "list(string)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "shared_name"
+ type: "string"
+ }
+ attr {
+ name: "container"
+ type: "string"
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "MultiDeviceIteratorFromStringHandle"
+ input_arg {
+ name: "string_handle"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "multi_device_iterator"
+ type: DT_RESOURCE
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ default_value {
+ list {
+ }
+ }
+ has_minimum: true
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ default_value {
+ list {
+ }
+ }
+ has_minimum: true
+ }
+ is_stateful: true
+}
+op {
+ name: "MultiDeviceIteratorGetNextFromShard"
+ input_arg {
+ name: "multi_device_iterator"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "shard_num"
+ type: DT_INT32
+ }
+ input_arg {
+ name: "incarnation_id"
+ type: DT_INT64
+ }
+ output_arg {
+ name: "components"
+ type_list_attr: "output_types"
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "MultiDeviceIteratorInit"
+ input_arg {
+ name: "dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "multi_device_iterator"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "max_buffer_size"
+ type: DT_INT64
+ }
+ output_arg {
+ name: "incarnation_id"
+ type: DT_INT64
+ }
+ is_stateful: true
+}
+op {
+ name: "MultiDeviceIteratorToStringHandle"
+ input_arg {
+ name: "multi_device_iterator"
+ type: DT_RESOURCE
+ }
+ output_arg {
+ name: "string_handle"
+ type: DT_STRING
+ }
+ is_stateful: true
+}
+op {
name: "Multinomial"
input_arg {
name: "logits"
@@ -19521,6 +20173,30 @@ op {
is_stateful: true
}
op {
+ name: "PrintV2"
+ input_arg {
+ name: "input"
+ type: DT_STRING
+ }
+ attr {
+ name: "output_stream"
+ type: "string"
+ default_value {
+ s: "stderr"
+ }
+ allowed_values {
+ list {
+ s: "stdout"
+ s: "stderr"
+ s: "log(info)"
+ s: "log(warning)"
+ s: "log(error)"
+ }
+ }
+ }
+ is_stateful: true
+}
+op {
name: "PriorityQueue"
output_arg {
name: "handle"
@@ -22608,6 +23284,59 @@ op {
is_stateful: true
}
op {
+ name: "ReduceDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "initial_state"
+ type_list_attr: "Tstate"
+ }
+ input_arg {
+ name: "other_arguments"
+ type_list_attr: "Targuments"
+ }
+ output_arg {
+ name: "components"
+ type_list_attr: "output_types"
+ }
+ attr {
+ name: "f"
+ type: "func"
+ }
+ attr {
+ name: "Tstate"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "Targuments"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "use_inter_op_parallelism"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
+}
+op {
name: "ReduceJoin"
input_arg {
name: "inputs"
@@ -28477,18 +29206,10 @@ op {
type: "type"
allowed_values {
list {
+ type: DT_HALF
+ type: DT_BFLOAT16
type: DT_FLOAT
type: DT_DOUBLE
- type: DT_INT32
- type: DT_UINT8
- type: DT_INT16
- type: DT_INT8
- type: DT_INT64
- type: DT_BFLOAT16
- type: DT_UINT16
- type: DT_HALF
- type: DT_UINT32
- type: DT_UINT64
}
}
}
@@ -28512,18 +29233,10 @@ op {
type: "type"
allowed_values {
list {
+ type: DT_HALF
+ type: DT_BFLOAT16
type: DT_FLOAT
type: DT_DOUBLE
- type: DT_INT32
- type: DT_UINT8
- type: DT_INT16
- type: DT_INT8
- type: DT_INT64
- type: DT_BFLOAT16
- type: DT_UINT16
- type: DT_HALF
- type: DT_UINT32
- type: DT_UINT64
}
}
}
@@ -28543,18 +29256,10 @@ op {
type: "type"
allowed_values {
list {
+ type: DT_HALF
+ type: DT_BFLOAT16
type: DT_FLOAT
type: DT_DOUBLE
- type: DT_INT32
- type: DT_UINT8
- type: DT_INT16
- type: DT_INT8
- type: DT_INT64
- type: DT_BFLOAT16
- type: DT_UINT16
- type: DT_HALF
- type: DT_UINT32
- type: DT_UINT64
}
}
}
@@ -28578,18 +29283,10 @@ op {
type: "type"
allowed_values {
list {
+ type: DT_HALF
+ type: DT_BFLOAT16
type: DT_FLOAT
type: DT_DOUBLE
- type: DT_INT32
- type: DT_UINT8
- type: DT_INT16
- type: DT_INT8
- type: DT_INT64
- type: DT_BFLOAT16
- type: DT_UINT16
- type: DT_HALF
- type: DT_UINT32
- type: DT_UINT64
}
}
}
@@ -32735,6 +33432,43 @@ op {
}
}
op {
+ name: "StringFormat"
+ input_arg {
+ name: "inputs"
+ type_list_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type: DT_STRING
+ }
+ attr {
+ name: "T"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "template"
+ type: "string"
+ default_value {
+ s: "%s"
+ }
+ }
+ attr {
+ name: "placeholder"
+ type: "string"
+ default_value {
+ s: "%s"
+ }
+ }
+ attr {
+ name: "summarize"
+ type: "int"
+ default_value {
+ i: 3
+ }
+ }
+}
+op {
name: "StringJoin"
input_arg {
name: "inputs"
@@ -32769,6 +33503,19 @@ op {
name: "output"
type: DT_INT32
}
+ attr {
+ name: "unit"
+ type: "string"
+ default_value {
+ s: "BYTE"
+ }
+ allowed_values {
+ list {
+ s: "BYTE"
+ s: "UTF8_CHAR"
+ }
+ }
+ }
}
op {
name: "StringSplit"
@@ -35370,6 +36117,17 @@ op {
}
}
op {
+ name: "UnicodeScript"
+ input_arg {
+ name: "input"
+ type: DT_INT32
+ }
+ output_arg {
+ name: "output"
+ type: DT_INT32
+ }
+}
+op {
name: "UniformCandidateSampler"
input_arg {
name: "true_classes"
@@ -35954,6 +36712,38 @@ op {
is_stateful: true
}
op {
+ name: "UpperBound"
+ input_arg {
+ name: "sorted_inputs"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "values"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "out_type"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ }
+ attr {
+ name: "out_type"
+ type: "type"
+ default_value {
+ type: DT_INT32
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+}
+op {
name: "VarHandleOp"
output_arg {
name: "resource"
@@ -36199,9 +36989,21 @@ op {
type: DT_VARIANT
}
input_arg {
- name: "window_size"
+ name: "size"
type: DT_INT64
}
+ input_arg {
+ name: "shift"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "stride"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "drop_remainder"
+ type: DT_BOOL
+ }
output_arg {
name: "handle"
type: DT_VARIANT
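
This hunk replaces the windowing dataset op's single window_size input with size, shift, stride, and drop_remainder, matching tf.data's window semantics: window starts advance by shift, elements within a window are taken every stride positions up to size of them, and drop_remainder discards trailing windows that come up short. A sketch of those semantics on plain indices (illustrative only, not dataset kernel code):

#include <cstdio>
#include <vector>

// Enumerates the element indices of each window over [0, n).
std::vector<std::vector<int>> Windows(int n, int size, int shift, int stride,
                                      bool drop_remainder) {
  std::vector<std::vector<int>> out;
  for (int start = 0; start < n; start += shift) {
    std::vector<int> w;
    for (int k = 0; k < size; ++k) {
      int idx = start + k * stride;
      if (idx >= n) break;
      w.push_back(idx);
    }
    if (drop_remainder && static_cast<int>(w.size()) < size) continue;
    out.push_back(w);
  }
  return out;
}

int main() {
  for (const auto& w : Windows(/*n=*/10, /*size=*/3, /*shift=*/2, /*stride=*/2,
                               /*drop_remainder=*/true)) {
    for (int i : w) std::printf("%d ", i);
    std::printf("\n");  // prints: 0 2 4 / 2 4 6 / 4 6 8
  }
}
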
@@ -36438,6 +37240,62 @@ op {
is_stateful: true
}
op {
+ name: "Xdivy"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "z"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+}
+op {
+ name: "Xlogy"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "z"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+}
+op {
name: "ZerosLike"
input_arg {
name: "x"
diff --git a/tensorflow/core/ops/resource_variable_ops.cc b/tensorflow/core/ops/resource_variable_ops.cc
index 26499540f1..adc9cd1486 100644
--- a/tensorflow/core/ops/resource_variable_ops.cc
+++ b/tensorflow/core/ops/resource_variable_ops.cc
@@ -19,6 +19,7 @@
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/lib/core/errors.h"
using ::tensorflow::shape_inference::InferenceContext;
using ::tensorflow::shape_inference::ShapeAndType;
@@ -56,6 +57,36 @@ Status ReadVariableShapeFn(InferenceContext* c) {
return Status::OK();
}
+Status ReadVariablesShapeFn(InferenceContext* c) {
+ int n;
+ TF_RETURN_IF_ERROR(c->GetAttr("N", &n));
+ DataTypeVector value_dtypes;
+ TF_RETURN_IF_ERROR(c->GetAttr("dtypes", &value_dtypes));
+ if (n != value_dtypes.size()) {
+ return errors::InvalidArgument(
+ "Mismatched number of arguments to ReadVariablesOp");
+ }
+ for (int i = 0; i < n; ++i) {
+ ShapeAndType shape_and_type;
+ auto* handle_data = c->input_handle_shapes_and_types(i);
+ if (handle_data == nullptr || handle_data->empty()) {
+ shape_and_type.shape = c->UnknownShape();
+ shape_and_type.dtype = DT_INVALID;
+ } else {
+ shape_and_type = (*handle_data)[0];
+ if (shape_and_type.dtype != value_dtypes[i]) {
+ return errors::InvalidArgument(
+ "Trying to read variable with wrong dtype. "
+ "Expected ",
+ DataTypeString(shape_and_type.dtype), " got ",
+ DataTypeString(value_dtypes[i]));
+ }
+ }
+ c->set_output(i, shape_and_type.shape);
+ }
+ return Status::OK();
+}
+
} // namespace
REGISTER_OP("VarHandleOp")
@@ -79,12 +110,53 @@ REGISTER_OP("VarHandleOp")
return Status::OK();
});
+REGISTER_OP("_VarHandlesOp")
+ .Attr("containers: list(string)")
+ .Attr("shared_names: list(string)")
+ .Attr("N: int >= 0")
+ .Attr("dtypes: list(type)")
+ .Attr("shapes: list(shape)")
+ .Output("resources: N * resource")
+ .SetIsStateful()
+ .SetShapeFn([](InferenceContext* c) {
+ int n;
+ TF_RETURN_IF_ERROR(c->GetAttr("N", &n));
+ DataTypeVector dtypes;
+ TF_RETURN_IF_ERROR(c->GetAttr("dtypes", &dtypes));
+ std::vector<PartialTensorShape> shapes;
+ TF_RETURN_IF_ERROR(c->GetAttr("shapes", &shapes));
+ if (dtypes.size() != n) {
+ return errors::InvalidArgument("Mismatched number of dtypes (n=", n,
+ ", num dtypes=", dtypes.size(), ")");
+ }
+ if (shapes.size() != n) {
+ return errors::InvalidArgument("Mismatched number of shapes (n=", n,
+ ", num shapes=", shapes.size(), ")");
+ }
+ for (int i = 0; i < n; ++i) {
+ c->set_output(i, c->Scalar());
+ ShapeHandle s;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shapes[i], &s));
+ c->set_output_handle_shapes_and_types(
+ i, std::vector<ShapeAndType>{{s, dtypes[i]}});
+ }
+
+ return Status::OK();
+ });
+
REGISTER_OP("ReadVariableOp")
.Input("resource: resource")
.Output("value: dtype")
.Attr("dtype: type")
.SetShapeFn(ReadVariableShapeFn);
+REGISTER_OP("_ReadVariablesOp")
+ .Attr("N: int >= 0")
+ .Input("resources: N * resource")
+ .Output("values: dtypes")
+ .Attr("dtypes: list(type)")
+ .SetShapeFn(ReadVariablesShapeFn);
+
Status ReadGrad(const AttrSlice& attrs, FunctionDef* g) {
// clang-format off
*g = FunctionDefHelper::Define(
diff --git a/tensorflow/core/ops/string_ops.cc b/tensorflow/core/ops/string_ops.cc
index ef8b15dc8a..b4fbde54d9 100644
--- a/tensorflow/core/ops/string_ops.cc
+++ b/tensorflow/core/ops/string_ops.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include "absl/strings/str_split.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
@@ -102,6 +103,32 @@ REGISTER_OP("AsString")
.Attr("fill: string = ''")
.SetShapeFn(shape_inference::UnchangedShape);
+REGISTER_OP("StringFormat")
+ .Input("inputs: T")
+ .Output("output: string")
+ .Attr("T: list(type) >= 0")
+ .Attr("template: string = '%s'")
+ .Attr("placeholder: string = '%s'")
+ .Attr("summarize: int = 3")
+ .SetShapeFn([](InferenceContext* c) {
+ string template_;
+ string placeholder;
+ TF_RETURN_IF_ERROR(c->GetAttr("template", &template_));
+ TF_RETURN_IF_ERROR(c->GetAttr("placeholder", &placeholder));
+
+ std::vector<std::string> split_template;
+ split_template = absl::StrSplit(template_, placeholder);
+ int64 num_placeholders = split_template.size() - 1;
+ if (c->num_inputs() != num_placeholders) {
+ return errors::InvalidArgument(strings::StrCat(
+ "num placeholders in template and num inputs must match: ",
+ num_placeholders, " vs. ", c->num_inputs()));
+ }
+
+ c->set_output(0, c->Scalar());
+ return Status::OK();
+ });
+
REGISTER_OP("StringJoin")
.Input("inputs: N * string")
.Attr("N: int")
@@ -176,6 +203,7 @@ REGISTER_OP("StringStrip")
REGISTER_OP("StringLength")
.Input("input: string")
.Output("output: int32")
+ .Attr("unit: {'BYTE', 'UTF8_CHAR'} = 'BYTE'")
.SetShapeFn(shape_inference::UnchangedShape);
REGISTER_OP("EncodeBase64")
@@ -216,4 +244,9 @@ REGISTER_OP("Substr")
return shape_inference::BroadcastBinaryOpShapeFn(c);
});
+REGISTER_OP("UnicodeScript")
+ .Input("input: int32")
+ .Output("output: int32")
+ .SetShapeFn(shape_inference::UnchangedShape);
+
} // namespace tensorflow
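
The StringFormat shape function determines the number of placeholders by splitting the template on the placeholder string (N placeholders produce N + 1 pieces) and then requires exactly one input per placeholder. The same count can be sketched with a plain substring search; CountPlaceholders below is illustrative, not the op's code:

#include <cstdio>
#include <string>

// Counts non-overlapping occurrences of `placeholder` in `tmpl`, mirroring the
// "pieces - 1" computation done with absl::StrSplit in the shape function.
int CountPlaceholders(const std::string& tmpl, const std::string& placeholder) {
  int count = 0;
  for (std::string::size_type pos = tmpl.find(placeholder);
       pos != std::string::npos;
       pos = tmpl.find(placeholder, pos + placeholder.size())) {
    ++count;
  }
  return count;
}

int main() {
  std::printf("%d\n", CountPlaceholders("x=%s, y=%s", "%s"));   // 2 -> 2 inputs
  std::printf("%d\n", CountPlaceholders("no holes here", "%s"));  // 0
}
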
diff --git a/tensorflow/core/platform/cloud/gcs_file_system.cc b/tensorflow/core/platform/cloud/gcs_file_system.cc
index 83228fab6f..83ea8539ed 100644
--- a/tensorflow/core/platform/cloud/gcs_file_system.cc
+++ b/tensorflow/core/platform/cloud/gcs_file_system.cc
@@ -25,6 +25,7 @@ limitations under the License.
#ifdef _WIN32
#include <io.h> // for _mktemp
#endif
+#include "absl/base/macros.h"
#include "include/json/json.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
@@ -63,7 +64,7 @@ constexpr int kGetChildrenDefaultPageSize = 1000;
// The HTTP response code "308 Resume Incomplete".
constexpr uint64 HTTP_CODE_RESUME_INCOMPLETE = 308;
// The environment variable that overrides the size of the readahead buffer.
-// DEPRECATED. Use GCS_BLOCK_SIZE_MB instead.
+ABSL_DEPRECATED("Use GCS_BLOCK_SIZE_MB instead.")
constexpr char kReadaheadBufferSize[] = "GCS_READAHEAD_BUFFER_SIZE_BYTES";
// The environment variable that disables the GCS block cache for reads.
// This is the explicit alternative to setting BLOCK_SIZE or MAX_SIZE to 0, and
diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl
index bb841aeab7..3b14757945 100644
--- a/tensorflow/core/platform/default/build_config.bzl
+++ b/tensorflow/core/platform/default/build_config.bzl
@@ -641,54 +641,41 @@ def tf_additional_lib_deps():
def tf_additional_core_deps():
return select({
- "//tensorflow:with_gcp_support_android_override": [],
- "//tensorflow:with_gcp_support_ios_override": [],
- "//tensorflow:with_gcp_support": [
+ "//tensorflow:android": [],
+ "//tensorflow:windows": [],
+ "//tensorflow:ios": [],
+ "//tensorflow:linux_s390x": [],
+ "//conditions:default": [
"//tensorflow/core/platform/cloud:gcs_file_system",
- ],
- "//conditions:default": [],
- }) + select({
- "//tensorflow:with_hdfs_support_windows_override": [],
- "//tensorflow:with_hdfs_support_android_override": [],
- "//tensorflow:with_hdfs_support_ios_override": [],
- "//tensorflow:with_hdfs_support": [
- "//tensorflow/core/platform/hadoop:hadoop_file_system",
- ],
- "//conditions:default": [],
- }) + select({
- "//tensorflow:with_aws_support_windows_override": [],
- "//tensorflow:with_aws_support_android_override": [],
- "//tensorflow:with_aws_support_ios_override": [],
- "//tensorflow:with_aws_support": [
"//tensorflow/core/platform/s3:s3_file_system",
+ "//tensorflow/core/platform/hadoop:hadoop_file_system",
],
- "//conditions:default": [],
})
# TODO(jart, jhseu): Delete when GCP is default on.
def tf_additional_cloud_op_deps():
return select({
- "//tensorflow:with_gcp_support_windows_override": [],
- "//tensorflow:with_gcp_support_android_override": [],
- "//tensorflow:with_gcp_support_ios_override": [],
- "//tensorflow:with_gcp_support": [
+ "//tensorflow:android": [],
+ "//tensorflow:windows": [],
+ "//tensorflow:ios": [],
+ "//tensorflow:linux_s390x": [],
+ "//conditions:default": [
"//tensorflow/contrib/cloud:bigquery_reader_ops_op_lib",
"//tensorflow/contrib/cloud:gcs_config_ops_op_lib",
],
- "//conditions:default": [],
})
# TODO(jart, jhseu): Delete when GCP is default on.
def tf_additional_cloud_kernel_deps():
return select({
- "//tensorflow:with_gcp_support_windows_override": [],
- "//tensorflow:with_gcp_support_android_override": [],
- "//tensorflow:with_gcp_support_ios_override": [],
- "//tensorflow:with_gcp_support": [
+ "//tensorflow:android": [],
+ "//tensorflow:windows": [],
+ "//tensorflow:ios": [],
+ "//tensorflow:linux_s390x": [],
+ "//conditions:default": [
"//tensorflow/contrib/cloud/kernels:bigquery_reader_ops",
"//tensorflow/contrib/cloud/kernels:gcs_config_ops",
],
- "//conditions:default": [],
})
def tf_lib_proto_parsing_deps():
diff --git a/tensorflow/core/platform/default/build_config_root.bzl b/tensorflow/core/platform/default/build_config_root.bzl
index 3a012c23fd..37475feebe 100644
--- a/tensorflow/core/platform/default/build_config_root.bzl
+++ b/tensorflow/core/platform/default/build_config_root.bzl
@@ -3,64 +3,64 @@
# be separate to avoid cyclic references.
def tf_cuda_tests_tags():
- return ["requires-gpu"]
+ return ["requires-gpu", "local", "gpu"]
def tf_sycl_tests_tags():
- return ["requires-gpu"]
+ return ["requires-gpu", "local", "gpu"]
def tf_additional_plugin_deps():
- return select({
- str(Label("//tensorflow:with_xla_support")): [
- str(Label("//tensorflow/compiler/jit"))
- ],
- "//conditions:default": [],
- })
+ return select({
+ str(Label("//tensorflow:with_xla_support")): [
+ str(Label("//tensorflow/compiler/jit")),
+ ],
+ "//conditions:default": [],
+ })
def tf_additional_xla_deps_py():
- return []
+ return []
def tf_additional_grpc_deps_py():
- return []
+ return []
def tf_additional_license_deps():
- return select({
- str(Label("//tensorflow:with_xla_support")): ["@llvm//:LICENSE.TXT"],
- "//conditions:default": [],
- })
+ return select({
+ str(Label("//tensorflow:with_xla_support")): ["@llvm//:LICENSE.TXT"],
+ "//conditions:default": [],
+ })
def tf_additional_verbs_deps():
- return select({
- str(Label("//tensorflow:with_verbs_support")): [
- str(Label("//tensorflow/contrib/verbs:verbs_server_lib")),
- str(Label("//tensorflow/contrib/verbs:grpc_verbs_client")),
- ],
- "//conditions:default": [],
- })
+ return select({
+ str(Label("//tensorflow:with_verbs_support")): [
+ str(Label("//tensorflow/contrib/verbs:verbs_server_lib")),
+ str(Label("//tensorflow/contrib/verbs:grpc_verbs_client")),
+ ],
+ "//conditions:default": [],
+ })
def tf_additional_mpi_deps():
- return select({
- str(Label("//tensorflow:with_mpi_support")): [
- str(Label("//tensorflow/contrib/mpi:mpi_server_lib")),
- ],
- "//conditions:default": [],
- })
+ return select({
+ str(Label("//tensorflow:with_mpi_support")): [
+ str(Label("//tensorflow/contrib/mpi:mpi_server_lib")),
+ ],
+ "//conditions:default": [],
+ })
def tf_additional_gdr_deps():
- return select({
- str(Label("//tensorflow:with_gdr_support")): [
- str(Label("//tensorflow/contrib/gdr:gdr_server_lib")),
- ],
- "//conditions:default": [],
- })
+ return select({
+ str(Label("//tensorflow:with_gdr_support")): [
+ str(Label("//tensorflow/contrib/gdr:gdr_server_lib")),
+ ],
+ "//conditions:default": [],
+ })
-def if_static(extra_deps, otherwise=[]):
- return select({
- str(Label("//tensorflow:framework_shared_object")): otherwise,
- "//conditions:default": extra_deps,
- })
+def if_static(extra_deps, otherwise = []):
+ return select({
+ str(Label("//tensorflow:framework_shared_object")): otherwise,
+ "//conditions:default": extra_deps,
+ })
-def if_dynamic_kernels(extra_deps, otherwise=[]):
- return select({
- str(Label("//tensorflow:dynamic_loaded_kernels")): extra_deps,
- "//conditions:default": otherwise,
- })
+def if_dynamic_kernels(extra_deps, otherwise = []):
+ return select({
+ str(Label("//tensorflow:dynamic_loaded_kernels")): extra_deps,
+ "//conditions:default": otherwise,
+ })
diff --git a/tensorflow/core/platform/default/cord.h b/tensorflow/core/platform/default/cord.h
index 1ab682182c..5823374d1a 100644
--- a/tensorflow/core/platform/default/cord.h
+++ b/tensorflow/core/platform/default/cord.h
@@ -16,9 +16,6 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_PLATFORM_DEFAULT_CORD_H_
#define TENSORFLOW_CORE_PLATFORM_DEFAULT_CORD_H_
-class Cord;
-namespace absl {
-using ::Cord;
-} // namespace absl
+// TODO(ebrevdo): Fill this in.
#endif // TENSORFLOW_CORE_PLATFORM_DEFAULT_CORD_H_
diff --git a/tensorflow/core/platform/default/device_tracer.cc b/tensorflow/core/platform/default/device_tracer.cc
index 0389149469..83c65dbfa9 100644
--- a/tensorflow/core/platform/default/device_tracer.cc
+++ b/tensorflow/core/platform/default/device_tracer.cc
@@ -321,7 +321,12 @@ class DeviceTracerImpl : public DeviceTracer,
return nullptr;
}
- bool IsEnabled(bool is_expensive) const override {
+ bool IsEnabledForAnnotations() const override {
+ // We are always enabled for 'Annotations'.
+ return true;
+ }
+
+ bool IsEnabledForActivities(bool is_expensive) const override {
// We don't do anything with 'Activities' so we are never 'enabled'.
return false;
}
diff --git a/tensorflow/core/platform/file_system.h b/tensorflow/core/platform/file_system.h
index 30059dc02e..156af6cdea 100644
--- a/tensorflow/core/platform/file_system.h
+++ b/tensorflow/core/platform/file_system.h
@@ -255,10 +255,13 @@ class WritableFile {
/// \brief Append 'data' to the file.
virtual Status Append(StringPiece data) = 0;
+ // TODO(ebrevdo): Remove this ifdef when absl is updated.
+#if defined(PLATFORM_GOOGLE)
// \brief Append 'data' to the file.
virtual Status Append(const absl::Cord& cord) {
return errors::Unimplemented("Append(absl::Cord) is not implemented");
}
+#endif
/// \brief Close the file.
///
diff --git a/tensorflow/core/platform/tracing.h b/tensorflow/core/platform/tracing.h
index 9974bbbb4e..aefbe64425 100644
--- a/tensorflow/core/platform/tracing.h
+++ b/tensorflow/core/platform/tracing.h
@@ -155,9 +155,12 @@ class TraceCollector {
StringPiece name_part1, StringPiece name_part2,
bool is_expensive) const = 0;
+ // Returns true if this annotation tracing is enabled for any op.
+ virtual bool IsEnabledForAnnotations() const = 0;
+
// Returns true if this activity handle tracking is enabled for an op of the
// given expensiveness.
- virtual bool IsEnabled(bool is_expensive) const = 0;
+ virtual bool IsEnabledForActivities(bool is_expensive) const = 0;
protected:
static string ConcatenateNames(StringPiece first, StringPiece second);
diff --git a/tensorflow/core/protobuf/config.proto b/tensorflow/core/protobuf/config.proto
index 625d5649e6..85cd02350a 100644
--- a/tensorflow/core/protobuf/config.proto
+++ b/tensorflow/core/protobuf/config.proto
@@ -68,7 +68,7 @@ message GPUOptions {
// after the process starts. Users are required to use vendor
// specific mechanisms (e.g., CUDA_VISIBLE_DEVICES) to control the
// physical to visible device mapping prior to invoking TensorFlow.
- // 2. In the code, the ids in this list are also called "CUDA GPU id"s,
+ // 2. In the code, the ids in this list are also called "platform GPU id"s,
// and the 'virtual' ids of GPU devices (i.e. the ids in the device
// name "/device:GPU:<id>") are also called "TF GPU id"s. Please
// refer to third_party/tensorflow/core/common_runtime/gpu/gpu_id.h
diff --git a/tensorflow/core/protobuf/replay_log.proto b/tensorflow/core/protobuf/replay_log.proto
new file mode 100644
index 0000000000..7644314fc9
--- /dev/null
+++ b/tensorflow/core/protobuf/replay_log.proto
@@ -0,0 +1,47 @@
+syntax = "proto3";
+
+option cc_enable_arenas = true;
+package tensorflow;
+
+import "tensorflow/core/framework/graph.proto";
+import "tensorflow/core/protobuf/cluster.proto";
+import "tensorflow/core/protobuf/master.proto";
+
+// Records the creation of a new replay session. We record the device listing
+// here to capture the state of the cluster.
+message NewReplaySession {
+ ListDevicesResponse devices = 1;
+ string session_handle = 2;
+}
+
+message ReplayOp {
+ double start_time_us = 31;
+ double end_time_us = 32;
+
+ oneof op {
+ CreateSessionRequest create_session = 1;
+ ExtendSessionRequest extend_session = 2;
+ PartialRunSetupRequest partial_run_setup = 3;
+ RunStepRequest run_step = 4;
+ CloseSessionRequest close_session = 5;
+ ListDevicesRequest list_devices = 6;
+ ResetRequest reset_request = 7;
+ MakeCallableRequest make_callable = 8;
+ RunCallableRequest run_callable = 9;
+ ReleaseCallableRequest release_callable = 10;
+ NewReplaySession new_replay_session = 11;
+ }
+
+ oneof response {
+ CreateSessionResponse create_session_response = 21;
+ ExtendSessionResponse extend_session_response = 22;
+ PartialRunSetupResponse partial_run_setup_response = 23;
+ RunStepResponse run_step_response = 24;
+ CloseSessionResponse close_session_response = 25;
+ ListDevicesResponse list_devices_response = 26;
+ ResetResponse reset_request_response = 27;
+ MakeCallableResponse make_callable_response = 28;
+ RunCallableResponse run_callable_response = 29;
+ ReleaseCallableResponse release_callable_response = 30;
+ }
+}
diff --git a/tensorflow/core/protobuf/rewriter_config.proto b/tensorflow/core/protobuf/rewriter_config.proto
index 07f984ceea..482178a540 100644
--- a/tensorflow/core/protobuf/rewriter_config.proto
+++ b/tensorflow/core/protobuf/rewriter_config.proto
@@ -75,6 +75,8 @@ message RewriterConfig {
// Try to allocate some independent Op outputs contiguously in order to
// merge or eliminate downstream Ops (off by default).
Toggle scoped_allocator_optimization = 15;
+ // Force small ops onto the CPU (default is ON).
+ Toggle pin_to_host_optimization = 18;
// Controls how many times we run the optimizers in meta optimizer (default
// is once).
@@ -141,8 +143,8 @@ message RewriterConfig {
// not configurable (in contrast to memory optimization passes through the
// meta-optimizer) and act only on manual op annotations.
//
- // Custom registered optimizers will be run after the base optimizers, in
- // the order that they are specified.
+ // Custom optimizers (see custom_optimizers) that are not part of this
+  // schedule will run afterwards, in the order in which they were specified.
repeated string optimizers = 100;
// Message to describe custom graph optimizer and its parameters
diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h
index 4129c93af5..b043a69431 100644
--- a/tensorflow/core/public/version.h
+++ b/tensorflow/core/public/version.h
@@ -19,12 +19,12 @@ limitations under the License.
// TensorFlow uses semantic versioning, see http://semver.org/.
#define TF_MAJOR_VERSION 1
-#define TF_MINOR_VERSION 10
+#define TF_MINOR_VERSION 11
#define TF_PATCH_VERSION 0
// TF_VERSION_SUFFIX is non-empty for pre-releases (e.g. "-alpha", "-alpha.1",
// "-beta", "-rc", "-rc.1")
-#define TF_VERSION_SUFFIX ""
+#define TF_VERSION_SUFFIX "-rc1"
#define TF_STR_HELPER(x) #x
#define TF_STR(x) TF_STR_HELPER(x)
diff --git a/tensorflow/core/util/cuda_kernel_helper.h b/tensorflow/core/util/cuda_kernel_helper.h
index 540adb58d4..f6f0408ccc 100644
--- a/tensorflow/core/util/cuda_kernel_helper.h
+++ b/tensorflow/core/util/cuda_kernel_helper.h
@@ -93,11 +93,11 @@ __device__ EIGEN_ALWAYS_INLINE Eigen::half CudaShuffleXorSync(
}
namespace cuda_helper {
-template <typename IntType>
-__device__ IntType upper_bound(IntType* first, IntType count, IntType val) {
- IntType* orig = first;
- IntType* it = nullptr;
- IntType step = 0;
+template <typename T, typename OutType = int32>
+__device__ OutType upper_bound(const T* first, OutType count, T val) {
+ const T* orig = first;
+ const T* it = nullptr;
+ OutType step = 0;
while (count > 0) {
it = first;
step = count / 2;
@@ -112,6 +112,27 @@ __device__ IntType upper_bound(IntType* first, IntType count, IntType val) {
return first - orig;
}
+
+template <typename T, typename OutType = int32>
+__device__ OutType lower_bound(const T* first, OutType count, T val) {
+ const T* orig = first;
+ const T* it = nullptr;
+ OutType step = 0;
+ while (count > 0) {
+ it = first;
+ step = count / 2;
+ it += step;
+ if (*it < val) {
+ first = ++it;
+ count -= step + 1;
+ } else {
+ count = step;
+ }
+ }
+
+ return first - orig;
+}
+
} // namespace cuda_helper
} // namespace tensorflow
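
Both device helpers are the standard halving binary search, returning an offset in OutType (int32 by default) so the result can be written straight into an int32/int64 output such as LowerBound/UpperBound. A host-side mirror of lower_bound, checked against the standard library (plain C++, no __device__ qualifiers):

#include <algorithm>
#include <cassert>
#include <vector>

// Host-side mirror of cuda_helper::lower_bound above: first position whose
// value is not less than val, expressed with the same count/step loop.
template <typename T, typename OutType = int>
OutType lower_bound(const T* first, OutType count, T val) {
  const T* orig = first;
  while (count > 0) {
    const T* it = first;
    OutType step = count / 2;
    it += step;
    if (*it < val) {
      first = ++it;
      count -= step + 1;
    } else {
      count = step;
    }
  }
  return static_cast<OutType>(first - orig);
}

int main() {
  std::vector<int> v = {1, 2, 2, 4, 7, 9};
  for (int val : {0, 2, 3, 9, 10}) {
    int got = lower_bound(v.data(), static_cast<int>(v.size()), val);
    int want =
        static_cast<int>(std::lower_bound(v.begin(), v.end(), val) - v.begin());
    assert(got == want);
  }
  return 0;
}
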
diff --git a/tensorflow/core/util/mkl_util.h b/tensorflow/core/util/mkl_util.h
index 2f2705de92..04aaea4f89 100644
--- a/tensorflow/core/util/mkl_util.h
+++ b/tensorflow/core/util/mkl_util.h
@@ -34,9 +34,8 @@ limitations under the License.
#endif
#ifdef INTEL_MKL_ML_ONLY
-// Using pragma message since #warning doesn't work with all compilers
-#pragma message("Compiling for INTEL MKL ML only will be deprecated soon.")
-#pragma message("Please use MKL DNN (the default option for --config=mkl)")
+#error \
+ "Compiling for INTEL MKL ML only is no longer supported.Please use MKL DNN (the default option for --config=mkl)"
#endif
#ifdef INTEL_MKL_ML_ONLY
@@ -2040,8 +2039,8 @@ class MklPrimitiveFactory {
/// Fuction to check whether primitive memory optimization is enabled
static inline bool IsPrimitiveMemOptEnabled() {
bool is_primitive_mem_opt_enabled = true;
- TF_CHECK_OK(ReadBoolFromEnvVar("TF_MKL_OPTIMIZE_PRIMITIVE_MEMUSE",
- true, &is_primitive_mem_opt_enabled));
+ TF_CHECK_OK(ReadBoolFromEnvVar("TF_MKL_OPTIMIZE_PRIMITIVE_MEMUSE", true,
+ &is_primitive_mem_opt_enabled));
return is_primitive_mem_opt_enabled;
}
@@ -2096,9 +2095,8 @@ static inline memory::format get_desired_format(int channel,
fmt_desired = is_2d ? memory::format::nChw16c : memory::format::nCdhw16c;
} else if (port::TestCPUFeature(port::CPUFeature::AVX2) &&
(channel % 8) == 0) {
- fmt_desired = is_2d
- ? memory::format::nChw8c
- : memory::format::ncdhw; // no avx2 support for 3d yet.
+ fmt_desired = is_2d ? memory::format::nChw8c
+ : memory::format::ncdhw; // no avx2 support for 3d yet.
} else {
fmt_desired = is_2d ? memory::format::nchw : memory::format::ncdhw;
}
@@ -2211,7 +2209,7 @@ inline primitive FindOrCreateReorder(const memory* from, const memory* to) {
// utility function to determine if it is conv 1x1 and stride != 1
// for purpose of temporarily disabling primitive reuse
inline bool IsConv1x1StrideNot1(memory::dims filter_dims,
- memory::dims strides) {
+ memory::dims strides) {
if (filter_dims.size() != 4 || strides.size() != 2) return false;
return ((filter_dims[2] == 1) && (filter_dims[3] == 1) &&
diff --git a/tensorflow/core/util/port.cc b/tensorflow/core/util/port.cc
index c081ceae57..e01058dff6 100644
--- a/tensorflow/core/util/port.cc
+++ b/tensorflow/core/util/port.cc
@@ -38,10 +38,10 @@ bool CudaSupportsHalfMatMulAndConv() {
}
bool IsMklEnabled() {
-#ifdef INTEL_MKL
+#if defined(INTEL_MKL) && defined(ENABLE_MKL)
return true;
#else
return false;
-#endif
+#endif // INTEL_MKL && ENABLE_MKL
}
} // end namespace tensorflow
diff --git a/tensorflow/core/util/sparse/sparse_tensor.h b/tensorflow/core/util/sparse/sparse_tensor.h
index 0f04b65f60..b9ca8ab395 100644
--- a/tensorflow/core/util/sparse/sparse_tensor.h
+++ b/tensorflow/core/util/sparse/sparse_tensor.h
@@ -20,6 +20,7 @@ limitations under the License.
#include <numeric>
#include <vector>
+#include "absl/base/macros.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
@@ -95,21 +96,21 @@ class SparseTensor {
SparseTensor() : dims_(0) {}
- // DEPRECATED: use Create() functions instead of constructors directly.
+ ABSL_DEPRECATED("Use Create() functions instead of constructors directly.")
SparseTensor(Tensor ix, Tensor vals, const TensorShape& shape)
: SparseTensor(ix, vals, TensorShapeToVector(shape),
UndefinedOrder(TensorShapeToVector(shape))) {}
- // DEPRECATED: use Create() functions instead of constructors directly.
+ ABSL_DEPRECATED("Use Create() functions instead of constructors directly.")
SparseTensor(Tensor ix, Tensor vals, const VarDimArray shape)
: SparseTensor(ix, vals, shape, UndefinedOrder(shape)) {}
- // DEPRECATED: use Create() functions instead of constructors directly.
+ ABSL_DEPRECATED("use Create() functions instead of constructors directly.")
SparseTensor(Tensor ix, Tensor vals, const TensorShape& shape,
const VarDimArray order)
: SparseTensor(ix, vals, TensorShapeToVector(shape), order) {}
- // DEPRECATED: use Create() functions instead of constructors directly.
+ ABSL_DEPRECATED("Use Create() functions instead of constructors directly.")
SparseTensor(Tensor ix, Tensor vals, const VarDimArray shape,
const VarDimArray order)
: ix_(ix),
@@ -237,9 +238,10 @@ class SparseTensor {
static Status Split(const SparseTensor& tensor, const int split_dim,
const int num_split, std::vector<SparseTensor>* result);
- // DEPRECATED: use the form of Split() that takes an output pointer and
- // returns a status instead.
template <typename T>
+ ABSL_DEPRECATED(
+ "Use the form of Split() that takes an output pointer and returns a "
+ "status instead.")
static std::vector<SparseTensor> Split(const SparseTensor& tensor,
const int split_dim,
const int num_split,
diff --git a/tensorflow/core/util/tensor_bundle/BUILD b/tensorflow/core/util/tensor_bundle/BUILD
index 648358606c..4d4db86df2 100644
--- a/tensorflow/core/util/tensor_bundle/BUILD
+++ b/tensorflow/core/util/tensor_bundle/BUILD
@@ -64,6 +64,7 @@ cc_library(
tf_cc_test(
name = "tensor_bundle_test",
srcs = ["tensor_bundle_test.cc"],
+ data = glob(["testdata/**"]),
deps = [
":tensor_bundle",
"//tensorflow/core:framework",
diff --git a/tensorflow/core/util/tensor_bundle/tensor_bundle.cc b/tensorflow/core/util/tensor_bundle/tensor_bundle.cc
index ea8a259d1a..2dcb57a1f9 100644
--- a/tensorflow/core/util/tensor_bundle/tensor_bundle.cc
+++ b/tensorflow/core/util/tensor_bundle/tensor_bundle.cc
@@ -64,27 +64,36 @@ namespace {
// Reads "num_elements" string elements from file[offset, offset+size) into the
// length-N "destination". Discards the original content of "destination".
//
-// Checksums the string lengths (as restored uint32, not varint32 bytes) and
-// string bytes, and stores it into "actual_crc32c".
+// Checksums the string lengths (as restored uint32 or uint64, not varint64
+// bytes) and string bytes, and stores it into "actual_crc32c".
Status ReadStringTensor(io::InputBuffer* buffered_file, size_t num_elements,
size_t offset, size_t size, string* destination,
uint32* actual_crc32c) {
if (size == 0) return Status::OK();
CHECK_GT(size, 0);
- // Reads "num_elements" varint32's from "buffered_file".
+ // Reads "num_elements" varint64's from "buffered_file".
TF_RETURN_IF_ERROR(buffered_file->Seek(offset));
- std::vector<uint32> string_lengths(num_elements);
+ std::vector<uint64> string_lengths(num_elements);
for (size_t i = 0; i < num_elements; ++i) {
- TF_RETURN_IF_ERROR(buffered_file->ReadVarint32(&string_lengths[i]));
+ TF_RETURN_IF_ERROR(buffered_file->ReadVarint64(&string_lengths[i]));
+ if (string_lengths[i] <= UINT32_MAX) {
+ // We need to do this because older checkpoints only used uint32s and we
+ // should still support them.
+ const uint32 elem_size_uint32 = static_cast<uint32>(string_lengths[i]);
+ *actual_crc32c = crc32c::Extend(
+ *actual_crc32c, reinterpret_cast<const char*>(&elem_size_uint32),
+ sizeof(uint32));
+ } else {
+ *actual_crc32c = crc32c::Extend(
+ *actual_crc32c, reinterpret_cast<const char*>(&string_lengths[i]),
+ sizeof(uint64));
+ }
}
if (offset + size < buffered_file->Tell()) {
return errors::DataLoss("String lengths longer than expected offset ",
offset + size);
}
- *actual_crc32c =
- crc32c::Value(reinterpret_cast<const char*>(string_lengths.data()),
- sizeof(uint32) * num_elements);
// Reads the length-checksum.
uint32 length_checksum = 0;
@@ -104,7 +113,7 @@ Status ReadStringTensor(io::InputBuffer* buffered_file, size_t num_elements,
// Reads the actual string bytes.
for (size_t i = 0; i < num_elements; ++i) {
- const uint32 string_length = string_lengths[i];
+ const uint64 string_length = string_lengths[i];
string* buffer = &destination[i];
buffer->resize(string_length);
@@ -218,8 +227,8 @@ Status WriteTensor(const Tensor& val, FileOutputBuffer* out,
Status WriteStringTensor(const Tensor& val, FileOutputBuffer* out,
size_t* bytes_written, uint32* crc32c) {
// On-disk format:
- // [varint32 len0]..[varint32 lenL][4 byte cksum on lengths][string bytes]
- // Var "crc32c" checksums the string lengths (as uint32, not varint32 bytes),
+ // [varint64 len0]..[varint64 lenL][4 byte cksum on lengths][string bytes]
+ // Var "crc32c" checksums the string lengths (as uint64, not varint64 bytes),
// the length-checksum, and all the string bytes.
DCHECK_EQ(val.dtype(), DT_STRING);
const string* strings = GetStringBackingBuffer(val);
@@ -230,12 +239,21 @@ Status WriteStringTensor(const Tensor& val, FileOutputBuffer* out,
*crc32c = 0;
for (int64 i = 0; i < val.NumElements(); ++i) {
const string* elem = &strings[i];
- DCHECK_EQ(elem->size(), static_cast<uint32>(elem->size()));
- const uint32 elem_size = static_cast<uint32>(elem->size());
-
- core::PutVarint32(&lengths, elem_size);
- *crc32c = crc32c::Extend(*crc32c, reinterpret_cast<const char*>(&elem_size),
- sizeof(uint32));
+ DCHECK_EQ(elem->size(), static_cast<uint64>(elem->size()));
+ const uint64 elem_size = static_cast<uint64>(elem->size());
+
+ core::PutVarint64(&lengths, elem_size);
+ if (elem_size <= UINT32_MAX) {
+ // We need to do this because older checkpoints only used uint32s and we
+ // should still support them.
+ const uint32 elem_size_uint32 = static_cast<uint32>(elem_size);
+ *crc32c = crc32c::Extend(*crc32c,
+ reinterpret_cast<const char*>(&elem_size_uint32),
+ sizeof(uint32));
+ } else {
+ *crc32c = crc32c::Extend(
+ *crc32c, reinterpret_cast<const char*>(&elem_size), sizeof(uint64));
+ }
}
TF_RETURN_IF_ERROR(out->Append(lengths));
*bytes_written = lengths.size();
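
String lengths are now serialized as varint64s, but the running checksum stays compatible with older bundles: lengths that fit in 32 bits are still checksummed over their 4-byte uint32 representation, and only lengths above UINT32_MAX are checksummed over 8 bytes. The width selection can be sketched as below; ExtendChecksum is a trivial byte-sum stand-in for crc32c::Extend, not the real CRC:

#include <cstdint>
#include <cstdio>
#include <cstring>

// Toy stand-in for crc32c::Extend: just sums bytes. Only the *width* choice
// in ChecksumLength mirrors the bundle code; the real checksum is CRC32C.
uint32_t ExtendChecksum(uint32_t crc, const char* data, size_t n) {
  for (size_t i = 0; i < n; ++i) crc += static_cast<unsigned char>(data[i]);
  return crc;
}

uint32_t ChecksumLength(uint32_t crc, uint64_t elem_size) {
  if (elem_size <= UINT32_MAX) {
    // Older checkpoints only ever checksummed 4-byte lengths, so short strings
    // keep the uint32 representation for compatibility.
    const uint32_t as32 = static_cast<uint32_t>(elem_size);
    return ExtendChecksum(crc, reinterpret_cast<const char*>(&as32), sizeof(as32));
  }
  return ExtendChecksum(crc, reinterpret_cast<const char*>(&elem_size),
                        sizeof(elem_size));
}

int main() {
  std::printf("%u\n", ChecksumLength(0, 5));                // 4-byte path
  std::printf("%u\n", ChecksumLength(0, (1ull << 32) + 5));  // 8-byte path
}
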
diff --git a/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc b/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc
index 59c42baa06..9567e4750b 100644
--- a/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc
+++ b/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc
@@ -39,6 +39,11 @@ string Prefix(const string& prefix) {
return strings::StrCat(testing::TmpDir(), "/", prefix);
}
+string TestdataPrefix(const string& prefix) {
+ return strings::StrCat(testing::TensorFlowSrcRoot(),
+ "/core/util/tensor_bundle/testdata/", prefix);
+}
+
template <typename T>
Tensor Constant(T v, TensorShape shape) {
Tensor ret(DataTypeToEnum<T>::value, shape);
@@ -458,7 +463,26 @@ TEST(TensorBundleTest, NonStandardShapes) {
TestNonStandardShapes<qint8>();
}
+TEST(TensorBundleTest, StringTensorsOldFormat) {
+  // Test a string tensor bundle created by an older version of the code that
+  // used varint32s to store string lengths (we now use varint64s).
+ BundleReader reader(Env::Default(), TestdataPrefix("old_string_tensors/foo"));
+ TF_ASSERT_OK(reader.status());
+ EXPECT_EQ(AllTensorKeys(&reader),
+ std::vector<string>({"floats", "scalar", "string_tensor", "strs"}));
+
+ Expect<string>(&reader, "string_tensor", Tensor(DT_STRING, TensorShape({1})));
+ Expect<string>(&reader, "scalar", test::AsTensor<string>({"hello"}));
+ Expect<string>(
+ &reader, "strs",
+ test::AsTensor<string>({"hello", "", "x01", string(1 << 10, 'c')}));
+ Expect<float>(&reader, "floats", Constant_2x3<float>(16.18));
+}
+
TEST(TensorBundleTest, StringTensors) {
+ constexpr size_t kLongLength = static_cast<size_t>(UINT32_MAX) + 1;
+ Tensor long_string_tensor(DT_STRING, TensorShape({1}));
+
{
BundleWriter writer(Env::Default(), Prefix("foo"));
TF_EXPECT_OK(writer.Add("string_tensor",
@@ -467,6 +491,12 @@ TEST(TensorBundleTest, StringTensors) {
TF_EXPECT_OK(writer.Add(
"strs",
test::AsTensor<string>({"hello", "", "x01", string(1 << 25, 'c')})));
+
+ // Requires a 64-bit length.
+ string* backing_string = long_string_tensor.flat<string>().data();
+ backing_string->assign(kLongLength, 'd');
+ TF_EXPECT_OK(writer.Add("long_scalar", long_string_tensor));
+
// Mixes in some floats.
TF_EXPECT_OK(writer.Add("floats", Constant_2x3<float>(16.18)));
TF_ASSERT_OK(writer.Finish());
@@ -474,9 +504,9 @@ TEST(TensorBundleTest, StringTensors) {
{
BundleReader reader(Env::Default(), Prefix("foo"));
TF_ASSERT_OK(reader.status());
- EXPECT_EQ(
- AllTensorKeys(&reader),
- std::vector<string>({"floats", "scalar", "string_tensor", "strs"}));
+ EXPECT_EQ(AllTensorKeys(&reader),
+ std::vector<string>({"floats", "long_scalar", "scalar",
+ "string_tensor", "strs"}));
Expect<string>(&reader, "string_tensor",
Tensor(DT_STRING, TensorShape({1})));
@@ -484,7 +514,35 @@ TEST(TensorBundleTest, StringTensors) {
Expect<string>(
&reader, "strs",
test::AsTensor<string>({"hello", "", "x01", string(1 << 25, 'c')}));
+
Expect<float>(&reader, "floats", Constant_2x3<float>(16.18));
+
+ // We don't use the Expect function so we can re-use the
+ // `long_string_tensor` buffer for reading out long_scalar to keep memory
+ // usage reasonable.
+ EXPECT_TRUE(reader.Contains("long_scalar"));
+ DataType dtype;
+ TensorShape shape;
+ TF_ASSERT_OK(reader.LookupDtypeAndShape("long_scalar", &dtype, &shape));
+ EXPECT_EQ(DT_STRING, dtype);
+ EXPECT_EQ(TensorShape({1}), shape);
+
+ // Zero-out the string so that we can be sure the new one is read in.
+ string* backing_string = long_string_tensor.flat<string>().data();
+ backing_string->assign("");
+
+ // Read long_scalar and check it contains kLongLength 'd's.
+ TF_ASSERT_OK(reader.Lookup("long_scalar", &long_string_tensor));
+ ASSERT_EQ(backing_string, long_string_tensor.flat<string>().data());
+ EXPECT_EQ(kLongLength, backing_string->length());
+ for (char c : *backing_string) {
+ // Not using ASSERT_EQ('d', c) because this way is twice as fast due to
+ // compiler optimizations.
+ if (c != 'd') {
+ FAIL() << "long_scalar is not full of 'd's as expected.";
+ break;
+ }
+ }
}
}
diff --git a/tensorflow/core/util/tensor_bundle/testdata/old_string_tensors/README b/tensorflow/core/util/tensor_bundle/testdata/old_string_tensors/README
new file mode 100644
index 0000000000..428d3ef79e
--- /dev/null
+++ b/tensorflow/core/util/tensor_bundle/testdata/old_string_tensors/README
@@ -0,0 +1,3 @@
+This tensor bundle was generated from cl/214343133, before string tensor
+lengths were written as varint64s. This is here to check backwards
+compatibility between the new code and old checkpoints.
diff --git a/tensorflow/core/util/tensor_bundle/testdata/old_string_tensors/foo.data-00000-of-00001 b/tensorflow/core/util/tensor_bundle/testdata/old_string_tensors/foo.data-00000-of-00001
new file mode 100644
index 0000000000..23b488e5fe
--- /dev/null
+++ b/tensorflow/core/util/tensor_bundle/testdata/old_string_tensors/foo.data-00000-of-00001
Binary files differ
diff --git a/tensorflow/core/util/tensor_bundle/testdata/old_string_tensors/foo.index b/tensorflow/core/util/tensor_bundle/testdata/old_string_tensors/foo.index
new file mode 100644
index 0000000000..a22a69e6e1
--- /dev/null
+++ b/tensorflow/core/util/tensor_bundle/testdata/old_string_tensors/foo.index
Binary files differ
diff --git a/tensorflow/core/util/work_sharder.cc b/tensorflow/core/util/work_sharder.cc
index f4bd2950e9..74f0713a61 100644
--- a/tensorflow/core/util/work_sharder.cc
+++ b/tensorflow/core/util/work_sharder.cc
@@ -50,6 +50,8 @@ void Shard(int max_parallelism, thread::ThreadPool* workers, int64 total,
max_parallelism);
}
+// DEPRECATED: Prefer threadpool->TransformRangeConcurrently, which allows you
+// to directly specify the shard size.
void Sharder::Do(int64 total, int64 cost_per_unit, const Work& work,
const Runner& runner, int max_parallelism) {
cost_per_unit = std::max(int64{1}, cost_per_unit);
diff --git a/tensorflow/core/util/work_sharder.h b/tensorflow/core/util/work_sharder.h
index b12c31c1ae..9db85a54c6 100644
--- a/tensorflow/core/util/work_sharder.h
+++ b/tensorflow/core/util/work_sharder.h
@@ -23,6 +23,9 @@ limitations under the License.
namespace tensorflow {
+// DEPRECATED: Prefer threadpool->TransformRangeConcurrently, which allows you
+// to directly specify the shard size. Use this function only if you want to
+// manually cap parallelism.
// Shards the "total" unit of work assuming each unit of work having
// roughly "cost_per_unit". Each unit of work is indexed 0, 1, ...,
// total - 1. Each shard contains 1 or more units of work and the
diff --git a/tensorflow/examples/android/BUILD b/tensorflow/examples/android/BUILD
index f327b645f5..f5f0d7c3c8 100644
--- a/tensorflow/examples/android/BUILD
+++ b/tensorflow/examples/android/BUILD
@@ -68,6 +68,7 @@ android_binary(
srcs = glob([
"src/**/*.java",
]),
+ aapt_version = "aapt",
# Package assets from assets dir as well as all model targets. Remove undesired models
# (and corresponding Activities in source) to reduce APK size.
assets = [
diff --git a/tensorflow/examples/autograph/integration_tests/errors_test.py b/tensorflow/examples/autograph/integration_tests/errors_test.py
index 69e5936832..9c10dad9aa 100644
--- a/tensorflow/examples/autograph/integration_tests/errors_test.py
+++ b/tensorflow/examples/autograph/integration_tests/errors_test.py
@@ -92,7 +92,7 @@ class ErrorsTest(tf.test.TestCase):
compiled_fn = ag.to_graph(test_fn)
with self.assertRaises(ag.TfRuntimeError) as error:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x = compiled_fn(tf.constant([4, 8]))
with ag.improved_errors(compiled_fn):
sess.run(x)
@@ -134,7 +134,7 @@ class ErrorsTest(tf.test.TestCase):
# frame with "g" as the function name but because we don't yet add
# try/except blocks to inner functions the name is "tf__g".
with self.assertRaises(ag.TfRuntimeError) as error:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x = compiled_fn(tf.constant([4, 8]))
with ag.improved_errors(compiled_fn):
sess.run(x)
diff --git a/tensorflow/examples/learn/text_classification_character_cnn.py b/tensorflow/examples/learn/text_classification_character_cnn.py
index afda170e2a..b8506fa8a4 100644
--- a/tensorflow/examples/learn/text_classification_character_cnn.py
+++ b/tensorflow/examples/learn/text_classification_character_cnn.py
@@ -74,7 +74,7 @@ def char_cnn_model(features, labels, mode):
kernel_size=FILTER_SHAPE2,
padding='VALID')
# Max across each filter to get useful features for classification.
- pool2 = tf.squeeze(tf.reduce_max(conv2, 1), squeeze_dims=[1])
+ pool2 = tf.squeeze(tf.reduce_max(conv2, 1), axis=[1])
# Apply regular WX + B and classification.
logits = tf.layers.dense(pool2, MAX_LABEL, activation=None)
diff --git a/tensorflow/examples/tutorials/mnist/BUILD b/tensorflow/examples/tutorials/mnist/BUILD
index d4070fdd1e..99da44d6d5 100644
--- a/tensorflow/examples/tutorials/mnist/BUILD
+++ b/tensorflow/examples/tutorials/mnist/BUILD
@@ -84,6 +84,18 @@ py_binary(
)
py_binary(
+ name = "mnist_softmax_xla",
+ srcs = [
+ "mnist_softmax_xla.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":input_data",
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+py_binary(
name = "mnist_deep",
srcs = [
"mnist_deep.py",
diff --git a/tensorflow/go/README.md b/tensorflow/go/README.md
index 288a32530a..3989f9b25a 100644
--- a/tensorflow/go/README.md
+++ b/tensorflow/go/README.md
@@ -10,7 +10,7 @@ Construct and execute TensorFlow graphs in Go.
## Quickstart
-Refer to [Installing TensorFlow for Go](https://www.tensorflow.org/install/install_go)
+Refer to [Installing TensorFlow for Go](https://www.tensorflow.org/install/lang_go)
## Building the TensorFlow C library from source
@@ -23,9 +23,7 @@ from source.
- [bazel](https://www.bazel.build/versions/master/docs/install.html)
- Environment to build TensorFlow from source code
- ([Linux](https://www.tensorflow.org/install/install_sources#PrepareLinux)
- or [OS
- X](https://www.tensorflow.org/install/install_sources#PrepareMac)).
+ ([Linux or macOS](https://www.tensorflow.org/install/source)).
If you don't need GPU support, then try the following:
```sh
diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go
index 322b35dd91..b4d4db3e4d 100644
--- a/tensorflow/go/op/wrappers.go
+++ b/tensorflow/go/op/wrappers.go
@@ -332,7 +332,7 @@ func FakeQuantWithMinMaxArgs(scope *Scope, inputs tf.Output, optional ...FakeQua
// Creates a new tensor by applying sparse `updates` to individual values or
// slices within a tensor (initially zero for numeric, empty for string) of
// the given `shape` according to indices. This operator is the inverse of the
-// @{tf.gather_nd} operator which extracts values or slices from a given tensor.
+// `tf.gather_nd` operator which extracts values or slices from a given tensor.
//
// If `indices` contains duplicates, then their updates are accumulated (summed).
//
@@ -1473,7 +1473,7 @@ type StridedSliceAttr func(optionalAttr)
//
// value: a bitmask where a bit i being 1 means to ignore the begin
// value and instead use the largest interval possible. At runtime
-// begin[i] will be replaced with `[0, n-1) if `stride[i] > 0` or
+// begin[i] will be replaced with `[0, n-1)` if `stride[i] > 0` or
// `[-1, n-1]` if `stride[i] < 0`
// If not specified, defaults to 0
func StridedSliceBeginMask(value int64) StridedSliceAttr {
@@ -1856,6 +1856,32 @@ func ReverseSequence(scope *Scope, input tf.Output, seq_lengths tf.Output, seq_d
return op.Output(0)
}
+// Ensures that the tensor's shape matches the expected shape.
+//
+// Raises an error if the input tensor's shape does not match the specified shape.
+// Returns the input tensor otherwise.
+//
+// Arguments:
+// input: A tensor, whose shape is to be validated.
+// shape: The expected (possibly partially specified) shape of the input tensor.
+//
+// Returns A tensor with the same shape and contents as the input tensor or value.
+func EnsureShape(scope *Scope, input tf.Output, shape tf.Shape) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"shape": shape}
+ opspec := tf.OpSpec{
+ Type: "EnsureShape",
+ Input: []tf.Input{
+ input,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
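To make the behaviour of the generated wrapper above concrete, here is a minimal, hedged Go sketch of using EnsureShape as a runtime shape check; the placeholder name, the shape, and the program scaffolding are illustrative assumptions rather than part of this change:

```go
package main

import (
	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
	s := op.NewScope()
	// Hypothetical float32 input with an unknown batch dimension.
	input := op.Placeholder(s, tf.Float, op.PlaceholderShape(tf.MakeShape(-1, 3)))
	// EnsureShape forwards the tensor unchanged and raises a runtime error
	// if its shape is incompatible with [batch, 3].
	checked := op.EnsureShape(s, input, tf.MakeShape(-1, 3))
	_ = checked
	if _, err := s.Finalize(); err != nil {
		panic(err)
	}
}
```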
// UniqueWithCountsV2Attr is an optional argument to UniqueWithCountsV2.
type UniqueWithCountsV2Attr func(optionalAttr)
@@ -2259,7 +2285,7 @@ func CheckNumerics(scope *Scope, tensor tf.Output, message string) (output tf.Ou
//
// output[\\(i_0, ..., i_{K-2}\\)] = params[indices[\\(i_0, ..., i_{K-2}\\)]]
//
-// Whereas in @{tf.gather} `indices` defines slices into the first
+// Whereas in `tf.gather` `indices` defines slices into the first
// dimension of `params`, in `tf.gather_nd`, `indices` defines slices into the
// first `N` dimensions of `params`, where `N = indices.shape[-1]`.
//
@@ -2356,6 +2382,8 @@ func CheckNumerics(scope *Scope, tensor tf.Output, message string) (output tf.Ou
// output = [['b0', 'b1'], ['d0', 'c1']]
// ```
//
+// See also `tf.gather` and `tf.batch_gather`.
+//
// Arguments:
// params: The tensor from which to gather values.
// indices: Index tensor.
@@ -2433,120 +2461,102 @@ func Gather(scope *Scope, params tf.Output, indices tf.Output, optional ...Gathe
return op.Output(0)
}
-// Creates a tensor filled with a scalar value.
+// LowerBoundAttr is an optional argument to LowerBound.
+type LowerBoundAttr func(optionalAttr)
+
+// LowerBoundOutType sets the optional out_type attribute to value.
+// If not specified, defaults to DT_INT32
+func LowerBoundOutType(value tf.DataType) LowerBoundAttr {
+ return func(m optionalAttr) {
+ m["out_type"] = value
+ }
+}
+
+// Applies lower_bound(sorted_search_values, values) along each row.
//
-// This operation creates a tensor of shape `dims` and fills it with `value`.
+// Each set of rows with the same index in (sorted_inputs, values) is treated
+// independently. The resulting row is the equivalent of calling
+// `np.searchsorted(sorted_inputs, values, side='left')`.
//
-// For example:
+// The result is not a global index to the entire
+// `Tensor`, but rather just the index in the last dimension.
//
-// ```
-// # Output tensor has shape [2, 3].
-// fill([2, 3], 9) ==> [[9, 9, 9]
-// [9, 9, 9]]
-// ```
+// A 2-D example:
+// sorted_sequence = [[0, 3, 9, 9, 10],
+// [1, 2, 3, 4, 5]]
+// values = [[2, 4, 9],
+// [0, 2, 6]]
+//
+// result = LowerBound(sorted_sequence, values)
+//
+// result == [[1, 2, 2],
+// [0, 1, 5]]
//
// Arguments:
-// dims: 1-D. Represents the shape of the output tensor.
-// value: 0-D (scalar). Value to fill the returned tensor.
+// sorted_inputs: 2-D Tensor where each row is ordered.
+// values: 2-D Tensor with the same number of rows as `sorted_search_values`. Contains
+// the values that will be searched for in `sorted_search_values`.
//
-// @compatibility(numpy)
-// Equivalent to np.full
-// @end_compatibility
-func Fill(scope *Scope, dims tf.Output, value tf.Output) (output tf.Output) {
+// Returns A `Tensor` with the same shape as `values`. It contains the first scalar index
+// into the last dimension where values can be inserted without changing the
+// ordered property.
+func LowerBound(scope *Scope, sorted_inputs tf.Output, values tf.Output, optional ...LowerBoundAttr) (output tf.Output) {
if scope.Err() != nil {
return
}
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
opspec := tf.OpSpec{
- Type: "Fill",
+ Type: "LowerBound",
Input: []tf.Input{
- dims, value,
+ sorted_inputs, values,
},
+ Attrs: attrs,
}
op := scope.AddOperation(opspec)
return op.Output(0)
}
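The LowerBound example above can also be exercised end to end. The following hedged sketch builds the documented 2-D case and runs it in a session; the session boilerplate is an assumption added only for illustration:

```go
package main

import (
	"fmt"

	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
	s := op.NewScope()
	// The sorted rows and search values from the doc comment above.
	sorted := op.Const(s, [][]int32{{0, 3, 9, 9, 10}, {1, 2, 3, 4, 5}})
	values := op.Const(s, [][]int32{{2, 4, 9}, {0, 2, 6}})
	result := op.LowerBound(s, sorted, values)

	graph, err := s.Finalize()
	if err != nil {
		panic(err)
	}
	sess, err := tf.NewSession(graph, nil)
	if err != nil {
		panic(err)
	}
	defer sess.Close()
	out, err := sess.Run(nil, []tf.Output{result}, nil)
	if err != nil {
		panic(err)
	}
	// Expected per the documentation above: [[1 2 2] [0 1 5]].
	fmt.Println(out[0].Value())
}
```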
-// EditDistanceAttr is an optional argument to EditDistance.
-type EditDistanceAttr func(optionalAttr)
-
-// EditDistanceNormalize sets the optional normalize attribute to value.
-//
-// value: boolean (if true, edit distances are normalized by length of truth).
-//
-// The output is:
-// If not specified, defaults to true
-func EditDistanceNormalize(value bool) EditDistanceAttr {
- return func(m optionalAttr) {
- m["normalize"] = value
- }
-}
-
-// Computes the (possibly normalized) Levenshtein Edit Distance.
-//
-// The inputs are variable-length sequences provided by SparseTensors
-// (hypothesis_indices, hypothesis_values, hypothesis_shape)
-// and
-// (truth_indices, truth_values, truth_shape).
+// Creates a tensor filled with a scalar value.
//
-// The inputs are:
+// This operation creates a tensor of shape `dims` and fills it with `value`.
//
-// Arguments:
-// hypothesis_indices: The indices of the hypothesis list SparseTensor.
-// This is an N x R int64 matrix.
-// hypothesis_values: The values of the hypothesis list SparseTensor.
-// This is an N-length vector.
-// hypothesis_shape: The shape of the hypothesis list SparseTensor.
-// This is an R-length vector.
-// truth_indices: The indices of the truth list SparseTensor.
-// This is an M x R int64 matrix.
-// truth_values: The values of the truth list SparseTensor.
-// This is an M-length vector.
-// truth_shape: truth indices, vector.
+// For example:
//
-// Returns A dense float tensor with rank R - 1.
+// ```
+// # Output tensor has shape [2, 3].
+// fill([2, 3], 9) ==> [[9, 9, 9]
+// [9, 9, 9]]
+// ```
//
-// For the example input:
+// `tf.fill` differs from `tf.constant` in a few ways:
//
-// // hypothesis represents a 2x1 matrix with variable-length values:
-// // (0,0) = ["a"]
-// // (1,0) = ["b"]
-// hypothesis_indices = [[0, 0, 0],
-// [1, 0, 0]]
-// hypothesis_values = ["a", "b"]
-// hypothesis_shape = [2, 1, 1]
+// * `tf.fill` only supports scalar contents, whereas `tf.constant` supports
+// Tensor values.
+// * `tf.fill` creates an Op in the computation graph that constructs the actual
+// Tensor value at runtime. This is in contrast to `tf.constant` which embeds
+// the entire Tensor into the graph with a `Const` node.
+// * Because `tf.fill` evaluates at graph runtime, it supports dynamic shapes
+// based on other runtime Tensors, unlike `tf.constant`.
//
-// // truth represents a 2x2 matrix with variable-length values:
-// // (0,0) = []
-// // (0,1) = ["a"]
-// // (1,0) = ["b", "c"]
-// // (1,1) = ["a"]
-// truth_indices = [[0, 1, 0],
-// [1, 0, 0],
-// [1, 0, 1],
-// [1, 1, 0]]
-// truth_values = ["a", "b", "c", "a"]
-// truth_shape = [2, 2, 2]
-// normalize = true
-//
-// The output will be:
+// Arguments:
+// dims: 1-D. Represents the shape of the output tensor.
+// value: 0-D (scalar). Value to fill the returned tensor.
//
-// // output is a 2x2 matrix with edit distances normalized by truth lengths.
-// output = [[inf, 1.0], // (0,0): no truth, (0,1): no hypothesis
-// [0.5, 1.0]] // (1,0): addition, (1,1): no hypothesis
-func EditDistance(scope *Scope, hypothesis_indices tf.Output, hypothesis_values tf.Output, hypothesis_shape tf.Output, truth_indices tf.Output, truth_values tf.Output, truth_shape tf.Output, optional ...EditDistanceAttr) (output tf.Output) {
+// @compatibility(numpy)
+// Equivalent to np.full
+// @end_compatibility
+func Fill(scope *Scope, dims tf.Output, value tf.Output) (output tf.Output) {
if scope.Err() != nil {
return
}
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
opspec := tf.OpSpec{
- Type: "EditDistance",
+ Type: "Fill",
Input: []tf.Input{
- hypothesis_indices, hypothesis_values, hypothesis_shape, truth_indices, truth_values, truth_shape,
+ dims, value,
},
- Attrs: attrs,
}
op := scope.AddOperation(opspec)
return op.Output(0)
@@ -2858,6 +2868,25 @@ func GuaranteeConst(scope *Scope, input tf.Output) (output tf.Output) {
return op.Output(0)
}
+// Returns a constant tensor on the host. Only for writing C++ tests.
+//
+// Arguments:
+// value: Attr `value` is the tensor to return.
+//
+func HostConst(scope *Scope, value tf.Tensor, dtype tf.DataType) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"value": value, "dtype": dtype}
+ opspec := tf.OpSpec{
+ Type: "HostConst",
+
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Splits a tensor into `num_split` tensors along one dimension.
//
// Arguments:
@@ -3377,6 +3406,204 @@ func PopulationCount(scope *Scope, x tf.Output) (y tf.Output) {
return op.Output(0)
}
+// Bucketize each feature based on bucket boundaries.
+//
+// An op that returns a list of float tensors, where each tensor represents the
+// bucketized values for a single feature.
+//
+// Arguments:
+// float_values: float; List of Rank 2 Tensors, each containing float values for a single feature.
+// bucket_boundaries: float; List of Rank 1 Tensors each containing the bucket boundaries for a single
+// feature.
+//
+// Returns int; List of Rank 2 Tensors each containing the bucketized values for a single feature.
+func BoostedTreesBucketize(scope *Scope, float_values []tf.Output, bucket_boundaries []tf.Output) (buckets []tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "BoostedTreesBucketize",
+ Input: []tf.Input{
+ tf.OutputList(float_values), tf.OutputList(bucket_boundaries),
+ },
+ }
+ op := scope.AddOperation(opspec)
+ if scope.Err() != nil {
+ return
+ }
+ var idx int
+ var err error
+ if buckets, idx, err = makeOutputList(op, idx, "buckets"); err != nil {
+ scope.UpdateErr("BoostedTreesBucketize", err)
+ return
+ }
+ return buckets
+}
+
+// BoostedTreesQuantileStreamResourceFlushAttr is an optional argument to BoostedTreesQuantileStreamResourceFlush.
+type BoostedTreesQuantileStreamResourceFlushAttr func(optionalAttr)
+
+// BoostedTreesQuantileStreamResourceFlushGenerateQuantiles sets the optional generate_quantiles attribute to value.
+//
+// value: bool; If True, the output will be the num_quantiles for each stream where the ith
+// entry is the ith quantile of the input with an approximation error of epsilon.
+// Duplicate values may be present.
+// If False, the output will be the points in the histogram that were collected,
+// which roughly translates to 1/epsilon boundaries, without any duplicates.
+// Defaults to False.
+// If not specified, defaults to false
+func BoostedTreesQuantileStreamResourceFlushGenerateQuantiles(value bool) BoostedTreesQuantileStreamResourceFlushAttr {
+ return func(m optionalAttr) {
+ m["generate_quantiles"] = value
+ }
+}
+
+// Flush the summaries for a quantile stream resource.
+//
+// An op that flushes the summaries for a quantile stream resource.
+//
+// Arguments:
+// quantile_stream_resource_handle: resource handle referring to a QuantileStreamResource.
+// num_buckets: int; approximate number of buckets unless using generate_quantiles.
+//
+// Returns the created operation.
+func BoostedTreesQuantileStreamResourceFlush(scope *Scope, quantile_stream_resource_handle tf.Output, num_buckets tf.Output, optional ...BoostedTreesQuantileStreamResourceFlushAttr) (o *tf.Operation) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "BoostedTreesQuantileStreamResourceFlush",
+ Input: []tf.Input{
+ quantile_stream_resource_handle, num_buckets,
+ },
+ Attrs: attrs,
+ }
+ return scope.AddOperation(opspec)
+}
+
+// Add the quantile summaries to each quantile stream resource.
+//
+// An op that adds a list of quantile summaries to a quantile stream resource. Each
+// summary Tensor is rank 2, containing summaries (value, weight, min_rank, max_rank)
+// for a single feature.
+//
+// Arguments:
+// quantile_stream_resource_handle: resource handle referring to a QuantileStreamResource.
+// summaries: string; List of Rank 2 Tensors, each containing the summaries for a single feature.
+//
+// Returns the created operation.
+func BoostedTreesQuantileStreamResourceAddSummaries(scope *Scope, quantile_stream_resource_handle tf.Output, summaries []tf.Output) (o *tf.Operation) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "BoostedTreesQuantileStreamResourceAddSummaries",
+ Input: []tf.Input{
+ quantile_stream_resource_handle, tf.OutputList(summaries),
+ },
+ }
+ return scope.AddOperation(opspec)
+}
+
+// Makes the summary of quantiles for the batch.
+//
+// An op that takes a list of tensors and outputs the quantile summaries for each tensor.
+//
+// Arguments:
+// float_values: float; List of Rank 2 Tensors each containing values for a single feature.
+// example_weights: float; Rank 1 Tensor with weights per instance.
+// epsilon: float; The required maximum approximation error.
+//
+// Returns float; List of Rank 2 Tensors each containing the quantile summary (value, weight,
+// min_rank, max_rank) of a single feature.
+func BoostedTreesMakeQuantileSummaries(scope *Scope, float_values []tf.Output, example_weights tf.Output, epsilon tf.Output) (summaries []tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "BoostedTreesMakeQuantileSummaries",
+ Input: []tf.Input{
+ tf.OutputList(float_values), example_weights, epsilon,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ if scope.Err() != nil {
+ return
+ }
+ var idx int
+ var err error
+ if summaries, idx, err = makeOutputList(op, idx, "summaries"); err != nil {
+ scope.UpdateErr("BoostedTreesMakeQuantileSummaries", err)
+ return
+ }
+ return summaries
+}
+
+// BoostedTreesCreateQuantileStreamResourceAttr is an optional argument to BoostedTreesCreateQuantileStreamResource.
+type BoostedTreesCreateQuantileStreamResourceAttr func(optionalAttr)
+
+// BoostedTreesCreateQuantileStreamResourceMaxElements sets the optional max_elements attribute to value.
+//
+// value: int; The maximum number of data points that can be fed to the stream.
+// If not specified, defaults to 1099511627776
+func BoostedTreesCreateQuantileStreamResourceMaxElements(value int64) BoostedTreesCreateQuantileStreamResourceAttr {
+ return func(m optionalAttr) {
+ m["max_elements"] = value
+ }
+}
+
+// Create the Resource for Quantile Streams.
+//
+// Arguments:
+// quantile_stream_resource_handle: resource; Handle to quantile stream resource.
+// epsilon: float; The required approximation error of the stream resource.
+// num_streams: int; The number of streams managed by the resource that shares the same epsilon.
+//
+// Returns the created operation.
+func BoostedTreesCreateQuantileStreamResource(scope *Scope, quantile_stream_resource_handle tf.Output, epsilon tf.Output, num_streams tf.Output, optional ...BoostedTreesCreateQuantileStreamResourceAttr) (o *tf.Operation) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "BoostedTreesCreateQuantileStreamResource",
+ Input: []tf.Input{
+ quantile_stream_resource_handle, epsilon, num_streams,
+ },
+ Attrs: attrs,
+ }
+ return scope.AddOperation(opspec)
+}
+
+// Checks whether a quantile stream has been initialized.
+//
+// An Op that checks if quantile stream resource is initialized.
+//
+// Arguments:
+// quantile_stream_resource_handle: resource; The reference to quantile stream resource handle.
+//
+// Returns bool; True if the resource is initialized, False otherwise.
+func IsBoostedTreesQuantileStreamResourceInitialized(scope *Scope, quantile_stream_resource_handle tf.Output) (is_initialized tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "IsBoostedTreesQuantileStreamResourceInitialized",
+ Input: []tf.Input{
+ quantile_stream_resource_handle,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
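The boosted-trees quantile ops above are meant to be used together. Below is a hedged sketch of the usual graph-construction order; the handle op (BoostedTreesQuantileStreamResourceHandleOp), the epsilon and bucket-count constants, and the caller-supplied boundary tensors are assumptions made for illustration, not something this hunk defines:

```go
package main

import (
	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

// buildQuantileStreamGraph sketches the usual ordering: create the stream
// resource, summarize a batch, add the summaries, flush them, and bucketize
// feature values. The returned operations must be run in that order.
func buildQuantileStreamGraph(s *op.Scope, features []tf.Output, weights tf.Output,
	boundaries []tf.Output) ([]*tf.Operation, []tf.Output) {
	// Assumed wrapper for the quantile stream resource handle op.
	handle := op.BoostedTreesQuantileStreamResourceHandleOp(s)
	epsilon := op.Const(s, float32(0.01))
	numStreams := op.Const(s, int64(len(features)))

	create := op.BoostedTreesCreateQuantileStreamResource(s, handle, epsilon, numStreams)
	summaries := op.BoostedTreesMakeQuantileSummaries(s, features, weights, epsilon)
	add := op.BoostedTreesQuantileStreamResourceAddSummaries(s, handle, summaries)
	flush := op.BoostedTreesQuantileStreamResourceFlush(s, handle, op.Const(s, int64(100)))

	// After the flush, features are bucketized against boundary tensors;
	// here they are simply supplied by the caller to keep the sketch short.
	buckets := op.BoostedTreesBucketize(s, features, boundaries)
	return []*tf.Operation{create, add, flush}, buckets
}

func main() {
	s := op.NewScope()
	f := op.Placeholder(s, tf.Float)
	w := op.Placeholder(s, tf.Float)
	b := op.Placeholder(s, tf.Float)
	buildQuantileStreamGraph(s, []tf.Output{f}, w, []tf.Output{b})
	if _, err := s.Finalize(); err != nil {
		panic(err)
	}
}
```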
// Calculates the prior from the training data (the bias) and fills in the first node with the logits' prior. Returns a boolean indicating whether to continue centering.
//
// Arguments:
@@ -3486,276 +3713,683 @@ func BoostedTreesExampleDebugOutputs(scope *Scope, tree_ensemble_handle tf.Outpu
return op.Output(0)
}
-// Computes the sum along sparse segments of a tensor.
-//
-// Like `SparseSegmentSum`, but allows missing ids in `segment_ids`. If an id is
-// misisng, the `output` tensor at that position will be zeroed.
+// Makes the summary of accumulated stats for the batch.
//
-// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-// segments.
+// The summary stats contain gradients and hessians accumulated into the corresponding node and bucket for each example.
//
-// For example:
+// Arguments:
+// node_ids: int32 Rank 1 Tensor containing node ids, which each example falls into for the requested layer.
+// gradients: float32; Rank 2 Tensor (shape=[#examples, 1]) for gradients.
+// hessians: float32; Rank 2 Tensor (shape=[#examples, 1]) for hessians.
+// bucketized_features_list: int32 list of Rank 1 Tensors, each containing the bucketized feature (for each feature column).
+// max_splits: int; the maximum number of splits possible in the whole tree.
+// num_buckets: int; equals to the maximum possible value of bucketized feature.
//
-// ```python
-// c = tf.constant([[1,2,3,4], [-1,-2,-3,-4], [5,6,7,8]])
+// Returns output Rank 4 Tensor (shape=[#features, #splits, #buckets, 2]) containing accumulated stats put into the corresponding node and bucket. The first index of 4th dimension refers to gradients, and the second to hessians.
+func BoostedTreesMakeStatsSummary(scope *Scope, node_ids tf.Output, gradients tf.Output, hessians tf.Output, bucketized_features_list []tf.Output, max_splits int64, num_buckets int64) (stats_summary tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"max_splits": max_splits, "num_buckets": num_buckets}
+ opspec := tf.OpSpec{
+ Type: "BoostedTreesMakeStatsSummary",
+ Input: []tf.Input{
+ node_ids, gradients, hessians, tf.OutputList(bucketized_features_list),
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Creates a tree ensemble model and returns a handle to it.
//
-// tf.sparse_segment_sum_with_num_segments(
-// c, tf.constant([0, 1]), tf.constant([0, 0]), num_segments=3)
-// # => [[0 0 0 0]
-// # [0 0 0 0]
-// # [0 0 0 0]]
+// Arguments:
+// tree_ensemble_handle: Handle to the tree ensemble resource to be created.
+// stamp_token: Token to use as the initial value of the resource stamp.
+// tree_ensemble_serialized: Serialized proto of the tree ensemble.
//
-// tf.sparse_segment_sum_with_num_segments(c,
-// tf.constant([0, 1]),
-// tf.constant([0, 2],
-// num_segments=4))
-// # => [[ 1 2 3 4]
-// # [ 0 0 0 0]
-// # [-1 -2 -3 -4]
-// # [ 0 0 0 0]]
-// ```
+// Returns the created operation.
+func BoostedTreesCreateEnsemble(scope *Scope, tree_ensemble_handle tf.Output, stamp_token tf.Output, tree_ensemble_serialized tf.Output) (o *tf.Operation) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "BoostedTreesCreateEnsemble",
+ Input: []tf.Input{
+ tree_ensemble_handle, stamp_token, tree_ensemble_serialized,
+ },
+ }
+ return scope.AddOperation(opspec)
+}
+
+// Checks whether a tree ensemble has been initialized.
//
// Arguments:
+// tree_ensemble_handle: Handle to the tree ensemble resource.
//
-// indices: A 1-D tensor. Has same rank as `segment_ids`.
-// segment_ids: A 1-D tensor. Values should be sorted and can be repeated.
-// num_segments: Should equal the number of distinct segment IDs.
-//
-// Returns Has same shape as data, except for dimension 0 which
-// has size `num_segments`.
-func SparseSegmentSumWithNumSegments(scope *Scope, data tf.Output, indices tf.Output, segment_ids tf.Output, num_segments tf.Output) (output tf.Output) {
+// Returns output boolean on whether it is initialized or not.
+func IsBoostedTreesEnsembleInitialized(scope *Scope, tree_ensemble_handle tf.Output) (is_initialized tf.Output) {
if scope.Err() != nil {
return
}
opspec := tf.OpSpec{
- Type: "SparseSegmentSumWithNumSegments",
+ Type: "IsBoostedTreesEnsembleInitialized",
Input: []tf.Input{
- data, indices, segment_ids, num_segments,
+ tree_ensemble_handle,
},
}
op := scope.AddOperation(opspec)
return op.Output(0)
}
-// PreventGradientAttr is an optional argument to PreventGradient.
-type PreventGradientAttr func(optionalAttr)
+// BoostedTreesEnsembleResourceHandleOpAttr is an optional argument to BoostedTreesEnsembleResourceHandleOp.
+type BoostedTreesEnsembleResourceHandleOpAttr func(optionalAttr)
-// PreventGradientMessage sets the optional message attribute to value.
-//
-// value: Will be printed in the error when anyone tries to differentiate
-// this operation.
+// BoostedTreesEnsembleResourceHandleOpContainer sets the optional container attribute to value.
// If not specified, defaults to ""
-func PreventGradientMessage(value string) PreventGradientAttr {
+func BoostedTreesEnsembleResourceHandleOpContainer(value string) BoostedTreesEnsembleResourceHandleOpAttr {
return func(m optionalAttr) {
- m["message"] = value
+ m["container"] = value
}
}
-// An identity op that triggers an error if a gradient is requested.
+// BoostedTreesEnsembleResourceHandleOpSharedName sets the optional shared_name attribute to value.
+// If not specified, defaults to ""
+func BoostedTreesEnsembleResourceHandleOpSharedName(value string) BoostedTreesEnsembleResourceHandleOpAttr {
+ return func(m optionalAttr) {
+ m["shared_name"] = value
+ }
+}
+
+// Creates a handle to a BoostedTreesEnsembleResource
+func BoostedTreesEnsembleResourceHandleOp(scope *Scope, optional ...BoostedTreesEnsembleResourceHandleOpAttr) (resource tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "BoostedTreesEnsembleResourceHandleOp",
+
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// ComputeAccidentalHitsAttr is an optional argument to ComputeAccidentalHits.
+type ComputeAccidentalHitsAttr func(optionalAttr)
+
+// ComputeAccidentalHitsSeed sets the optional seed attribute to value.
//
-// When executed in a graph, this op outputs its input tensor as-is.
+// value: If either seed or seed2 are set to be non-zero, the random number
+// generator is seeded by the given seed. Otherwise, it is seeded by a
+// random seed.
+// If not specified, defaults to 0
+func ComputeAccidentalHitsSeed(value int64) ComputeAccidentalHitsAttr {
+ return func(m optionalAttr) {
+ m["seed"] = value
+ }
+}
+
+// ComputeAccidentalHitsSeed2 sets the optional seed2 attribute to value.
//
-// When building ops to compute gradients, the TensorFlow gradient system
-// will return an error when trying to lookup the gradient of this op,
-// because no gradient must ever be registered for this function. This
-// op exists to prevent subtle bugs from silently returning unimplemented
-// gradients in some corner cases.
+// value: A second seed to avoid seed collision.
+// If not specified, defaults to 0
+func ComputeAccidentalHitsSeed2(value int64) ComputeAccidentalHitsAttr {
+ return func(m optionalAttr) {
+ m["seed2"] = value
+ }
+}
+
+// Computes the ids of the positions in sampled_candidates that match true_labels.
+//
+// When doing log-odds NCE, the result of this op should be passed through a
+// SparseToDense op, then added to the logits of the sampled candidates. This has
+// the effect of 'removing' the sampled labels that match the true labels by
+// making the classifier sure that they are sampled labels.
//
// Arguments:
-// input: any tensor.
+// true_classes: The true_classes output of UnpackSparseLabels.
+// sampled_candidates: The sampled_candidates output of CandidateSampler.
+// num_true: Number of true labels per context.
//
-// Returns the same input tensor.
-func PreventGradient(scope *Scope, input tf.Output, optional ...PreventGradientAttr) (output tf.Output) {
+// Returns:
+//   indices: A vector of indices corresponding to rows of true_candidates.
+//   ids: A vector of IDs of positions in sampled_candidates that match a true_label
+//   for the row with the corresponding index in indices.
+//   weights: A vector of the same length as indices and ids, in which each element
+//   is -FLOAT_MAX.
+func ComputeAccidentalHits(scope *Scope, true_classes tf.Output, sampled_candidates tf.Output, num_true int64, optional ...ComputeAccidentalHitsAttr) (indices tf.Output, ids tf.Output, weights tf.Output) {
if scope.Err() != nil {
return
}
- attrs := map[string]interface{}{}
+ attrs := map[string]interface{}{"num_true": num_true}
for _, a := range optional {
a(attrs)
}
opspec := tf.OpSpec{
- Type: "PreventGradient",
+ Type: "ComputeAccidentalHits",
Input: []tf.Input{
- input,
+ true_classes, sampled_candidates,
},
Attrs: attrs,
}
op := scope.AddOperation(opspec)
- return op.Output(0)
+ return op.Output(0), op.Output(1), op.Output(2)
}
-// Computes asin of x element-wise.
-func Asin(scope *Scope, x tf.Output) (y tf.Output) {
+// FixedUnigramCandidateSamplerAttr is an optional argument to FixedUnigramCandidateSampler.
+type FixedUnigramCandidateSamplerAttr func(optionalAttr)
+
+// FixedUnigramCandidateSamplerVocabFile sets the optional vocab_file attribute to value.
+//
+// value: Each valid line in this file (which should have a CSV-like format)
+// corresponds to a valid word ID. IDs are in sequential order, starting from
+// num_reserved_ids. The last entry in each line is expected to be a value
+// corresponding to the count or relative probability. Exactly one of vocab_file
+// and unigrams needs to be passed to this op.
+// If not specified, defaults to ""
+func FixedUnigramCandidateSamplerVocabFile(value string) FixedUnigramCandidateSamplerAttr {
+ return func(m optionalAttr) {
+ m["vocab_file"] = value
+ }
+}
+
+// FixedUnigramCandidateSamplerDistortion sets the optional distortion attribute to value.
+//
+// value: The distortion is used to skew the unigram probability distribution.
+// Each weight is first raised to the distortion's power before adding to the
+// internal unigram distribution. As a result, distortion = 1.0 gives regular
+// unigram sampling (as defined by the vocab file), and distortion = 0.0 gives
+// a uniform distribution.
+// If not specified, defaults to 1
+func FixedUnigramCandidateSamplerDistortion(value float32) FixedUnigramCandidateSamplerAttr {
+ return func(m optionalAttr) {
+ m["distortion"] = value
+ }
+}
+
+// FixedUnigramCandidateSamplerNumReservedIds sets the optional num_reserved_ids attribute to value.
+//
+// value: Optionally some reserved IDs can be added in the range [0,
+// ..., num_reserved_ids) by the users. One use case is that a special unknown
+// word token is used as ID 0. These IDs will have a sampling probability of 0.
+// If not specified, defaults to 0
+func FixedUnigramCandidateSamplerNumReservedIds(value int64) FixedUnigramCandidateSamplerAttr {
+ return func(m optionalAttr) {
+ m["num_reserved_ids"] = value
+ }
+}
+
+// FixedUnigramCandidateSamplerNumShards sets the optional num_shards attribute to value.
+//
+// value: A sampler can be used to sample from a subset of the original range
+// in order to speed up the whole computation through parallelism. This parameter
+// (together with 'shard') indicates the number of partitions that are being
+// used in the overall computation.
+// If not specified, defaults to 1
+//
+// REQUIRES: value >= 1
+func FixedUnigramCandidateSamplerNumShards(value int64) FixedUnigramCandidateSamplerAttr {
+ return func(m optionalAttr) {
+ m["num_shards"] = value
+ }
+}
+
+// FixedUnigramCandidateSamplerShard sets the optional shard attribute to value.
+//
+// value: A sampler can be used to sample from a subset of the original range
+// in order to speed up the whole computation through parallelism. This parameter
+// (together with 'num_shards') indicates the particular partition number of a
+// sampler op, when partitioning is being used.
+// If not specified, defaults to 0
+//
+// REQUIRES: value >= 0
+func FixedUnigramCandidateSamplerShard(value int64) FixedUnigramCandidateSamplerAttr {
+ return func(m optionalAttr) {
+ m["shard"] = value
+ }
+}
+
+// FixedUnigramCandidateSamplerUnigrams sets the optional unigrams attribute to value.
+//
+// value: A list of unigram counts or probabilities, one per ID in sequential
+// order. Exactly one of vocab_file and unigrams should be passed to this op.
+// If not specified, defaults to <>
+func FixedUnigramCandidateSamplerUnigrams(value []float32) FixedUnigramCandidateSamplerAttr {
+ return func(m optionalAttr) {
+ m["unigrams"] = value
+ }
+}
+
+// FixedUnigramCandidateSamplerSeed sets the optional seed attribute to value.
+//
+// value: If either seed or seed2 are set to be non-zero, the random number
+// generator is seeded by the given seed. Otherwise, it is seeded by a
+// random seed.
+// If not specified, defaults to 0
+func FixedUnigramCandidateSamplerSeed(value int64) FixedUnigramCandidateSamplerAttr {
+ return func(m optionalAttr) {
+ m["seed"] = value
+ }
+}
+
+// FixedUnigramCandidateSamplerSeed2 sets the optional seed2 attribute to value.
+//
+// value: A second seed to avoid seed collision.
+// If not specified, defaults to 0
+func FixedUnigramCandidateSamplerSeed2(value int64) FixedUnigramCandidateSamplerAttr {
+ return func(m optionalAttr) {
+ m["seed2"] = value
+ }
+}
+
+// Generates labels for candidate sampling with a learned unigram distribution.
+//
+// A unigram sampler could use a fixed unigram distribution read from a
+// file or passed in as an in-memory array instead of building up the distribution
+// from data on the fly. There is also an option to skew the distribution by
+// applying a distortion power to the weights.
+//
+// The vocabulary file should be in CSV-like format, with the last field
+// being the weight associated with the word.
+//
+// For each batch, this op picks a single set of sampled candidate labels.
+//
+// The advantages of sampling candidates per-batch are simplicity and the
+// possibility of efficient dense matrix multiplication. The disadvantage is that
+// the sampled candidates must be chosen independently of the context and of the
+// true labels.
+//
+// Arguments:
+// true_classes: A batch_size * num_true matrix, in which each row contains the
+// IDs of the num_true target_classes in the corresponding original label.
+// num_true: Number of true labels per context.
+// num_sampled: Number of candidates to randomly sample.
+// unique: If unique is true, we sample with rejection, so that all sampled
+// candidates in a batch are unique. This requires some approximation to
+// estimate the post-rejection sampling probabilities.
+// range_max: The sampler will sample integers from the interval [0, range_max).
+//
+// Returns:
+//   sampled_candidates: A vector of length num_sampled, in which each element is
+//   the ID of a sampled candidate.
+//   true_expected_count: A batch_size * num_true matrix, representing the number of
+//   times each candidate is expected to occur in a batch of sampled candidates.
+//   If unique=true, then this is a probability.
+//   sampled_expected_count: A vector of length num_sampled, for each sampled
+//   candidate representing the number of times the candidate is expected to occur
+//   in a batch of sampled candidates. If unique=true, then this is a probability.
+func FixedUnigramCandidateSampler(scope *Scope, true_classes tf.Output, num_true int64, num_sampled int64, unique bool, range_max int64, optional ...FixedUnigramCandidateSamplerAttr) (sampled_candidates tf.Output, true_expected_count tf.Output, sampled_expected_count tf.Output) {
if scope.Err() != nil {
return
}
+ attrs := map[string]interface{}{"num_true": num_true, "num_sampled": num_sampled, "unique": unique, "range_max": range_max}
+ for _, a := range optional {
+ a(attrs)
+ }
opspec := tf.OpSpec{
- Type: "Asin",
+ Type: "FixedUnigramCandidateSampler",
Input: []tf.Input{
- x,
+ true_classes,
},
+ Attrs: attrs,
}
op := scope.AddOperation(opspec)
- return op.Output(0)
+ return op.Output(0), op.Output(1), op.Output(2)
}
-// Computes the sum along sparse segments of a tensor.
-//
-// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-// segments.
-//
-// Like `SegmentSum`, but `segment_ids` can have rank less than `data`'s first
-// dimension, selecting a subset of dimension 0, specified by `indices`.
-//
-// For example:
+// LogUniformCandidateSamplerAttr is an optional argument to LogUniformCandidateSampler.
+type LogUniformCandidateSamplerAttr func(optionalAttr)
+
+// LogUniformCandidateSamplerSeed sets the optional seed attribute to value.
//
-// ```python
-// c = tf.constant([[1,2,3,4], [-1,-2,-3,-4], [5,6,7,8]])
+// value: If either seed or seed2 are set to be non-zero, the random number
+// generator is seeded by the given seed. Otherwise, it is seeded by a
+// random seed.
+// If not specified, defaults to 0
+func LogUniformCandidateSamplerSeed(value int64) LogUniformCandidateSamplerAttr {
+ return func(m optionalAttr) {
+ m["seed"] = value
+ }
+}
+
+// LogUniformCandidateSamplerSeed2 sets the optional seed2 attribute to value.
//
-// # Select two rows, one segment.
-// tf.sparse_segment_sum(c, tf.constant([0, 1]), tf.constant([0, 0]))
-// # => [[0 0 0 0]]
+// value: A second seed to avoid seed collision.
+// If not specified, defaults to 0
+func LogUniformCandidateSamplerSeed2(value int64) LogUniformCandidateSamplerAttr {
+ return func(m optionalAttr) {
+ m["seed2"] = value
+ }
+}
+
+// Generates labels for candidate sampling with a log-uniform distribution.
//
-// # Select two rows, two segment.
-// tf.sparse_segment_sum(c, tf.constant([0, 1]), tf.constant([0, 1]))
-// # => [[ 1 2 3 4]
-// # [-1 -2 -3 -4]]
+// See explanations of candidate sampling and the data formats at
+// go/candidate-sampling.
//
-// # Select all rows, two segments.
-// tf.sparse_segment_sum(c, tf.constant([0, 1, 2]), tf.constant([0, 0, 1]))
-// # => [[0 0 0 0]
-// # [5 6 7 8]]
+// For each batch, this op picks a single set of sampled candidate labels.
//
-// # Which is equivalent to:
-// tf.segment_sum(c, tf.constant([0, 0, 1]))
-// ```
+// The advantages of sampling candidates per-batch are simplicity and the
+// possibility of efficient dense matrix multiplication. The disadvantage is that
+// the sampled candidates must be chosen independently of the context and of the
+// true labels.
//
// Arguments:
+// true_classes: A batch_size * num_true matrix, in which each row contains the
+// IDs of the num_true target_classes in the corresponding original label.
+// num_true: Number of true labels per context.
+// num_sampled: Number of candidates to randomly sample.
+// unique: If unique is true, we sample with rejection, so that all sampled
+// candidates in a batch are unique. This requires some approximation to
+// estimate the post-rejection sampling probabilities.
+// range_max: The sampler will sample integers from the interval [0, range_max).
//
-// indices: A 1-D tensor. Has same rank as `segment_ids`.
-// segment_ids: A 1-D tensor. Values should be sorted and can be repeated.
-//
-// Returns Has same shape as data, except for dimension 0 which
-// has size `k`, the number of segments.
-func SparseSegmentSum(scope *Scope, data tf.Output, indices tf.Output, segment_ids tf.Output) (output tf.Output) {
+// Returns:
+//   sampled_candidates: A vector of length num_sampled, in which each element is
+//   the ID of a sampled candidate.
+//   true_expected_count: A batch_size * num_true matrix, representing the number of
+//   times each candidate is expected to occur in a batch of sampled candidates.
+//   If unique=true, then this is a probability.
+//   sampled_expected_count: A vector of length num_sampled, for each sampled
+//   candidate representing the number of times the candidate is expected to occur
+//   in a batch of sampled candidates. If unique=true, then this is a probability.
+func LogUniformCandidateSampler(scope *Scope, true_classes tf.Output, num_true int64, num_sampled int64, unique bool, range_max int64, optional ...LogUniformCandidateSamplerAttr) (sampled_candidates tf.Output, true_expected_count tf.Output, sampled_expected_count tf.Output) {
if scope.Err() != nil {
return
}
+ attrs := map[string]interface{}{"num_true": num_true, "num_sampled": num_sampled, "unique": unique, "range_max": range_max}
+ for _, a := range optional {
+ a(attrs)
+ }
opspec := tf.OpSpec{
- Type: "SparseSegmentSum",
+ Type: "LogUniformCandidateSampler",
Input: []tf.Input{
- data, indices, segment_ids,
+ true_classes,
},
+ Attrs: attrs,
}
op := scope.AddOperation(opspec)
- return op.Output(0)
+ return op.Output(0), op.Output(1), op.Output(2)
}
-// Computes hyperbolic sine of x element-wise.
-func Sinh(scope *Scope, x tf.Output) (y tf.Output) {
+// UniformCandidateSamplerAttr is an optional argument to UniformCandidateSampler.
+type UniformCandidateSamplerAttr func(optionalAttr)
+
+// UniformCandidateSamplerSeed sets the optional seed attribute to value.
+//
+// value: If either seed or seed2 are set to be non-zero, the random number
+// generator is seeded by the given seed. Otherwise, it is seeded by a
+// random seed.
+// If not specified, defaults to 0
+func UniformCandidateSamplerSeed(value int64) UniformCandidateSamplerAttr {
+ return func(m optionalAttr) {
+ m["seed"] = value
+ }
+}
+
+// UniformCandidateSamplerSeed2 sets the optional seed2 attribute to value.
+//
+// value: A second seed to avoid seed collision.
+// If not specified, defaults to 0
+func UniformCandidateSamplerSeed2(value int64) UniformCandidateSamplerAttr {
+ return func(m optionalAttr) {
+ m["seed2"] = value
+ }
+}
+
+// Generates labels for candidate sampling with a uniform distribution.
+//
+// See explanations of candidate sampling and the data formats at
+// go/candidate-sampling.
+//
+// For each batch, this op picks a single set of sampled candidate labels.
+//
+// The advantages of sampling candidates per-batch are simplicity and the
+// possibility of efficient dense matrix multiplication. The disadvantage is that
+// the sampled candidates must be chosen independently of the context and of the
+// true labels.
+//
+// Arguments:
+// true_classes: A batch_size * num_true matrix, in which each row contains the
+// IDs of the num_true target_classes in the corresponding original label.
+// num_true: Number of true labels per context.
+// num_sampled: Number of candidates to randomly sample.
+// unique: If unique is true, we sample with rejection, so that all sampled
+// candidates in a batch are unique. This requires some approximation to
+// estimate the post-rejection sampling probabilities.
+// range_max: The sampler will sample integers from the interval [0, range_max).
+//
+// Returns:
+//   sampled_candidates: A vector of length num_sampled, in which each element is
+//   the ID of a sampled candidate.
+//   true_expected_count: A batch_size * num_true matrix, representing the number of
+//   times each candidate is expected to occur in a batch of sampled candidates.
+//   If unique=true, then this is a probability.
+//   sampled_expected_count: A vector of length num_sampled, for each sampled
+//   candidate representing the number of times the candidate is expected to occur
+//   in a batch of sampled candidates. If unique=true, then this is a probability.
+func UniformCandidateSampler(scope *Scope, true_classes tf.Output, num_true int64, num_sampled int64, unique bool, range_max int64, optional ...UniformCandidateSamplerAttr) (sampled_candidates tf.Output, true_expected_count tf.Output, sampled_expected_count tf.Output) {
if scope.Err() != nil {
return
}
+ attrs := map[string]interface{}{"num_true": num_true, "num_sampled": num_sampled, "unique": unique, "range_max": range_max}
+ for _, a := range optional {
+ a(attrs)
+ }
opspec := tf.OpSpec{
- Type: "Sinh",
+ Type: "UniformCandidateSampler",
Input: []tf.Input{
- x,
+ true_classes,
},
+ Attrs: attrs,
}
op := scope.AddOperation(opspec)
- return op.Output(0)
+ return op.Output(0), op.Output(1), op.Output(2)
}
-// Computes the minimum along segments of a tensor.
+// GenerateVocabRemappingAttr is an optional argument to GenerateVocabRemapping.
+type GenerateVocabRemappingAttr func(optionalAttr)
+
+// GenerateVocabRemappingOldVocabSize sets the optional old_vocab_size attribute to value.
//
-// Read @{$math_ops#segmentation$the section on segmentation} for an explanation of
-// segments.
+// value: Number of entries in the old vocab file to consider. If -1,
+// use the entire old vocabulary.
+// If not specified, defaults to -1
//
-// This operator is similar to the unsorted segment sum operator found
-// [(here)](../../../api_docs/python/math_ops.md#UnsortedSegmentSum).
-// Instead of computing the sum over segments, it computes the minimum such that:
+// REQUIRES: value >= -1
+func GenerateVocabRemappingOldVocabSize(value int64) GenerateVocabRemappingAttr {
+ return func(m optionalAttr) {
+ m["old_vocab_size"] = value
+ }
+}
+
+// Given a path to new and old vocabulary files, returns a remapping Tensor of
//
-// \\(output_i = \min_j data_j\\) where min is over `j` such
-// that `segment_ids[j] == i`.
+// length `num_new_vocab`, where `remapping[i]` contains the row number in the old
+// vocabulary that corresponds to row `i` in the new vocabulary (starting at line
+// `new_vocab_offset` and up to `num_new_vocab` entities), or `-1` if entry `i`
+// in the new vocabulary is not in the old vocabulary. The old vocabulary is
+// constrained to the first `old_vocab_size` entries if `old_vocab_size` is not the
+// default value of -1.
//
-// If the minimum is empty for a given segment ID `i`, it outputs the largest
-// possible value for the specific numeric type,
-// `output[i] = numeric_limits<T>::max()`.
+// `num_vocab_offset` enables
+// use in the partitioned variable case, and should generally be set through
+// examining partitioning info. The format of the files should be a text file,
+// with each line containing a single entity within the vocabulary.
//
-// Arguments:
+// For example, with `new_vocab_file` a text file containing each of the following
+// elements on a single line: `[f0, f1, f2, f3]`, old_vocab_file = [f1, f0, f3],
+// `num_new_vocab = 3, new_vocab_offset = 1`, the returned remapping would be
+// `[0, -1, 2]`.
//
-// segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s
-// first dimension.
+// The op also returns a count of how many entries in the new vocabulary
+// were present in the old vocabulary, which is used to calculate the number of
+// values to initialize in a weight matrix remapping.
//
+// This functionality can be used to remap both row vocabularies (typically,
+// features) and column vocabularies (typically, classes) from TensorFlow
+// checkpoints. Note that the partitioning logic relies on contiguous vocabularies
+// corresponding to div-partitioned variables. Moreover, the underlying remapping
+// uses an IndexTable (as opposed to an inexact CuckooTable), so client code should
+// use the corresponding index_table_from_file() as the FeatureColumn framework
+// does (as opposed to tf.feature_to_id(), which uses a CuckooTable).
//
-// Returns Has same shape as data, except for dimension 0 which
-// has size `num_segments`.
-func UnsortedSegmentMin(scope *Scope, data tf.Output, segment_ids tf.Output, num_segments tf.Output) (output tf.Output) {
+// Arguments:
+// new_vocab_file: Path to the new vocab file.
+// old_vocab_file: Path to the old vocab file.
+// new_vocab_offset: How many entries into the new vocab file to start reading.
+// num_new_vocab: Number of entries in the new vocab file to remap.
+//
+// Returns:
+//   remapping: A Tensor of length num_new_vocab where the element at index i is
+//   equal to the old ID that maps to the new ID i. This element is -1 for any
+//   new ID that is not found in the old vocabulary.
+//   num_present: Number of new vocab entries found in old vocab.
+func GenerateVocabRemapping(scope *Scope, new_vocab_file tf.Output, old_vocab_file tf.Output, new_vocab_offset int64, num_new_vocab int64, optional ...GenerateVocabRemappingAttr) (remapping tf.Output, num_present tf.Output) {
if scope.Err() != nil {
return
}
+ attrs := map[string]interface{}{"new_vocab_offset": new_vocab_offset, "num_new_vocab": num_new_vocab}
+ for _, a := range optional {
+ a(attrs)
+ }
opspec := tf.OpSpec{
- Type: "UnsortedSegmentMin",
+ Type: "GenerateVocabRemapping",
Input: []tf.Input{
- data, segment_ids, num_segments,
+ new_vocab_file, old_vocab_file,
},
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0), op.Output(1)
+}
+
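As a hedged illustration, the worked example in the GenerateVocabRemapping comment above maps directly onto the wrapper; the file paths below are placeholders assumed to contain the listed vocab entries, one per line:

```go
package main

import (
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
	s := op.NewScope()
	// Assumed files matching the example above: new_vocab.txt holds f0, f1, f2, f3
	// and old_vocab.txt holds f1, f0, f3 (one entry per line).
	newVocab := op.Const(s, "new_vocab.txt")
	oldVocab := op.Const(s, "old_vocab.txt")
	// With new_vocab_offset=1 and num_new_vocab=3, the remapping node should
	// evaluate to [0, -1, 2] and num_present to 2, per the documentation above.
	remapping, numPresent := op.GenerateVocabRemapping(s, newVocab, oldVocab, 1, 3)
	_, _ = remapping, numPresent
	if _, err := s.Finalize(); err != nil {
		panic(err)
	}
}
```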
+// Broadcasts a tensor value to one or more other devices.
+func CollectiveBcastSend(scope *Scope, input tf.Output, group_size int64, group_key int64, instance_key int64, shape tf.Shape) (data tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"group_size": group_size, "group_key": group_key, "instance_key": instance_key, "shape": shape}
+ opspec := tf.OpSpec{
+ Type: "CollectiveBcastSend",
+ Input: []tf.Input{
+ input,
+ },
+ Attrs: attrs,
}
op := scope.AddOperation(opspec)
return op.Output(0)
}
-// Computes rectified linear 6: `min(max(features, 0), 6)`.
-func Relu6(scope *Scope, features tf.Output) (activations tf.Output) {
+// Mutually reduces multiple tensors of identical type and shape.
+func CollectiveReduce(scope *Scope, input tf.Output, group_size int64, group_key int64, instance_key int64, merge_op string, final_op string, subdiv_offsets []int64) (data tf.Output) {
if scope.Err() != nil {
return
}
+ attrs := map[string]interface{}{"group_size": group_size, "group_key": group_key, "instance_key": instance_key, "merge_op": merge_op, "final_op": final_op, "subdiv_offsets": subdiv_offsets}
opspec := tf.OpSpec{
- Type: "Relu6",
+ Type: "CollectiveReduce",
Input: []tf.Input{
- features,
+ input,
},
+ Attrs: attrs,
}
op := scope.AddOperation(opspec)
return op.Output(0)
}
-// Computes the sum along segments of a tensor.
+// AbortAttr is an optional argument to Abort.
+type AbortAttr func(optionalAttr)
+
+// AbortErrorMsg sets the optional error_msg attribute to value.
//
-// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-// segments.
+// value: A string which is the message associated with the exception.
+// If not specified, defaults to ""
+func AbortErrorMsg(value string) AbortAttr {
+ return func(m optionalAttr) {
+ m["error_msg"] = value
+ }
+}
+
+// AbortExitWithoutError sets the optional exit_without_error attribute to value.
+// If not specified, defaults to false
+func AbortExitWithoutError(value bool) AbortAttr {
+ return func(m optionalAttr) {
+ m["exit_without_error"] = value
+ }
+}
+
+// Raise an exception to abort the process when called.
//
-// Computes a tensor such that
-// \\(output[i] = sum_{j...} data[j...]\\) where the sum is over tuples `j...` such
-// that `segment_ids[j...] == i`. Unlike `SegmentSum`, `segment_ids`
-// need not be sorted and need not cover all values in the full
-// range of valid values.
+// If exit_without_error is true, the process will exit normally,
+// otherwise it will exit with a SIGABORT signal.
//
-// If the sum is empty for a given segment ID `i`, `output[i] = 0`.
-// If the given segment ID `i` is negative, the value is dropped and will not be
-// added to the sum of the segment.
+// Returns nothing but an exception.
//
-// `num_segments` should equal the number of distinct segment IDs.
+// Returns the created operation.
+func Abort(scope *Scope, optional ...AbortAttr) (o *tf.Operation) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "Abort",
+
+ Attrs: attrs,
+ }
+ return scope.AddOperation(opspec)
+}
+
+// Forwards the input to the output.
//
-// <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
-// <img style="width:100%" src="https://www.tensorflow.org/images/UnsortedSegmentSum.png" alt>
-// </div>
+// This operator represents the loop termination condition used by the
+// "pivot" switches of a loop.
//
// Arguments:
+// input: A boolean scalar, representing the branch predicate of the Switch op.
//
-// segment_ids: A tensor whose shape is a prefix of `data.shape`.
+// Returns The same tensor as `input`.
+func LoopCond(scope *Scope, input tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "LoopCond",
+ Input: []tf.Input{
+ input,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Returns a tensor of zeros with the same shape and type as x.
//
+// Arguments:
+// x: a tensor of type T.
//
-// Returns Has same shape as data, except for the first `segment_ids.rank`
-// dimensions, which are replaced with a single dimension which has size
-// `num_segments`.
-func UnsortedSegmentSum(scope *Scope, data tf.Output, segment_ids tf.Output, num_segments tf.Output) (output tf.Output) {
+// Returns a tensor of the same shape and type as x but filled with zeros.
+func ZerosLike(scope *Scope, x tf.Output) (y tf.Output) {
if scope.Err() != nil {
return
}
opspec := tf.OpSpec{
- Type: "UnsortedSegmentSum",
+ Type: "ZerosLike",
Input: []tf.Input{
- data, segment_ids, num_segments,
+ x,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Returns a copy of the input tensor.
+func Snapshot(scope *Scope, input tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "Snapshot",
+ Input: []tf.Input{
+ input,
},
}
op := scope.AddOperation(opspec)
@@ -4116,6 +4750,162 @@ func SlideDataset(scope *Scope, input_dataset tf.Output, window_size tf.Output,
return op.Output(0)
}
+// EditDistanceAttr is an optional argument to EditDistance.
+type EditDistanceAttr func(optionalAttr)
+
+// EditDistanceNormalize sets the optional normalize attribute to value.
+//
+// value: boolean (if true, edit distances are normalized by the length of truth).
+// If not specified, defaults to true
+func EditDistanceNormalize(value bool) EditDistanceAttr {
+ return func(m optionalAttr) {
+ m["normalize"] = value
+ }
+}
+
+// Computes the (possibly normalized) Levenshtein Edit Distance.
+//
+// The inputs are variable-length sequences provided by SparseTensors
+// (hypothesis_indices, hypothesis_values, hypothesis_shape)
+// and
+// (truth_indices, truth_values, truth_shape).
+//
+// Arguments:
+// hypothesis_indices: The indices of the hypothesis list SparseTensor.
+// This is an N x R int64 matrix.
+// hypothesis_values: The values of the hypothesis list SparseTensor.
+// This is an N-length vector.
+// hypothesis_shape: The shape of the hypothesis list SparseTensor.
+// This is an R-length vector.
+// truth_indices: The indices of the truth list SparseTensor.
+// This is an M x R int64 matrix.
+// truth_values: The values of the truth list SparseTensor.
+// This is an M-length vector.
+// truth_shape: The shape of the truth list SparseTensor.
+// This is an R-length vector.
+//
+// Returns A dense float tensor with rank R - 1.
+//
+// For the example input:
+//
+// // hypothesis represents a 2x1 matrix with variable-length values:
+// // (0,0) = ["a"]
+// // (1,0) = ["b"]
+// hypothesis_indices = [[0, 0, 0],
+// [1, 0, 0]]
+// hypothesis_values = ["a", "b"]
+// hypothesis_shape = [2, 1, 1]
+//
+// // truth represents a 2x2 matrix with variable-length values:
+// // (0,0) = []
+// // (0,1) = ["a"]
+// // (1,0) = ["b", "c"]
+// // (1,1) = ["a"]
+// truth_indices = [[0, 1, 0],
+// [1, 0, 0],
+// [1, 0, 1],
+// [1, 1, 0]]
+// truth_values = ["a", "b", "c", "a"]
+// truth_shape = [2, 2, 2]
+// normalize = true
+//
+// The output will be:
+//
+// // output is a 2x2 matrix with edit distances normalized by truth lengths.
+// output = [[inf, 1.0], // (0,0): no truth, (0,1): no hypothesis
+// [0.5, 1.0]] // (1,0): addition, (1,1): no hypothesis
+func EditDistance(scope *Scope, hypothesis_indices tf.Output, hypothesis_values tf.Output, hypothesis_shape tf.Output, truth_indices tf.Output, truth_values tf.Output, truth_shape tf.Output, optional ...EditDistanceAttr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "EditDistance",
+ Input: []tf.Input{
+ hypothesis_indices, hypothesis_values, hypothesis_shape, truth_indices, truth_values, truth_shape,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
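To make the sparse-input layout above concrete, here is a hedged sketch of the documented 2x1 hypothesis versus 2x2 truth example written against this wrapper; the standard tensorflow/go imports and session calls are assumed and are not part of this diff.

package main

import (
	"fmt"

	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
	s := op.NewScope()
	// Hypothesis: 2x1 matrix with values ["a"] and ["b"].
	hypIdx := op.Const(s, [][]int64{{0, 0, 0}, {1, 0, 0}})
	hypVal := op.Const(s, []string{"a", "b"})
	hypShape := op.Const(s, []int64{2, 1, 1})
	// Truth: 2x2 matrix with values [], ["a"], ["b", "c"], ["a"].
	truthIdx := op.Const(s, [][]int64{{0, 1, 0}, {1, 0, 0}, {1, 0, 1}, {1, 1, 0}})
	truthVal := op.Const(s, []string{"a", "b", "c", "a"})
	truthShape := op.Const(s, []int64{2, 2, 2})

	dist := op.EditDistance(s, hypIdx, hypVal, hypShape, truthIdx, truthVal, truthShape,
		op.EditDistanceNormalize(true))

	graph, err := s.Finalize()
	if err != nil {
		panic(err)
	}
	sess, err := tf.NewSession(graph, nil)
	if err != nil {
		panic(err)
	}
	defer sess.Close()
	out, err := sess.Run(nil, []tf.Output{dist}, nil)
	if err != nil {
		panic(err)
	}
	fmt.Println(out[0].Value()) // per the docs above: [[+Inf 1] [0.5 1]]
}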
+// DepthwiseConv2dNativeBackpropInputAttr is an optional argument to DepthwiseConv2dNativeBackpropInput.
+type DepthwiseConv2dNativeBackpropInputAttr func(optionalAttr)
+
+// DepthwiseConv2dNativeBackpropInputDataFormat sets the optional data_format attribute to value.
+//
+// value: Specify the data format of the input and output data. With the
+// default format "NHWC", the data is stored in the order of:
+// [batch, height, width, channels].
+// Alternatively, the format could be "NCHW", the data storage order of:
+// [batch, channels, height, width].
+// If not specified, defaults to "NHWC"
+func DepthwiseConv2dNativeBackpropInputDataFormat(value string) DepthwiseConv2dNativeBackpropInputAttr {
+ return func(m optionalAttr) {
+ m["data_format"] = value
+ }
+}
+
+// DepthwiseConv2dNativeBackpropInputDilations sets the optional dilations attribute to value.
+//
+// value: 1-D tensor of length 4. The dilation factor for each dimension of
+// `input`. If set to k > 1, there will be k-1 skipped cells between each filter
+// element on that dimension. The dimension order is determined by the value of
+// `data_format`, see above for details. Dilations in the batch and depth
+// dimensions must be 1.
+// If not specified, defaults to <i:1 i:1 i:1 i:1 >
+func DepthwiseConv2dNativeBackpropInputDilations(value []int64) DepthwiseConv2dNativeBackpropInputAttr {
+ return func(m optionalAttr) {
+ m["dilations"] = value
+ }
+}
+
+// Computes the gradients of depthwise convolution with respect to the input.
+//
+// Arguments:
+// input_sizes: An integer vector representing the shape of `input`, based
+// on `data_format`. For example, if `data_format` is 'NHWC' then
+// `input` is a 4-D `[batch, height, width, channels]` tensor.
+// filter: 4-D with shape
+// `[filter_height, filter_width, in_channels, depthwise_multiplier]`.
+// out_backprop: 4-D with shape based on `data_format`.
+// For example, if `data_format` is 'NHWC' then
+// out_backprop shape is `[batch, out_height, out_width, out_channels]`.
+// Gradients w.r.t. the output of the convolution.
+// strides: The stride of the sliding window for each dimension of the input
+// of the convolution.
+// padding: The type of padding algorithm to use.
+//
+// Returns 4-D with shape according to `data_format`. For example, if
+// `data_format` is 'NHWC', output shape is `[batch, in_height,
+// in_width, in_channels]`. Gradient w.r.t. the input of the
+// convolution.
+func DepthwiseConv2dNativeBackpropInput(scope *Scope, input_sizes tf.Output, filter tf.Output, out_backprop tf.Output, strides []int64, padding string, optional ...DepthwiseConv2dNativeBackpropInputAttr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"strides": strides, "padding": padding}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "DepthwiseConv2dNativeBackpropInput",
+ Input: []tf.Input{
+ input_sizes, filter, out_backprop,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// ApproximateEqualAttr is an optional argument to ApproximateEqual.
type ApproximateEqualAttr func(optionalAttr)
@@ -4284,124 +5074,90 @@ func SparseReduceSumSparse(scope *Scope, input_indices tf.Output, input_values t
return op.Output(0), op.Output(1), op.Output(2)
}
-// Returns x + y element-wise.
+// AllCandidateSamplerAttr is an optional argument to AllCandidateSampler.
+type AllCandidateSamplerAttr func(optionalAttr)
+
+// AllCandidateSamplerSeed sets the optional seed attribute to value.
//
-// *NOTE*: `Add` supports broadcasting. `AddN` does not. More about broadcasting
-// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
-func AddV2(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "AddV2",
- Input: []tf.Input{
- x, y,
- },
+// value: If either seed or seed2 are set to be non-zero, the random number
+// generator is seeded by the given seed. Otherwise, it is seeded by a
+// random seed.
+// If not specified, defaults to 0
+func AllCandidateSamplerSeed(value int64) AllCandidateSamplerAttr {
+ return func(m optionalAttr) {
+ m["seed"] = value
}
- op := scope.AddOperation(opspec)
- return op.Output(0)
}
-// NthElementAttr is an optional argument to NthElement.
-type NthElementAttr func(optionalAttr)
-
-// NthElementReverse sets the optional reverse attribute to value.
+// AllCandidateSamplerSeed2 sets the optional seed2 attribute to value.
//
-// value: When set to True, find the nth-largest value in the vector and vice
-// versa.
-// If not specified, defaults to false
-func NthElementReverse(value bool) NthElementAttr {
+// value: A second seed to avoid seed collision.
+// If not specified, defaults to 0
+func AllCandidateSamplerSeed2(value int64) AllCandidateSamplerAttr {
return func(m optionalAttr) {
- m["reverse"] = value
+ m["seed2"] = value
}
}
-// Finds values of the `n`-th order statistic for the last dimension.
+// Generates labels for candidate sampling with a learned unigram distribution.
//
-// If the input is a vector (rank-1), finds the entries which is the nth-smallest
-// value in the vector and outputs their values as scalar tensor.
+// See explanations of candidate sampling and the data formats at
+// go/candidate-sampling.
//
-// For matrices (resp. higher rank input), computes the entries which is the
-// nth-smallest value in each row (resp. vector along the last dimension). Thus,
+// For each batch, this op picks a single set of sampled candidate labels.
//
-// values.shape = input.shape[:-1]
+// The advantages of sampling candidates per-batch are simplicity and the
+// possibility of efficient dense matrix multiplication. The disadvantage is that
+// the sampled candidates must be chosen independently of the context and of the
+// true labels.
//
// Arguments:
-// input: 1-D or higher with last dimension at least `n+1`.
-// n: 0-D. Position of sorted vector to select along the last dimension (along
-// each row for matrices). Valid range of n is `[0, input.shape[:-1])`
+// true_classes: A batch_size * num_true matrix, in which each row contains the
+// IDs of the num_true target_classes in the corresponding original label.
+// num_true: Number of true labels per context.
+// num_sampled: Number of candidates to produce.
+// unique: If unique is true, we sample with rejection, so that all sampled
+// candidates in a batch are unique. This requires some approximation to
+// estimate the post-rejection sampling probabilities.
//
-// Returns The `n`-th order statistic along each last dimensional slice.
-func NthElement(scope *Scope, input tf.Output, n tf.Output, optional ...NthElementAttr) (values tf.Output) {
+// Returns A vector of length num_sampled, in which each element is
+// the ID of a sampled candidate. A batch_size * num_true matrix, representing
+// the number of times each candidate is expected to occur in a batch
+// of sampled candidates. If unique=true, then this is a probability. A vector
+// of length num_sampled, for each sampled candidate representing the number of
+// times the candidate is expected to occur in a batch of sampled candidates.
+// If unique=true, then this is a probability.
+func AllCandidateSampler(scope *Scope, true_classes tf.Output, num_true int64, num_sampled int64, unique bool, optional ...AllCandidateSamplerAttr) (sampled_candidates tf.Output, true_expected_count tf.Output, sampled_expected_count tf.Output) {
if scope.Err() != nil {
return
}
- attrs := map[string]interface{}{}
+ attrs := map[string]interface{}{"num_true": num_true, "num_sampled": num_sampled, "unique": unique}
for _, a := range optional {
a(attrs)
}
opspec := tf.OpSpec{
- Type: "NthElement",
+ Type: "AllCandidateSampler",
Input: []tf.Input{
- input, n,
+ true_classes,
},
Attrs: attrs,
}
op := scope.AddOperation(opspec)
- return op.Output(0)
+ return op.Output(0), op.Output(1), op.Output(2)
}
-// Computes the maximum along segments of a tensor.
-//
-// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-// segments.
-//
-// This operator is similar to the unsorted segment sum operator found
-// [(here)](../../../api_docs/python/math_ops.md#UnsortedSegmentSum).
-// Instead of computing the sum over segments, it computes the maximum such that:
-//
-// \\(output_i = \max_j data_j\\) where max is over `j` such
-// that `segment_ids[j] == i`.
-//
-// If the maximum is empty for a given segment ID `i`, it outputs the smallest
-// possible value for the specific numeric type,
-// `output[i] = numeric_limits<T>::lowest()`.
-//
-// <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
-// <img style="width:100%" src="https://www.tensorflow.org/images/UnsortedSegmentMax.png" alt>
-// </div>
-//
-// Arguments:
-//
-// segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s
-// first dimension.
-//
+// Returns x + y element-wise.
//
-// Returns Has same shape as data, except for dimension 0 which
-// has size `num_segments`.
-func UnsortedSegmentMax(scope *Scope, data tf.Output, segment_ids tf.Output, num_segments tf.Output) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "UnsortedSegmentMax",
- Input: []tf.Input{
- data, segment_ids, num_segments,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Computes exponential of x element-wise. \\(y = e^x\\).
-func Exp(scope *Scope, x tf.Output) (y tf.Output) {
+// *NOTE*: `Add` supports broadcasting. `AddN` does not. More about broadcasting
+// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
+func AddV2(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
if scope.Err() != nil {
return
}
opspec := tf.OpSpec{
- Type: "Exp",
+ Type: "AddV2",
Input: []tf.Input{
- x,
+ x, y,
},
}
op := scope.AddOperation(opspec)
@@ -4500,6 +5256,120 @@ func Requantize(scope *Scope, input tf.Output, input_min tf.Output, input_max tf
return op.Output(0), op.Output(1), op.Output(2)
}
+// PreventGradientAttr is an optional argument to PreventGradient.
+type PreventGradientAttr func(optionalAttr)
+
+// PreventGradientMessage sets the optional message attribute to value.
+//
+// value: Will be printed in the error when anyone tries to differentiate
+// this operation.
+// If not specified, defaults to ""
+func PreventGradientMessage(value string) PreventGradientAttr {
+ return func(m optionalAttr) {
+ m["message"] = value
+ }
+}
+
+// An identity op that triggers an error if a gradient is requested.
+//
+// When executed in a graph, this op outputs its input tensor as-is.
+//
+// When building ops to compute gradients, the TensorFlow gradient system
+// will return an error when trying to look up the gradient of this op,
+// because no gradient must ever be registered for this function. This
+// op exists to prevent subtle bugs from silently returning unimplemented
+// gradients in some corner cases.
+//
+// Arguments:
+// input: any tensor.
+//
+// Returns the same input tensor.
+func PreventGradient(scope *Scope, input tf.Output, optional ...PreventGradientAttr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "PreventGradient",
+ Input: []tf.Input{
+ input,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Computes asin of x element-wise.
+func Asin(scope *Scope, x tf.Output) (y tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "Asin",
+ Input: []tf.Input{
+ x,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Computes the sum along sparse segments of a tensor.
+//
+// Like `SparseSegmentSum`, but allows missing ids in `segment_ids`. If an id is
+// missing, the `output` tensor at that position will be zeroed.
+//
+// Read
+// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+// for an explanation of segments.
+//
+// For example:
+//
+// ```python
+// c = tf.constant([[1,2,3,4], [-1,-2,-3,-4], [5,6,7,8]])
+//
+// tf.sparse_segment_sum_with_num_segments(
+// c, tf.constant([0, 1]), tf.constant([0, 0]), num_segments=3)
+// # => [[0 0 0 0]
+// # [0 0 0 0]
+// # [0 0 0 0]]
+//
+// tf.sparse_segment_sum_with_num_segments(c,
+// tf.constant([0, 1]),
+// tf.constant([0, 2]),
+// num_segments=4)
+// # => [[ 1 2 3 4]
+// # [ 0 0 0 0]
+// # [-1 -2 -3 -4]
+// # [ 0 0 0 0]]
+// ```
+//
+// Arguments:
+//
+// indices: A 1-D tensor. Has same rank as `segment_ids`.
+// segment_ids: A 1-D tensor. Values should be sorted and can be repeated.
+// num_segments: Should equal the number of distinct segment IDs.
+//
+// Returns Has same shape as data, except for dimension 0 which
+// has size `num_segments`.
+func SparseSegmentSumWithNumSegments(scope *Scope, data tf.Output, indices tf.Output, segment_ids tf.Output, num_segments tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "SparseSegmentSumWithNumSegments",
+ Input: []tf.Input{
+ data, indices, segment_ids, num_segments,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Computes the determinant of one or more square matrices.
//
// The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions
@@ -4724,6 +5594,74 @@ func TensorSliceDataset(scope *Scope, components []tf.Output, output_shapes []tf
return op.Output(0)
}
+// Computes hyperbolic sine of x element-wise.
+func Sinh(scope *Scope, x tf.Output) (y tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "Sinh",
+ Input: []tf.Input{
+ x,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Computes the sum along sparse segments of a tensor.
+//
+// Read
+// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+// for an explanation of segments.
+//
+// Like `SegmentSum`, but `segment_ids` can have rank less than `data`'s first
+// dimension, selecting a subset of dimension 0, specified by `indices`.
+//
+// For example:
+//
+// ```python
+// c = tf.constant([[1,2,3,4], [-1,-2,-3,-4], [5,6,7,8]])
+//
+// # Select two rows, one segment.
+// tf.sparse_segment_sum(c, tf.constant([0, 1]), tf.constant([0, 0]))
+// # => [[0 0 0 0]]
+//
+// # Select two rows, two segments.
+// tf.sparse_segment_sum(c, tf.constant([0, 1]), tf.constant([0, 1]))
+// # => [[ 1 2 3 4]
+// # [-1 -2 -3 -4]]
+//
+// # Select all rows, two segments.
+// tf.sparse_segment_sum(c, tf.constant([0, 1, 2]), tf.constant([0, 0, 1]))
+// # => [[0 0 0 0]
+// # [5 6 7 8]]
+//
+// # Which is equivalent to:
+// tf.segment_sum(c, tf.constant([0, 0, 1]))
+// ```
+//
+// Arguments:
+//
+// indices: A 1-D tensor. Has same rank as `segment_ids`.
+// segment_ids: A 1-D tensor. Values should be sorted and can be repeated.
+//
+// Returns Has same shape as data, except for dimension 0 which
+// has size `k`, the number of segments.
+func SparseSegmentSum(scope *Scope, data tf.Output, indices tf.Output, segment_ids tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "SparseSegmentSum",
+ Input: []tf.Input{
+ data, indices, segment_ids,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
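The Python snippets above carry over almost verbatim to the Go wrapper. As an illustrative, assumed sketch (standard tensorflow/go boilerplate, not taken from this diff), the "two rows, one segment" case looks like:

package main

import (
	"fmt"

	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
	s := op.NewScope()
	c := op.Const(s, [][]int32{{1, 2, 3, 4}, {-1, -2, -3, -4}, {5, 6, 7, 8}})
	indices := op.Const(s, []int32{0, 1})  // select rows 0 and 1
	segments := op.Const(s, []int32{0, 0}) // both go into segment 0
	sum := op.SparseSegmentSum(s, c, indices, segments)

	graph, err := s.Finalize()
	if err != nil {
		panic(err)
	}
	sess, err := tf.NewSession(graph, nil)
	if err != nil {
		panic(err)
	}
	defer sess.Close()
	out, err := sess.Run(nil, []tf.Output{sum}, nil)
	if err != nil {
		panic(err)
	}
	fmt.Println(out[0].Value()) // [[0 0 0 0]]
}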
// Computes natural logarithm of (1 + x) element-wise.
//
// I.e., \\(y = \log_e (1 + x)\\).
@@ -5225,6 +6163,47 @@ func Reciprocal(scope *Scope, x tf.Output) (y tf.Output) {
return op.Output(0)
}
+// Transforms `input_dataset` containing `Example` protos as vectors of DT_STRING into a dataset of `Tensor` or `SparseTensor` objects representing the parsed features.
+//
+// Arguments:
+//
+//
+// dense_defaults: A dict mapping string keys to `Tensor`s.
+// The keys of the dict must match the dense_keys of the feature.
+// sparse_keys: A list of string keys in the examples features.
+// The results for these keys will be returned as `SparseTensor` objects.
+// dense_keys: A list of Ndense string Tensors (scalars).
+// The keys expected in the Examples features associated with dense values.
+// sparse_types: A list of `DTypes` of the same length as `sparse_keys`.
+// Only `tf.float32` (`FloatList`), `tf.int64` (`Int64List`),
+// and `tf.string` (`BytesList`) are supported.
+// dense_shapes: List of tuples with the same length as `dense_keys`.
+// The shape of the data for each dense feature referenced by `dense_keys`.
+// Required for any input tensors identified by `dense_keys`. Must be
+// either fully defined, or may contain an unknown first dimension.
+// An unknown first dimension means the feature is treated as having
+// a variable number of blocks, and the output shape along this dimension
+// is considered unknown at graph build time. Padding is applied for
+// minibatch elements smaller than the maximum number of blocks for the
+// given feature along this dimension.
+// output_types: The type list for the return values.
+// output_shapes: The list of shapes being produced.
+func ParseExampleDataset(scope *Scope, input_dataset tf.Output, num_parallel_calls tf.Output, dense_defaults []tf.Output, sparse_keys []string, dense_keys []string, sparse_types []tf.DataType, dense_shapes []tf.Shape, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"sparse_keys": sparse_keys, "dense_keys": dense_keys, "sparse_types": sparse_types, "dense_shapes": dense_shapes, "output_types": output_types, "output_shapes": output_shapes}
+ opspec := tf.OpSpec{
+ Type: "ParseExampleDataset",
+ Input: []tf.Input{
+ input_dataset, num_parallel_calls, tf.OutputList(dense_defaults),
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Returns a batched matrix tensor with new batched diagonal values.
//
// Given `input` and `diagonal`, this operation returns a tensor with the
@@ -5416,26 +6395,6 @@ func LogicalAnd(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
return op.Output(0)
}
-// Checks whether a tree ensemble has been initialized.
-//
-// Arguments:
-// tree_ensemble_handle: Handle to the tree ensemble resouce.
-//
-// Returns output boolean on whether it is initialized or not.
-func IsBoostedTreesEnsembleInitialized(scope *Scope, tree_ensemble_handle tf.Output) (is_initialized tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "IsBoostedTreesEnsembleInitialized",
- Input: []tf.Input{
- tree_ensemble_handle,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// CastAttr is an optional argument to Cast.
type CastAttr func(optionalAttr)
@@ -5619,6 +6578,44 @@ func QuantizedAvgPool(scope *Scope, input tf.Output, min_input tf.Output, max_in
return op.Output(0), op.Output(1), op.Output(2)
}
+// Extract `patches` from `input` and put them in the "depth" output
+// dimension. 3D extension of `extract_image_patches`.
+//
+// Arguments:
+// input: 5-D Tensor with shape `[batch, in_planes, in_rows, in_cols, depth]`.
+// ksizes: The size of the sliding window for each dimension of `input`.
+// strides: 1-D of length 5. How far the centers of two consecutive patches are in
+// `input`. Must be: `[1, stride_planes, stride_rows, stride_cols, 1]`.
+// padding: The type of padding algorithm to use.
+//
+// We specify the size-related attributes as:
+//
+// ```python
+// ksizes = [1, ksize_planes, ksize_rows, ksize_cols, 1]
+// strides = [1, stride_planes, strides_rows, strides_cols, 1]
+// ```
+//
+// Returns 5-D Tensor with shape `[batch, out_planes, out_rows, out_cols,
+// ksize_planes * ksize_rows * ksize_cols * depth]` containing patches
+// with size `ksize_planes x ksize_rows x ksize_cols x depth` vectorized
+// in the "depth" dimension. Note `out_planes`, `out_rows` and `out_cols`
+// are the dimensions of the output patches.
+func ExtractVolumePatches(scope *Scope, input tf.Output, ksizes []int64, strides []int64, padding string) (patches tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"ksizes": ksizes, "strides": strides, "padding": padding}
+ opspec := tf.OpSpec{
+ Type: "ExtractVolumePatches",
+ Input: []tf.Input{
+ input,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// FractionalAvgPoolAttr is an optional argument to FractionalAvgPool.
type FractionalAvgPoolAttr func(optionalAttr)
@@ -6189,6 +7186,98 @@ func ResourceSparseApplyAdadelta(scope *Scope, var_ tf.Output, accum tf.Output,
return scope.AddOperation(opspec)
}
+// Gets next element for the provided shard number.
+//
+// Arguments:
+// multi_device_iterator: A MultiDeviceIterator resource.
+// shard_num: Integer representing which shard to fetch data for.
+// incarnation_id: Which incarnation of the MultiDeviceIterator is running.
+// output_types: The type list for the return values.
+// output_shapes: The list of shapes being produced.
+//
+// Returns Result of the get_next on the dataset.
+func MultiDeviceIteratorGetNextFromShard(scope *Scope, multi_device_iterator tf.Output, shard_num tf.Output, incarnation_id tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (components []tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes}
+ opspec := tf.OpSpec{
+ Type: "MultiDeviceIteratorGetNextFromShard",
+ Input: []tf.Input{
+ multi_device_iterator, shard_num, incarnation_id,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ if scope.Err() != nil {
+ return
+ }
+ var idx int
+ var err error
+ if components, idx, err = makeOutputList(op, idx, "components"); err != nil {
+ scope.UpdateErr("MultiDeviceIteratorGetNextFromShard", err)
+ return
+ }
+ return components
+}
+
+// Computes rectified linear 6: `min(max(features, 0), 6)`.
+func Relu6(scope *Scope, features tf.Output) (activations tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "Relu6",
+ Input: []tf.Input{
+ features,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Computes the minimum along segments of a tensor.
+//
+// Read
+// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#segmentation)
+// for an explanation of segments.
+//
+// This operator is similar to the unsorted segment sum operator found
+// [(here)](../../../api_docs/python/math_ops.md#UnsortedSegmentSum).
+// Instead of computing the sum over segments, it computes the minimum such that:
+//
+// \\(output_i = \min_{j...} data[j...]\\) where min is over tuples `j...` such
+// that `segment_ids[j...] == i`.
+//
+// If the minimum is empty for a given segment ID `i`, it outputs the largest
+// possible value for the specific numeric type,
+// `output[i] = numeric_limits<T>::max()`.
+//
+// If the given segment ID `i` is negative, then the corresponding value is
+// dropped, and will not be included in the result.
+//
+// Arguments:
+//
+// segment_ids: A tensor whose shape is a prefix of `data.shape`.
+//
+//
+// Returns Has same shape as data, except for the first `segment_ids.rank`
+// dimensions, which are replaced with a single dimension which has size
+// `num_segments`.
+func UnsortedSegmentMin(scope *Scope, data tf.Output, segment_ids tf.Output, num_segments tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "UnsortedSegmentMin",
+ Input: []tf.Input{
+ data, segment_ids, num_segments,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Computes rectified linear gradients for a Relu operation.
//
// Arguments:
@@ -6476,7 +7565,7 @@ func ConcatOffset(scope *Scope, concat_dim tf.Output, shape []tf.Output) (offset
return offset
}
-// Compute the lower regularized incomplete Gamma function `Q(a, x)`.
+// Compute the lower regularized incomplete Gamma function `P(a, x)`.
//
// The lower regularized incomplete Gamma function is defined as:
//
@@ -7162,6 +8251,44 @@ func BiasAddGrad(scope *Scope, out_backprop tf.Output, optional ...BiasAddGradAt
return op.Output(0)
}
+// Bucketizes 'input' based on 'boundaries'.
+//
+// For example, if the inputs are
+// boundaries = [0, 10, 100]
+// input = [[-5, 10000]
+// [150, 10]
+// [5, 100]]
+//
+// then the output will be
+// output = [[0, 3]
+// [3, 2]
+// [1, 3]]
+//
+// Arguments:
+// input: A Tensor of any shape containing int or float values.
+// boundaries: A sorted list of floats giving the boundaries of the buckets.
+//
+// Returns Same shape as 'input', with each value of input replaced by its bucket index.
+//
+// @compatibility(numpy)
+// Equivalent to np.digitize.
+// @end_compatibility
+func Bucketize(scope *Scope, input tf.Output, boundaries []float32) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"boundaries": boundaries}
+ opspec := tf.OpSpec{
+ Type: "Bucketize",
+ Input: []tf.Input{
+ input,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
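Since `boundaries` is an attribute rather than a tensor input, it is passed as a plain Go slice at graph-construction time. The following is a small, illustrative sketch of the documented example, assuming the usual tensorflow/go session setup rather than anything introduced by this change.

package main

import (
	"fmt"

	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
	s := op.NewScope()
	in := op.Const(s, [][]float32{{-5, 10000}, {150, 10}, {5, 100}})
	// boundaries attribute: buckets are (-inf,0), [0,10), [10,100), [100,+inf)
	buckets := op.Bucketize(s, in, []float32{0, 10, 100})

	graph, err := s.Finalize()
	if err != nil {
		panic(err)
	}
	sess, err := tf.NewSession(graph, nil)
	if err != nil {
		panic(err)
	}
	defer sess.Close()
	out, err := sess.Run(nil, []tf.Output{buckets}, nil)
	if err != nil {
		panic(err)
	}
	fmt.Println(out[0].Value()) // [[0 3] [3 2] [1 3]]
}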
// FusedBatchNormV2Attr is an optional argument to FusedBatchNormV2.
type FusedBatchNormV2Attr func(optionalAttr)
@@ -7910,6 +9037,214 @@ func QueueDequeueV2(scope *Scope, handle tf.Output, component_types []tf.DataTyp
return components
}
+// ParseSequenceExampleAttr is an optional argument to ParseSequenceExample.
+type ParseSequenceExampleAttr func(optionalAttr)
+
+// ParseSequenceExampleNcontextSparse sets the optional Ncontext_sparse attribute to value.
+// If not specified, defaults to 0
+//
+// REQUIRES: value >= 0
+func ParseSequenceExampleNcontextSparse(value int64) ParseSequenceExampleAttr {
+ return func(m optionalAttr) {
+ m["Ncontext_sparse"] = value
+ }
+}
+
+// ParseSequenceExampleNcontextDense sets the optional Ncontext_dense attribute to value.
+// If not specified, defaults to 0
+//
+// REQUIRES: value >= 0
+func ParseSequenceExampleNcontextDense(value int64) ParseSequenceExampleAttr {
+ return func(m optionalAttr) {
+ m["Ncontext_dense"] = value
+ }
+}
+
+// ParseSequenceExampleNfeatureListSparse sets the optional Nfeature_list_sparse attribute to value.
+// If not specified, defaults to 0
+//
+// REQUIRES: value >= 0
+func ParseSequenceExampleNfeatureListSparse(value int64) ParseSequenceExampleAttr {
+ return func(m optionalAttr) {
+ m["Nfeature_list_sparse"] = value
+ }
+}
+
+// ParseSequenceExampleNfeatureListDense sets the optional Nfeature_list_dense attribute to value.
+// If not specified, defaults to 0
+//
+// REQUIRES: value >= 0
+func ParseSequenceExampleNfeatureListDense(value int64) ParseSequenceExampleAttr {
+ return func(m optionalAttr) {
+ m["Nfeature_list_dense"] = value
+ }
+}
+
+// ParseSequenceExampleContextSparseTypes sets the optional context_sparse_types attribute to value.
+//
+// value: A list of Ncontext_sparse types; the data types of data in
+// each context Feature given in context_sparse_keys.
+// Currently the ParseSingleSequenceExample supports DT_FLOAT (FloatList),
+// DT_INT64 (Int64List), and DT_STRING (BytesList).
+// If not specified, defaults to <>
+//
+// REQUIRES: len(value) >= 0
+func ParseSequenceExampleContextSparseTypes(value []tf.DataType) ParseSequenceExampleAttr {
+ return func(m optionalAttr) {
+ m["context_sparse_types"] = value
+ }
+}
+
+// ParseSequenceExampleFeatureListDenseTypes sets the optional feature_list_dense_types attribute to value.
+// If not specified, defaults to <>
+//
+// REQUIRES: len(value) >= 0
+func ParseSequenceExampleFeatureListDenseTypes(value []tf.DataType) ParseSequenceExampleAttr {
+ return func(m optionalAttr) {
+ m["feature_list_dense_types"] = value
+ }
+}
+
+// ParseSequenceExampleContextDenseShapes sets the optional context_dense_shapes attribute to value.
+//
+// value: A list of Ncontext_dense shapes; the shapes of data in
+// each context Feature given in context_dense_keys.
+// The number of elements in the Feature corresponding to context_dense_key[j]
+// must always equal context_dense_shapes[j].NumEntries().
+// The shape of context_dense_values[j] will match context_dense_shapes[j].
+// If not specified, defaults to <>
+//
+// REQUIRES: len(value) >= 0
+func ParseSequenceExampleContextDenseShapes(value []tf.Shape) ParseSequenceExampleAttr {
+ return func(m optionalAttr) {
+ m["context_dense_shapes"] = value
+ }
+}
+
+// ParseSequenceExampleFeatureListSparseTypes sets the optional feature_list_sparse_types attribute to value.
+//
+// value: A list of Nfeature_list_sparse types; the data types
+// of data in each FeatureList given in feature_list_sparse_keys.
+// Currently the ParseSingleSequenceExample supports DT_FLOAT (FloatList),
+// DT_INT64 (Int64List), and DT_STRING (BytesList).
+// If not specified, defaults to <>
+//
+// REQUIRES: len(value) >= 0
+func ParseSequenceExampleFeatureListSparseTypes(value []tf.DataType) ParseSequenceExampleAttr {
+ return func(m optionalAttr) {
+ m["feature_list_sparse_types"] = value
+ }
+}
+
+// ParseSequenceExampleFeatureListDenseShapes sets the optional feature_list_dense_shapes attribute to value.
+//
+// value: A list of Nfeature_list_dense shapes; the shapes of
+// data in each FeatureList given in feature_list_dense_keys.
+// The shape of each Feature in the FeatureList corresponding to
+// feature_list_dense_key[j] must always equal
+// feature_list_dense_shapes[j].NumEntries().
+// If not specified, defaults to <>
+//
+// REQUIRES: len(value) >= 0
+func ParseSequenceExampleFeatureListDenseShapes(value []tf.Shape) ParseSequenceExampleAttr {
+ return func(m optionalAttr) {
+ m["feature_list_dense_shapes"] = value
+ }
+}
+
+// Transforms a vector of brain.SequenceExample protos (as strings) into typed tensors.
+//
+// Arguments:
+// serialized: A vector containing binary serialized SequenceExample protos.
+// debug_name: A vector containing the names of the serialized protos.
+// May contain, for example, table key (descriptive) name for the
+// corresponding serialized proto. This is purely useful for debugging
+// purposes, and the presence of values here has no effect on the output.
+// May also be an empty vector if no name is available.
+// context_dense_defaults: A list of Ncontext_dense Tensors (some may be empty).
+// context_dense_defaults[j] provides default values
+// when the SequenceExample's context map lacks context_dense_key[j].
+// If an empty Tensor is provided for context_dense_defaults[j],
+// then the Feature context_dense_keys[j] is required.
+// The input type is inferred from context_dense_defaults[j], even when it's
+// empty. If context_dense_defaults[j] is not empty, its shape must match
+// context_dense_shapes[j].
+// feature_list_dense_missing_assumed_empty: A vector listing the
+// FeatureList keys which may be missing from the SequenceExamples. If the
+// associated FeatureList is missing, it is treated as empty. By default,
+// any FeatureList not listed in this vector must exist in the SequenceExamples.
+// context_sparse_keys: A list of Ncontext_sparse string Tensors (scalars).
+// The keys expected in the Examples' features associated with context_sparse
+// values.
+// context_dense_keys: A list of Ncontext_dense string Tensors (scalars).
+// The keys expected in the SequenceExamples' context features associated with
+// dense values.
+// feature_list_sparse_keys: A list of Nfeature_list_sparse string Tensors
+// (scalars). The keys expected in the FeatureLists associated with sparse
+// values.
+// feature_list_dense_keys: A list of Nfeature_list_dense string Tensors (scalars).
+// The keys expected in the SequenceExamples' feature_lists associated
+// with lists of dense values.
+func ParseSequenceExample(scope *Scope, serialized tf.Output, debug_name tf.Output, context_dense_defaults []tf.Output, feature_list_dense_missing_assumed_empty []string, context_sparse_keys []string, context_dense_keys []string, feature_list_sparse_keys []string, feature_list_dense_keys []string, optional ...ParseSequenceExampleAttr) (context_sparse_indices []tf.Output, context_sparse_values []tf.Output, context_sparse_shapes []tf.Output, context_dense_values []tf.Output, feature_list_sparse_indices []tf.Output, feature_list_sparse_values []tf.Output, feature_list_sparse_shapes []tf.Output, feature_list_dense_values []tf.Output, feature_list_dense_lengths []tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"feature_list_dense_missing_assumed_empty": feature_list_dense_missing_assumed_empty, "context_sparse_keys": context_sparse_keys, "context_dense_keys": context_dense_keys, "feature_list_sparse_keys": feature_list_sparse_keys, "feature_list_dense_keys": feature_list_dense_keys}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "ParseSequenceExample",
+ Input: []tf.Input{
+ serialized, debug_name, tf.OutputList(context_dense_defaults),
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ if scope.Err() != nil {
+ return
+ }
+ var idx int
+ var err error
+ if context_sparse_indices, idx, err = makeOutputList(op, idx, "context_sparse_indices"); err != nil {
+ scope.UpdateErr("ParseSequenceExample", err)
+ return
+ }
+ if context_sparse_values, idx, err = makeOutputList(op, idx, "context_sparse_values"); err != nil {
+ scope.UpdateErr("ParseSequenceExample", err)
+ return
+ }
+ if context_sparse_shapes, idx, err = makeOutputList(op, idx, "context_sparse_shapes"); err != nil {
+ scope.UpdateErr("ParseSequenceExample", err)
+ return
+ }
+ if context_dense_values, idx, err = makeOutputList(op, idx, "context_dense_values"); err != nil {
+ scope.UpdateErr("ParseSequenceExample", err)
+ return
+ }
+ if feature_list_sparse_indices, idx, err = makeOutputList(op, idx, "feature_list_sparse_indices"); err != nil {
+ scope.UpdateErr("ParseSequenceExample", err)
+ return
+ }
+ if feature_list_sparse_values, idx, err = makeOutputList(op, idx, "feature_list_sparse_values"); err != nil {
+ scope.UpdateErr("ParseSequenceExample", err)
+ return
+ }
+ if feature_list_sparse_shapes, idx, err = makeOutputList(op, idx, "feature_list_sparse_shapes"); err != nil {
+ scope.UpdateErr("ParseSequenceExample", err)
+ return
+ }
+ if feature_list_dense_values, idx, err = makeOutputList(op, idx, "feature_list_dense_values"); err != nil {
+ scope.UpdateErr("ParseSequenceExample", err)
+ return
+ }
+ if feature_list_dense_lengths, idx, err = makeOutputList(op, idx, "feature_list_dense_lengths"); err != nil {
+ scope.UpdateErr("ParseSequenceExample", err)
+ return
+ }
+ return context_sparse_indices, context_sparse_values, context_sparse_shapes, context_dense_values, feature_list_sparse_indices, feature_list_sparse_values, feature_list_sparse_shapes, feature_list_dense_values, feature_list_dense_lengths
+}
+
// Computes the Gauss error function of `x` element-wise.
func Erf(scope *Scope, x tf.Output) (y tf.Output) {
if scope.Err() != nil {
@@ -8070,6 +9405,119 @@ func OneHot(scope *Scope, indices tf.Output, depth tf.Output, on_value tf.Output
return op.Output(0)
}
+// Computes exponential of x element-wise. \\(y = e^x\\).
+func Exp(scope *Scope, x tf.Output) (y tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "Exp",
+ Input: []tf.Input{
+ x,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// NthElementAttr is an optional argument to NthElement.
+type NthElementAttr func(optionalAttr)
+
+// NthElementReverse sets the optional reverse attribute to value.
+//
+// value: When set to True, find the nth-largest value in the vector and vice
+// versa.
+// If not specified, defaults to false
+func NthElementReverse(value bool) NthElementAttr {
+ return func(m optionalAttr) {
+ m["reverse"] = value
+ }
+}
+
+// Finds values of the `n`-th order statistic for the last dimension.
+//
+// If the input is a vector (rank-1), finds the entry which is the nth-smallest
+// value in the vector and outputs its value as a scalar tensor.
+//
+// For matrices (resp. higher rank input), computes the entry which is the
+// nth-smallest value in each row (resp. vector along the last dimension). Thus,
+//
+// values.shape = input.shape[:-1]
+//
+// Arguments:
+// input: 1-D or higher with last dimension at least `n+1`.
+// n: 0-D. Position of sorted vector to select along the last dimension (along
+// each row for matrices). Valid range of n is `[0, input.shape[:-1])`
+//
+// Returns The `n`-th order statistic along each last dimensional slice.
+func NthElement(scope *Scope, input tf.Output, n tf.Output, optional ...NthElementAttr) (values tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "NthElement",
+ Input: []tf.Input{
+ input, n,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Computes the maximum along segments of a tensor.
+//
+// Read
+// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+// for an explanation of segments.
+//
+// This operator is similar to the unsorted segment sum operator found
+// [(here)](../../../api_docs/python/math_ops.md#UnsortedSegmentSum).
+// Instead of computing the sum over segments, it computes the maximum such that:
+//
+// \\(output_i = \max_{j...} data[j...]\\) where max is over tuples `j...` such
+// that `segment_ids[j...] == i`.
+//
+// If the maximum is empty for a given segment ID `i`, it outputs the smallest
+// possible value for the specific numeric type,
+// `output[i] = numeric_limits<T>::lowest()`.
+//
+// If the given segment ID `i` is negative, then the corresponding value is
+// dropped, and will not be included in the result.
+//
+// <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
+// <img style="width:100%" src="https://www.tensorflow.org/images/UnsortedSegmentMax.png" alt>
+// </div>
+//
+// Arguments:
+//
+// segment_ids: A tensor whose shape is a prefix of `data.shape`.
+//
+//
+// Returns Has same shape as data, except for the first `segment_ids.rank`
+// dimensions, which are replaced with a single dimension which has size
+// `num_segments`.
+func UnsortedSegmentMax(scope *Scope, data tf.Output, segment_ids tf.Output, num_segments tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "UnsortedSegmentMax",
+ Input: []tf.Input{
+ data, segment_ids, num_segments,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Transforms a vector of brain.Example protos (as strings) into typed tensors.
//
// Arguments:
@@ -8711,41 +10159,48 @@ func RandomStandardNormal(scope *Scope, shape tf.Output, dtype tf.DataType, opti
return op.Output(0)
}
-// ResourceApplyFtrlAttr is an optional argument to ResourceApplyFtrl.
-type ResourceApplyFtrlAttr func(optionalAttr)
+// RandomUniformIntAttr is an optional argument to RandomUniformInt.
+type RandomUniformIntAttr func(optionalAttr)
-// ResourceApplyFtrlUseLocking sets the optional use_locking attribute to value.
+// RandomUniformIntSeed sets the optional seed attribute to value.
//
-// value: If `True`, updating of the var and accum tensors will be protected
-// by a lock; otherwise the behavior is undefined, but may exhibit less
-// contention.
-// If not specified, defaults to false
-func ResourceApplyFtrlUseLocking(value bool) ResourceApplyFtrlAttr {
+// value: If either `seed` or `seed2` are set to be non-zero, the random number
+// generator is seeded by the given seed. Otherwise, it is seeded by a
+// random seed.
+// If not specified, defaults to 0
+func RandomUniformIntSeed(value int64) RandomUniformIntAttr {
return func(m optionalAttr) {
- m["use_locking"] = value
+ m["seed"] = value
}
}
-// Update '*var' according to the Ftrl-proximal scheme.
+// RandomUniformIntSeed2 sets the optional seed2 attribute to value.
//
-// accum_new = accum + grad * grad
-// linear += grad - (accum_new^(-lr_power) - accum^(-lr_power)) / lr * var
-// quadratic = 1.0 / (accum_new^(lr_power) * lr) + 2 * l2
-// var = (sign(linear) * l1 - linear) / quadratic if |linear| > l1 else 0.0
-// accum = accum_new
+// value: A second seed to avoid seed collision.
+// If not specified, defaults to 0
+func RandomUniformIntSeed2(value int64) RandomUniformIntAttr {
+ return func(m optionalAttr) {
+ m["seed2"] = value
+ }
+}
+
+// Outputs random integers from a uniform distribution.
+//
+// The generated values are uniform integers in the range `[minval, maxval)`.
+// The lower bound `minval` is included in the range, while the upper bound
+// `maxval` is excluded.
+//
+// The random integers are slightly biased unless `maxval - minval` is an exact
+// power of two. The bias is small for values of `maxval - minval` significantly
+// smaller than the range of the output (either `2^32` or `2^64`).
//
// Arguments:
-// var_: Should be from a Variable().
-// accum: Should be from a Variable().
-// linear: Should be from a Variable().
-// grad: The gradient.
-// lr: Scaling factor. Must be a scalar.
-// l1: L1 regulariation. Must be a scalar.
-// l2: L2 regulariation. Must be a scalar.
-// lr_power: Scaling factor. Must be a scalar.
+// shape: The shape of the output tensor.
+// minval: 0-D. Inclusive lower bound on the generated integers.
+// maxval: 0-D. Exclusive upper bound on the generated integers.
//
-// Returns the created operation.
-func ResourceApplyFtrl(scope *Scope, var_ tf.Output, accum tf.Output, linear tf.Output, grad tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, lr_power tf.Output, optional ...ResourceApplyFtrlAttr) (o *tf.Operation) {
+// Returns A tensor of the specified shape filled with uniform random integers.
+func RandomUniformInt(scope *Scope, shape tf.Output, minval tf.Output, maxval tf.Output, optional ...RandomUniformIntAttr) (output tf.Output) {
if scope.Err() != nil {
return
}
@@ -8754,13 +10209,72 @@ func ResourceApplyFtrl(scope *Scope, var_ tf.Output, accum tf.Output, linear tf.
a(attrs)
}
opspec := tf.OpSpec{
- Type: "ResourceApplyFtrl",
+ Type: "RandomUniformInt",
Input: []tf.Input{
- var_, accum, linear, grad, lr, l1, l2, lr_power,
+ shape, minval, maxval,
},
Attrs: attrs,
}
- return scope.AddOperation(opspec)
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
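Here `shape`, `minval` and `maxval` are tensors, while the seeds are optional attributes. A hypothetical sketch that draws ten integers from [0, 100) with a fixed seed, assuming the standard tensorflow/go imports and session API:

package main

import (
	"fmt"

	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
	s := op.NewScope()
	shape := op.Const(s, []int64{10})
	minval := op.Const(s, int64(0))   // inclusive lower bound
	maxval := op.Const(s, int64(100)) // exclusive upper bound
	r := op.RandomUniformInt(s, shape, minval, maxval, op.RandomUniformIntSeed(42))

	graph, err := s.Finalize()
	if err != nil {
		panic(err)
	}
	sess, err := tf.NewSession(graph, nil)
	if err != nil {
		panic(err)
	}
	defer sess.Close()
	out, err := sess.Run(nil, []tf.Output{r}, nil)
	if err != nil {
		panic(err)
	}
	fmt.Println(out[0].Value()) // ten pseudo-random int64 values in [0, 100)
}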
+// FusedResizeAndPadConv2DAttr is an optional argument to FusedResizeAndPadConv2D.
+type FusedResizeAndPadConv2DAttr func(optionalAttr)
+
+// FusedResizeAndPadConv2DResizeAlignCorners sets the optional resize_align_corners attribute to value.
+//
+// value: If true, the centers of the 4 corner pixels of the input and output tensors are
+// aligned, preserving the values at the corner pixels. Defaults to false.
+// If not specified, defaults to false
+func FusedResizeAndPadConv2DResizeAlignCorners(value bool) FusedResizeAndPadConv2DAttr {
+ return func(m optionalAttr) {
+ m["resize_align_corners"] = value
+ }
+}
+
+// Performs a resize and padding as a preprocess during a convolution.
+//
+// It's often possible to do spatial transformations more efficiently as part of
+// the packing stage of a convolution, so this op allows for an optimized
+// implementation where these stages are fused together. This prevents the need to
+// write out the intermediate results as whole tensors, reducing memory pressure,
+// and we can get some latency gains by merging the transformation calculations.
+// The data_format attribute for Conv2D isn't supported by this op, and defaults to
+// 'NHWC' order.
+// Internally this op uses a single per-graph scratch buffer, which means that it
+// will block if multiple versions are being run in parallel. This is because this
+// operator is primarily an optimization to minimize memory usage.
+//
+// Arguments:
+// input: 4-D with shape `[batch, in_height, in_width, in_channels]`.
+// size: A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The
+// new size for the images.
+// paddings: A two-column matrix specifying the padding sizes. The number of
+// rows must be the same as the rank of `input`.
+// filter: 4-D with shape
+// `[filter_height, filter_width, in_channels, out_channels]`.
+//
+// strides: 1-D of length 4. The stride of the sliding window for each dimension
+// of `input`. Must be in the same order as the dimension specified with format.
+// padding: The type of padding algorithm to use.
+func FusedResizeAndPadConv2D(scope *Scope, input tf.Output, size tf.Output, paddings tf.Output, filter tf.Output, mode string, strides []int64, padding string, optional ...FusedResizeAndPadConv2DAttr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"mode": mode, "strides": strides, "padding": padding}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "FusedResizeAndPadConv2D",
+ Input: []tf.Input{
+ input, size, paddings, filter,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
}
// RandomUniformAttr is an optional argument to RandomUniform.
@@ -8817,6 +10331,58 @@ func RandomUniform(scope *Scope, shape tf.Output, dtype tf.DataType, optional ..
return op.Output(0)
}
+// ResourceApplyFtrlAttr is an optional argument to ResourceApplyFtrl.
+type ResourceApplyFtrlAttr func(optionalAttr)
+
+// ResourceApplyFtrlUseLocking sets the optional use_locking attribute to value.
+//
+// value: If `True`, updating of the var and accum tensors will be protected
+// by a lock; otherwise the behavior is undefined, but may exhibit less
+// contention.
+// If not specified, defaults to false
+func ResourceApplyFtrlUseLocking(value bool) ResourceApplyFtrlAttr {
+ return func(m optionalAttr) {
+ m["use_locking"] = value
+ }
+}
+
+// Update '*var' according to the Ftrl-proximal scheme.
+//
+// accum_new = accum + grad * grad
+// linear += grad - (accum_new^(-lr_power) - accum^(-lr_power)) / lr * var
+// quadratic = 1.0 / (accum_new^(lr_power) * lr) + 2 * l2
+// var = (sign(linear) * l1 - linear) / quadratic if |linear| > l1 else 0.0
+// accum = accum_new
+//
+// Arguments:
+// var_: Should be from a Variable().
+// accum: Should be from a Variable().
+// linear: Should be from a Variable().
+// grad: The gradient.
+// lr: Scaling factor. Must be a scalar.
+// l1: L1 regularization. Must be a scalar.
+// l2: L2 regularization. Must be a scalar.
+// lr_power: Scaling factor. Must be a scalar.
+//
+// Returns the created operation.
+func ResourceApplyFtrl(scope *Scope, var_ tf.Output, accum tf.Output, linear tf.Output, grad tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, lr_power tf.Output, optional ...ResourceApplyFtrlAttr) (o *tf.Operation) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "ResourceApplyFtrl",
+ Input: []tf.Input{
+ var_, accum, linear, grad, lr, l1, l2, lr_power,
+ },
+ Attrs: attrs,
+ }
+ return scope.AddOperation(opspec)
+}
+
// Encode audio data using the WAV file format.
//
// This operation will generate a string suitable to be saved out to create a .wav
@@ -8953,23 +10519,6 @@ func Assert(scope *Scope, condition tf.Output, data []tf.Output, optional ...Ass
return scope.AddOperation(opspec)
}
-// Broadcasts a tensor value to one or more other devices.
-func CollectiveBcastSend(scope *Scope, input tf.Output, group_size int64, group_key int64, instance_key int64, shape tf.Shape) (data tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"group_size": group_size, "group_key": group_key, "instance_key": instance_key, "shape": shape}
- opspec := tf.OpSpec{
- Type: "CollectiveBcastSend",
- Input: []tf.Input{
- input,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Split a `SparseTensor` into `num_split` tensors along one dimension.
//
// If the `shape[split_dim]` is not an integer multiple of `num_split`. Slices
@@ -9093,6 +10642,118 @@ func ResourceSparseApplyFtrlV2(scope *Scope, var_ tf.Output, accum tf.Output, li
return scope.AddOperation(opspec)
}
+// Calculates gains for each feature and returns the best possible split information for the feature.
+//
+// The split information is the best threshold (bucket id), gains and left/right node contributions per node for each feature.
+//
+// It is possible that not all nodes can be split on each feature. Hence, the list of possible nodes can differ between the features. Therefore, we return `node_ids_list` for each feature, containing the list of nodes that this feature can be used to split.
+//
+// In this manner, the output is the best split per features and per node, so that it needs to be combined later to produce the best split for each node (among all possible features).
+//
+// The output lists are all of the same length, `num_features`.
+// The output shapes are compatible in a way that the first dimension of all tensors of all lists are the same and equal to the number of possible split nodes for each feature.
+//
+// Arguments:
+// node_id_range: A Rank 1 tensor (shape=[2]) to specify the range [first, last) of node ids to process within `stats_summary_list`. The nodes are iterated between the two nodes specified by the tensor, as like `for node_id in range(node_id_range[0], node_id_range[1])` (Note that the last index node_id_range[1] is exclusive).
+// stats_summary_list: A list of Rank 3 tensor (#shape=[max_splits, bucket, 2]) for accumulated stats summary (gradient/hessian) per node per buckets for each feature. The first dimension of the tensor is the maximum number of splits, and thus not all elements of it will be used, but only the indexes specified by node_ids will be used.
+// l1: l1 regularization factor on leaf weights, per instance based.
+// l2: l2 regularization factor on leaf weights, per instance based.
+// tree_complexity: adjustment to the gain, per leaf based.
+// min_node_weight: minimum avg of hessians in a node required for the node to be considered for splitting.
+// max_splits: the number of nodes that can be split in the whole tree. Used as a dimension of output tensors.
+//
+// Returns An output list of Rank 1 tensors indicating possible split node ids for each feature. The length of the list is num_features, but each tensor has different size as each feature provides different possible nodes. See above for details like shapes and sizes. An output list of Rank 1 tensors indicating the best gains for each feature to split for certain nodes. See above for details like shapes and sizes. An output list of Rank 1 tensors indicating the bucket id to compare with (as a threshold) for split in each node. See above for details like shapes and sizes. A list of Rank 2 tensors indicating the contribution of the left nodes when branching from parent nodes (given by the tensor element in the output node_ids_list) to the left direction by the given threshold for each feature. This value will be used to make the left node value by adding to the parent node value. Second dimension size is 1 for 1-dimensional logits, but would be larger for multi-class problems. See above for details like shapes and sizes. A list of Rank 2 tensors, with the same shape/conditions as left_node_contribs_list, but just that the value is for the right node.
+func BoostedTreesCalculateBestGainsPerFeature(scope *Scope, node_id_range tf.Output, stats_summary_list []tf.Output, l1 tf.Output, l2 tf.Output, tree_complexity tf.Output, min_node_weight tf.Output, max_splits int64) (node_ids_list []tf.Output, gains_list []tf.Output, thresholds_list []tf.Output, left_node_contribs_list []tf.Output, right_node_contribs_list []tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"max_splits": max_splits}
+ opspec := tf.OpSpec{
+ Type: "BoostedTreesCalculateBestGainsPerFeature",
+ Input: []tf.Input{
+ node_id_range, tf.OutputList(stats_summary_list), l1, l2, tree_complexity, min_node_weight,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ if scope.Err() != nil {
+ return
+ }
+ var idx int
+ var err error
+ if node_ids_list, idx, err = makeOutputList(op, idx, "node_ids_list"); err != nil {
+ scope.UpdateErr("BoostedTreesCalculateBestGainsPerFeature", err)
+ return
+ }
+ if gains_list, idx, err = makeOutputList(op, idx, "gains_list"); err != nil {
+ scope.UpdateErr("BoostedTreesCalculateBestGainsPerFeature", err)
+ return
+ }
+ if thresholds_list, idx, err = makeOutputList(op, idx, "thresholds_list"); err != nil {
+ scope.UpdateErr("BoostedTreesCalculateBestGainsPerFeature", err)
+ return
+ }
+ if left_node_contribs_list, idx, err = makeOutputList(op, idx, "left_node_contribs_list"); err != nil {
+ scope.UpdateErr("BoostedTreesCalculateBestGainsPerFeature", err)
+ return
+ }
+ if right_node_contribs_list, idx, err = makeOutputList(op, idx, "right_node_contribs_list"); err != nil {
+ scope.UpdateErr("BoostedTreesCalculateBestGainsPerFeature", err)
+ return
+ }
+ return node_ids_list, gains_list, thresholds_list, left_node_contribs_list, right_node_contribs_list
+}
+
+// EncodePngAttr is an optional argument to EncodePng.
+type EncodePngAttr func(optionalAttr)
+
+// EncodePngCompression sets the optional compression attribute to value.
+//
+// value: Compression level.
+// If not specified, defaults to -1
+func EncodePngCompression(value int64) EncodePngAttr {
+ return func(m optionalAttr) {
+ m["compression"] = value
+ }
+}
+
+// PNG-encode an image.
+//
+// `image` is a 3-D uint8 or uint16 Tensor of shape `[height, width, channels]`
+// where `channels` is:
+//
+// * 1: for grayscale.
+// * 2: for grayscale + alpha.
+// * 3: for RGB.
+// * 4: for RGBA.
+//
+// The ZLIB compression level, `compression`, can be -1 for the PNG-encoder
+// default or a value from 0 to 9. 9 is the highest compression level, generating
+// the smallest output, but is slower.
+//
+// Arguments:
+// image: 3-D with shape `[height, width, channels]`.
+//
+// Returns 0-D. PNG-encoded image.
+func EncodePng(scope *Scope, image tf.Output, optional ...EncodePngAttr) (contents tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "EncodePng",
+ Input: []tf.Input{
+ image,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
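+// Illustrative sketch (added in review, not generated code): one way to wire
+// EncodePng together with the Placeholder wrapper from this package. The
+// helper name exampleEncodePng is hypothetical.
+func exampleEncodePng(s *Scope) tf.Output {
+	// A [height, width, channels] uint8 image fed in at run time.
+	img := Placeholder(s, tf.Uint8)
+	// Compression 9 yields the smallest output at the cost of speed;
+	// omitting the attr keeps the encoder default (-1).
+	return EncodePng(s, img, EncodePngCompression(9))
+}
+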
// DataFormatVecPermuteAttr is an optional argument to DataFormatVecPermute.
type DataFormatVecPermuteAttr func(optionalAttr)
@@ -9143,6 +10804,29 @@ func DataFormatVecPermute(scope *Scope, x tf.Output, optional ...DataFormatVecPe
return op.Output(0)
}
+// Initializes the multi device iterator with the given dataset.
+//
+// Arguments:
+// dataset: Dataset to be iterated upon.
+// multi_device_iterator: A MultiDeviceIteratorResource.
+// max_buffer_size: The maximum size of the host side per device buffer to keep.
+//
+// Returns An int64 indicating which incarnation of the MultiDeviceIterator
+// is running.
+func MultiDeviceIteratorInit(scope *Scope, dataset tf.Output, multi_device_iterator tf.Output, max_buffer_size tf.Output) (incarnation_id tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "MultiDeviceIteratorInit",
+ Input: []tf.Input{
+ dataset, multi_device_iterator, max_buffer_size,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Computes the gradient of `igamma(a, x)` wrt `a`.
func IgammaGradA(scope *Scope, a tf.Output, x tf.Output) (z tf.Output) {
if scope.Err() != nil {
@@ -9188,6 +10872,49 @@ func StringToHashBucket(scope *Scope, string_tensor tf.Output, num_buckets int64
return op.Output(0)
}
+// StaticRegexReplaceAttr is an optional argument to StaticRegexReplace.
+type StaticRegexReplaceAttr func(optionalAttr)
+
+// StaticRegexReplaceReplaceGlobal sets the optional replace_global attribute to value.
+//
+// value: If True, the replacement is global, otherwise the replacement
+// is done only on the first match.
+// If not specified, defaults to true
+func StaticRegexReplaceReplaceGlobal(value bool) StaticRegexReplaceAttr {
+ return func(m optionalAttr) {
+ m["replace_global"] = value
+ }
+}
+
+// Replaces the match of pattern in input with rewrite.
+//
+// It follows the re2 syntax (https://github.com/google/re2/wiki/Syntax)
+//
+// Arguments:
+// input: The text to be processed.
+// pattern: The regular expression to match the input.
+// rewrite: The rewrite to be applied to the matched expression.
+//
+// Returns The text after applying pattern and rewrite.
+func StaticRegexReplace(scope *Scope, input tf.Output, pattern string, rewrite string, optional ...StaticRegexReplaceAttr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"pattern": pattern, "rewrite": rewrite}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "StaticRegexReplace",
+ Input: []tf.Input{
+ input,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
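+// Illustrative sketch (added in review, not generated code): redacting digits
+// with StaticRegexReplace. The helper name exampleStaticRegexReplace is
+// hypothetical; Const comes from this package.
+func exampleStaticRegexReplace(s *Scope) tf.Output {
+	text := Const(s, []string{"call 555-0199", "room 42"})
+	// replace_global defaults to true; setting it to false would rewrite
+	// only the first match in each element.
+	return StaticRegexReplace(s, text, "[0-9]", "#",
+		StaticRegexReplaceReplaceGlobal(true))
+}
+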
// Computes gradients for the exponential linear (Elu) operation.
//
// Arguments:
@@ -9263,6 +10990,112 @@ func ReadVariableOp(scope *Scope, resource tf.Output, dtype tf.DataType) (value
return op.Output(0)
}
+// This op consumes a lock created by `MutexLock`.
+//
+// This op exists to consume a tensor created by `MutexLock` (other than
+// direct control dependencies). It should be the only op that consumes the tensor,
+// and will raise an error if it is not. Its only purpose is to keep the
+// mutex lock tensor alive until it is consumed by this op.
+//
+// **NOTE**: This operation must run on the same device as its input. This may
+// be enforced via the `colocate_with` mechanism.
+//
+// Arguments:
+// mutex_lock: A tensor returned by `MutexLock`.
+//
+// Returns the created operation.
+func ConsumeMutexLock(scope *Scope, mutex_lock tf.Output) (o *tf.Operation) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "ConsumeMutexLock",
+ Input: []tf.Input{
+ mutex_lock,
+ },
+ }
+ return scope.AddOperation(opspec)
+}
+
+// ResourceScatterNdAddAttr is an optional argument to ResourceScatterNdAdd.
+type ResourceScatterNdAddAttr func(optionalAttr)
+
+// ResourceScatterNdAddUseLocking sets the optional use_locking attribute to value.
+//
+// value: An optional bool. Defaults to True. If True, the assignment will
+// be protected by a lock; otherwise the behavior is undefined,
+// but may exhibit less contention.
+// If not specified, defaults to true
+func ResourceScatterNdAddUseLocking(value bool) ResourceScatterNdAddAttr {
+ return func(m optionalAttr) {
+ m["use_locking"] = value
+ }
+}
+
+// Adds sparse `updates` to individual values or slices within a given
+//
+// variable according to `indices`.
+//
+// `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
+//
+// `indices` must be integer tensor, containing indices into `ref`.
+// It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
+//
+// The innermost dimension of `indices` (with length `K`) corresponds to
+// indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
+// dimension of `ref`.
+//
+// `updates` is `Tensor` of rank `Q-1+P-K` with shape:
+//
+// ```
+// [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
+// ```
+//
+// For example, say we want to add 4 scattered elements to a rank-1 tensor
+// with 8 elements. In Python, that update would look like this:
+//
+// ```python
+// ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8], use_resource=True)
+// indices = tf.constant([[4], [3], [1] ,[7]])
+// updates = tf.constant([9, 10, 11, 12])
+// update = tf.scatter_nd_add(ref, indices, updates)
+// with tf.Session() as sess:
+//   print(sess.run(update))
+// ```
+//
+// The resulting update to ref would look like this:
+//
+// [1, 12, 3, 14, 14, 6, 7, 20]
+//
+// See `tf.scatter_nd` for more details about how to make updates to
+// slices.
+//
+// Arguments:
+// ref: A resource handle. Must be from a VarHandleOp.
+// indices: A Tensor. Must be one of the following types: int32, int64.
+// A tensor of indices into ref.
+// updates: A Tensor. Must have the same type as ref. A tensor of
+// values to add to ref.
+//
+// Returns the created operation.
+func ResourceScatterNdAdd(scope *Scope, ref tf.Output, indices tf.Output, updates tf.Output, optional ...ResourceScatterNdAddAttr) (o *tf.Operation) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "ResourceScatterNdAdd",
+ Input: []tf.Input{
+ ref, indices, updates,
+ },
+ Attrs: attrs,
+ }
+ return scope.AddOperation(opspec)
+}
+
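+// Illustrative sketch (added in review, not generated code): the Go analogue
+// of the Python snippet above. The ref handle is assumed to come from a
+// VarHandleOp-backed int32 variable of 8 elements created elsewhere; the
+// helper name exampleResourceScatterNdAdd is hypothetical.
+func exampleResourceScatterNdAdd(s *Scope, ref tf.Output) *tf.Operation {
+	// Add 9, 10, 11 and 12 to the elements at indices 4, 3, 1 and 7.
+	indices := Const(s, [][]int32{{4}, {3}, {1}, {7}})
+	updates := Const(s, []int32{9, 10, 11, 12})
+	return ResourceScatterNdAdd(s, ref, indices, updates)
+}
+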
// Updates the tree ensemble by either adding a layer to the last tree being grown
//
// or by starting a new tree.
@@ -10004,68 +11837,31 @@ func ResourceScatterDiv(scope *Scope, resource tf.Output, indices tf.Output, upd
return scope.AddOperation(opspec)
}
-// ResourceScatterNdAddAttr is an optional argument to ResourceScatterNdAdd.
-type ResourceScatterNdAddAttr func(optionalAttr)
+// StatelessRandomNormalAttr is an optional argument to StatelessRandomNormal.
+type StatelessRandomNormalAttr func(optionalAttr)
-// ResourceScatterNdAddUseLocking sets the optional use_locking attribute to value.
+// StatelessRandomNormalDtype sets the optional dtype attribute to value.
//
-// value: An optional bool. Defaults to True. If True, the assignment will
-// be protected by a lock; otherwise the behavior is undefined,
-// but may exhibit less contention.
-// If not specified, defaults to true
-func ResourceScatterNdAddUseLocking(value bool) ResourceScatterNdAddAttr {
+// value: The type of the output.
+// If not specified, defaults to DT_FLOAT
+func StatelessRandomNormalDtype(value tf.DataType) StatelessRandomNormalAttr {
return func(m optionalAttr) {
- m["use_locking"] = value
+ m["dtype"] = value
}
}
-// Adds sparse `updates` to individual values or slices within a given
-//
-// variable according to `indices`.
-//
-// `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
-//
-// `indices` must be integer tensor, containing indices into `ref`.
-// It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
-//
-// The innermost dimension of `indices` (with length `K`) corresponds to
-// indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
-// dimension of `ref`.
-//
-// `updates` is `Tensor` of rank `Q-1+P-K` with shape:
-//
-// ```
-// [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
-// ```
-//
-// For example, say we want to update 4 scattered elements to a rank-1 tensor to
-// 8 elements. In Python, that update would look like this:
-//
-// ```python
-// ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8], use_resource=True)
-// indices = tf.constant([[4], [3], [1] ,[7]])
-// updates = tf.constant([9, 10, 11, 12])
-// update = tf.scatter_nd_add(ref, indices, updates)
-// with tf.Session() as sess:
-// print sess.run(update)
-// ```
-//
-// The resulting update to ref would look like this:
+// Outputs deterministic pseudorandom values from a normal distribution.
//
-// [1, 12, 3, 14, 14, 6, 7, 20]
+// The generated values will have mean 0 and standard deviation 1.
//
-// See @{tf.scatter_nd} for more details about how to make updates to
-// slices.
+// The outputs are a deterministic function of `shape` and `seed`.
//
// Arguments:
-// ref: A resource handle. Must be from a VarHandleOp.
-// indices: A Tensor. Must be one of the following types: int32, int64.
-// A tensor of indices into ref.
-// updates: A Tensor. Must have the same type as ref. A tensor of
-// values to add to ref.
+// shape: The shape of the output tensor.
+// seed: 2 seeds (shape [2]).
//
-// Returns the created operation.
-func ResourceScatterNdAdd(scope *Scope, ref tf.Output, indices tf.Output, updates tf.Output, optional ...ResourceScatterNdAddAttr) (o *tf.Operation) {
+// Returns Random values with specified shape.
+func StatelessRandomNormal(scope *Scope, shape tf.Output, seed tf.Output, optional ...StatelessRandomNormalAttr) (output tf.Output) {
if scope.Err() != nil {
return
}
@@ -10074,57 +11870,93 @@ func ResourceScatterNdAdd(scope *Scope, ref tf.Output, indices tf.Output, update
a(attrs)
}
opspec := tf.OpSpec{
- Type: "ResourceScatterNdAdd",
+ Type: "StatelessRandomNormal",
Input: []tf.Input{
- ref, indices, updates,
+ shape, seed,
},
Attrs: attrs,
}
- return scope.AddOperation(opspec)
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
}
-// Mutually reduces multiple tensors of identical type and shape.
-func CollectiveReduce(scope *Scope, input tf.Output, group_size int64, group_key int64, instance_key int64, merge_op string, final_op string, subdiv_offsets []int64) (data tf.Output) {
+// Creates a sequence of numbers.
+//
+// This operation creates a sequence of numbers that begins at `start` and
+// extends by increments of `delta` up to but not including `limit`.
+//
+// For example:
+//
+// ```
+// # 'start' is 3
+// # 'limit' is 18
+// # 'delta' is 3
+// tf.range(start, limit, delta) ==> [3, 6, 9, 12, 15]
+// ```
+//
+// Arguments:
+// start: 0-D (scalar). First entry in the sequence.
+// limit: 0-D (scalar). Upper limit of sequence, exclusive.
+// delta: 0-D (scalar). Optional. Default is 1. Number that increments `start`.
+//
+// Returns 1-D.
+func Range(scope *Scope, start tf.Output, limit tf.Output, delta tf.Output) (output tf.Output) {
if scope.Err() != nil {
return
}
- attrs := map[string]interface{}{"group_size": group_size, "group_key": group_key, "instance_key": instance_key, "merge_op": merge_op, "final_op": final_op, "subdiv_offsets": subdiv_offsets}
opspec := tf.OpSpec{
- Type: "CollectiveReduce",
+ Type: "Range",
Input: []tf.Input{
- input,
+ start, limit, delta,
},
- Attrs: attrs,
}
op := scope.AddOperation(opspec)
return op.Output(0)
}
-// StatelessRandomNormalAttr is an optional argument to StatelessRandomNormal.
-type StatelessRandomNormalAttr func(optionalAttr)
+// ResourceApplyMomentumAttr is an optional argument to ResourceApplyMomentum.
+type ResourceApplyMomentumAttr func(optionalAttr)
-// StatelessRandomNormalDtype sets the optional dtype attribute to value.
+// ResourceApplyMomentumUseLocking sets the optional use_locking attribute to value.
//
-// value: The type of the output.
-// If not specified, defaults to DT_FLOAT
-func StatelessRandomNormalDtype(value tf.DataType) StatelessRandomNormalAttr {
+// value: If `True`, updating of the var and accum tensors will be protected
+// by a lock; otherwise the behavior is undefined, but may exhibit less
+// contention.
+// If not specified, defaults to false
+func ResourceApplyMomentumUseLocking(value bool) ResourceApplyMomentumAttr {
return func(m optionalAttr) {
- m["dtype"] = value
+ m["use_locking"] = value
}
}
-// Outputs deterministic pseudorandom values from a normal distribution.
+// ResourceApplyMomentumUseNesterov sets the optional use_nesterov attribute to value.
//
-// The generated values will have mean 0 and standard deviation 1.
+// value: If `True`, the tensor passed to compute grad will be
+// var - lr * momentum * accum, so in the end, the var you get is actually
+// var - lr * momentum * accum.
+// If not specified, defaults to false
+func ResourceApplyMomentumUseNesterov(value bool) ResourceApplyMomentumAttr {
+ return func(m optionalAttr) {
+ m["use_nesterov"] = value
+ }
+}
+
+// Update '*var' according to the momentum scheme. Set use_nesterov = True if you
//
-// The outputs are a deterministic function of `shape` and `seed`.
+// want to use Nesterov momentum.
+//
+// accum = accum * momentum + grad
+// var -= lr * accum
//
// Arguments:
-// shape: The shape of the output tensor.
-// seed: 2 seeds (shape [2]).
+// var_: Should be from a Variable().
+// accum: Should be from a Variable().
+// lr: Scaling factor. Must be a scalar.
+// grad: The gradient.
+// momentum: Momentum. Must be a scalar.
//
-// Returns Random values with specified shape.
-func StatelessRandomNormal(scope *Scope, shape tf.Output, seed tf.Output, optional ...StatelessRandomNormalAttr) (output tf.Output) {
+// Returns the created operation.
+func ResourceApplyMomentum(scope *Scope, var_ tf.Output, accum tf.Output, lr tf.Output, grad tf.Output, momentum tf.Output, optional ...ResourceApplyMomentumAttr) (o *tf.Operation) {
if scope.Err() != nil {
return
}
@@ -10133,12 +11965,54 @@ func StatelessRandomNormal(scope *Scope, shape tf.Output, seed tf.Output, option
a(attrs)
}
opspec := tf.OpSpec{
- Type: "StatelessRandomNormal",
+ Type: "ResourceApplyMomentum",
Input: []tf.Input{
- shape, seed,
+ var_, accum, lr, grad, momentum,
},
Attrs: attrs,
}
+ return scope.AddOperation(opspec)
+}
+
+// Exits the current frame to its parent frame.
+//
+// Exit makes its input `data` available to the parent frame.
+//
+// Arguments:
+// data: The tensor to be made available to the parent frame.
+//
+// Returns The same tensor as `data`.
+func Exit(scope *Scope, data tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "Exit",
+ Input: []tf.Input{
+ data,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Produce a string tensor that encodes the state of a Reader.
+//
+// Not all Readers support being serialized, so this can produce an
+// Unimplemented error.
+//
+// Arguments:
+// reader_handle: Handle to a Reader.
+func ReaderSerializeStateV2(scope *Scope, reader_handle tf.Output) (state tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "ReaderSerializeStateV2",
+ Input: []tf.Input{
+ reader_handle,
+ },
+ }
op := scope.AddOperation(opspec)
return op.Output(0)
}
@@ -10276,68 +12150,6 @@ func StringJoin(scope *Scope, inputs []tf.Output, optional ...StringJoinAttr) (o
return op.Output(0)
}
-// StringSplitV2Attr is an optional argument to StringSplitV2.
-type StringSplitV2Attr func(optionalAttr)
-
-// StringSplitV2Maxsplit sets the optional maxsplit attribute to value.
-//
-// value: An `int`. If `maxsplit > 0`, limit of the split of the result.
-// If not specified, defaults to -1
-func StringSplitV2Maxsplit(value int64) StringSplitV2Attr {
- return func(m optionalAttr) {
- m["maxsplit"] = value
- }
-}
-
-// Split elements of `source` based on `sep` into a `SparseTensor`.
-//
-// Let N be the size of source (typically N will be the batch size). Split each
-// element of `source` based on `sep` and return a `SparseTensor`
-// containing the split tokens. Empty tokens are ignored.
-//
-// For example, N = 2, source[0] is 'hello world' and source[1] is 'a b c',
-// then the output will be
-// ```
-// st.indices = [0, 0;
-// 0, 1;
-// 1, 0;
-// 1, 1;
-// 1, 2]
-// st.shape = [2, 3]
-// st.values = ['hello', 'world', 'a', 'b', 'c']
-// ```
-//
-// If `sep` is given, consecutive delimiters are not grouped together and are
-// deemed to delimit empty strings. For example, source of `"1<>2<><>3"` and
-// sep of `"<>"` returns `["1", "2", "", "3"]`. If `sep` is None or an empty
-// string, consecutive whitespace are regarded as a single separator, and the
-// result will contain no empty strings at the startor end if the string has
-// leading or trailing whitespace.
-//
-// Note that the above mentioned behavior matches python's str.split.
-//
-// Arguments:
-// input: `1-D` string `Tensor`, the strings to split.
-// sep: `0-D` string `Tensor`, the delimiter character.
-func StringSplitV2(scope *Scope, input tf.Output, sep tf.Output, optional ...StringSplitV2Attr) (indices tf.Output, values tf.Output, shape tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "StringSplitV2",
- Input: []tf.Input{
- input, sep,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1), op.Output(2)
-}
-
// MaxPoolAttr is an optional argument to MaxPool.
type MaxPoolAttr func(optionalAttr)
@@ -10818,6 +12630,51 @@ func Conj(scope *Scope, input tf.Output) (output tf.Output) {
return op.Output(0)
}
+// ProdAttr is an optional argument to Prod.
+type ProdAttr func(optionalAttr)
+
+// ProdKeepDims sets the optional keep_dims attribute to value.
+//
+// value: If true, retain reduced dimensions with length 1.
+// If not specified, defaults to false
+func ProdKeepDims(value bool) ProdAttr {
+ return func(m optionalAttr) {
+ m["keep_dims"] = value
+ }
+}
+
+// Computes the product of elements across dimensions of a tensor.
+//
+// Reduces `input` along the dimensions given in `axis`. Unless
+// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in
+// `axis`. If `keep_dims` is true, the reduced dimensions are
+// retained with length 1.
+//
+// Arguments:
+// input: The tensor to reduce.
+// axis: The dimensions to reduce. Must be in the range
+// `[-rank(input), rank(input))`.
+//
+// Returns The reduced tensor.
+func Prod(scope *Scope, input tf.Output, axis tf.Output, optional ...ProdAttr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "Prod",
+ Input: []tf.Input{
+ input, axis,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
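+// Illustrative sketch (added in review, not generated code): reducing a 2x3
+// matrix with Prod along axis 0 while keeping the reduced dimension. The
+// helper name is hypothetical; Const comes from this package.
+func exampleProd(s *Scope) tf.Output {
+	m := Const(s, [][]float32{{1, 2, 3}, {4, 5, 6}})
+	axis := Const(s, []int32{0})
+	// Result has shape [1, 3] and value [[4, 10, 18]].
+	return Prod(s, m, axis, ProdKeepDims(true))
+}
+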
// ResizeBilinearAttr is an optional argument to ResizeBilinear.
type ResizeBilinearAttr func(optionalAttr)
@@ -10862,21 +12719,6 @@ func ResizeBilinear(scope *Scope, images tf.Output, size tf.Output, optional ...
return op.Output(0)
}
-// Computes softsign: `features / (abs(features) + 1)`.
-func Softsign(scope *Scope, features tf.Output) (activations tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "Softsign",
- Input: []tf.Input{
- features,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Creates a TensorList which, when stacked, has the value of `tensor`.
//
// Each tensor in the result list corresponds to one row of the input tensor.
@@ -10897,81 +12739,6 @@ func TensorListFromTensor(scope *Scope, tensor tf.Output, element_shape tf.Outpu
return op.Output(0)
}
-// GenerateVocabRemappingAttr is an optional argument to GenerateVocabRemapping.
-type GenerateVocabRemappingAttr func(optionalAttr)
-
-// GenerateVocabRemappingOldVocabSize sets the optional old_vocab_size attribute to value.
-//
-// value: Number of entries in the old vocab file to consider. If -1,
-// use the entire old vocabulary.
-// If not specified, defaults to -1
-//
-// REQUIRES: value >= -1
-func GenerateVocabRemappingOldVocabSize(value int64) GenerateVocabRemappingAttr {
- return func(m optionalAttr) {
- m["old_vocab_size"] = value
- }
-}
-
-// Given a path to new and old vocabulary files, returns a remapping Tensor of
-//
-// length `num_new_vocab`, where `remapping[i]` contains the row number in the old
-// vocabulary that corresponds to row `i` in the new vocabulary (starting at line
-// `new_vocab_offset` and up to `num_new_vocab` entities), or `-1` if entry `i`
-// in the new vocabulary is not in the old vocabulary. The old vocabulary is
-// constrained to the first `old_vocab_size` entries if `old_vocab_size` is not the
-// default value of -1.
-//
-// `num_vocab_offset` enables
-// use in the partitioned variable case, and should generally be set through
-// examining partitioning info. The format of the files should be a text file,
-// with each line containing a single entity within the vocabulary.
-//
-// For example, with `new_vocab_file` a text file containing each of the following
-// elements on a single line: `[f0, f1, f2, f3]`, old_vocab_file = [f1, f0, f3],
-// `num_new_vocab = 3, new_vocab_offset = 1`, the returned remapping would be
-// `[0, -1, 2]`.
-//
-// The op also returns a count of how many entries in the new vocabulary
-// were present in the old vocabulary, which is used to calculate the number of
-// values to initialize in a weight matrix remapping
-//
-// This functionality can be used to remap both row vocabularies (typically,
-// features) and column vocabularies (typically, classes) from TensorFlow
-// checkpoints. Note that the partitioning logic relies on contiguous vocabularies
-// corresponding to div-partitioned variables. Moreover, the underlying remapping
-// uses an IndexTable (as opposed to an inexact CuckooTable), so client code should
-// use the corresponding index_table_from_file() as the FeatureColumn framework
-// does (as opposed to tf.feature_to_id(), which uses a CuckooTable).
-//
-// Arguments:
-// new_vocab_file: Path to the new vocab file.
-// old_vocab_file: Path to the old vocab file.
-// new_vocab_offset: How many entries into the new vocab file to start reading.
-// num_new_vocab: Number of entries in the new vocab file to remap.
-//
-// Returns A Tensor of length num_new_vocab where the element at index i
-// is equal to the old ID that maps to the new ID i. This element is -1 for any
-// new ID that is not found in the old vocabulary.Number of new vocab entries found in old vocab.
-func GenerateVocabRemapping(scope *Scope, new_vocab_file tf.Output, old_vocab_file tf.Output, new_vocab_offset int64, num_new_vocab int64, optional ...GenerateVocabRemappingAttr) (remapping tf.Output, num_present tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"new_vocab_offset": new_vocab_offset, "num_new_vocab": num_new_vocab}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "GenerateVocabRemapping",
- Input: []tf.Input{
- new_vocab_file, old_vocab_file,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1)
-}
-
// Assigns sparse updates to the variable referenced by `resource`.
//
// This operation computes
@@ -11178,65 +12945,6 @@ func StageClear(scope *Scope, dtypes []tf.DataType, optional ...StageClearAttr)
return scope.AddOperation(opspec)
}
-// ComputeAccidentalHitsAttr is an optional argument to ComputeAccidentalHits.
-type ComputeAccidentalHitsAttr func(optionalAttr)
-
-// ComputeAccidentalHitsSeed sets the optional seed attribute to value.
-//
-// value: If either seed or seed2 are set to be non-zero, the random number
-// generator is seeded by the given seed. Otherwise, it is seeded by a
-// random seed.
-// If not specified, defaults to 0
-func ComputeAccidentalHitsSeed(value int64) ComputeAccidentalHitsAttr {
- return func(m optionalAttr) {
- m["seed"] = value
- }
-}
-
-// ComputeAccidentalHitsSeed2 sets the optional seed2 attribute to value.
-//
-// value: An second seed to avoid seed collision.
-// If not specified, defaults to 0
-func ComputeAccidentalHitsSeed2(value int64) ComputeAccidentalHitsAttr {
- return func(m optionalAttr) {
- m["seed2"] = value
- }
-}
-
-// Computes the ids of the positions in sampled_candidates that match true_labels.
-//
-// When doing log-odds NCE, the result of this op should be passed through a
-// SparseToDense op, then added to the logits of the sampled candidates. This has
-// the effect of 'removing' the sampled labels that match the true labels by
-// making the classifier sure that they are sampled labels.
-//
-// Arguments:
-// true_classes: The true_classes output of UnpackSparseLabels.
-// sampled_candidates: The sampled_candidates output of CandidateSampler.
-// num_true: Number of true labels per context.
-//
-// Returns A vector of indices corresponding to rows of true_candidates.A vector of IDs of positions in sampled_candidates that match a true_label
-// for the row with the corresponding index in indices.A vector of the same length as indices and ids, in which each element
-// is -FLOAT_MAX.
-func ComputeAccidentalHits(scope *Scope, true_classes tf.Output, sampled_candidates tf.Output, num_true int64, optional ...ComputeAccidentalHitsAttr) (indices tf.Output, ids tf.Output, weights tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"num_true": num_true}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "ComputeAccidentalHits",
- Input: []tf.Input{
- true_classes, sampled_candidates,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1), op.Output(2)
-}
-
// QuantizedRelu6Attr is an optional argument to QuantizedRelu6.
type QuantizedRelu6Attr func(optionalAttr)
@@ -11365,36 +13073,43 @@ func FixedLengthRecordReaderV2(scope *Scope, record_bytes int64, optional ...Fix
return op.Output(0)
}
-// The gradient operator for the SparseAdd op.
+// StringLengthAttr is an optional argument to StringLength.
+type StringLengthAttr func(optionalAttr)
+
+// StringLengthUnit sets the optional unit attribute to value.
+// If not specified, defaults to "BYTE"
+func StringLengthUnit(value string) StringLengthAttr {
+ return func(m optionalAttr) {
+ m["unit"] = value
+ }
+}
+
+// String lengths of `input`.
//
-// The SparseAdd op calculates A + B, where A, B, and the sum are all represented
-// as `SparseTensor` objects. This op takes in the upstream gradient w.r.t.
-// non-empty values of the sum, and outputs the gradients w.r.t. the non-empty
-// values of A and B.
+// Computes the length of each string given in the input tensor.
//
// Arguments:
-// backprop_val_grad: 1-D with shape `[nnz(sum)]`. The gradient with respect to
-// the non-empty values of the sum.
-// a_indices: 2-D. The `indices` of the `SparseTensor` A, size `[nnz(A), ndims]`.
-// b_indices: 2-D. The `indices` of the `SparseTensor` B, size `[nnz(B), ndims]`.
-// sum_indices: 2-D. The `indices` of the sum `SparseTensor`, size
-// `[nnz(sum), ndims]`.
+// input: The string for which to compute the length.
//
-// Returns 1-D with shape `[nnz(A)]`. The gradient with respect to the
-// non-empty values of A.1-D with shape `[nnz(B)]`. The gradient with respect to the
-// non-empty values of B.
-func SparseAddGrad(scope *Scope, backprop_val_grad tf.Output, a_indices tf.Output, b_indices tf.Output, sum_indices tf.Output) (a_val_grad tf.Output, b_val_grad tf.Output) {
+// Returns Integer tensor that has the same shape as `input`. The output contains the
+// element-wise string lengths of `input`.
+func StringLength(scope *Scope, input tf.Output, optional ...StringLengthAttr) (output tf.Output) {
if scope.Err() != nil {
return
}
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
opspec := tf.OpSpec{
- Type: "SparseAddGrad",
+ Type: "StringLength",
Input: []tf.Input{
- backprop_val_grad, a_indices, b_indices, sum_indices,
+ input,
},
+ Attrs: attrs,
}
op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1)
+ return op.Output(0)
}
// Converts each string in the input Tensor to its hash mod by a number of buckets.
@@ -11747,7 +13462,7 @@ func ResourceScatterNdUpdateUseLocking(value bool) ResourceScatterNdUpdateAttr {
//
// [1, 11, 3, 10, 9, 6, 7, 12]
//
-// See @{tf.scatter_nd} for more details about how to make updates to
+// See `tf.scatter_nd` for more details about how to make updates to
// slices.
//
// Arguments:
@@ -11776,6 +13491,26 @@ func ResourceScatterNdUpdate(scope *Scope, ref tf.Output, indices tf.Output, upd
return scope.AddOperation(opspec)
}
+// Produces a string handle for the given MultiDeviceIterator.
+//
+// Arguments:
+// multi_device_iterator: A MultiDeviceIterator resource.
+//
+// Returns A string representing the resource.
+func MultiDeviceIteratorToStringHandle(scope *Scope, multi_device_iterator tf.Output) (string_handle tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "MultiDeviceIteratorToStringHandle",
+ Input: []tf.Input{
+ multi_device_iterator,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Applies softmax to a batched N-D `SparseTensor`.
//
// The inputs represent an N-D SparseTensor with logical shape `[..., B, C]`
@@ -12006,6 +13741,27 @@ func DataFormatDimMap(scope *Scope, x tf.Output, optional ...DataFormatDimMapAtt
return op.Output(0)
}
+// Retrieves the tree ensemble resource stamp token, number of trees and growing statistics.
+//
+// Arguments:
+// tree_ensemble_handle: Handle to the tree ensemble.
+//
+// Returns:
+//	stamp_token: Stamp token of the tree ensemble resource.
+//	num_trees: The number of trees in the tree ensemble resource.
+//	num_finalized_trees: The number of trees that were finished successfully.
+//	num_attempted_layers: The number of layers we attempted to build (but not necessarily succeeded).
+//	last_layer_nodes_range: Rank size 2 tensor that contains start and end ids of the nodes in the latest layer.
+func BoostedTreesGetEnsembleStates(scope *Scope, tree_ensemble_handle tf.Output) (stamp_token tf.Output, num_trees tf.Output, num_finalized_trees tf.Output, num_attempted_layers tf.Output, last_layer_nodes_range tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "BoostedTreesGetEnsembleStates",
+ Input: []tf.Input{
+ tree_ensemble_handle,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4)
+}
+
// ResourceApplyPowerSignAttr is an optional argument to ResourceApplyPowerSign.
type ResourceApplyPowerSignAttr func(optionalAttr)
@@ -12230,10 +13986,188 @@ func MutexLock(scope *Scope, mutex tf.Output) (mutex_lock tf.Output) {
return op.Output(0)
}
+// StringFormatAttr is an optional argument to StringFormat.
+type StringFormatAttr func(optionalAttr)
+
+// StringFormatTemplate sets the optional template attribute to value.
+//
+// value: A string, the template to format tensor summaries into.
+// If not specified, defaults to "%s"
+func StringFormatTemplate(value string) StringFormatAttr {
+ return func(m optionalAttr) {
+ m["template"] = value
+ }
+}
+
+// StringFormatPlaceholder sets the optional placeholder attribute to value.
+//
+// value: A string; at each placeholder in the template, a subsequent tensor summary will be inserted.
+// If not specified, defaults to "%s"
+func StringFormatPlaceholder(value string) StringFormatAttr {
+ return func(m optionalAttr) {
+ m["placeholder"] = value
+ }
+}
+
+// StringFormatSummarize sets the optional summarize attribute to value.
+//
+// value: When formatting the tensor summaries, print the first and last `summarize` entries of each tensor dimension.
+// If not specified, defaults to 3
+func StringFormatSummarize(value int64) StringFormatAttr {
+ return func(m optionalAttr) {
+ m["summarize"] = value
+ }
+}
+
+// Formats a string template using a list of tensors.
+//
+// Formats a string template using a list of tensors, pretty-printing tensor summaries.
+//
+// Arguments:
+// inputs: The list of tensors to format into the placeholder string.
+//
+// Returns The resulting string scalar.
+func StringFormat(scope *Scope, inputs []tf.Output, optional ...StringFormatAttr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "StringFormat",
+ Input: []tf.Input{
+ tf.OutputList(inputs),
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
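+// Illustrative sketch (added in review, not generated code): formatting a
+// tensor summary into a message with StringFormat. The helper name is
+// hypothetical, and the exact summary text is determined by the op at run time.
+func exampleStringFormat(s *Scope) tf.Output {
+	t := Const(s, []int32{1, 2, 3, 4, 5, 6, 7})
+	// With summarize=3 the placeholder expands to roughly "[1 2 3 ... 5 6 7]".
+	return StringFormat(s, []tf.Output{t},
+		StringFormatTemplate("tensor: %s"),
+		StringFormatSummarize(3))
+}
+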
+// ShapeAttr is an optional argument to Shape.
+type ShapeAttr func(optionalAttr)
+
+// ShapeOutType sets the optional out_type attribute to value.
+// If not specified, defaults to DT_INT32
+func ShapeOutType(value tf.DataType) ShapeAttr {
+ return func(m optionalAttr) {
+ m["out_type"] = value
+ }
+}
+
+// Returns the shape of a tensor.
+//
+// This operation returns a 1-D integer tensor representing the shape of `input`.
+//
+// For example:
+//
+// ```
+// # 't' is [[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]]
+// shape(t) ==> [2, 2, 3]
+// ```
+func Shape(scope *Scope, input tf.Output, optional ...ShapeAttr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "Shape",
+ Input: []tf.Input{
+ input,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
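+// Illustrative sketch (added in review, not generated code): requesting the
+// shape of a tensor as int64 via the out_type attr. The helper name is
+// hypothetical.
+func exampleShape(s *Scope) tf.Output {
+	t := Const(s, [][][]int32{{{1, 1, 1}, {2, 2, 2}}, {{3, 3, 3}, {4, 4, 4}}})
+	// Yields the 1-D tensor [2, 2, 3], with int64 elements instead of the
+	// default int32.
+	return Shape(s, t, ShapeOutType(tf.Int64))
+}
+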
+// Computes the power of one value to another.
+//
+// Given a tensor `x` and a tensor `y`, this operation computes \\(x^y\\) for
+// corresponding elements in `x` and `y`. For example:
+//
+// ```
+// # tensor 'x' is [[2, 2], [3, 3]]
+// # tensor 'y' is [[8, 16], [2, 3]]
+// tf.pow(x, y) ==> [[256, 65536], [9, 27]]
+// ```
+func Pow(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "Pow",
+ Input: []tf.Input{
+ x, y,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
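+// Illustrative sketch (added in review, not generated code): the element-wise
+// power from the comment above, expressed with this package's Const helper.
+// The helper name is hypothetical.
+func examplePow(s *Scope) tf.Output {
+	x := Const(s, [][]int32{{2, 2}, {3, 3}})
+	y := Const(s, [][]int32{{8, 16}, {2, 3}})
+	// Yields [[256, 65536], [9, 27]].
+	return Pow(s, x, y)
+}
+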
+// Computes fingerprints of the input strings.
+//
+// Arguments:
+// input: vector of strings to compute fingerprints on.
+//
+// Returns a (N,2) shaped matrix where N is the number of elements in the input
+// vector. Each row contains the low and high parts of the fingerprint.
+func SdcaFprint(scope *Scope, input tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "SdcaFprint",
+ Input: []tf.Input{
+ input,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// The gradient operator for the SparseAdd op.
+//
+// The SparseAdd op calculates A + B, where A, B, and the sum are all represented
+// as `SparseTensor` objects. This op takes in the upstream gradient w.r.t.
+// non-empty values of the sum, and outputs the gradients w.r.t. the non-empty
+// values of A and B.
+//
+// Arguments:
+// backprop_val_grad: 1-D with shape `[nnz(sum)]`. The gradient with respect to
+// the non-empty values of the sum.
+// a_indices: 2-D. The `indices` of the `SparseTensor` A, size `[nnz(A), ndims]`.
+// b_indices: 2-D. The `indices` of the `SparseTensor` B, size `[nnz(B), ndims]`.
+// sum_indices: 2-D. The `indices` of the sum `SparseTensor`, size
+// `[nnz(sum), ndims]`.
+//
+// Returns:
+//	a_val_grad: 1-D with shape `[nnz(A)]`. The gradient with respect to the non-empty values of A.
+//	b_val_grad: 1-D with shape `[nnz(B)]`. The gradient with respect to the non-empty values of B.
+func SparseAddGrad(scope *Scope, backprop_val_grad tf.Output, a_indices tf.Output, b_indices tf.Output, sum_indices tf.Output) (a_val_grad tf.Output, b_val_grad tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "SparseAddGrad",
+ Input: []tf.Input{
+ backprop_val_grad, a_indices, b_indices, sum_indices,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0), op.Output(1)
+}
+
// Computes the mean along segments of a tensor.
//
-// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-// segments.
+// Read
+// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+// for an explanation of segments.
//
// Computes a tensor such that
// \\(output_i = \frac{\sum_j data_j}{N}\\) where `mean` is
@@ -12248,7 +14182,7 @@ func MutexLock(scope *Scope, mutex tf.Output) (mutex_lock tf.Output) {
//
// Arguments:
//
-// segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s
+// segment_ids: A 1-D tensor whose size is equal to the size of `data`'s
// first dimension. Values should be sorted and can be repeated.
//
// Returns Has same shape as data, except for dimension 0 which
@@ -12367,7 +14301,7 @@ func BatchDataset(scope *Scope, input_dataset tf.Output, batch_size tf.Output, o
//
// Arguments:
// input: A string tensor of the text to be processed.
-// pattern: A 1-D string tensor of the regular expression to match the input.
+// pattern: A scalar string tensor containing the regular expression to match the input.
//
// Returns A bool tensor with the same shape as `input`.
func RegexFullMatch(scope *Scope, input tf.Output, pattern tf.Output) (output tf.Output) {
@@ -12421,6 +14355,79 @@ func InTopKV2(scope *Scope, predictions tf.Output, targets tf.Output, k tf.Outpu
return op.Output(0)
}
+// RandomPoissonV2Attr is an optional argument to RandomPoissonV2.
+type RandomPoissonV2Attr func(optionalAttr)
+
+// RandomPoissonV2Seed sets the optional seed attribute to value.
+//
+// value: If either `seed` or `seed2` are set to be non-zero, the random number
+// generator is seeded by the given seed. Otherwise, it is seeded by a
+// random seed.
+// If not specified, defaults to 0
+func RandomPoissonV2Seed(value int64) RandomPoissonV2Attr {
+ return func(m optionalAttr) {
+ m["seed"] = value
+ }
+}
+
+// RandomPoissonV2Seed2 sets the optional seed2 attribute to value.
+//
+// value: A second seed to avoid seed collision.
+// If not specified, defaults to 0
+func RandomPoissonV2Seed2(value int64) RandomPoissonV2Attr {
+ return func(m optionalAttr) {
+ m["seed2"] = value
+ }
+}
+
+// RandomPoissonV2Dtype sets the optional dtype attribute to value.
+// If not specified, defaults to DT_INT64
+func RandomPoissonV2Dtype(value tf.DataType) RandomPoissonV2Attr {
+ return func(m optionalAttr) {
+ m["dtype"] = value
+ }
+}
+
+// Outputs random values from the Poisson distribution(s) described by rate.
+//
+// This op uses two algorithms, depending on rate. If rate >= 10, then
+// the algorithm by Hormann is used to acquire samples via
+// transformation-rejection.
+// See http://www.sciencedirect.com/science/article/pii/0167668793909974.
+//
+// Otherwise, Knuth's algorithm is used to acquire samples via multiplying uniform
+// random variables.
+// See Donald E. Knuth (1969). Seminumerical Algorithms. The Art of Computer
+// Programming, Volume 2. Addison Wesley
+//
+// Arguments:
+// shape: 1-D integer tensor. Shape of independent samples to draw from each
+// distribution described by the shape parameters given in rate.
+// rate: A tensor in which each scalar is a "rate" parameter describing the
+// associated poisson distribution.
+//
+// Returns A tensor with shape `shape + shape(rate)`. Each slice
+// `[:, ..., :, i0, i1, ...iN]` contains the samples drawn for
+// `rate[i0, i1, ...iN]`.
+func RandomPoissonV2(scope *Scope, shape tf.Output, rate tf.Output, optional ...RandomPoissonV2Attr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "RandomPoissonV2",
+ Input: []tf.Input{
+ shape, rate,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
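+// Illustrative sketch (added in review, not generated code): drawing
+// reproducible Poisson samples with an explicit seed and a float output dtype.
+// The helper name is hypothetical.
+func exampleRandomPoissonV2(s *Scope) tf.Output {
+	shape := Const(s, []int32{10}) // 10 samples per rate
+	rate := Const(s, []float32{1.0, 5.0, 20.0})
+	// Output shape is [10, 3]: shape + shape(rate).
+	return RandomPoissonV2(s, shape, rate,
+		RandomPoissonV2Seed(7),
+		RandomPoissonV2Dtype(tf.Float))
+}
+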
// DecodeAndCropJpegAttr is an optional argument to DecodeAndCropJpeg.
type DecodeAndCropJpegAttr func(optionalAttr)
@@ -12537,78 +14544,6 @@ func DecodeAndCropJpeg(scope *Scope, contents tf.Output, crop_window tf.Output,
return op.Output(0)
}
-// AllCandidateSamplerAttr is an optional argument to AllCandidateSampler.
-type AllCandidateSamplerAttr func(optionalAttr)
-
-// AllCandidateSamplerSeed sets the optional seed attribute to value.
-//
-// value: If either seed or seed2 are set to be non-zero, the random number
-// generator is seeded by the given seed. Otherwise, it is seeded by a
-// random seed.
-// If not specified, defaults to 0
-func AllCandidateSamplerSeed(value int64) AllCandidateSamplerAttr {
- return func(m optionalAttr) {
- m["seed"] = value
- }
-}
-
-// AllCandidateSamplerSeed2 sets the optional seed2 attribute to value.
-//
-// value: An second seed to avoid seed collision.
-// If not specified, defaults to 0
-func AllCandidateSamplerSeed2(value int64) AllCandidateSamplerAttr {
- return func(m optionalAttr) {
- m["seed2"] = value
- }
-}
-
-// Generates labels for candidate sampling with a learned unigram distribution.
-//
-// See explanations of candidate sampling and the data formats at
-// go/candidate-sampling.
-//
-// For each batch, this op picks a single set of sampled candidate labels.
-//
-// The advantages of sampling candidates per-batch are simplicity and the
-// possibility of efficient dense matrix multiplication. The disadvantage is that
-// the sampled candidates must be chosen independently of the context and of the
-// true labels.
-//
-// Arguments:
-// true_classes: A batch_size * num_true matrix, in which each row contains the
-// IDs of the num_true target_classes in the corresponding original label.
-// num_true: Number of true labels per context.
-// num_sampled: Number of candidates to produce.
-// unique: If unique is true, we sample with rejection, so that all sampled
-// candidates in a batch are unique. This requires some approximation to
-// estimate the post-rejection sampling probabilities.
-//
-// Returns A vector of length num_sampled, in which each element is
-// the ID of a sampled candidate.A batch_size * num_true matrix, representing
-// the number of times each candidate is expected to occur in a batch
-// of sampled candidates. If unique=true, then this is a probability.A vector of length num_sampled, for each sampled
-// candidate representing the number of times the candidate is expected
-// to occur in a batch of sampled candidates. If unique=true, then this is a
-// probability.
-func AllCandidateSampler(scope *Scope, true_classes tf.Output, num_true int64, num_sampled int64, unique bool, optional ...AllCandidateSamplerAttr) (sampled_candidates tf.Output, true_expected_count tf.Output, sampled_expected_count tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"num_true": num_true, "num_sampled": num_sampled, "unique": unique}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "AllCandidateSampler",
- Input: []tf.Input{
- true_classes,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1), op.Output(2)
-}
-
// Adds two `SparseTensor` objects to produce another `SparseTensor`.
//
// The input `SparseTensor` objects' indices are assumed ordered in standard
@@ -13419,6 +15354,78 @@ func StringToHashBucketFast(scope *Scope, input tf.Output, num_buckets int64) (o
return op.Output(0)
}
+// Returns the last element of the input list as well as a list with all but that element.
+//
+// Fails if the list is empty.
+//
+// input_handle: the input list
+// tensor: the withdrawn last element of the list
+// element_dtype: the type of elements in the list
+// element_shape: the shape of the output tensor
+func TensorListPopBack(scope *Scope, input_handle tf.Output, element_dtype tf.DataType) (output_handle tf.Output, tensor tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"element_dtype": element_dtype}
+ opspec := tf.OpSpec{
+ Type: "TensorListPopBack",
+ Input: []tf.Input{
+ input_handle,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0), op.Output(1)
+}
+
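+// Illustrative sketch (added in review, not generated code): building a list
+// with TensorListFromTensor (defined elsewhere in this package) and popping its
+// last element. The helper name is hypothetical.
+func exampleTensorListPopBack(s *Scope) (tf.Output, tf.Output) {
+	rows := Const(s, [][]float32{{1, 2, 3}, {4, 5, 6}})
+	elementShape := Const(s, []int32{3})
+	list := TensorListFromTensor(s, rows, elementShape)
+	// Returns the shortened list handle and the popped tensor [4, 5, 6].
+	return TensorListPopBack(s, list, tf.Float)
+}
+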
+// MaxPoolGradGradAttr is an optional argument to MaxPoolGradGrad.
+type MaxPoolGradGradAttr func(optionalAttr)
+
+// MaxPoolGradGradDataFormat sets the optional data_format attribute to value.
+//
+// value: Specify the data format of the input and output data. With the
+// default format "NHWC", the data is stored in the order of:
+// [batch, in_height, in_width, in_channels].
+// Alternatively, the format could be "NCHW", the data storage order of:
+// [batch, in_channels, in_height, in_width].
+// If not specified, defaults to "NHWC"
+func MaxPoolGradGradDataFormat(value string) MaxPoolGradGradAttr {
+ return func(m optionalAttr) {
+ m["data_format"] = value
+ }
+}
+
+// Computes second-order gradients of the maxpooling function.
+//
+// Arguments:
+// orig_input: The original input tensor.
+// orig_output: The original output tensor.
+// grad: 4-D. Gradients of gradients w.r.t. the input of `max_pool`.
+// ksize: The size of the window for each dimension of the input tensor.
+// strides: The stride of the sliding window for each dimension of the
+// input tensor.
+// padding: The type of padding algorithm to use.
+//
+// Returns Gradients of gradients w.r.t. the input to `max_pool`.
+func MaxPoolGradGrad(scope *Scope, orig_input tf.Output, orig_output tf.Output, grad tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPoolGradGradAttr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "MaxPoolGradGrad",
+ Input: []tf.Input{
+ orig_input, orig_output, grad,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// TensorArrayGatherV3Attr is an optional argument to TensorArrayGatherV3.
type TensorArrayGatherV3Attr func(optionalAttr)
@@ -13465,33 +15472,6 @@ func TensorArrayGatherV3(scope *Scope, handle tf.Output, indices tf.Output, flow
return op.Output(0)
}
-// This op consumes a lock created by `MutexLock`.
-//
-// This op exists to consume a tensor created by `MutexLock` (other than
-// direct control dependencies). It should be the only that consumes the tensor,
-// and will raise an error if it is not. Its only purpose is to keep the
-// mutex lock tensor alive until it is consumed by this op.
-//
-// **NOTE**: This operation must run on the same device as its input. This may
-// be enforced via the `colocate_with` mechanism.
-//
-// Arguments:
-// mutex_lock: A tensor returned by `MutexLock`.
-//
-// Returns the created operation.
-func ConsumeMutexLock(scope *Scope, mutex_lock tf.Output) (o *tf.Operation) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "ConsumeMutexLock",
- Input: []tf.Input{
- mutex_lock,
- },
- }
- return scope.AddOperation(opspec)
-}
-
// Returns x / y element-wise for integer types.
//
// Truncation designates that negative numbers will round fractional quantities
@@ -14443,6 +16423,25 @@ func ResourceApplyProximalGradientDescent(scope *Scope, var_ tf.Output, alpha tf
return scope.AddOperation(opspec)
}
+// Returns 0 if the denominator is zero.
+//
+// *NOTE*: `DivNoNan` supports broadcasting. More about broadcasting
+// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
+func DivNoNan(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "DivNoNan",
+ Input: []tf.Input{
+ x, y,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
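+// Illustrative sketch (added in review, not generated code): DivNoNan yields 0
+// where the denominator is 0 instead of Inf/NaN. The helper name is
+// hypothetical.
+func exampleDivNoNan(s *Scope) tf.Output {
+	x := Const(s, []float32{1, 2, 3})
+	y := Const(s, []float32{2, 0, 4})
+	// Yields [0.5, 0, 0.75].
+	return DivNoNan(s, x, y)
+}
+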
// Computes the gradient for the sqrt of `x` wrt its input.
//
// Specifically, `grad = dy * 0.5 / y`, where `y = sqrt(x)`, and `dy`
@@ -14543,79 +16542,6 @@ func RandomPoisson(scope *Scope, shape tf.Output, rate tf.Output, optional ...Ra
return op.Output(0)
}
-// LogUniformCandidateSamplerAttr is an optional argument to LogUniformCandidateSampler.
-type LogUniformCandidateSamplerAttr func(optionalAttr)
-
-// LogUniformCandidateSamplerSeed sets the optional seed attribute to value.
-//
-// value: If either seed or seed2 are set to be non-zero, the random number
-// generator is seeded by the given seed. Otherwise, it is seeded by a
-// random seed.
-// If not specified, defaults to 0
-func LogUniformCandidateSamplerSeed(value int64) LogUniformCandidateSamplerAttr {
- return func(m optionalAttr) {
- m["seed"] = value
- }
-}
-
-// LogUniformCandidateSamplerSeed2 sets the optional seed2 attribute to value.
-//
-// value: An second seed to avoid seed collision.
-// If not specified, defaults to 0
-func LogUniformCandidateSamplerSeed2(value int64) LogUniformCandidateSamplerAttr {
- return func(m optionalAttr) {
- m["seed2"] = value
- }
-}
-
-// Generates labels for candidate sampling with a log-uniform distribution.
-//
-// See explanations of candidate sampling and the data formats at
-// go/candidate-sampling.
-//
-// For each batch, this op picks a single set of sampled candidate labels.
-//
-// The advantages of sampling candidates per-batch are simplicity and the
-// possibility of efficient dense matrix multiplication. The disadvantage is that
-// the sampled candidates must be chosen independently of the context and of the
-// true labels.
-//
-// Arguments:
-// true_classes: A batch_size * num_true matrix, in which each row contains the
-// IDs of the num_true target_classes in the corresponding original label.
-// num_true: Number of true labels per context.
-// num_sampled: Number of candidates to randomly sample.
-// unique: If unique is true, we sample with rejection, so that all sampled
-// candidates in a batch are unique. This requires some approximation to
-// estimate the post-rejection sampling probabilities.
-// range_max: The sampler will sample integers from the interval [0, range_max).
-//
-// Returns A vector of length num_sampled, in which each element is
-// the ID of a sampled candidate.A batch_size * num_true matrix, representing
-// the number of times each candidate is expected to occur in a batch
-// of sampled candidates. If unique=true, then this is a probability.A vector of length num_sampled, for each sampled
-// candidate representing the number of times the candidate is expected
-// to occur in a batch of sampled candidates. If unique=true, then this is a
-// probability.
-func LogUniformCandidateSampler(scope *Scope, true_classes tf.Output, num_true int64, num_sampled int64, unique bool, range_max int64, optional ...LogUniformCandidateSamplerAttr) (sampled_candidates tf.Output, true_expected_count tf.Output, sampled_expected_count tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"num_true": num_true, "num_sampled": num_sampled, "unique": unique, "range_max": range_max}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "LogUniformCandidateSampler",
- Input: []tf.Input{
- true_classes,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1), op.Output(2)
-}
-
// Returns the max of x and y (i.e. x > y ? x : y) element-wise.
//
// *NOTE*: `Maximum` supports broadcasting. More about broadcasting
@@ -14901,109 +16827,6 @@ func Zeta(scope *Scope, x tf.Output, q tf.Output) (z tf.Output) {
return op.Output(0)
}
-// ProdAttr is an optional argument to Prod.
-type ProdAttr func(optionalAttr)
-
-// ProdKeepDims sets the optional keep_dims attribute to value.
-//
-// value: If true, retain reduced dimensions with length 1.
-// If not specified, defaults to false
-func ProdKeepDims(value bool) ProdAttr {
- return func(m optionalAttr) {
- m["keep_dims"] = value
- }
-}
-
-// Computes the product of elements across dimensions of a tensor.
-//
-// Reduces `input` along the dimensions given in `axis`. Unless
-// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in
-// `axis`. If `keep_dims` is true, the reduced dimensions are
-// retained with length 1.
-//
-// Arguments:
-// input: The tensor to reduce.
-// axis: The dimensions to reduce. Must be in the range
-// `[-rank(input), rank(input))`.
-//
-// Returns The reduced tensor.
-func Prod(scope *Scope, input tf.Output, axis tf.Output, optional ...ProdAttr) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "Prod",
- Input: []tf.Input{
- input, axis,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// FusedResizeAndPadConv2DAttr is an optional argument to FusedResizeAndPadConv2D.
-type FusedResizeAndPadConv2DAttr func(optionalAttr)
-
-// FusedResizeAndPadConv2DResizeAlignCorners sets the optional resize_align_corners attribute to value.
-//
-// value: If true, the centers of the 4 corner pixels of the input and output tensors are
-// aligned, preserving the values at the corner pixels. Defaults to false.
-// If not specified, defaults to false
-func FusedResizeAndPadConv2DResizeAlignCorners(value bool) FusedResizeAndPadConv2DAttr {
- return func(m optionalAttr) {
- m["resize_align_corners"] = value
- }
-}
-
-// Performs a resize and padding as a preprocess during a convolution.
-//
-// It's often possible to do spatial transformations more efficiently as part of
-// the packing stage of a convolution, so this op allows for an optimized
-// implementation where these stages are fused together. This prevents the need to
-// write out the intermediate results as whole tensors, reducing memory pressure,
-// and we can get some latency gains by merging the transformation calculations.
-// The data_format attribute for Conv2D isn't supported by this op, and defaults to
-// 'NHWC' order.
-// Internally this op uses a single per-graph scratch buffer, which means that it
-// will block if multiple versions are being run in parallel. This is because this
-// operator is primarily an optimization to minimize memory usage.
-//
-// Arguments:
-// input: 4-D with shape `[batch, in_height, in_width, in_channels]`.
-// size: A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The
-// new size for the images.
-// paddings: A two-column matrix specifying the padding sizes. The number of
-// rows must be the same as the rank of `input`.
-// filter: 4-D with shape
-// `[filter_height, filter_width, in_channels, out_channels]`.
-//
-// strides: 1-D of length 4. The stride of the sliding window for each dimension
-// of `input`. Must be in the same order as the dimension specified with format.
-// padding: The type of padding algorithm to use.
-func FusedResizeAndPadConv2D(scope *Scope, input tf.Output, size tf.Output, paddings tf.Output, filter tf.Output, mode string, strides []int64, padding string, optional ...FusedResizeAndPadConv2DAttr) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"mode": mode, "strides": strides, "padding": padding}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "FusedResizeAndPadConv2D",
- Input: []tf.Input{
- input, size, paddings, filter,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Returns a list of tensors with the same shapes and contents as the input
//
// tensors.
@@ -15350,6 +17173,36 @@ func BytesProducedStatsDataset(scope *Scope, input_dataset tf.Output, tag tf.Out
return op.Output(0)
}
+// Check if the input matches the regex pattern.
+//
+// The input is a string tensor of any shape. The pattern is the
+// regular expression to be matched with every element of the input tensor.
+// The boolean values (True or False) of the output tensor indicate
+// if the input matches the regex pattern provided.
+//
+// The pattern follows the re2 syntax (https://github.com/google/re2/wiki/Syntax)
+//
+// Arguments:
+// input: A string tensor of the text to be processed.
+// pattern: The regular expression to match the input.
+//
+// Returns A bool tensor with the same shape as `input`.
+func StaticRegexFullMatch(scope *Scope, input tf.Output, pattern string) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"pattern": pattern}
+ opspec := tf.OpSpec{
+ Type: "StaticRegexFullMatch",
+ Input: []tf.Input{
+ input,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
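As a usage sketch (not part of the generated file): assuming the standard `tensorflow/go` and `tensorflow/go/op` packages, the newly added `StaticRegexFullMatch` wrapper could be exercised like this; the sample strings and RE2 pattern are illustrative only.

```go
package main

import (
	"fmt"

	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
	s := op.NewScope()
	// Two sample strings; the RE2 pattern requires letters followed by digits.
	input := op.Const(s, []string{"ab123", "xyz"})
	matches := op.StaticRegexFullMatch(s, input, "[a-z]+[0-9]+")

	graph, err := s.Finalize()
	if err != nil {
		panic(err)
	}
	sess, err := tf.NewSession(graph, nil)
	if err != nil {
		panic(err)
	}
	defer sess.Close()

	out, err := sess.Run(nil, []tf.Output{matches}, nil)
	if err != nil {
		panic(err)
	}
	fmt.Println(out[0].Value()) // [true false]
}
```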
// ResourceSparseApplyProximalGradientDescentAttr is an optional argument to ResourceSparseApplyProximalGradientDescent.
type ResourceSparseApplyProximalGradientDescentAttr func(optionalAttr)
@@ -15849,6 +17702,64 @@ func CudnnRNNBackprop(scope *Scope, input tf.Output, input_h tf.Output, input_c
return op.Output(0), op.Output(1), op.Output(2), op.Output(3)
}
+// UpperBoundAttr is an optional argument to UpperBound.
+type UpperBoundAttr func(optionalAttr)
+
+// UpperBoundOutType sets the optional out_type attribute to value.
+// If not specified, defaults to DT_INT32
+func UpperBoundOutType(value tf.DataType) UpperBoundAttr {
+ return func(m optionalAttr) {
+ m["out_type"] = value
+ }
+}
+
+// Applies upper_bound(sorted_search_values, values) along each row.
+//
+// Each set of rows with the same index in (sorted_inputs, values) is treated
+// independently. The resulting row is the equivalent of calling
+// `np.searchsorted(sorted_inputs, values, side='right')`.
+//
+// The result is not a global index to the entire
+// `Tensor`, but rather just the index in the last dimension.
+//
+// A 2-D example:
+// sorted_sequence = [[0, 3, 9, 9, 10],
+// [1, 2, 3, 4, 5]]
+// values = [[2, 4, 9],
+// [0, 2, 6]]
+//
+// result = UpperBound(sorted_sequence, values)
+//
+// result == [[1, 2, 4],
+// [0, 2, 5]]
+//
+// Arguments:
+// sorted_inputs: 2-D Tensor where each row is ordered.
+// values: 2-D Tensor with the same numbers of rows as `sorted_search_values`. Contains
+// the values that will be searched for in `sorted_search_values`.
+//
+// Returns A `Tensor` with the same shape as `values`. It contains the last scalar index
+// into the last dimension where values can be inserted without changing the
+// ordered property.
+func UpperBound(scope *Scope, sorted_inputs tf.Output, values tf.Output, optional ...UpperBoundAttr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "UpperBound",
+ Input: []tf.Input{
+ sorted_inputs, values,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
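A minimal graph-construction sketch for the new `UpperBound` wrapper, reusing the 2-D example from the comment above; the helper name is illustrative and the snippet assumes the `tensorflow/go` bindings.

```go
package example

import (
	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

// buildUpperBound searches each row of `values` within the matching row of
// `sorted`; with the data below the result is [[1 2 4] [0 2 5]].
func buildUpperBound(s *op.Scope) tf.Output {
	sorted := op.Const(s.SubScope("sorted"), [][]int32{{0, 3, 9, 9, 10}, {1, 2, 3, 4, 5}})
	values := op.Const(s.SubScope("values"), [][]int32{{2, 4, 9}, {0, 2, 6}})
	// out_type stays at its DT_INT32 default; pass op.UpperBoundOutType(tf.Int64)
	// to get int64 indices instead.
	return op.UpperBound(s, sorted, values)
}
```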
// FractionalMaxPoolGradAttr is an optional argument to FractionalMaxPoolGrad.
type FractionalMaxPoolGradAttr func(optionalAttr)
@@ -15947,6 +17858,23 @@ func ResourceApplyAdagradDA(scope *Scope, var_ tf.Output, gradient_accumulator t
return scope.AddOperation(opspec)
}
+// Creates a dataset containing elements of first component of `input_dataset` having true in the last component.
+func FilterByLastComponentDataset(scope *Scope, input_dataset tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes}
+ opspec := tf.OpSpec{
+ Type: "FilterByLastComponentDataset",
+ Input: []tf.Input{
+ input_dataset,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// CudnnRNNCanonicalToParamsAttr is an optional argument to CudnnRNNCanonicalToParams.
type CudnnRNNCanonicalToParamsAttr func(optionalAttr)
@@ -16368,175 +18296,6 @@ func FractionalAvgPoolGrad(scope *Scope, orig_input_tensor_shape tf.Output, out_
return op.Output(0)
}
-// BoostedTreesEnsembleResourceHandleOpAttr is an optional argument to BoostedTreesEnsembleResourceHandleOp.
-type BoostedTreesEnsembleResourceHandleOpAttr func(optionalAttr)
-
-// BoostedTreesEnsembleResourceHandleOpContainer sets the optional container attribute to value.
-// If not specified, defaults to ""
-func BoostedTreesEnsembleResourceHandleOpContainer(value string) BoostedTreesEnsembleResourceHandleOpAttr {
- return func(m optionalAttr) {
- m["container"] = value
- }
-}
-
-// BoostedTreesEnsembleResourceHandleOpSharedName sets the optional shared_name attribute to value.
-// If not specified, defaults to ""
-func BoostedTreesEnsembleResourceHandleOpSharedName(value string) BoostedTreesEnsembleResourceHandleOpAttr {
- return func(m optionalAttr) {
- m["shared_name"] = value
- }
-}
-
-// Creates a handle to a BoostedTreesEnsembleResource
-func BoostedTreesEnsembleResourceHandleOp(scope *Scope, optional ...BoostedTreesEnsembleResourceHandleOpAttr) (resource tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "BoostedTreesEnsembleResourceHandleOp",
-
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// ResourceApplyMomentumAttr is an optional argument to ResourceApplyMomentum.
-type ResourceApplyMomentumAttr func(optionalAttr)
-
-// ResourceApplyMomentumUseLocking sets the optional use_locking attribute to value.
-//
-// value: If `True`, updating of the var and accum tensors will be protected
-// by a lock; otherwise the behavior is undefined, but may exhibit less
-// contention.
-// If not specified, defaults to false
-func ResourceApplyMomentumUseLocking(value bool) ResourceApplyMomentumAttr {
- return func(m optionalAttr) {
- m["use_locking"] = value
- }
-}
-
-// ResourceApplyMomentumUseNesterov sets the optional use_nesterov attribute to value.
-//
-// value: If `True`, the tensor passed to compute grad will be
-// var - lr * momentum * accum, so in the end, the var you get is actually
-// var - lr * momentum * accum.
-// If not specified, defaults to false
-func ResourceApplyMomentumUseNesterov(value bool) ResourceApplyMomentumAttr {
- return func(m optionalAttr) {
- m["use_nesterov"] = value
- }
-}
-
-// Update '*var' according to the momentum scheme. Set use_nesterov = True if you
-//
-// want to use Nesterov momentum.
-//
-// accum = accum * momentum + grad
-// var -= lr * accum
-//
-// Arguments:
-// var_: Should be from a Variable().
-// accum: Should be from a Variable().
-// lr: Scaling factor. Must be a scalar.
-// grad: The gradient.
-// momentum: Momentum. Must be a scalar.
-//
-// Returns the created operation.
-func ResourceApplyMomentum(scope *Scope, var_ tf.Output, accum tf.Output, lr tf.Output, grad tf.Output, momentum tf.Output, optional ...ResourceApplyMomentumAttr) (o *tf.Operation) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "ResourceApplyMomentum",
- Input: []tf.Input{
- var_, accum, lr, grad, momentum,
- },
- Attrs: attrs,
- }
- return scope.AddOperation(opspec)
-}
-
-// MaxPoolGradGradAttr is an optional argument to MaxPoolGradGrad.
-type MaxPoolGradGradAttr func(optionalAttr)
-
-// MaxPoolGradGradDataFormat sets the optional data_format attribute to value.
-//
-// value: Specify the data format of the input and output data. With the
-// default format "NHWC", the data is stored in the order of:
-// [batch, in_height, in_width, in_channels].
-// Alternatively, the format could be "NCHW", the data storage order of:
-// [batch, in_channels, in_height, in_width].
-// If not specified, defaults to "NHWC"
-func MaxPoolGradGradDataFormat(value string) MaxPoolGradGradAttr {
- return func(m optionalAttr) {
- m["data_format"] = value
- }
-}
-
-// Computes second-order gradients of the maxpooling function.
-//
-// Arguments:
-// orig_input: The original input tensor.
-// orig_output: The original output tensor.
-// grad: 4-D. Gradients of gradients w.r.t. the input of `max_pool`.
-// ksize: The size of the window for each dimension of the input tensor.
-// strides: The stride of the sliding window for each dimension of the
-// input tensor.
-// padding: The type of padding algorithm to use.
-//
-// Returns Gradients of gradients w.r.t. the input to `max_pool`.
-func MaxPoolGradGrad(scope *Scope, orig_input tf.Output, orig_output tf.Output, grad tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPoolGradGradAttr) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "MaxPoolGradGrad",
- Input: []tf.Input{
- orig_input, orig_output, grad,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Returns the last element of the input list as well as a list with all but that element.
-//
-// Fails if the list is empty.
-//
-// input_handle: the input list
-// tensor: the withdrawn last element of the list
-// element_dtype: the type of elements in the list
-// element_shape: the shape of the output tensor
-func TensorListPopBack(scope *Scope, input_handle tf.Output, element_dtype tf.DataType) (output_handle tf.Output, tensor tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"element_dtype": element_dtype}
- opspec := tf.OpSpec{
- Type: "TensorListPopBack",
- Input: []tf.Input{
- input_handle,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1)
-}
-
// Returns element-wise integer closest to x.
//
// If the result is midway between two representable values,
@@ -16806,7 +18565,8 @@ func DecodeCSVSelectCols(value []int64) DecodeCSVAttr {
// records: Each string is a record/row in the csv and all records should have
// the same format.
// record_defaults: One tensor per column of the input record, with either a
-// scalar default value for that column or empty if the column is required.
+// scalar default value for that column or an empty vector if the column is
+// required.
//
// Returns Each tensor will have the same shape as records.
func DecodeCSV(scope *Scope, records tf.Output, record_defaults []tf.Output, optional ...DecodeCSVAttr) (output []tf.Output) {
@@ -17573,8 +19333,9 @@ func ReaderNumRecordsProducedV2(scope *Scope, reader_handle tf.Output) (records_
// Computes the sum along segments of a tensor.
//
-// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-// segments.
+// Read
+// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+// for an explanation of segments.
//
// Computes a tensor such that
// \\(output_i = \sum_j data_j\\) where sum is over `j` such
@@ -17588,7 +19349,7 @@ func ReaderNumRecordsProducedV2(scope *Scope, reader_handle tf.Output) (records_
//
// Arguments:
//
-// segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s
+// segment_ids: A 1-D tensor whose size is equal to the size of `data`'s
// first dimension. Values should be sorted and can be repeated.
//
// Returns Has same shape as data, except for dimension 0 which
@@ -17825,31 +19586,6 @@ func SparseDenseCwiseAdd(scope *Scope, sp_indices tf.Output, sp_values tf.Output
return op.Output(0)
}
-// Read an element from the TensorArray into output `value`.
-//
-// Arguments:
-// handle: The handle to a TensorArray.
-//
-// flow_in: A float scalar that enforces proper chaining of operations.
-// dtype: The type of the elem that is returned.
-//
-// Returns The tensor that is read from the TensorArray.
-func TensorArrayReadV3(scope *Scope, handle tf.Output, index tf.Output, flow_in tf.Output, dtype tf.DataType) (value tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"dtype": dtype}
- opspec := tf.OpSpec{
- Type: "TensorArrayReadV3",
- Input: []tf.Input{
- handle, index, flow_in,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// QuantizeV2Attr is an optional argument to QuantizeV2.
type QuantizeV2Attr func(optionalAttr)
@@ -18814,27 +20550,6 @@ func OptimizeDataset(scope *Scope, input_dataset tf.Output, optimizations tf.Out
return op.Output(0)
}
-// Retrieves the tree ensemble resource stamp token, number of trees and growing statistics.
-//
-// Arguments:
-// tree_ensemble_handle: Handle to the tree ensemble.
-//
-// Returns Stamp token of the tree ensemble resource.The number of trees in the tree ensemble resource.The number of trees that were finished successfully.The number of layers we attempted to build (but not necessarily succeeded).Rank size 2 tensor that contains start and end ids of the nodes in the latest
-// layer.
-func BoostedTreesGetEnsembleStates(scope *Scope, tree_ensemble_handle tf.Output) (stamp_token tf.Output, num_trees tf.Output, num_finalized_trees tf.Output, num_attempted_layers tf.Output, last_layer_nodes_range tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "BoostedTreesGetEnsembleStates",
- Input: []tf.Input{
- tree_ensemble_handle,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4)
-}
-
// Returns the element-wise min of two SparseTensors.
//
// Assumes the two SparseTensors have the same shape, i.e., no broadcasting.
@@ -19268,6 +20983,201 @@ func Sum(scope *Scope, input tf.Output, axis tf.Output, optional ...SumAttr) (ou
return op.Output(0)
}
+// EnterAttr is an optional argument to Enter.
+type EnterAttr func(optionalAttr)
+
+// EnterIsConstant sets the optional is_constant attribute to value.
+//
+// value: If true, the output is constant within the child frame.
+// If not specified, defaults to false
+func EnterIsConstant(value bool) EnterAttr {
+ return func(m optionalAttr) {
+ m["is_constant"] = value
+ }
+}
+
+// EnterParallelIterations sets the optional parallel_iterations attribute to value.
+//
+// value: The number of iterations allowed to run in parallel.
+// If not specified, defaults to 10
+func EnterParallelIterations(value int64) EnterAttr {
+ return func(m optionalAttr) {
+ m["parallel_iterations"] = value
+ }
+}
+
+// Creates or finds a child frame, and makes `data` available to the child frame.
+//
+// This op is used together with `Exit` to create loops in the graph.
+// The unique `frame_name` is used by the `Executor` to identify frames. If
+// `is_constant` is true, `output` is a constant in the child frame; otherwise
+// it may be changed in the child frame. At most `parallel_iterations` iterations
+// are run in parallel in the child frame.
+//
+// Arguments:
+// data: The tensor to be made available to the child frame.
+// frame_name: The name of the child frame.
+//
+// Returns The same tensor as `data`.
+func Enter(scope *Scope, data tf.Output, frame_name string, optional ...EnterAttr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"frame_name": frame_name}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "Enter",
+ Input: []tf.Input{
+ data,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Add all input tensors element-wise.
+//
+// Arguments:
+// inputs: Must all be the same size and shape.
+func AddN(scope *Scope, inputs []tf.Output) (sum tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "AddN",
+ Input: []tf.Input{
+ tf.OutputList(inputs),
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// TryRpcAttr is an optional argument to TryRpc.
+type TryRpcAttr func(optionalAttr)
+
+// TryRpcProtocol sets the optional protocol attribute to value.
+//
+// value: RPC protocol to use. Empty string means use the default protocol.
+// Options include 'grpc'.
+// If not specified, defaults to ""
+func TryRpcProtocol(value string) TryRpcAttr {
+ return func(m optionalAttr) {
+ m["protocol"] = value
+ }
+}
+
+// TryRpcFailFast sets the optional fail_fast attribute to value.
+//
+// value: `boolean`. If `true` (default), then failures to connect
+// (i.e., the server does not immediately respond) cause an RPC failure.
+// If not specified, defaults to true
+func TryRpcFailFast(value bool) TryRpcAttr {
+ return func(m optionalAttr) {
+ m["fail_fast"] = value
+ }
+}
+
+// TryRpcTimeoutInMs sets the optional timeout_in_ms attribute to value.
+//
+// value: `int`. If `0` (default), then the kernel will run the RPC
+// request and only time out if the RPC deadline passes or the session times out.
+// If this value is greater than `0`, then the op will raise an exception if
+// the RPC takes longer than `timeout_in_ms`.
+// If not specified, defaults to 0
+func TryRpcTimeoutInMs(value int64) TryRpcAttr {
+ return func(m optionalAttr) {
+ m["timeout_in_ms"] = value
+ }
+}
+
+// Perform batches of RPC requests.
+//
+// This op asynchronously performs either a single RPC request, or a batch
+// of requests. RPC requests are defined by three main parameters:
+//
+// - `address` (the host+port or BNS address of the request)
+// - `method` (the method name for the request)
+// - `request` (the serialized proto string, or vector of strings,
+// of the RPC request argument).
+//
+// For example, if you have an RPC service running at localhost:2345,
+// and its interface is configured with the following proto declaration:
+//
+// ```
+// service MyService {
+// rpc MyMethod(MyRequestProto) returns (MyResponseProto) {
+// }
+// };
+// ```
+//
+// then call this op with arguments:
+//
+// ```
+// address = "localhost:2345"
+// method = "MyService/MyMethod"
+// ```
+//
+// The `request` tensor is a string tensor representing serialized `MyRequestProto`
+// strings; and the output string tensor `response` will have the same shape
+// and contain (upon successful completion) corresponding serialized
+// `MyResponseProto` strings.
+//
+// For example, to send a single, empty, `MyRequestProto`, call
+// this op with `request = ""`. To send 5 **parallel** empty requests,
+// call this op with `request = ["", "", "", "", ""]`.
+//
+// More generally, one can create a batch of `MyRequestProto` serialized protos
+// from regular batched tensors using the `encode_proto` op, and convert
+// the response `MyResponseProto` serialized protos to batched tensors
+// using the `decode_proto` op.
+//
+// **NOTE** Working with serialized proto strings is faster than instantiating
+// actual proto objects in memory, so no performance degradation is expected
+// compared to writing custom kernels for this workflow.
+//
+// Unlike the standard `Rpc` op, if the connection fails or the remote worker
+// returns an error status, this op does **not** reraise the exception.
+// Instead, the `status_code` and `status_message` entry for the corresponding RPC
+// call is set with the error returned from the RPC call. The `response` tensor
+// will contain valid response values for those minibatch entries whose RPCs did
+// not fail; the rest of the entries will have empty strings.
+//
+// Arguments:
+// address: `0-D` or `1-D`. The address (i.e. host_name:port) of the RPC server.
+// If this tensor has more than 1 element, then multiple parallel rpc requests
+// are sent. This argument broadcasts with `method` and `request`.
+// method: `0-D` or `1-D`. The method address on the RPC server.
+// If this tensor has more than 1 element, then multiple parallel rpc requests
+// are sent. This argument broadcasts with `address` and `request`.
+// request: `0-D` or `1-D`. Serialized proto strings: the rpc request argument.
+// If this tensor has more than 1 element, then multiple parallel rpc requests
+// are sent. This argument broadcasts with `address` and `method`.
+//
+// Returns Same shape as `request`. Serialized proto strings: the rpc responses. Same shape as `request`. Values correspond to tensorflow Status enum codes. Same shape as `request`. Values correspond to Status messages
+// returned from the RPC calls.
+func TryRpc(scope *Scope, address tf.Output, method tf.Output, request tf.Output, optional ...TryRpcAttr) (response tf.Output, status_code tf.Output, status_message tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "TryRpc",
+ Input: []tf.Input{
+ address, method, request,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0), op.Output(1), op.Output(2)
+}
+
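A hedged sketch of how the new `TryRpc` wrapper might be wired up; the service address, method name, and empty request payloads are placeholders, and the helper name is illustrative.

```go
package example

import (
	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

// buildTryRpc issues three parallel RPCs; because `request` has three
// elements, the scalar `address` and `method` broadcast across them.
// Per-call failures land in statusCode/statusMessage instead of failing the op.
func buildTryRpc(s *op.Scope) (response, statusCode, statusMessage tf.Output) {
	address := op.Const(s.SubScope("address"), "localhost:2345")
	method := op.Const(s.SubScope("method"), "MyService/MyMethod")
	request := op.Const(s.SubScope("request"), []string{"", "", ""}) // empty serialized protos
	return op.TryRpc(s, address, method, request,
		op.TryRpcProtocol("grpc"),
		op.TryRpcTimeoutInMs(500))
}
```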
// Delete the tensor specified by its handle in the session.
//
// Arguments:
@@ -19505,8 +21415,9 @@ func QuantizedResizeBilinear(scope *Scope, images tf.Output, size tf.Output, min
// Computes the minimum along segments of a tensor.
//
-// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-// segments.
+// Read
+// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+// for an explanation of segments.
//
// Computes a tensor such that
// \\(output_i = \min_j(data_j)\\) where `min` is over `j` such
@@ -19520,7 +21431,7 @@ func QuantizedResizeBilinear(scope *Scope, images tf.Output, size tf.Output, min
//
// Arguments:
//
-// segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s
+// segment_ids: A 1-D tensor whose size is equal to the size of `data`'s
// first dimension. Values should be sorted and can be repeated.
//
// Returns Has same shape as data, except for dimension 0 which
@@ -19634,164 +21545,6 @@ func SdcaOptimizer(scope *Scope, sparse_example_indices []tf.Output, sparse_feat
return out_example_state_data, out_delta_sparse_weights, out_delta_dense_weights
}
-// ShapeAttr is an optional argument to Shape.
-type ShapeAttr func(optionalAttr)
-
-// ShapeOutType sets the optional out_type attribute to value.
-// If not specified, defaults to DT_INT32
-func ShapeOutType(value tf.DataType) ShapeAttr {
- return func(m optionalAttr) {
- m["out_type"] = value
- }
-}
-
-// Returns the shape of a tensor.
-//
-// This operation returns a 1-D integer tensor representing the shape of `input`.
-//
-// For example:
-//
-// ```
-// # 't' is [[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]]
-// shape(t) ==> [2, 2, 3]
-// ```
-func Shape(scope *Scope, input tf.Output, optional ...ShapeAttr) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "Shape",
- Input: []tf.Input{
- input,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Computes the power of one value to another.
-//
-// Given a tensor `x` and a tensor `y`, this operation computes \\(x^y\\) for
-// corresponding elements in `x` and `y`. For example:
-//
-// ```
-// # tensor 'x' is [[2, 2]], [3, 3]]
-// # tensor 'y' is [[8, 16], [2, 3]]
-// tf.pow(x, y) ==> [[256, 65536], [9, 27]]
-// ```
-func Pow(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "Pow",
- Input: []tf.Input{
- x, y,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Computes fingerprints of the input strings.
-//
-// Arguments:
-// input: vector of strings to compute fingerprints on.
-//
-// Returns a (N,2) shaped matrix where N is the number of elements in the input
-// vector. Each row contains the low and high parts of the fingerprint.
-func SdcaFprint(scope *Scope, input tf.Output) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "SdcaFprint",
- Input: []tf.Input{
- input,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// RandomPoissonV2Attr is an optional argument to RandomPoissonV2.
-type RandomPoissonV2Attr func(optionalAttr)
-
-// RandomPoissonV2Seed sets the optional seed attribute to value.
-//
-// value: If either `seed` or `seed2` are set to be non-zero, the random number
-// generator is seeded by the given seed. Otherwise, it is seeded by a
-// random seed.
-// If not specified, defaults to 0
-func RandomPoissonV2Seed(value int64) RandomPoissonV2Attr {
- return func(m optionalAttr) {
- m["seed"] = value
- }
-}
-
-// RandomPoissonV2Seed2 sets the optional seed2 attribute to value.
-//
-// value: A second seed to avoid seed collision.
-// If not specified, defaults to 0
-func RandomPoissonV2Seed2(value int64) RandomPoissonV2Attr {
- return func(m optionalAttr) {
- m["seed2"] = value
- }
-}
-
-// RandomPoissonV2Dtype sets the optional dtype attribute to value.
-// If not specified, defaults to DT_INT64
-func RandomPoissonV2Dtype(value tf.DataType) RandomPoissonV2Attr {
- return func(m optionalAttr) {
- m["dtype"] = value
- }
-}
-
-// Outputs random values from the Poisson distribution(s) described by rate.
-//
-// This op uses two algorithms, depending on rate. If rate >= 10, then
-// the algorithm by Hormann is used to acquire samples via
-// transformation-rejection.
-// See http://www.sciencedirect.com/science/article/pii/0167668793909974.
-//
-// Otherwise, Knuth's algorithm is used to acquire samples via multiplying uniform
-// random variables.
-// See Donald E. Knuth (1969). Seminumerical Algorithms. The Art of Computer
-// Programming, Volume 2. Addison Wesley
-//
-// Arguments:
-// shape: 1-D integer tensor. Shape of independent samples to draw from each
-// distribution described by the shape parameters given in rate.
-// rate: A tensor in which each scalar is a "rate" parameter describing the
-// associated poisson distribution.
-//
-// Returns A tensor with shape `shape + shape(rate)`. Each slice
-// `[:, ..., :, i0, i1, ...iN]` contains the samples drawn for
-// `rate[i0, i1, ...iN]`.
-func RandomPoissonV2(scope *Scope, shape tf.Output, rate tf.Output, optional ...RandomPoissonV2Attr) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "RandomPoissonV2",
- Input: []tf.Input{
- shape, rate,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// MatrixTriangularSolveAttr is an optional argument to MatrixTriangularSolve.
type MatrixTriangularSolveAttr func(optionalAttr)
@@ -19902,76 +21655,6 @@ func RangeDataset(scope *Scope, start tf.Output, stop tf.Output, step tf.Output,
return op.Output(0)
}
-// DepthwiseConv2dNativeBackpropInputAttr is an optional argument to DepthwiseConv2dNativeBackpropInput.
-type DepthwiseConv2dNativeBackpropInputAttr func(optionalAttr)
-
-// DepthwiseConv2dNativeBackpropInputDataFormat sets the optional data_format attribute to value.
-//
-// value: Specify the data format of the input and output data. With the
-// default format "NHWC", the data is stored in the order of:
-// [batch, height, width, channels].
-// Alternatively, the format could be "NCHW", the data storage order of:
-// [batch, channels, height, width].
-// If not specified, defaults to "NHWC"
-func DepthwiseConv2dNativeBackpropInputDataFormat(value string) DepthwiseConv2dNativeBackpropInputAttr {
- return func(m optionalAttr) {
- m["data_format"] = value
- }
-}
-
-// DepthwiseConv2dNativeBackpropInputDilations sets the optional dilations attribute to value.
-//
-// value: 1-D tensor of length 4. The dilation factor for each dimension of
-// `input`. If set to k > 1, there will be k-1 skipped cells between each filter
-// element on that dimension. The dimension order is determined by the value of
-// `data_format`, see above for details. Dilations in the batch and depth
-// dimensions must be 1.
-// If not specified, defaults to <i:1 i:1 i:1 i:1 >
-func DepthwiseConv2dNativeBackpropInputDilations(value []int64) DepthwiseConv2dNativeBackpropInputAttr {
- return func(m optionalAttr) {
- m["dilations"] = value
- }
-}
-
-// Computes the gradients of depthwise convolution with respect to the input.
-//
-// Arguments:
-// input_sizes: An integer vector representing the shape of `input`, based
-// on `data_format`. For example, if `data_format` is 'NHWC' then
-// `input` is a 4-D `[batch, height, width, channels]` tensor.
-// filter: 4-D with shape
-// `[filter_height, filter_width, in_channels, depthwise_multiplier]`.
-// out_backprop: 4-D with shape based on `data_format`.
-// For example, if `data_format` is 'NHWC' then
-// out_backprop shape is `[batch, out_height, out_width, out_channels]`.
-// Gradients w.r.t. the output of the convolution.
-// strides: The stride of the sliding window for each dimension of the input
-// of the convolution.
-// padding: The type of padding algorithm to use.
-//
-// Returns 4-D with shape according to `data_format`. For example, if
-// `data_format` is 'NHWC', output shape is `[batch, in_height,
-// in_width, in_channels]`. Gradient w.r.t. the input of the
-// convolution.
-func DepthwiseConv2dNativeBackpropInput(scope *Scope, input_sizes tf.Output, filter tf.Output, out_backprop tf.Output, strides []int64, padding string, optional ...DepthwiseConv2dNativeBackpropInputAttr) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"strides": strides, "padding": padding}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "DepthwiseConv2dNativeBackpropInput",
- Input: []tf.Input{
- input_sizes, filter, out_backprop,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Stops gradient computation.
//
// When executed in a graph, this op outputs its input tensor as-is.
@@ -20241,23 +21924,44 @@ func QuantizeDownAndShrinkRange(scope *Scope, input tf.Output, input_min tf.Outp
return op.Output(0), op.Output(1), op.Output(2)
}
-// Forwards the input to the output.
+// Computes the sum along segments of a tensor.
//
-// This operator represents the loop termination condition used by the
-// "pivot" switches of a loop.
+// Read
+// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+// for an explanation of segments.
+//
+// Computes a tensor such that
+// \\(output[i] = \sum_{j...} data[j...]\\) where the sum is over tuples `j...` such
+// that `segment_ids[j...] == i`. Unlike `SegmentSum`, `segment_ids`
+// need not be sorted and need not cover all values in the full
+// range of valid values.
+//
+// If the sum is empty for a given segment ID `i`, `output[i] = 0`.
+// If the given segment ID `i` is negative, the value is dropped and will not be
+// added to the sum of the segment.
+//
+// `num_segments` should equal the number of distinct segment IDs.
+//
+// <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
+// <img style="width:100%" src="https://www.tensorflow.org/images/UnsortedSegmentSum.png" alt>
+// </div>
//
// Arguments:
-// input: A boolean scalar, representing the branch predicate of the Switch op.
//
-// Returns The same tensor as `input`.
-func LoopCond(scope *Scope, input tf.Output) (output tf.Output) {
+// segment_ids: A tensor whose shape is a prefix of `data.shape`.
+//
+//
+// Returns Has same shape as data, except for the first `segment_ids.rank`
+// dimensions, which are replaced with a single dimension which has size
+// `num_segments`.
+func UnsortedSegmentSum(scope *Scope, data tf.Output, segment_ids tf.Output, num_segments tf.Output) (output tf.Output) {
if scope.Err() != nil {
return
}
opspec := tf.OpSpec{
- Type: "LoopCond",
+ Type: "UnsortedSegmentSum",
Input: []tf.Input{
- input,
+ data, segment_ids, num_segments,
},
}
op := scope.AddOperation(opspec)
@@ -20266,27 +21970,31 @@ func LoopCond(scope *Scope, input tf.Output) (output tf.Output) {
// Computes the product along segments of a tensor.
//
-// Read @{$math_ops#segmentation$the section on segmentation} for an explanation of
-// segments.
+// Read
+// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#segmentation)
+// for an explanation of segments.
//
// This operator is similar to the unsorted segment sum operator found
// [(here)](../../../api_docs/python/math_ops.md#UnsortedSegmentSum).
// Instead of computing the sum over segments, it computes the product of all
// entries belonging to a segment such that:
//
-// \\(output_i = \prod_j data_j\\) where the product is over `j` such
-// that `segment_ids[j] == i`.
+// \\(output_i = \prod_{j...} data[j...]\\) where the product is over tuples
+// `j...` such that `segment_ids[j...] == i`.
//
// If there is no entry for a given segment ID `i`, it outputs 1.
//
+// If the given segment ID `i` is negative, then the corresponding value is
+// dropped, and will not be included in the result.
+//
// Arguments:
//
-// segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s
-// first dimension.
+// segment_ids: A tensor whose shape is a prefix of `data.shape`.
//
//
-// Returns Has same shape as data, except for dimension 0 which
-// has size `num_segments`.
+// Returns Has same shape as data, except for the first `segment_ids.rank`
+// dimensions, which are replaced with a single dimension which has size
+// `num_segments`.
func UnsortedSegmentProd(scope *Scope, data tf.Output, segment_ids tf.Output, num_segments tf.Output) (output tf.Output) {
if scope.Err() != nil {
return
@@ -20301,90 +22009,172 @@ func UnsortedSegmentProd(scope *Scope, data tf.Output, segment_ids tf.Output, nu
return op.Output(0)
}
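A short sketch of the unsorted-segment pattern described above, using `UnsortedSegmentProd`; the helper name and sample data are illustrative.

```go
package example

import (
	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

// buildSegmentProd multiplies rows of `data` that share a segment id.
// With ids {0, 1, 0} and 3 segments: output[0] = data[0]*data[2] = [5 12],
// output[1] = data[1] = [3 4], and the empty segment 2 is filled with ones.
func buildSegmentProd(s *op.Scope) tf.Output {
	data := op.Const(s.SubScope("data"), [][]float32{{1, 2}, {3, 4}, {5, 6}})
	segmentIDs := op.Const(s.SubScope("ids"), []int32{0, 1, 0})
	numSegments := op.Const(s.SubScope("n"), int32(3))
	return op.UnsortedSegmentProd(s, data, segmentIDs, numSegments)
}
```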
-// RandomUniformIntAttr is an optional argument to RandomUniformInt.
-type RandomUniformIntAttr func(optionalAttr)
-
-// RandomUniformIntSeed sets the optional seed attribute to value.
+// Computes the mean along sparse segments of a tensor.
//
-// value: If either `seed` or `seed2` are set to be non-zero, the random number
-// generator is seeded by the given seed. Otherwise, it is seeded by a
-// random seed.
-// If not specified, defaults to 0
-func RandomUniformIntSeed(value int64) RandomUniformIntAttr {
- return func(m optionalAttr) {
- m["seed"] = value
- }
-}
-
-// RandomUniformIntSeed2 sets the optional seed2 attribute to value.
+// Read
+// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+// for an explanation of segments.
//
-// value: A second seed to avoid seed collision.
-// If not specified, defaults to 0
-func RandomUniformIntSeed2(value int64) RandomUniformIntAttr {
- return func(m optionalAttr) {
- m["seed2"] = value
+// Like `SegmentMean`, but `segment_ids` can have rank less than `data`'s first
+// dimension, selecting a subset of dimension 0, specified by `indices`.
+//
+// Arguments:
+//
+// indices: A 1-D tensor. Has same rank as `segment_ids`.
+// segment_ids: A 1-D tensor. Values should be sorted and can be repeated.
+//
+// Returns Has same shape as data, except for dimension 0 which
+// has size `k`, the number of segments.
+func SparseSegmentMean(scope *Scope, data tf.Output, indices tf.Output, segment_ids tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
}
+ opspec := tf.OpSpec{
+ Type: "SparseSegmentMean",
+ Input: []tf.Input{
+ data, indices, segment_ids,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
}
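A companion sketch for the relocated `SparseSegmentMean` wrapper; names and values are illustrative.

```go
package example

import (
	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

// buildSparseSegmentMean averages a subset of rows of `data`: `indices`
// selects rows 0 and 2, both assigned to segment 0, so the single output
// row is (data[0] + data[2]) / 2 = [3 4].
func buildSparseSegmentMean(s *op.Scope) tf.Output {
	data := op.Const(s.SubScope("data"), [][]float32{{1, 2}, {3, 4}, {5, 6}})
	indices := op.Const(s.SubScope("indices"), []int32{0, 2})
	segmentIDs := op.Const(s.SubScope("ids"), []int32{0, 0})
	return op.SparseSegmentMean(s, data, indices, segmentIDs)
}
```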
-// Outputs random integers from a uniform distribution.
-//
-// The generated values are uniform integers in the range `[minval, maxval)`.
-// The lower bound `minval` is included in the range, while the upper bound
-// `maxval` is excluded.
+// Deserializes a serialized tree ensemble config and replaces current tree
//
-// The random integers are slightly biased unless `maxval - minval` is an exact
-// power of two. The bias is small for values of `maxval - minval` significantly
-// smaller than the range of the output (either `2^32` or `2^64`).
+// ensemble.
//
// Arguments:
-// shape: The shape of the output tensor.
-// minval: 0-D. Inclusive lower bound on the generated integers.
-// maxval: 0-D. Exclusive upper bound on the generated integers.
+// tree_ensemble_handle: Handle to the tree ensemble.
+// stamp_token: Token to use as the new value of the resource stamp.
+// tree_ensemble_serialized: Serialized proto of the ensemble.
//
-// Returns A tensor of the specified shape filled with uniform random integers.
-func RandomUniformInt(scope *Scope, shape tf.Output, minval tf.Output, maxval tf.Output, optional ...RandomUniformIntAttr) (output tf.Output) {
+// Returns the created operation.
+func BoostedTreesDeserializeEnsemble(scope *Scope, tree_ensemble_handle tf.Output, stamp_token tf.Output, tree_ensemble_serialized tf.Output) (o *tf.Operation) {
if scope.Err() != nil {
return
}
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
+ opspec := tf.OpSpec{
+ Type: "BoostedTreesDeserializeEnsemble",
+ Input: []tf.Input{
+ tree_ensemble_handle, stamp_token, tree_ensemble_serialized,
+ },
+ }
+ return scope.AddOperation(opspec)
+}
+
+// Transforms a tf.Example proto (as a string) into typed tensors.
+//
+// Arguments:
+// serialized: A vector containing a batch of binary serialized Example protos.
+// dense_defaults: A list of Tensors (some may be empty), whose length matches
+// the length of `dense_keys`. dense_defaults[j] provides default values
+// when the example's feature_map lacks dense_key[j]. If an empty Tensor is
+// provided for dense_defaults[j], then the Feature dense_keys[j] is required.
+// The input type is inferred from dense_defaults[j], even when it's empty.
+// If dense_defaults[j] is not empty, and dense_shapes[j] is fully defined,
+// then the shape of dense_defaults[j] must match that of dense_shapes[j].
+// If dense_shapes[j] has an undefined major dimension (variable strides dense
+// feature), dense_defaults[j] must contain a single element:
+// the padding element.
+// num_sparse: The number of sparse features to be parsed from the example. This
+// must match the lengths of `sparse_keys` and `sparse_types`.
+// sparse_keys: A list of `num_sparse` strings.
+// The keys expected in the Examples' features associated with sparse values.
+// dense_keys: The keys expected in the Examples' features associated with dense
+// values.
+// sparse_types: A list of `num_sparse` types; the data types of data in each
+// Feature given in sparse_keys.
+// Currently the ParseSingleExample op supports DT_FLOAT (FloatList),
+// DT_INT64 (Int64List), and DT_STRING (BytesList).
+// dense_shapes: The shapes of data in each Feature given in dense_keys.
+// The length of this list must match the length of `dense_keys`. The
+// number of elements in the Feature corresponding to dense_key[j] must
+// always equal dense_shapes[j].NumEntries(). If dense_shapes[j] ==
+// (D0, D1, ..., DN) then the shape of output Tensor dense_values[j]
+// will be (D0, D1, ..., DN): In the case dense_shapes[j] = (-1, D1,
+// ..., DN), the shape of the output Tensor dense_values[j] will be (M,
+// D1, .., DN), where M is the number of blocks of elements of length
+// D1 * .... * DN, in the input.
+func ParseSingleExample(scope *Scope, serialized tf.Output, dense_defaults []tf.Output, num_sparse int64, sparse_keys []string, dense_keys []string, sparse_types []tf.DataType, dense_shapes []tf.Shape) (sparse_indices []tf.Output, sparse_values []tf.Output, sparse_shapes []tf.Output, dense_values []tf.Output) {
+ if scope.Err() != nil {
+ return
}
+ attrs := map[string]interface{}{"num_sparse": num_sparse, "sparse_keys": sparse_keys, "dense_keys": dense_keys, "sparse_types": sparse_types, "dense_shapes": dense_shapes}
opspec := tf.OpSpec{
- Type: "RandomUniformInt",
+ Type: "ParseSingleExample",
Input: []tf.Input{
- shape, minval, maxval,
+ serialized, tf.OutputList(dense_defaults),
},
Attrs: attrs,
}
op := scope.AddOperation(opspec)
- return op.Output(0)
+ if scope.Err() != nil {
+ return
+ }
+ var idx int
+ var err error
+ if sparse_indices, idx, err = makeOutputList(op, idx, "sparse_indices"); err != nil {
+ scope.UpdateErr("ParseSingleExample", err)
+ return
+ }
+ if sparse_values, idx, err = makeOutputList(op, idx, "sparse_values"); err != nil {
+ scope.UpdateErr("ParseSingleExample", err)
+ return
+ }
+ if sparse_shapes, idx, err = makeOutputList(op, idx, "sparse_shapes"); err != nil {
+ scope.UpdateErr("ParseSingleExample", err)
+ return
+ }
+ if dense_values, idx, err = makeOutputList(op, idx, "dense_values"); err != nil {
+ scope.UpdateErr("ParseSingleExample", err)
+ return
+ }
+ return sparse_indices, sparse_values, sparse_shapes, dense_values
}
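A hedged sketch of calling the `ParseSingleExample` wrapper; the feature names ("age", "tags"), the scalar dense shape, and the empty default (which marks the dense feature as required) are purely illustrative, and the serialized Example proto is fed at run time.

```go
package example

import (
	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

// buildParseSingleExample declares one dense int64 feature "age" (required,
// because its default tensor is empty) and one sparse string feature "tags".
func buildParseSingleExample(s *op.Scope) (sparseValues, denseValues []tf.Output) {
	serialized := op.Placeholder(s.SubScope("proto"), tf.String)
	denseDefaults := []tf.Output{op.Const(s.SubScope("age_default"), []int64{})}
	_, sparseValues, _, denseValues = op.ParseSingleExample(s, serialized, denseDefaults,
		1,                            // num_sparse
		[]string{"tags"},             // sparse_keys
		[]string{"age"},              // dense_keys
		[]tf.DataType{tf.String},     // sparse_types
		[]tf.Shape{tf.ScalarShape()}, // dense_shapes
	)
	return sparseValues, denseValues
}
```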
-// Computes the mean along sparse segments of a tensor.
-//
-// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-// segments.
+// WholeFileReaderV2Attr is an optional argument to WholeFileReaderV2.
+type WholeFileReaderV2Attr func(optionalAttr)
+
+// WholeFileReaderV2Container sets the optional container attribute to value.
//
-// Like `SegmentMean`, but `segment_ids` can have rank less than `data`'s first
-// dimension, selecting a subset of dimension 0, specified by `indices`.
+// value: If non-empty, this reader is placed in the given container.
+// Otherwise, a default container is used.
+// If not specified, defaults to ""
+func WholeFileReaderV2Container(value string) WholeFileReaderV2Attr {
+ return func(m optionalAttr) {
+ m["container"] = value
+ }
+}
+
+// WholeFileReaderV2SharedName sets the optional shared_name attribute to value.
//
-// Arguments:
+// value: If non-empty, this reader is named in the given bucket
+// with this shared_name. Otherwise, the node name is used instead.
+// If not specified, defaults to ""
+func WholeFileReaderV2SharedName(value string) WholeFileReaderV2Attr {
+ return func(m optionalAttr) {
+ m["shared_name"] = value
+ }
+}
+
+// A Reader that outputs the entire contents of a file as a value.
//
-// indices: A 1-D tensor. Has same rank as `segment_ids`.
-// segment_ids: A 1-D tensor. Values should be sorted and can be repeated.
+// To use, enqueue filenames in a Queue. The output of ReaderRead will
+// be a filename (key) and the contents of that file (value).
//
-// Returns Has same shape as data, except for dimension 0 which
-// has size `k`, the number of segments.
-func SparseSegmentMean(scope *Scope, data tf.Output, indices tf.Output, segment_ids tf.Output) (output tf.Output) {
+// Returns The handle to reference the Reader.
+func WholeFileReaderV2(scope *Scope, optional ...WholeFileReaderV2Attr) (reader_handle tf.Output) {
if scope.Err() != nil {
return
}
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
opspec := tf.OpSpec{
- Type: "SparseSegmentMean",
- Input: []tf.Input{
- data, indices, segment_ids,
- },
+ Type: "WholeFileReaderV2",
+
+ Attrs: attrs,
}
op := scope.AddOperation(opspec)
return op.Output(0)
@@ -20433,8 +22223,9 @@ func Cosh(scope *Scope, x tf.Output) (y tf.Output) {
// Like `SparseSegmentMean`, but allows missing ids in `segment_ids`. If an id is
// missing, the `output` tensor at that position will be zeroed.
//
-// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-// segments.
+// Read
+// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+// for an explanation of segments.
//
// Arguments:
//
@@ -20579,8 +22370,9 @@ func SparseSegmentMeanGrad(scope *Scope, grad tf.Output, indices tf.Output, segm
//
// N is the size of the segment being reduced.
//
-// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-// segments.
+// Read
+// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+// for an explanation of segments.
//
// Arguments:
//
@@ -20638,8 +22430,9 @@ func Igammac(scope *Scope, a tf.Output, x tf.Output) (z tf.Output) {
// Like `SparseSegmentSqrtN`, but allows missing ids in `segment_ids`. If an id is
// missing, the `output` tensor at that position will be zeroed.
//
-// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-// segments.
+// Read
+// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+// for an explanation of segments.
//
// Arguments:
//
@@ -20802,40 +22595,6 @@ func Any(scope *Scope, input tf.Output, axis tf.Output, optional ...AnyAttr) (ou
return op.Output(0)
}
-// Creates a sequence of numbers.
-//
-// This operation creates a sequence of numbers that begins at `start` and
-// extends by increments of `delta` up to but not including `limit`.
-//
-// For example:
-//
-// ```
-// # 'start' is 3
-// # 'limit' is 18
-// # 'delta' is 3
-// tf.range(start, limit, delta) ==> [3, 6, 9, 12, 15]
-// ```
-//
-// Arguments:
-// start: 0-D (scalar). First entry in the sequence.
-// limit: 0-D (scalar). Upper limit of sequence, exclusive.
-// delta: 0-D (scalar). Optional. Default is 1. Number that increments `start`.
-//
-// Returns 1-D.
-func Range(scope *Scope, start tf.Output, limit tf.Output, delta tf.Output) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "Range",
- Input: []tf.Input{
- start, limit, delta,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// DestroyResourceOpAttr is an optional argument to DestroyResourceOp.
type DestroyResourceOpAttr func(optionalAttr)
@@ -21000,8 +22759,9 @@ func Imag(scope *Scope, input tf.Output, optional ...ImagAttr) (output tf.Output
// Computes the maximum along segments of a tensor.
//
-// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-// segments.
+// Read
+// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+// for an explanation of segments.
//
// Computes a tensor such that
// \\(output_i = \max_j(data_j)\\) where `max` is over `j` such
@@ -21015,7 +22775,7 @@ func Imag(scope *Scope, input tf.Output, optional ...ImagAttr) (output tf.Output
//
// Arguments:
//
-// segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s
+// segment_ids: A 1-D tensor whose size is equal to the size of `data`'s
// first dimension. Values should be sorted and can be repeated.
//
// Returns Has same shape as data, except for dimension 0 which
@@ -21899,156 +23659,6 @@ func LookupTableFindV2(scope *Scope, table_handle tf.Output, keys tf.Output, def
return op.Output(0)
}
-// Bucketizes 'input' based on 'boundaries'.
-//
-// For example, if the inputs are
-// boundaries = [0, 10, 100]
-// input = [[-5, 10000]
-// [150, 10]
-// [5, 100]]
-//
-// then the output will be
-// output = [[0, 3]
-// [3, 2]
-// [1, 3]]
-//
-// Arguments:
-// input: Any shape of Tensor contains with int or float type.
-// boundaries: A sorted list of floats gives the boundary of the buckets.
-//
-// Returns Same shape with 'input', each value of input replaced with bucket index.
-//
-// @compatibility(numpy)
-// Equivalent to np.digitize.
-// @end_compatibility
-func Bucketize(scope *Scope, input tf.Output, boundaries []float32) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"boundaries": boundaries}
- opspec := tf.OpSpec{
- Type: "Bucketize",
- Input: []tf.Input{
- input,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Calculates gains for each feature and returns the best possible split information for the feature.
-//
-// The split information is the best threshold (bucket id), gains and left/right node contributions per node for each feature.
-//
-// It is possible that not all nodes can be split on each feature. Hence, the list of possible nodes can differ between the features. Therefore, we return `node_ids_list` for each feature, containing the list of nodes that this feature can be used to split.
-//
-// In this manner, the output is the best split per features and per node, so that it needs to be combined later to produce the best split for each node (among all possible features).
-//
-// The length of output lists are all of the same length, `num_features`.
-// The output shapes are compatible in a way that the first dimension of all tensors of all lists are the same and equal to the number of possible split nodes for each feature.
-//
-// Arguments:
-// node_id_range: A Rank 1 tensor (shape=[2]) to specify the range [first, last) of node ids to process within `stats_summary_list`. The nodes are iterated between the two nodes specified by the tensor, as like `for node_id in range(node_id_range[0], node_id_range[1])` (Note that the last index node_id_range[1] is exclusive).
-// stats_summary_list: A list of Rank 3 tensor (#shape=[max_splits, bucket, 2]) for accumulated stats summary (gradient/hessian) per node per buckets for each feature. The first dimension of the tensor is the maximum number of splits, and thus not all elements of it will be used, but only the indexes specified by node_ids will be used.
-// l1: l1 regularization factor on leaf weights, per instance based.
-// l2: l2 regularization factor on leaf weights, per instance based.
-// tree_complexity: adjustment to the gain, per leaf based.
-// min_node_weight: mininum avg of hessians in a node before required for the node to be considered for splitting.
-// max_splits: the number of nodes that can be split in the whole tree. Used as a dimension of output tensors.
-//
-// Returns An output list of Rank 1 tensors indicating possible split node ids for each feature. The length of the list is num_features, but each tensor has different size as each feature provides different possible nodes. See above for details like shapes and sizes.An output list of Rank 1 tensors indicating the best gains for each feature to split for certain nodes. See above for details like shapes and sizes.An output list of Rank 1 tensors indicating the bucket id to compare with (as a threshold) for split in each node. See above for details like shapes and sizes.A list of Rank 2 tensors indicating the contribution of the left nodes when branching from parent nodes (given by the tensor element in the output node_ids_list) to the left direction by the given threshold for each feature. This value will be used to make the left node value by adding to the parent node value. Second dimension size is 1 for 1-dimensional logits, but would be larger for multi-class problems. See above for details like shapes and sizes.A list of Rank 2 tensors, with the same shape/conditions as left_node_contribs_list, but just that the value is for the right node.
-func BoostedTreesCalculateBestGainsPerFeature(scope *Scope, node_id_range tf.Output, stats_summary_list []tf.Output, l1 tf.Output, l2 tf.Output, tree_complexity tf.Output, min_node_weight tf.Output, max_splits int64) (node_ids_list []tf.Output, gains_list []tf.Output, thresholds_list []tf.Output, left_node_contribs_list []tf.Output, right_node_contribs_list []tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"max_splits": max_splits}
- opspec := tf.OpSpec{
- Type: "BoostedTreesCalculateBestGainsPerFeature",
- Input: []tf.Input{
- node_id_range, tf.OutputList(stats_summary_list), l1, l2, tree_complexity, min_node_weight,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- if scope.Err() != nil {
- return
- }
- var idx int
- var err error
- if node_ids_list, idx, err = makeOutputList(op, idx, "node_ids_list"); err != nil {
- scope.UpdateErr("BoostedTreesCalculateBestGainsPerFeature", err)
- return
- }
- if gains_list, idx, err = makeOutputList(op, idx, "gains_list"); err != nil {
- scope.UpdateErr("BoostedTreesCalculateBestGainsPerFeature", err)
- return
- }
- if thresholds_list, idx, err = makeOutputList(op, idx, "thresholds_list"); err != nil {
- scope.UpdateErr("BoostedTreesCalculateBestGainsPerFeature", err)
- return
- }
- if left_node_contribs_list, idx, err = makeOutputList(op, idx, "left_node_contribs_list"); err != nil {
- scope.UpdateErr("BoostedTreesCalculateBestGainsPerFeature", err)
- return
- }
- if right_node_contribs_list, idx, err = makeOutputList(op, idx, "right_node_contribs_list"); err != nil {
- scope.UpdateErr("BoostedTreesCalculateBestGainsPerFeature", err)
- return
- }
- return node_ids_list, gains_list, thresholds_list, left_node_contribs_list, right_node_contribs_list
-}
-
-// EncodePngAttr is an optional argument to EncodePng.
-type EncodePngAttr func(optionalAttr)
-
-// EncodePngCompression sets the optional compression attribute to value.
-//
-// value: Compression level.
-// If not specified, defaults to -1
-func EncodePngCompression(value int64) EncodePngAttr {
- return func(m optionalAttr) {
- m["compression"] = value
- }
-}
-
-// PNG-encode an image.
-//
-// `image` is a 3-D uint8 or uint16 Tensor of shape `[height, width, channels]`
-// where `channels` is:
-//
-// * 1: for grayscale.
-// * 2: for grayscale + alpha.
-// * 3: for RGB.
-// * 4: for RGBA.
-//
-// The ZLIB compression level, `compression`, can be -1 for the PNG-encoder
-// default or a value from 0 to 9. 9 is the highest compression level, generating
-// the smallest output, but is slower.
-//
-// Arguments:
-// image: 3-D with shape `[height, width, channels]`.
-//
-// Returns 0-D. PNG-encoded image.
-func EncodePng(scope *Scope, image tf.Output, optional ...EncodePngAttr) (contents tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "EncodePng",
- Input: []tf.Input{
- image,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Updates the table to associates keys with values.
//
// The tensor `keys` must be of the same type as the keys of the table.
@@ -22366,6 +23976,58 @@ func HashTableV2(scope *Scope, key_dtype tf.DataType, value_dtype tf.DataType, o
return op.Output(0)
}
+// MultiDeviceIteratorFromStringHandleAttr is an optional argument to MultiDeviceIteratorFromStringHandle.
+type MultiDeviceIteratorFromStringHandleAttr func(optionalAttr)
+
+// MultiDeviceIteratorFromStringHandleOutputTypes sets the optional output_types attribute to value.
+//
+// value: The type list for the return values.
+// If not specified, defaults to <>
+//
+// REQUIRES: len(value) >= 0
+func MultiDeviceIteratorFromStringHandleOutputTypes(value []tf.DataType) MultiDeviceIteratorFromStringHandleAttr {
+ return func(m optionalAttr) {
+ m["output_types"] = value
+ }
+}
+
+// MultiDeviceIteratorFromStringHandleOutputShapes sets the optional output_shapes attribute to value.
+//
+// value: The list of shapes being produced.
+// If not specified, defaults to <>
+//
+// REQUIRES: len(value) >= 0
+func MultiDeviceIteratorFromStringHandleOutputShapes(value []tf.Shape) MultiDeviceIteratorFromStringHandleAttr {
+ return func(m optionalAttr) {
+ m["output_shapes"] = value
+ }
+}
+
+// Generates a MultiDeviceIterator resource from its provided string handle.
+//
+// Arguments:
+// string_handle: String representing the resource.
+//
+// Returns A MultiDeviceIterator resource.
+func MultiDeviceIteratorFromStringHandle(scope *Scope, string_handle tf.Output, optional ...MultiDeviceIteratorFromStringHandleAttr) (multi_device_iterator tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "MultiDeviceIteratorFromStringHandle",
+ Input: []tf.Input{
+ string_handle,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// MutableHashTableV2Attr is an optional argument to MutableHashTableV2.
type MutableHashTableV2Attr func(optionalAttr)
@@ -22790,6 +24452,31 @@ func TensorSummary(scope *Scope, tensor tf.Output, optional ...TensorSummaryAttr
return op.Output(0)
}
+// Read an element from the TensorArray into output `value`.
+//
+// Arguments:
+// handle: The handle to a TensorArray.
+//
+// flow_in: A float scalar that enforces proper chaining of operations.
+// dtype: The type of the elem that is returned.
+//
+// Returns The tensor that is read from the TensorArray.
+func TensorArrayReadV3(scope *Scope, handle tf.Output, index tf.Output, flow_in tf.Output, dtype tf.DataType) (value tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"dtype": dtype}
+ opspec := tf.OpSpec{
+ Type: "TensorArrayReadV3",
+ Input: []tf.Input{
+ handle, index, flow_in,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Computes the gradient for the tanh of `x` wrt its input.
//
// Specifically, `grad = dy * (1 - y*y)`, where `y = tanh(x)`, and `dy`
@@ -23431,29 +25118,57 @@ func TensorListSetItem(scope *Scope, input_handle tf.Output, index tf.Output, it
return op.Output(0)
}
-// Computes the matrix exponential of one or more square matrices:
+// Creates a Tensor by indexing into the TensorList.
//
-// DEPRECATED at GraphDef version 27: Use Python implementation tf.linalg.matrix_exponential instead.
+// Each row in the produced Tensor corresponds to the element in the TensorList
+// specified by the given index (see `tf.gather`).
//
-// \\(exp(A) = \sum_{n=0}^\infty A^n/n!\\)
-//
-// The exponential is computed using a combination of the scaling and squaring
-// method and the Pade approximation. Details can be founds in:
-// Nicholas J. Higham, "The scaling and squaring method for the matrix exponential
-// revisited," SIAM J. Matrix Anal. Applic., 26:1179-1193, 2005.
+// input_handle: The input tensor list.
+// indices: The indices used to index into the list.
+// values: The tensor.
+func TensorListGather(scope *Scope, input_handle tf.Output, indices tf.Output, element_dtype tf.DataType) (values tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"element_dtype": element_dtype}
+ opspec := tf.OpSpec{
+ Type: "TensorListGather",
+ Input: []tf.Input{
+ input_handle, indices,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Creates a TensorList by indexing into a Tensor.
//
-// The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions
-// form square matrices. The output is a tensor of the same shape as the input
-// containing the exponential for all input submatrices `[..., :, :]`.
+// Each member of the TensorList corresponds to one row of the input tensor,
+// specified by the given index (see `tf.gather`).
//
-// Arguments:
-// input: Shape is `[..., M, M]`.
-//
-// Returns Shape is `[..., M, M]`.
+// tensor: The input tensor.
+// indices: The indices used to index into the list.
+// element_shape: The shape of the elements in the list (can be less specified than
+// the shape of the tensor).
+// output_handle: The TensorList.
+func TensorListScatter(scope *Scope, tensor tf.Output, indices tf.Output, element_shape tf.Output) (output_handle tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "TensorListScatter",
+ Input: []tf.Input{
+ tensor, indices, element_shape,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
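A small sketch of the scatter/gather pair above, assuming `tf`/`op` imports and a scope `s` created with `op.NewScope()`:

```go
// Sketch only: scatter the rows of a 3x2 tensor into a TensorList, then
// gather two of them back in reverse order.
t := op.Const(s, [][]float32{{1, 2}, {3, 4}, {5, 6}})
list := op.TensorListScatter(s, t,
	op.Const(s, []int32{0, 1, 2}), // indices: one list element per row
	op.Const(s, []int32{2}))       // element_shape: each element is a length-2 vector
rows := op.TensorListGather(s, list, op.Const(s, []int32{2, 0}), tf.Float)
_ = rows // fetching rows in a session is expected to yield [[5 6] [1 2]]
```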
+// Deprecated, use the Python implementation tf.linalg.matrix_exponential.
//
-// @compatibility(scipy)
-// Equivalent to scipy.linalg.expm
-// @end_compatibility
+// DEPRECATED at GraphDef version 27: Use Python implementation tf.linalg.matrix_exponential instead.
func MatrixExponential(scope *Scope, input tf.Output) (output tf.Output) {
if scope.Err() != nil {
return
@@ -23906,6 +25621,45 @@ func Svd(scope *Scope, input tf.Output, optional ...SvdAttr) (s tf.Output, u tf.
return op.Output(0), op.Output(1), op.Output(2)
}
+// PrintV2Attr is an optional argument to PrintV2.
+type PrintV2Attr func(optionalAttr)
+
+// PrintV2OutputStream sets the optional output_stream attribute to value.
+//
+// value: A string specifying the output stream or logging level to print to.
+// If not specified, defaults to "stderr"
+func PrintV2OutputStream(value string) PrintV2Attr {
+ return func(m optionalAttr) {
+ m["output_stream"] = value
+ }
+}
+
+// Prints a string scalar.
+//
+// Prints a string scalar to the desired output_stream.
+//
+// Arguments:
+// input: The string scalar to print.
+//
+// Returns the created operation.
+func PrintV2(scope *Scope, input tf.Output, optional ...PrintV2Attr) (o *tf.Operation) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "PrintV2",
+ Input: []tf.Input{
+ input,
+ },
+ Attrs: attrs,
+ }
+ return scope.AddOperation(opspec)
+}
+
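A minimal usage sketch of the functional-option pattern used by these optional attributes, assuming the tensorflow/go bindings are imported as `tf` and this generated package as `op`:

```go
// Sketch only: print a string scalar to stdout instead of the default stderr.
s := op.NewScope()
msg := op.Const(s, "hello from PrintV2")
printOp := op.PrintV2(s, msg, op.PrintV2OutputStream("stdout"))

graph, err := s.Finalize()
if err != nil {
	panic(err)
}
sess, err := tf.NewSession(graph, nil)
if err != nil {
	panic(err)
}
defer sess.Close()
// PrintV2 yields a *tf.Operation, so it is run as a target with no fetches.
if _, err := sess.Run(nil, nil, []*tf.Operation{printOp}); err != nil {
	panic(err)
}
```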
// QueueEnqueueManyV2Attr is an optional argument to QueueEnqueueManyV2.
type QueueEnqueueManyV2Attr func(optionalAttr)
@@ -23959,8 +25713,9 @@ func QueueEnqueueManyV2(scope *Scope, handle tf.Output, components []tf.Output,
// Computes the product along segments of a tensor.
//
-// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-// segments.
+// Read
+// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+// for an explanation of segments.
//
// Computes a tensor such that
// \\(output_i = \prod_j data_j\\) where the product is over `j` such
@@ -23974,7 +25729,7 @@ func QueueEnqueueManyV2(scope *Scope, handle tf.Output, components []tf.Output,
//
// Arguments:
//
-// segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s
+// segment_ids: A 1-D tensor whose size is equal to the size of `data`'s
// first dimension. Values should be sorted and can be repeated.
//
// Returns Has same shape as data, except for dimension 0 which
@@ -24999,7 +26754,7 @@ func ResourceApplyAdamUseNesterov(value bool) ResourceApplyAdamAttr {
// Update '*var' according to the Adam algorithm.
//
-// $$lr_t := \text{learning_rate} * \sqrt{(1 - beta_2^t) / (1 - beta_1^t)}$$
+// $$lr_t := \text{learning\_rate} * \sqrt{1 - beta_2^t} / (1 - beta_1^t)$$
// $$m_t := beta_1 * m_{t-1} + (1 - beta_1) * g$$
// $$v_t := beta_2 * v_{t-1} + (1 - beta_2) * g * g$$
// $$variable := variable - lr_t * m_t / (\sqrt{v_t} + \epsilon)$$
@@ -25396,6 +27151,260 @@ func DecodeGif(scope *Scope, contents tf.Output) (image tf.Output) {
return op.Output(0)
}
+// LearnedUnigramCandidateSamplerAttr is an optional argument to LearnedUnigramCandidateSampler.
+type LearnedUnigramCandidateSamplerAttr func(optionalAttr)
+
+// LearnedUnigramCandidateSamplerSeed sets the optional seed attribute to value.
+//
+// value: If either seed or seed2 are set to be non-zero, the random number
+// generator is seeded by the given seed. Otherwise, it is seeded by a
+// random seed.
+// If not specified, defaults to 0
+func LearnedUnigramCandidateSamplerSeed(value int64) LearnedUnigramCandidateSamplerAttr {
+ return func(m optionalAttr) {
+ m["seed"] = value
+ }
+}
+
+// LearnedUnigramCandidateSamplerSeed2 sets the optional seed2 attribute to value.
+//
+// value: A second seed to avoid seed collision.
+// If not specified, defaults to 0
+func LearnedUnigramCandidateSamplerSeed2(value int64) LearnedUnigramCandidateSamplerAttr {
+ return func(m optionalAttr) {
+ m["seed2"] = value
+ }
+}
+
+// Generates labels for candidate sampling with a learned unigram distribution.
+//
+// See explanations of candidate sampling and the data formats at
+// go/candidate-sampling.
+//
+// For each batch, this op picks a single set of sampled candidate labels.
+//
+// The advantages of sampling candidates per-batch are simplicity and the
+// possibility of efficient dense matrix multiplication. The disadvantage is that
+// the sampled candidates must be chosen independently of the context and of the
+// true labels.
+//
+// Arguments:
+// true_classes: A batch_size * num_true matrix, in which each row contains the
+// IDs of the num_true target_classes in the corresponding original label.
+// num_true: Number of true labels per context.
+// num_sampled: Number of candidates to randomly sample.
+// unique: If unique is true, we sample with rejection, so that all sampled
+// candidates in a batch are unique. This requires some approximation to
+// estimate the post-rejection sampling probabilities.
+// range_max: The sampler will sample integers from the interval [0, range_max).
+//
+// Returns:
+// sampled_candidates: A vector of length num_sampled, in which each element is
+// the ID of a sampled candidate.
+// true_expected_count: A batch_size * num_true matrix, representing the number
+// of times each candidate is expected to occur in a batch of sampled candidates.
+// If unique=true, then this is a probability.
+// sampled_expected_count: A vector of length num_sampled, for each sampled
+// candidate representing the number of times the candidate is expected to occur
+// in a batch of sampled candidates. If unique=true, then this is a probability.
+func LearnedUnigramCandidateSampler(scope *Scope, true_classes tf.Output, num_true int64, num_sampled int64, unique bool, range_max int64, optional ...LearnedUnigramCandidateSamplerAttr) (sampled_candidates tf.Output, true_expected_count tf.Output, sampled_expected_count tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"num_true": num_true, "num_sampled": num_sampled, "unique": unique, "range_max": range_max}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "LearnedUnigramCandidateSampler",
+ Input: []tf.Input{
+ true_classes,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0), op.Output(1), op.Output(2)
+}
+
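A hedged usage sketch (assuming `tf`/`op` imports and a scope `s`); the sampled IDs depend on the seed and on the learned distribution:

```go
// Sketch only: sample 4 candidate IDs from [0, 10) for two examples with
// one true label each, seeding for repeatability.
trueClasses := op.Const(s, [][]int64{{3}, {7}})
sampled, trueExpected, sampledExpected := op.LearnedUnigramCandidateSampler(
	s, trueClasses,
	1,    // num_true
	4,    // num_sampled
	true, // unique: sample with rejection
	10,   // range_max
	op.LearnedUnigramCandidateSamplerSeed(42))
_, _, _ = sampled, trueExpected, sampledExpected
```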
+// SerializeSparseAttr is an optional argument to SerializeSparse.
+type SerializeSparseAttr func(optionalAttr)
+
+// SerializeSparseOutType sets the optional out_type attribute to value.
+//
+// value: The `dtype` to use for serialization; the supported types are `string`
+// (default) and `variant`.
+// If not specified, defaults to DT_STRING
+func SerializeSparseOutType(value tf.DataType) SerializeSparseAttr {
+ return func(m optionalAttr) {
+ m["out_type"] = value
+ }
+}
+
+// Serialize a `SparseTensor` into a `[3]` `Tensor` object.
+//
+// Arguments:
+// sparse_indices: 2-D. The `indices` of the `SparseTensor`.
+// sparse_values: 1-D. The `values` of the `SparseTensor`.
+// sparse_shape: 1-D. The `shape` of the `SparseTensor`.
+func SerializeSparse(scope *Scope, sparse_indices tf.Output, sparse_values tf.Output, sparse_shape tf.Output, optional ...SerializeSparseAttr) (serialized_sparse tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "SerializeSparse",
+ Input: []tf.Input{
+ sparse_indices, sparse_values, sparse_shape,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
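A short sketch (assuming `tf`/`op` imports and a scope `s`); passing `SerializeSparseOutType` would switch the serialization dtype away from the default string:

```go
// Sketch only: serialize a 2x3 SparseTensor with two non-zero entries into
// a length-3 string tensor holding (indices, values, shape).
indices := op.Const(s, [][]int64{{0, 0}, {1, 2}})
values := op.Const(s, []float32{1.5, 2.5})
shape := op.Const(s, []int64{2, 3})
serialized := op.SerializeSparse(s, indices, values, shape)
_ = serialized
```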
+// RandomShuffleQueueV2Attr is an optional argument to RandomShuffleQueueV2.
+type RandomShuffleQueueV2Attr func(optionalAttr)
+
+// RandomShuffleQueueV2Shapes sets the optional shapes attribute to value.
+//
+// value: The shape of each component in a value. The length of this attr must
+// be either 0 or the same as the length of component_types. If the length of
+// this attr is 0, the shapes of queue elements are not constrained, and
+// only one element may be dequeued at a time.
+// If not specified, defaults to <>
+//
+// REQUIRES: len(value) >= 0
+func RandomShuffleQueueV2Shapes(value []tf.Shape) RandomShuffleQueueV2Attr {
+ return func(m optionalAttr) {
+ m["shapes"] = value
+ }
+}
+
+// RandomShuffleQueueV2Capacity sets the optional capacity attribute to value.
+//
+// value: The upper bound on the number of elements in this queue.
+// Negative numbers mean no limit.
+// If not specified, defaults to -1
+func RandomShuffleQueueV2Capacity(value int64) RandomShuffleQueueV2Attr {
+ return func(m optionalAttr) {
+ m["capacity"] = value
+ }
+}
+
+// RandomShuffleQueueV2MinAfterDequeue sets the optional min_after_dequeue attribute to value.
+//
+// value: Dequeue will block unless there would be this
+// many elements after the dequeue or the queue is closed. This
+// ensures a minimum level of mixing of elements.
+// If not specified, defaults to 0
+func RandomShuffleQueueV2MinAfterDequeue(value int64) RandomShuffleQueueV2Attr {
+ return func(m optionalAttr) {
+ m["min_after_dequeue"] = value
+ }
+}
+
+// RandomShuffleQueueV2Seed sets the optional seed attribute to value.
+//
+// value: If either seed or seed2 is set to be non-zero, the random number
+// generator is seeded by the given seed. Otherwise, a random seed is used.
+// If not specified, defaults to 0
+func RandomShuffleQueueV2Seed(value int64) RandomShuffleQueueV2Attr {
+ return func(m optionalAttr) {
+ m["seed"] = value
+ }
+}
+
+// RandomShuffleQueueV2Seed2 sets the optional seed2 attribute to value.
+//
+// value: A second seed to avoid seed collision.
+// If not specified, defaults to 0
+func RandomShuffleQueueV2Seed2(value int64) RandomShuffleQueueV2Attr {
+ return func(m optionalAttr) {
+ m["seed2"] = value
+ }
+}
+
+// RandomShuffleQueueV2Container sets the optional container attribute to value.
+//
+// value: If non-empty, this queue is placed in the given container.
+// Otherwise, a default container is used.
+// If not specified, defaults to ""
+func RandomShuffleQueueV2Container(value string) RandomShuffleQueueV2Attr {
+ return func(m optionalAttr) {
+ m["container"] = value
+ }
+}
+
+// RandomShuffleQueueV2SharedName sets the optional shared_name attribute to value.
+//
+// value: If non-empty, this queue will be shared under the given name
+// across multiple sessions.
+// If not specified, defaults to ""
+func RandomShuffleQueueV2SharedName(value string) RandomShuffleQueueV2Attr {
+ return func(m optionalAttr) {
+ m["shared_name"] = value
+ }
+}
+
+// A queue that randomizes the order of elements.
+//
+// Arguments:
+// component_types: The type of each component in a value.
+//
+// Returns The handle to the queue.
+func RandomShuffleQueueV2(scope *Scope, component_types []tf.DataType, optional ...RandomShuffleQueueV2Attr) (handle tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"component_types": component_types}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "RandomShuffleQueueV2",
+
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
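A sketch of creating the queue with a few of the optional attributes (assuming `tf`/`op` imports and a scope `s`); enqueue/dequeue ops such as `QueueEnqueueV2` would then operate on the returned handle:

```go
// Sketch only: a shuffling queue of scalar floats with bounded capacity and
// a minimum number of buffered elements to guarantee mixing.
queue := op.RandomShuffleQueueV2(s, []tf.DataType{tf.Float},
	op.RandomShuffleQueueV2Capacity(100),
	op.RandomShuffleQueueV2MinAfterDequeue(10),
	op.RandomShuffleQueueV2Shapes([]tf.Shape{tf.ScalarShape()}))
_ = queue
```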
+// Draw bounding boxes on a batch of images.
+//
+// Outputs a copy of `images` but draws on top of the pixels zero or more bounding
+// boxes specified by the locations in `boxes`. The coordinates of the each
+// bounding box in `boxes` are encoded as `[y_min, x_min, y_max, x_max]`. The
+// bounding box coordinates are floats in `[0.0, 1.0]` relative to the width and
+// height of the underlying image.
+//
+// For example, if an image is 100 x 200 pixels (height x width) and the bounding
+// box is `[0.1, 0.2, 0.5, 0.9]`, the upper-left and bottom-right coordinates of
+// the bounding box will be `(40, 10)` to `(180, 50)` (in (x,y) coordinates).
+//
+// Parts of the bounding box may fall outside the image.
+//
+// Arguments:
+// images: 4-D with shape `[batch, height, width, depth]`. A batch of images.
+// boxes: 3-D with shape `[batch, num_bounding_boxes, 4]` containing bounding
+// boxes.
+//
+// Returns 4-D with the same shape as `images`. The batch of input images with
+// bounding boxes drawn on the images.
+func DrawBoundingBoxes(scope *Scope, images tf.Output, boxes tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "DrawBoundingBoxes",
+ Input: []tf.Input{
+ images, boxes,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
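A short sketch (assuming `tf`/`op` imports and a scope `s`); `Fill` is only used here to fabricate a small all-zeros image batch:

```go
// Sketch only: draw one box on a 1x4x4x3 all-zeros image batch.
images := op.Fill(s, op.Const(s, []int32{1, 4, 4, 3}), op.Const(s, float32(0)))
boxes := op.Const(s, [][][]float32{{{0.1, 0.2, 0.5, 0.9}}}) // [batch, num_boxes, 4]
withBoxes := op.DrawBoundingBoxes(s, images, boxes)
_ = withBoxes
```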
// Gets the next output from the given iterator.
//
// This operation is a synchronous version IteratorGetNext. It should only be used
@@ -26154,178 +28163,6 @@ func FakeParam(scope *Scope, dtype tf.DataType, shape tf.Shape) (output tf.Outpu
return op.Output(0)
}
-// EncodeProtoAttr is an optional argument to EncodeProto.
-type EncodeProtoAttr func(optionalAttr)
-
-// EncodeProtoDescriptorSource sets the optional descriptor_source attribute to value.
-// If not specified, defaults to "local://"
-func EncodeProtoDescriptorSource(value string) EncodeProtoAttr {
- return func(m optionalAttr) {
- m["descriptor_source"] = value
- }
-}
-
-// The op serializes protobuf messages provided in the input tensors.
-//
-// The types of the tensors in `values` must match the schema for the
-// fields specified in `field_names`. All the tensors in `values` must
-// have a common shape prefix, *batch_shape*.
-//
-// The `sizes` tensor specifies repeat counts for each field. The repeat
-// count (last dimension) of a each tensor in `values` must be greater
-// than or equal to corresponding repeat count in `sizes`.
-//
-// A `message_type` name must be provided to give context for the field
-// names. The actual message descriptor can be looked up either in the
-// linked-in descriptor pool or a filename provided by the caller using
-// the `descriptor_source` attribute.
-//
-// The `descriptor_source` attribute selects a source of protocol
-// descriptors to consult when looking up `message_type`. This may be a
-// filename containing a serialized `FileDescriptorSet` message,
-// or the special value `local://`, in which case only descriptors linked
-// into the code will be searched; the filename can be on any filesystem
-// accessible to TensorFlow.
-//
-// You can build a `descriptor_source` file using the `--descriptor_set_out`
-// and `--include_imports` options to the protocol compiler `protoc`.
-//
-// The `local://` database only covers descriptors linked into the
-// code via C++ libraries, not Python imports. You can link in a proto descriptor
-// by creating a cc_library target with alwayslink=1.
-//
-// There are a few special cases in the value mapping:
-//
-// Submessage and group fields must be pre-serialized as TensorFlow strings.
-//
-// TensorFlow lacks support for unsigned int64s, so they must be
-// represented as `tf.int64` with the same twos-complement bit pattern
-// (the obvious way).
-//
-// Unsigned int32 values can be represented exactly with `tf.int64`, or
-// with sign wrapping if the input is of type `tf.int32`.
-//
-// Arguments:
-// sizes: Tensor of int32 with shape `[batch_shape, len(field_names)]`.
-// values: List of tensors containing values for the corresponding field.
-// field_names: List of strings containing proto field names.
-// message_type: Name of the proto message type to decode.
-//
-// Returns Tensor of serialized protos with shape `batch_shape`.
-func EncodeProto(scope *Scope, sizes tf.Output, values []tf.Output, field_names []string, message_type string, optional ...EncodeProtoAttr) (bytes tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"field_names": field_names, "message_type": message_type}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "EncodeProto",
- Input: []tf.Input{
- sizes, tf.OutputList(values),
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Creates a TensorArray for storing the gradients of values in the given handle.
-//
-// If the given TensorArray gradient already exists, returns a reference to it.
-//
-// Locks the size of the original TensorArray by disabling its dynamic size flag.
-//
-// **A note about the input flow_in:**
-//
-// The handle flow_in forces the execution of the gradient lookup to occur
-// only after certain other operations have occurred. For example, when
-// the forward TensorArray is dynamically sized, writes to this TensorArray
-// may resize the object. The gradient TensorArray is statically sized based
-// on the size of the forward TensorArray when this operation executes.
-// Furthermore, the size of the forward TensorArray is frozen by this call.
-// As a result, the flow is used to ensure that the call to generate the gradient
-// TensorArray only happens after all writes are executed.
-//
-// In the case of dynamically sized TensorArrays, gradient computation should
-// only be performed on read operations that have themselves been chained via
-// flow to occur only after all writes have executed. That way the final size
-// of the forward TensorArray is known when this operation is called.
-//
-// **A note about the source attribute:**
-//
-// TensorArray gradient calls use an accumulator TensorArray object. If
-// multiple gradients are calculated and run in the same session, the multiple
-// gradient nodes may accidentally flow through the same accumulator TensorArray.
-// This double counts and generally breaks the TensorArray gradient flow.
-//
-// The solution is to identify which gradient call this particular
-// TensorArray gradient is being called in. This is performed by identifying
-// a unique string (e.g. "gradients", "gradients_1", ...) from the input
-// gradient Tensor's name. This string is used as a suffix when creating
-// the TensorArray gradient object here (the attribute `source`).
-//
-// The attribute `source` is added as a suffix to the forward TensorArray's
-// name when performing the creation / lookup, so that each separate gradient
-// calculation gets its own TensorArray accumulator.
-//
-// Arguments:
-// handle: The handle to the forward TensorArray.
-// flow_in: A float scalar that enforces proper chaining of operations.
-// source: The gradient source string, used to decide which gradient TensorArray
-// to return.
-func TensorArrayGradV3(scope *Scope, handle tf.Output, flow_in tf.Output, source string) (grad_handle tf.Output, flow_out tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"source": source}
- opspec := tf.OpSpec{
- Type: "TensorArrayGradV3",
- Input: []tf.Input{
- handle, flow_in,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1)
-}
-
-// Creates a dataset that splits a SparseTensor into elements row-wise.
-func SparseTensorSliceDataset(scope *Scope, indices tf.Output, values tf.Output, dense_shape tf.Output) (handle tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "SparseTensorSliceDataset",
- Input: []tf.Input{
- indices, values, dense_shape,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Returns x / y element-wise for real types.
-//
-// If `x` and `y` are reals, this will return the floating-point division.
-//
-// *NOTE*: `Div` supports broadcasting. More about broadcasting
-// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
-func RealDiv(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "RealDiv",
- Input: []tf.Input{
- x, y,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Adds v into specified rows of x.
//
// Computes y = x; y[i, :] += v; return y.
@@ -26621,6 +28458,255 @@ func StackPushV2(scope *Scope, handle tf.Output, elem tf.Output, optional ...Sta
return op.Output(0)
}
+// StringSplitV2Attr is an optional argument to StringSplitV2.
+type StringSplitV2Attr func(optionalAttr)
+
+// StringSplitV2Maxsplit sets the optional maxsplit attribute to value.
+//
+// value: An `int`. If `maxsplit > 0`, limit of the split of the result.
+// If not specified, defaults to -1
+func StringSplitV2Maxsplit(value int64) StringSplitV2Attr {
+ return func(m optionalAttr) {
+ m["maxsplit"] = value
+ }
+}
+
+// Split elements of `source` based on `sep` into a `SparseTensor`.
+//
+// Let N be the size of source (typically N will be the batch size). Split each
+// element of `source` based on `sep` and return a `SparseTensor`
+// containing the split tokens. Empty tokens are ignored.
+//
+// For example, N = 2, source[0] is 'hello world' and source[1] is 'a b c',
+// then the output will be
+// ```
+// st.indices = [0, 0;
+// 0, 1;
+// 1, 0;
+// 1, 1;
+// 1, 2]
+// st.shape = [2, 3]
+// st.values = ['hello', 'world', 'a', 'b', 'c']
+// ```
+//
+// If `sep` is given, consecutive delimiters are not grouped together and are
+// deemed to delimit empty strings. For example, source of `"1<>2<><>3"` and
+// sep of `"<>"` returns `["1", "2", "", "3"]`. If `sep` is None or an empty
+// string, consecutive whitespace characters are regarded as a single separator, and the
+// result will contain no empty strings at the start or end if the string has
+// leading or trailing whitespace.
+//
+// Note that the above-mentioned behavior matches Python's str.split.
+//
+// Arguments:
+// input: `1-D` string `Tensor`, the strings to split.
+// sep: `0-D` string `Tensor`, the delimiter character.
+func StringSplitV2(scope *Scope, input tf.Output, sep tf.Output, optional ...StringSplitV2Attr) (indices tf.Output, values tf.Output, shape tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "StringSplitV2",
+ Input: []tf.Input{
+ input, sep,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0), op.Output(1), op.Output(2)
+}
+
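A sketch of the whitespace-splitting case described above, assuming the tensorflow/go bindings are imported as `tf` and this generated package as `op`:

```go
// Sketch only: split two strings on runs of whitespace and fetch the sparse
// components.
s := op.NewScope()
source := op.Const(s, []string{"hello world", "a b c"})
sep := op.Const(s, "") // empty separator => split on whitespace
indices, values, shape := op.StringSplitV2(s, source, sep)

graph, err := s.Finalize()
if err != nil {
	panic(err)
}
sess, err := tf.NewSession(graph, nil)
if err != nil {
	panic(err)
}
defer sess.Close()
out, err := sess.Run(nil, []tf.Output{indices, values, shape}, nil)
if err != nil {
	panic(err)
}
_ = out // out[1] is expected to hold ["hello" "world" "a" "b" "c"]; out[2] holds [2 3]
```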
+// Computes softsign: `features / (abs(features) + 1)`.
+func Softsign(scope *Scope, features tf.Output) (activations tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "Softsign",
+ Input: []tf.Input{
+ features,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// EncodeProtoAttr is an optional argument to EncodeProto.
+type EncodeProtoAttr func(optionalAttr)
+
+// EncodeProtoDescriptorSource sets the optional descriptor_source attribute to value.
+// If not specified, defaults to "local://"
+func EncodeProtoDescriptorSource(value string) EncodeProtoAttr {
+ return func(m optionalAttr) {
+ m["descriptor_source"] = value
+ }
+}
+
+// The op serializes protobuf messages provided in the input tensors.
+//
+// The types of the tensors in `values` must match the schema for the
+// fields specified in `field_names`. All the tensors in `values` must
+// have a common shape prefix, *batch_shape*.
+//
+// The `sizes` tensor specifies repeat counts for each field. The repeat
+// count (last dimension) of a each tensor in `values` must be greater
+// than or equal to corresponding repeat count in `sizes`.
+//
+// A `message_type` name must be provided to give context for the field
+// names. The actual message descriptor can be looked up either in the
+// linked-in descriptor pool or a filename provided by the caller using
+// the `descriptor_source` attribute.
+//
+// The `descriptor_source` attribute selects a source of protocol
+// descriptors to consult when looking up `message_type`. This may be a
+// filename containing a serialized `FileDescriptorSet` message,
+// or the special value `local://`, in which case only descriptors linked
+// into the code will be searched; the filename can be on any filesystem
+// accessible to TensorFlow.
+//
+// You can build a `descriptor_source` file using the `--descriptor_set_out`
+// and `--include_imports` options to the protocol compiler `protoc`.
+//
+// The `local://` database only covers descriptors linked into the
+// code via C++ libraries, not Python imports. You can link in a proto descriptor
+// by creating a cc_library target with alwayslink=1.
+//
+// There are a few special cases in the value mapping:
+//
+// Submessage and group fields must be pre-serialized as TensorFlow strings.
+//
+// TensorFlow lacks support for unsigned int64s, so they must be
+// represented as `tf.int64` with the same twos-complement bit pattern
+// (the obvious way).
+//
+// Unsigned int32 values can be represented exactly with `tf.int64`, or
+// with sign wrapping if the input is of type `tf.int32`.
+//
+// Arguments:
+// sizes: Tensor of int32 with shape `[batch_shape, len(field_names)]`.
+// values: List of tensors containing values for the corresponding field.
+// field_names: List of strings containing proto field names.
+// message_type: Name of the proto message type to decode.
+//
+// Returns Tensor of serialized protos with shape `batch_shape`.
+func EncodeProto(scope *Scope, sizes tf.Output, values []tf.Output, field_names []string, message_type string, optional ...EncodeProtoAttr) (bytes tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"field_names": field_names, "message_type": message_type}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "EncodeProto",
+ Input: []tf.Input{
+ sizes, tf.OutputList(values),
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Creates a TensorArray for storing the gradients of values in the given handle.
+//
+// If the given TensorArray gradient already exists, returns a reference to it.
+//
+// Locks the size of the original TensorArray by disabling its dynamic size flag.
+//
+// **A note about the input flow_in:**
+//
+// The handle flow_in forces the execution of the gradient lookup to occur
+// only after certain other operations have occurred. For example, when
+// the forward TensorArray is dynamically sized, writes to this TensorArray
+// may resize the object. The gradient TensorArray is statically sized based
+// on the size of the forward TensorArray when this operation executes.
+// Furthermore, the size of the forward TensorArray is frozen by this call.
+// As a result, the flow is used to ensure that the call to generate the gradient
+// TensorArray only happens after all writes are executed.
+//
+// In the case of dynamically sized TensorArrays, gradient computation should
+// only be performed on read operations that have themselves been chained via
+// flow to occur only after all writes have executed. That way the final size
+// of the forward TensorArray is known when this operation is called.
+//
+// **A note about the source attribute:**
+//
+// TensorArray gradient calls use an accumulator TensorArray object. If
+// multiple gradients are calculated and run in the same session, the multiple
+// gradient nodes may accidentally flow through the same accumulator TensorArray.
+// This double counts and generally breaks the TensorArray gradient flow.
+//
+// The solution is to identify which gradient call this particular
+// TensorArray gradient is being called in. This is performed by identifying
+// a unique string (e.g. "gradients", "gradients_1", ...) from the input
+// gradient Tensor's name. This string is used as a suffix when creating
+// the TensorArray gradient object here (the attribute `source`).
+//
+// The attribute `source` is added as a suffix to the forward TensorArray's
+// name when performing the creation / lookup, so that each separate gradient
+// calculation gets its own TensorArray accumulator.
+//
+// Arguments:
+// handle: The handle to the forward TensorArray.
+// flow_in: A float scalar that enforces proper chaining of operations.
+// source: The gradient source string, used to decide which gradient TensorArray
+// to return.
+func TensorArrayGradV3(scope *Scope, handle tf.Output, flow_in tf.Output, source string) (grad_handle tf.Output, flow_out tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"source": source}
+ opspec := tf.OpSpec{
+ Type: "TensorArrayGradV3",
+ Input: []tf.Input{
+ handle, flow_in,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0), op.Output(1)
+}
+
+// Creates a dataset that splits a SparseTensor into elements row-wise.
+func SparseTensorSliceDataset(scope *Scope, indices tf.Output, values tf.Output, dense_shape tf.Output) (handle tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "SparseTensorSliceDataset",
+ Input: []tf.Input{
+ indices, values, dense_shape,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Returns x / y element-wise for real types.
+//
+// If `x` and `y` are reals, this will return the floating-point division.
+//
+// *NOTE*: `Div` supports broadcasting. More about broadcasting
+// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
+func RealDiv(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "RealDiv",
+ Input: []tf.Input{
+ x, y,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Creates a dataset that concatenates `input_dataset` with `another_dataset`.
func ConcatenateDataset(scope *Scope, input_dataset tf.Output, another_dataset tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) {
if scope.Err() != nil {
@@ -27016,8 +29102,10 @@ func TensorArrayGradV2(scope *Scope, handle tf.Output, flow_in tf.Output, source
// If `len` defines a substring that would extend beyond the length of the input
// string, then as many characters as possible are used.
//
-// If `pos` is negative or specifies a character index larger than any of the input
-// strings, then an `InvalidArgumentError` is thrown.
+// A negative `pos` indicates distance within the string backwards from the end.
+//
+// If `pos` specifies an index which is out of range for any of the input strings,
+// then an `InvalidArgumentError` is thrown.
//
// `pos` and `len` must have the same shape, otherwise a `ValueError` is thrown on
// Op creation.
@@ -27422,35 +29510,6 @@ func MakeIterator(scope *Scope, dataset tf.Output, iterator tf.Output) (o *tf.Op
return scope.AddOperation(opspec)
}
-// Makes the summary of accumulated stats for the batch.
-//
-// The summary stats contains gradients and hessians accumulated into the corresponding node and bucket for each example.
-//
-// Arguments:
-// node_ids: int32 Rank 1 Tensor containing node ids, which each example falls into for the requested layer.
-// gradients: float32; Rank 2 Tensor (shape=[#examples, 1]) for gradients.
-// hessians: float32; Rank 2 Tensor (shape=[#examples, 1]) for hessians.
-// bucketized_features_list: int32 list of Rank 1 Tensors, each containing the bucketized feature (for each feature column).
-// max_splits: int; the maximum number of splits possible in the whole tree.
-// num_buckets: int; equals to the maximum possible value of bucketized feature.
-//
-// Returns output Rank 4 Tensor (shape=[#features, #splits, #buckets, 2]) containing accumulated stats put into the corresponding node and bucket. The first index of 4th dimension refers to gradients, and the second to hessians.
-func BoostedTreesMakeStatsSummary(scope *Scope, node_ids tf.Output, gradients tf.Output, hessians tf.Output, bucketized_features_list []tf.Output, max_splits int64, num_buckets int64) (stats_summary tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"max_splits": max_splits, "num_buckets": num_buckets}
- opspec := tf.OpSpec{
- Type: "BoostedTreesMakeStatsSummary",
- Input: []tf.Input{
- node_ids, gradients, hessians, tf.OutputList(bucketized_features_list),
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Adjust the contrast of one or more images.
//
// `images` is a tensor of at least 3 dimensions. The last 3 dimensions are
@@ -27643,6 +29702,8 @@ func IteratorFromStringHandle(scope *Scope, string_handle tf.Output, optional ..
// On GPU, if an out of bound index is found, a 0 is stored in the
// corresponding output value.
//
+// See also `tf.batch_gather` and `tf.gather_nd`.
+//
// Arguments:
// params: The tensor from which to gather values. Must be at least rank
// `axis + 1`.
@@ -28153,6 +30214,30 @@ func FFT(scope *Scope, input tf.Output) (output tf.Output) {
return op.Output(0)
}
+// Identity transformation that models performance.
+//
+// Arguments:
+// input_dataset: A variant tensor representing the input dataset.
+//
+//
+func ModelDataset(scope *Scope, input_dataset tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes}
+ opspec := tf.OpSpec{
+ Type: "ModelDataset",
+ Input: []tf.Input{
+ input_dataset,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Performs a padding as a preprocess during a convolution.
//
// Similar to FusedResizeAndPadConv2d, this op allows for an optimized
@@ -28842,10 +30927,16 @@ func EncodeBase64(scope *Scope, input tf.Output, optional ...EncodeBase64Attr) (
//
// Arguments:
//
-// window_size: A scalar representing the number of elements to accumulate in a window.
+// size: A scalar representing the number of elements to accumulate in a window.
+// shift: A scalar representing the steps moving the sliding window forward in one
+// iteration. It must be positive.
+// stride: A scalar representing the stride of the input elements of the sliding window.
+// It must be positive.
+// drop_remainder: A scalar representing whether a window should be dropped in case its size is
+// smaller than desired.
//
//
-func WindowDataset(scope *Scope, input_dataset tf.Output, window_size tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) {
+func WindowDataset(scope *Scope, input_dataset tf.Output, size tf.Output, shift tf.Output, stride tf.Output, drop_remainder tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) {
if scope.Err() != nil {
return
}
@@ -28853,7 +30944,7 @@ func WindowDataset(scope *Scope, input_dataset tf.Output, window_size tf.Output,
opspec := tf.OpSpec{
Type: "WindowDataset",
Input: []tf.Input{
- input_dataset, window_size,
+ input_dataset, size, shift, stride, drop_remainder,
},
Attrs: attrs,
}
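The hunk above is a signature change: WindowDataset now takes four inputs (`size`, `shift`, `stride`, `drop_remainder`) instead of a single `window_size`, so existing Go callers must be updated. A hedged sketch of the new call shape, assuming a `RangeDataset` wrapper exists in the same generated package (it is not shown in this diff):

```go
// Sketch only: window a range dataset into non-overlapping windows of 5
// elements, dropping a short final window.
s := op.NewScope()
types := []tf.DataType{tf.Int64}
shapes := []tf.Shape{tf.ScalarShape()}
ds := op.RangeDataset(s,
	op.Const(s, int64(0)), op.Const(s, int64(20)), op.Const(s, int64(1)),
	types, shapes)
windowed := op.WindowDataset(s, ds,
	op.Const(s, int64(5)), // size
	op.Const(s, int64(5)), // shift
	op.Const(s, int64(1)), // stride
	op.Const(s, true),     // drop_remainder
	types, shapes)
_ = windowed
```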
@@ -29542,260 +31633,6 @@ func TensorArraySplitV3(scope *Scope, handle tf.Output, value tf.Output, lengths
return op.Output(0)
}
-// SerializeSparseAttr is an optional argument to SerializeSparse.
-type SerializeSparseAttr func(optionalAttr)
-
-// SerializeSparseOutType sets the optional out_type attribute to value.
-//
-// value: The `dtype` to use for serialization; the supported types are `string`
-// (default) and `variant`.
-// If not specified, defaults to DT_STRING
-func SerializeSparseOutType(value tf.DataType) SerializeSparseAttr {
- return func(m optionalAttr) {
- m["out_type"] = value
- }
-}
-
-// Serialize a `SparseTensor` into a `[3]` `Tensor` object.
-//
-// Arguments:
-// sparse_indices: 2-D. The `indices` of the `SparseTensor`.
-// sparse_values: 1-D. The `values` of the `SparseTensor`.
-// sparse_shape: 1-D. The `shape` of the `SparseTensor`.
-func SerializeSparse(scope *Scope, sparse_indices tf.Output, sparse_values tf.Output, sparse_shape tf.Output, optional ...SerializeSparseAttr) (serialized_sparse tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "SerializeSparse",
- Input: []tf.Input{
- sparse_indices, sparse_values, sparse_shape,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// RandomShuffleQueueV2Attr is an optional argument to RandomShuffleQueueV2.
-type RandomShuffleQueueV2Attr func(optionalAttr)
-
-// RandomShuffleQueueV2Shapes sets the optional shapes attribute to value.
-//
-// value: The shape of each component in a value. The length of this attr must
-// be either 0 or the same as the length of component_types. If the length of
-// this attr is 0, the shapes of queue elements are not constrained, and
-// only one element may be dequeued at a time.
-// If not specified, defaults to <>
-//
-// REQUIRES: len(value) >= 0
-func RandomShuffleQueueV2Shapes(value []tf.Shape) RandomShuffleQueueV2Attr {
- return func(m optionalAttr) {
- m["shapes"] = value
- }
-}
-
-// RandomShuffleQueueV2Capacity sets the optional capacity attribute to value.
-//
-// value: The upper bound on the number of elements in this queue.
-// Negative numbers mean no limit.
-// If not specified, defaults to -1
-func RandomShuffleQueueV2Capacity(value int64) RandomShuffleQueueV2Attr {
- return func(m optionalAttr) {
- m["capacity"] = value
- }
-}
-
-// RandomShuffleQueueV2MinAfterDequeue sets the optional min_after_dequeue attribute to value.
-//
-// value: Dequeue will block unless there would be this
-// many elements after the dequeue or the queue is closed. This
-// ensures a minimum level of mixing of elements.
-// If not specified, defaults to 0
-func RandomShuffleQueueV2MinAfterDequeue(value int64) RandomShuffleQueueV2Attr {
- return func(m optionalAttr) {
- m["min_after_dequeue"] = value
- }
-}
-
-// RandomShuffleQueueV2Seed sets the optional seed attribute to value.
-//
-// value: If either seed or seed2 is set to be non-zero, the random number
-// generator is seeded by the given seed. Otherwise, a random seed is used.
-// If not specified, defaults to 0
-func RandomShuffleQueueV2Seed(value int64) RandomShuffleQueueV2Attr {
- return func(m optionalAttr) {
- m["seed"] = value
- }
-}
-
-// RandomShuffleQueueV2Seed2 sets the optional seed2 attribute to value.
-//
-// value: A second seed to avoid seed collision.
-// If not specified, defaults to 0
-func RandomShuffleQueueV2Seed2(value int64) RandomShuffleQueueV2Attr {
- return func(m optionalAttr) {
- m["seed2"] = value
- }
-}
-
-// RandomShuffleQueueV2Container sets the optional container attribute to value.
-//
-// value: If non-empty, this queue is placed in the given container.
-// Otherwise, a default container is used.
-// If not specified, defaults to ""
-func RandomShuffleQueueV2Container(value string) RandomShuffleQueueV2Attr {
- return func(m optionalAttr) {
- m["container"] = value
- }
-}
-
-// RandomShuffleQueueV2SharedName sets the optional shared_name attribute to value.
-//
-// value: If non-empty, this queue will be shared under the given name
-// across multiple sessions.
-// If not specified, defaults to ""
-func RandomShuffleQueueV2SharedName(value string) RandomShuffleQueueV2Attr {
- return func(m optionalAttr) {
- m["shared_name"] = value
- }
-}
-
-// A queue that randomizes the order of elements.
-//
-// Arguments:
-// component_types: The type of each component in a value.
-//
-// Returns The handle to the queue.
-func RandomShuffleQueueV2(scope *Scope, component_types []tf.DataType, optional ...RandomShuffleQueueV2Attr) (handle tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"component_types": component_types}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "RandomShuffleQueueV2",
-
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Draw bounding boxes on a batch of images.
-//
-// Outputs a copy of `images` but draws on top of the pixels zero or more bounding
-// boxes specified by the locations in `boxes`. The coordinates of the each
-// bounding box in `boxes` are encoded as `[y_min, x_min, y_max, x_max]`. The
-// bounding box coordinates are floats in `[0.0, 1.0]` relative to the width and
-// height of the underlying image.
-//
-// For example, if an image is 100 x 200 pixels (height x width) and the bounding
-// box is `[0.1, 0.2, 0.5, 0.9]`, the upper-left and bottom-right coordinates of
-// the bounding box will be `(40, 10)` to `(180, 50)` (in (x,y) coordinates).
-//
-// Parts of the bounding box may fall outside the image.
-//
-// Arguments:
-// images: 4-D with shape `[batch, height, width, depth]`. A batch of images.
-// boxes: 3-D with shape `[batch, num_bounding_boxes, 4]` containing bounding
-// boxes.
-//
-// Returns 4-D with the same shape as `images`. The batch of input images with
-// bounding boxes drawn on the images.
-func DrawBoundingBoxes(scope *Scope, images tf.Output, boxes tf.Output) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "DrawBoundingBoxes",
- Input: []tf.Input{
- images, boxes,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// LearnedUnigramCandidateSamplerAttr is an optional argument to LearnedUnigramCandidateSampler.
-type LearnedUnigramCandidateSamplerAttr func(optionalAttr)
-
-// LearnedUnigramCandidateSamplerSeed sets the optional seed attribute to value.
-//
-// value: If either seed or seed2 are set to be non-zero, the random number
-// generator is seeded by the given seed. Otherwise, it is seeded by a
-// random seed.
-// If not specified, defaults to 0
-func LearnedUnigramCandidateSamplerSeed(value int64) LearnedUnigramCandidateSamplerAttr {
- return func(m optionalAttr) {
- m["seed"] = value
- }
-}
-
-// LearnedUnigramCandidateSamplerSeed2 sets the optional seed2 attribute to value.
-//
-// value: An second seed to avoid seed collision.
-// If not specified, defaults to 0
-func LearnedUnigramCandidateSamplerSeed2(value int64) LearnedUnigramCandidateSamplerAttr {
- return func(m optionalAttr) {
- m["seed2"] = value
- }
-}
-
-// Generates labels for candidate sampling with a learned unigram distribution.
-//
-// See explanations of candidate sampling and the data formats at
-// go/candidate-sampling.
-//
-// For each batch, this op picks a single set of sampled candidate labels.
-//
-// The advantages of sampling candidates per-batch are simplicity and the
-// possibility of efficient dense matrix multiplication. The disadvantage is that
-// the sampled candidates must be chosen independently of the context and of the
-// true labels.
-//
-// Arguments:
-// true_classes: A batch_size * num_true matrix, in which each row contains the
-// IDs of the num_true target_classes in the corresponding original label.
-// num_true: Number of true labels per context.
-// num_sampled: Number of candidates to randomly sample.
-// unique: If unique is true, we sample with rejection, so that all sampled
-// candidates in a batch are unique. This requires some approximation to
-// estimate the post-rejection sampling probabilities.
-// range_max: The sampler will sample integers from the interval [0, range_max).
-//
-// Returns A vector of length num_sampled, in which each element is
-// the ID of a sampled candidate.A batch_size * num_true matrix, representing
-// the number of times each candidate is expected to occur in a batch
-// of sampled candidates. If unique=true, then this is a probability.A vector of length num_sampled, for each sampled
-// candidate representing the number of times the candidate is expected
-// to occur in a batch of sampled candidates. If unique=true, then this is a
-// probability.
-func LearnedUnigramCandidateSampler(scope *Scope, true_classes tf.Output, num_true int64, num_sampled int64, unique bool, range_max int64, optional ...LearnedUnigramCandidateSamplerAttr) (sampled_candidates tf.Output, true_expected_count tf.Output, sampled_expected_count tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"num_true": num_true, "num_sampled": num_sampled, "unique": unique, "range_max": range_max}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "LearnedUnigramCandidateSampler",
- Input: []tf.Input{
- true_classes,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1), op.Output(2)
-}
-
// Computes gradients for the scaled exponential linear (Selu) operation.
//
// Arguments:
@@ -30008,27 +31845,6 @@ func TensorArrayScatterV2(scope *Scope, handle tf.Output, indices tf.Output, val
return op.Output(0)
}
-// Creates a tree ensemble model and returns a handle to it.
-//
-// Arguments:
-// tree_ensemble_handle: Handle to the tree ensemble resource to be created.
-// stamp_token: Token to use as the initial value of the resource stamp.
-// tree_ensemble_serialized: Serialized proto of the tree ensemble.
-//
-// Returns the created operation.
-func BoostedTreesCreateEnsemble(scope *Scope, tree_ensemble_handle tf.Output, stamp_token tf.Output, tree_ensemble_serialized tf.Output) (o *tf.Operation) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "BoostedTreesCreateEnsemble",
- Input: []tf.Input{
- tree_ensemble_handle, stamp_token, tree_ensemble_serialized,
- },
- }
- return scope.AddOperation(opspec)
-}
-
// Applies sparse addition to `input` using individual values or slices
//
// from `updates` according to indices `indices`. The updates are non-aliasing:
@@ -30063,7 +31879,7 @@ func BoostedTreesCreateEnsemble(scope *Scope, tree_ensemble_handle tf.Output, st
//
// [1, 13, 3, 14, 14, 6, 7, 20]
//
-// See @{tf.scatter_nd} for more details about how to make updates to slices.
+// See `tf.scatter_nd` for more details about how to make updates to slices.
//
// Arguments:
// input: A Tensor.
@@ -30216,6 +32032,32 @@ func FractionalMaxPool(scope *Scope, value tf.Output, pooling_ratio []float32, o
return op.Output(0), op.Output(1), op.Output(2)
}
+// Creates a MultiDeviceIterator resource.
+//
+// Arguments:
+// devices: A list of devices the iterator works across.
+// shared_name: If non-empty, this resource will be shared under the given name
+// across multiple sessions.
+// container: If non-empty, this resource is placed in the given container.
+// Otherwise, a default container is used.
+// output_types: The type list for the return values.
+// output_shapes: The list of shapes being produced.
+//
+// Returns Handle to the resource created.
+func MultiDeviceIterator(scope *Scope, devices []string, shared_name string, container string, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"devices": devices, "shared_name": shared_name, "container": container, "output_types": output_types, "output_shapes": output_shapes}
+ opspec := tf.OpSpec{
+ Type: "MultiDeviceIterator",
+
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
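A short sketch (assuming `tf`/`op` imports and a scope `s`); the returned handle would subsequently be initialized and read through the other MultiDeviceIterator ops:

```go
// Sketch only: create a MultiDeviceIterator resource over a single CPU
// device, shared across sessions under the name "train_iter".
iter := op.MultiDeviceIterator(s,
	[]string{"/device:CPU:0"},    // devices
	"train_iter",                 // shared_name
	"",                           // container (default)
	[]tf.DataType{tf.Int64},      // output_types
	[]tf.Shape{tf.ScalarShape()}) // output_shapes
_ = iter
```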
// Deprecated. Use TensorArraySizeV3
//
// DEPRECATED at GraphDef version 26: Use TensorArraySizeV3
@@ -30680,6 +32522,41 @@ func MapIncompleteSize(scope *Scope, dtypes []tf.DataType, optional ...MapIncomp
return op.Output(0)
}
+// Generate the bucket boundaries for each feature based on accumulated summaries.
+//
+// An op that returns a list of float tensors for a quantile stream resource. Each
+// tensor is Rank 1 containing bucket boundaries for a single feature.
+//
+// Arguments:
+// quantile_stream_resource_handle: resource handle referring to a QuantileStreamResource.
+// num_features: inferred int; number of features to get bucket boundaries for.
+//
+// Returns float; List of Rank 1 Tensors each containing the bucket boundaries for a feature.
+func BoostedTreesQuantileStreamResourceGetBucketBoundaries(scope *Scope, quantile_stream_resource_handle tf.Output, num_features int64) (bucket_boundaries []tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"num_features": num_features}
+ opspec := tf.OpSpec{
+ Type: "BoostedTreesQuantileStreamResourceGetBucketBoundaries",
+ Input: []tf.Input{
+ quantile_stream_resource_handle,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ if scope.Err() != nil {
+ return
+ }
+ var idx int
+ var err error
+ if bucket_boundaries, idx, err = makeOutputList(op, idx, "bucket_boundaries"); err != nil {
+ scope.UpdateErr("BoostedTreesQuantileStreamResourceGetBucketBoundaries", err)
+ return
+ }
+ return bucket_boundaries
+}
+
// OrderedMapUnstageAttr is an optional argument to OrderedMapUnstage.
type OrderedMapUnstageAttr func(optionalAttr)
@@ -30751,6 +32628,43 @@ func OrderedMapUnstage(scope *Scope, key tf.Output, indices tf.Output, dtypes []
return values
}
+// BoostedTreesQuantileStreamResourceHandleOpAttr is an optional argument to BoostedTreesQuantileStreamResourceHandleOp.
+type BoostedTreesQuantileStreamResourceHandleOpAttr func(optionalAttr)
+
+// BoostedTreesQuantileStreamResourceHandleOpContainer sets the optional container attribute to value.
+// If not specified, defaults to ""
+func BoostedTreesQuantileStreamResourceHandleOpContainer(value string) BoostedTreesQuantileStreamResourceHandleOpAttr {
+ return func(m optionalAttr) {
+ m["container"] = value
+ }
+}
+
+// BoostedTreesQuantileStreamResourceHandleOpSharedName sets the optional shared_name attribute to value.
+// If not specified, defaults to ""
+func BoostedTreesQuantileStreamResourceHandleOpSharedName(value string) BoostedTreesQuantileStreamResourceHandleOpAttr {
+ return func(m optionalAttr) {
+ m["shared_name"] = value
+ }
+}
+
+// Creates a handle to a BoostedTreesQuantileStreamResource.
+func BoostedTreesQuantileStreamResourceHandleOp(scope *Scope, optional ...BoostedTreesQuantileStreamResourceHandleOpAttr) (resource tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "BoostedTreesQuantileStreamResourceHandleOp",
+
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// OrderedMapSizeAttr is an optional argument to OrderedMapSize.
type OrderedMapSizeAttr func(optionalAttr)
@@ -31077,79 +32991,6 @@ func CudnnRNNParamsToCanonical(scope *Scope, num_layers tf.Output, num_units tf.
return weights, biases
}
-// UniformCandidateSamplerAttr is an optional argument to UniformCandidateSampler.
-type UniformCandidateSamplerAttr func(optionalAttr)
-
-// UniformCandidateSamplerSeed sets the optional seed attribute to value.
-//
-// value: If either seed or seed2 are set to be non-zero, the random number
-// generator is seeded by the given seed. Otherwise, it is seeded by a
-// random seed.
-// If not specified, defaults to 0
-func UniformCandidateSamplerSeed(value int64) UniformCandidateSamplerAttr {
- return func(m optionalAttr) {
- m["seed"] = value
- }
-}
-
-// UniformCandidateSamplerSeed2 sets the optional seed2 attribute to value.
-//
-// value: An second seed to avoid seed collision.
-// If not specified, defaults to 0
-func UniformCandidateSamplerSeed2(value int64) UniformCandidateSamplerAttr {
- return func(m optionalAttr) {
- m["seed2"] = value
- }
-}
-
-// Generates labels for candidate sampling with a uniform distribution.
-//
-// See explanations of candidate sampling and the data formats at
-// go/candidate-sampling.
-//
-// For each batch, this op picks a single set of sampled candidate labels.
-//
-// The advantages of sampling candidates per-batch are simplicity and the
-// possibility of efficient dense matrix multiplication. The disadvantage is that
-// the sampled candidates must be chosen independently of the context and of the
-// true labels.
-//
-// Arguments:
-// true_classes: A batch_size * num_true matrix, in which each row contains the
-// IDs of the num_true target_classes in the corresponding original label.
-// num_true: Number of true labels per context.
-// num_sampled: Number of candidates to randomly sample.
-// unique: If unique is true, we sample with rejection, so that all sampled
-// candidates in a batch are unique. This requires some approximation to
-// estimate the post-rejection sampling probabilities.
-// range_max: The sampler will sample integers from the interval [0, range_max).
-//
-// Returns A vector of length num_sampled, in which each element is
-// the ID of a sampled candidate.A batch_size * num_true matrix, representing
-// the number of times each candidate is expected to occur in a batch
-// of sampled candidates. If unique=true, then this is a probability.A vector of length num_sampled, for each sampled
-// candidate representing the number of times the candidate is expected
-// to occur in a batch of sampled candidates. If unique=true, then this is a
-// probability.
-func UniformCandidateSampler(scope *Scope, true_classes tf.Output, num_true int64, num_sampled int64, unique bool, range_max int64, optional ...UniformCandidateSamplerAttr) (sampled_candidates tf.Output, true_expected_count tf.Output, sampled_expected_count tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"num_true": num_true, "num_sampled": num_sampled, "unique": unique, "range_max": range_max}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "UniformCandidateSampler",
- Input: []tf.Input{
- true_classes,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1), op.Output(2)
-}
-
// CTCLossAttr is an optional argument to CTCLoss.
type CTCLossAttr func(optionalAttr)
@@ -31300,621 +33141,3 @@ func Switch(scope *Scope, data tf.Output, pred tf.Output) (output_false tf.Outpu
op := scope.AddOperation(opspec)
return op.Output(0), op.Output(1)
}
-
-// Add all input tensors element wise.
-//
-// Arguments:
-// inputs: Must all be the same size and shape.
-func AddN(scope *Scope, inputs []tf.Output) (sum tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "AddN",
- Input: []tf.Input{
- tf.OutputList(inputs),
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// TryRpcAttr is an optional argument to TryRpc.
-type TryRpcAttr func(optionalAttr)
-
-// TryRpcProtocol sets the optional protocol attribute to value.
-//
-// value: RPC protocol to use. Empty string means use the default protocol.
-// Options include 'grpc'.
-// If not specified, defaults to ""
-func TryRpcProtocol(value string) TryRpcAttr {
- return func(m optionalAttr) {
- m["protocol"] = value
- }
-}
-
-// TryRpcFailFast sets the optional fail_fast attribute to value.
-//
-// value: `boolean`. If `true` (default), then failures to connect
-// (i.e., the server does not immediately respond) cause an RPC failure.
-// If not specified, defaults to true
-func TryRpcFailFast(value bool) TryRpcAttr {
- return func(m optionalAttr) {
- m["fail_fast"] = value
- }
-}
-
-// TryRpcTimeoutInMs sets the optional timeout_in_ms attribute to value.
-//
-// value: `int`. If `0` (default), then the kernel will run the RPC
-// request and only time out if the RPC deadline passes or the session times out.
-// If this value is greater than `0`, then the op will raise an exception if
-// the RPC takes longer than `timeout_in_ms`.
-// If not specified, defaults to 0
-func TryRpcTimeoutInMs(value int64) TryRpcAttr {
- return func(m optionalAttr) {
- m["timeout_in_ms"] = value
- }
-}
-
-// Perform batches of RPC requests.
-//
-// This op asynchronously performs either a single RPC request, or a batch
-// of requests. RPC requests are defined by three main parameters:
-//
-// - `address` (the host+port or BNS address of the request)
-// - `method` (the method name for the request)
-// - `request` (the serialized proto string, or vector of strings,
-// of the RPC request argument).
-//
-// For example, if you have an RPC service running on port localhost:2345,
-// and its interface is configured with the following proto declaration:
-//
-// ```
-// service MyService {
-// rpc MyMethod(MyRequestProto) returns (MyResponseProto) {
-// }
-// };
-// ```
-//
-// then call this op with arguments:
-//
-// ```
-// address = "localhost:2345"
-// method = "MyService/MyMethod"
-// ```
-//
-// The `request` tensor is a string tensor representing serialized `MyRequestProto`
-// strings; and the output string tensor `response` will have the same shape
-// and contain (upon successful completion) corresponding serialized
-// `MyResponseProto` strings.
-//
-// For example, to send a single, empty, `MyRequestProto`, call
-// this op with `request = ""`. To send 5 **parallel** empty requests,
-// call this op with `request = ["", "", "", "", ""]`.
-//
-// More generally, one can create a batch of `MyRequestProto` serialized protos
-// from regular batched tensors using the `encode_proto` op, and convert
-// the response `MyResponseProto` serialized protos to batched tensors
-// using the `decode_proto` op.
-//
-// **NOTE** Working with serialized proto strings is faster than instantiating
-// actual proto objects in memory, so no performance degradation is expected
-// compared to writing custom kernels for this workflow.
-//
-// Unlike the standard `Rpc` op, if the connection fails or the remote worker
-// returns an error status, this op does **not** reraise the exception.
-// Instead, the `status_code` and `status_message` entry for the corresponding RPC
-// call is set with the error returned from the RPC call. The `response` tensor
-// will contain valid response values for those minibatch entries whose RPCs did
-// not fail; the rest of the entries will have empty strings.
-//
-// Arguments:
-// address: `0-D` or `1-D`. The address (i.e. host_name:port) of the RPC server.
-// If this tensor has more than 1 element, then multiple parallel rpc requests
-// are sent. This argument broadcasts with `method` and `request`.
-// method: `0-D` or `1-D`. The method address on the RPC server.
-// If this tensor has more than 1 element, then multiple parallel rpc requests
-// are sent. This argument broadcasts with `address` and `request`.
-// request: `0-D` or `1-D`. Serialized proto strings: the rpc request argument.
-// If this tensor has more than 1 element, then multiple parallel rpc requests
-// are sent. This argument broadcasts with `address` and `method`.
-//
-// Returns:
-//   response: Same shape as `request`. Serialized proto strings: the rpc responses.
-//   status_code: Same shape as `request`. Values correspond to tensorflow Status enum codes.
-//   status_message: Same shape as `request`. Values correspond to Status messages
-//     returned from the RPC calls.
-func TryRpc(scope *Scope, address tf.Output, method tf.Output, request tf.Output, optional ...TryRpcAttr) (response tf.Output, status_code tf.Output, status_message tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "TryRpc",
- Input: []tf.Input{
- address, method, request,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1), op.Output(2)
-}
-
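To make the batching and broadcasting behaviour described above concrete, here is a minimal, hypothetical Go sketch. The address, method, and attribute values are placeholders, and the `op.NewScope`/`op.Const` helpers from the `tensorflow/go/op` package are assumptions, not part of this change.

```go
package main

import (
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
	s := op.NewScope()
	// Three parallel calls to the same method; the inputs broadcast against
	// each other, so address and method can stay scalar.
	address := op.Const(s, "localhost:2345")
	method := op.Const(s, "MyService/MyMethod")
	request := op.Const(s, []string{"", "", ""}) // three empty MyRequestProto payloads
	response, statusCode, statusMessage := op.TryRpc(s, address, method, request,
		op.TryRpcFailFast(true), op.TryRpcTimeoutInMs(2000))
	if _, err := s.Finalize(); err != nil {
		panic(err)
	}
	// Failed entries surface through statusCode/statusMessage rather than
	// raising, per the op documentation above.
	_, _, _ = response, statusCode, statusMessage
}
```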
-// EnterAttr is an optional argument to Enter.
-type EnterAttr func(optionalAttr)
-
-// EnterIsConstant sets the optional is_constant attribute to value.
-//
-// value: If true, the output is constant within the child frame.
-// If not specified, defaults to false
-func EnterIsConstant(value bool) EnterAttr {
- return func(m optionalAttr) {
- m["is_constant"] = value
- }
-}
-
-// EnterParallelIterations sets the optional parallel_iterations attribute to value.
-//
-// value: The number of iterations allowed to run in parallel.
-// If not specified, defaults to 10
-func EnterParallelIterations(value int64) EnterAttr {
- return func(m optionalAttr) {
- m["parallel_iterations"] = value
- }
-}
-
-// Creates or finds a child frame, and makes `data` available to the child frame.
-//
-// This op is used together with `Exit` to create loops in the graph.
-// The unique `frame_name` is used by the `Executor` to identify frames. If
-// `is_constant` is true, `output` is a constant in the child frame; otherwise
-// it may be changed in the child frame. At most `parallel_iterations` iterations
-// are run in parallel in the child frame.
-//
-// Arguments:
-// data: The tensor to be made available to the child frame.
-// frame_name: The name of the child frame.
-//
-// Returns The same tensor as `data`.
-func Enter(scope *Scope, data tf.Output, frame_name string, optional ...EnterAttr) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"frame_name": frame_name}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "Enter",
- Input: []tf.Input{
- data,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Produce a string tensor that encodes the state of a Reader.
-//
-// Not all Readers support being serialized, so this can produce an
-// Unimplemented error.
-//
-// Arguments:
-// reader_handle: Handle to a Reader.
-func ReaderSerializeStateV2(scope *Scope, reader_handle tf.Output) (state tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "ReaderSerializeStateV2",
- Input: []tf.Input{
- reader_handle,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Exits the current frame to its parent frame.
-//
-// Exit makes its input `data` available to the parent frame.
-//
-// Arguments:
-// data: The tensor to be made available to the parent frame.
-//
-// Returns The same tensor as `data`.
-func Exit(scope *Scope, data tf.Output) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "Exit",
- Input: []tf.Input{
- data,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
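Enter and Exit, shown above, are the low-level frame primitives behind graph-mode loops. The sketch below only moves a value into a child frame and back out; a real loop additionally wires Merge, Switch, and NextIteration nodes, which are omitted here. The helpers and frame name are illustrative assumptions.

```go
package main

import (
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
	s := op.NewScope()
	data := op.Const(s, int64(0))
	// Move `data` into a child frame; a full loop would also add Merge,
	// Switch and NextIteration nodes inside this frame.
	entered := op.Enter(s, data, "example_while_frame",
		op.EnterIsConstant(false), op.EnterParallelIterations(10))
	// Make the (here unmodified) value visible to the parent frame again.
	result := op.Exit(s, entered)
	if _, err := s.Finalize(); err != nil {
		panic(err)
	}
	_ = result
}
```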
-// Returns a copy of the input tensor.
-func Snapshot(scope *Scope, input tf.Output) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "Snapshot",
- Input: []tf.Input{
- input,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Returns a tensor of zeros with the same shape and type as x.
-//
-// Arguments:
-// x: a tensor of type T.
-//
-// Returns a tensor of the same shape and type as x but filled with zeros.
-func ZerosLike(scope *Scope, x tf.Output) (y tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "ZerosLike",
- Input: []tf.Input{
- x,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// AbortAttr is an optional argument to Abort.
-type AbortAttr func(optionalAttr)
-
-// AbortErrorMsg sets the optional error_msg attribute to value.
-//
-// value: A string which is the message associated with the exception.
-// If not specified, defaults to ""
-func AbortErrorMsg(value string) AbortAttr {
- return func(m optionalAttr) {
- m["error_msg"] = value
- }
-}
-
-// AbortExitWithoutError sets the optional exit_without_error attribute to value.
-// If not specified, defaults to false
-func AbortExitWithoutError(value bool) AbortAttr {
- return func(m optionalAttr) {
- m["exit_without_error"] = value
- }
-}
-
-// Raise an exception to abort the process when called.
-//
-// If exit_without_error is true, the process will exit normally,
-// otherwise it will exit with a SIGABRT signal.
-//
-// Returns nothing but an exception.
-//
-// Returns the created operation.
-func Abort(scope *Scope, optional ...AbortAttr) (o *tf.Operation) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "Abort",
-
- Attrs: attrs,
- }
- return scope.AddOperation(opspec)
-}
-
-// FixedUnigramCandidateSamplerAttr is an optional argument to FixedUnigramCandidateSampler.
-type FixedUnigramCandidateSamplerAttr func(optionalAttr)
-
-// FixedUnigramCandidateSamplerVocabFile sets the optional vocab_file attribute to value.
-//
-// value: Each valid line in this file (which should have a CSV-like format)
-// corresponds to a valid word ID. IDs are in sequential order, starting from
-// num_reserved_ids. The last entry in each line is expected to be a value
-// corresponding to the count or relative probability. Exactly one of vocab_file
-// and unigrams needs to be passed to this op.
-// If not specified, defaults to ""
-func FixedUnigramCandidateSamplerVocabFile(value string) FixedUnigramCandidateSamplerAttr {
- return func(m optionalAttr) {
- m["vocab_file"] = value
- }
-}
-
-// FixedUnigramCandidateSamplerDistortion sets the optional distortion attribute to value.
-//
-// value: The distortion is used to skew the unigram probability distribution.
-// Each weight is first raised to the distortion's power before adding to the
-// internal unigram distribution. As a result, distortion = 1.0 gives regular
-// unigram sampling (as defined by the vocab file), and distortion = 0.0 gives
-// a uniform distribution.
-// If not specified, defaults to 1
-func FixedUnigramCandidateSamplerDistortion(value float32) FixedUnigramCandidateSamplerAttr {
- return func(m optionalAttr) {
- m["distortion"] = value
- }
-}
-
-// FixedUnigramCandidateSamplerNumReservedIds sets the optional num_reserved_ids attribute to value.
-//
-// value: Optionally some reserved IDs can be added in the range [0,
-// ..., num_reserved_ids) by the users. One use case is that a special unknown
-// word token is used as ID 0. These IDs will have a sampling probability of 0.
-// If not specified, defaults to 0
-func FixedUnigramCandidateSamplerNumReservedIds(value int64) FixedUnigramCandidateSamplerAttr {
- return func(m optionalAttr) {
- m["num_reserved_ids"] = value
- }
-}
-
-// FixedUnigramCandidateSamplerNumShards sets the optional num_shards attribute to value.
-//
-// value: A sampler can be used to sample from a subset of the original range
-// in order to speed up the whole computation through parallelism. This parameter
-// (together with 'shard') indicates the number of partitions that are being
-// used in the overall computation.
-// If not specified, defaults to 1
-//
-// REQUIRES: value >= 1
-func FixedUnigramCandidateSamplerNumShards(value int64) FixedUnigramCandidateSamplerAttr {
- return func(m optionalAttr) {
- m["num_shards"] = value
- }
-}
-
-// FixedUnigramCandidateSamplerShard sets the optional shard attribute to value.
-//
-// value: A sampler can be used to sample from a subset of the original range
-// in order to speed up the whole computation through parallelism. This parameter
-// (together with 'num_shards') indicates the particular partition number of a
-// sampler op, when partitioning is being used.
-// If not specified, defaults to 0
-//
-// REQUIRES: value >= 0
-func FixedUnigramCandidateSamplerShard(value int64) FixedUnigramCandidateSamplerAttr {
- return func(m optionalAttr) {
- m["shard"] = value
- }
-}
-
-// FixedUnigramCandidateSamplerUnigrams sets the optional unigrams attribute to value.
-//
-// value: A list of unigram counts or probabilities, one per ID in sequential
-// order. Exactly one of vocab_file and unigrams should be passed to this op.
-// If not specified, defaults to <>
-func FixedUnigramCandidateSamplerUnigrams(value []float32) FixedUnigramCandidateSamplerAttr {
- return func(m optionalAttr) {
- m["unigrams"] = value
- }
-}
-
-// FixedUnigramCandidateSamplerSeed sets the optional seed attribute to value.
-//
-// value: If either seed or seed2 are set to be non-zero, the random number
-// generator is seeded by the given seed. Otherwise, it is seeded by a
-// random seed.
-// If not specified, defaults to 0
-func FixedUnigramCandidateSamplerSeed(value int64) FixedUnigramCandidateSamplerAttr {
- return func(m optionalAttr) {
- m["seed"] = value
- }
-}
-
-// FixedUnigramCandidateSamplerSeed2 sets the optional seed2 attribute to value.
-//
-// value: A second seed to avoid seed collision.
-// If not specified, defaults to 0
-func FixedUnigramCandidateSamplerSeed2(value int64) FixedUnigramCandidateSamplerAttr {
- return func(m optionalAttr) {
- m["seed2"] = value
- }
-}
-
-// Generates labels for candidate sampling with a learned unigram distribution.
-//
-// A unigram sampler could use a fixed unigram distribution read from a
-// file or passed in as an in-memory array instead of building up the distribution
-// from data on the fly. There is also an option to skew the distribution by
-// applying a distortion power to the weights.
-//
-// The vocabulary file should be in CSV-like format, with the last field
-// being the weight associated with the word.
-//
-// For each batch, this op picks a single set of sampled candidate labels.
-//
-// The advantages of sampling candidates per-batch are simplicity and the
-// possibility of efficient dense matrix multiplication. The disadvantage is that
-// the sampled candidates must be chosen independently of the context and of the
-// true labels.
-//
-// Arguments:
-// true_classes: A batch_size * num_true matrix, in which each row contains the
-// IDs of the num_true target_classes in the corresponding original label.
-// num_true: Number of true labels per context.
-// num_sampled: Number of candidates to randomly sample.
-// unique: If unique is true, we sample with rejection, so that all sampled
-// candidates in a batch are unique. This requires some approximation to
-// estimate the post-rejection sampling probabilities.
-// range_max: The sampler will sample integers from the interval [0, range_max).
-//
-// Returns:
-//   sampled_candidates: A vector of length num_sampled, in which each element is
-//     the ID of a sampled candidate.
-//   true_expected_count: A batch_size * num_true matrix, representing the number of
-//     times each candidate is expected to occur in a batch of sampled candidates.
-//     If unique=true, then this is a probability.
-//   sampled_expected_count: A vector of length num_sampled, for each sampled
-//     candidate representing the number of times the candidate is expected to occur
-//     in a batch of sampled candidates. If unique=true, then this is a probability.
-func FixedUnigramCandidateSampler(scope *Scope, true_classes tf.Output, num_true int64, num_sampled int64, unique bool, range_max int64, optional ...FixedUnigramCandidateSamplerAttr) (sampled_candidates tf.Output, true_expected_count tf.Output, sampled_expected_count tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"num_true": num_true, "num_sampled": num_sampled, "unique": unique, "range_max": range_max}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "FixedUnigramCandidateSampler",
- Input: []tf.Input{
- true_classes,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1), op.Output(2)
-}
-
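For context on the in-memory unigram path described above, here is a hedged Go sketch that passes the counts via the `unigrams` attribute instead of a vocab_file. The vocabulary size, counts, and distortion value are toy assumptions, as are the `op.NewScope`/`op.Const` helpers.

```go
package main

import (
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
	s := op.NewScope()
	// num_true=1 label per example, over a toy vocabulary of 4 IDs.
	trueClasses := op.Const(s, [][]int64{{0}, {2}})
	// Provide the unigram counts in memory and flatten the distribution
	// slightly with distortion < 1.
	sampled, trueExpected, sampledExpected := op.FixedUnigramCandidateSampler(
		s, trueClasses, 1, 3, true, 4,
		op.FixedUnigramCandidateSamplerUnigrams([]float32{10, 5, 3, 1}),
		op.FixedUnigramCandidateSamplerDistortion(0.75))
	if _, err := s.Finalize(); err != nil {
		panic(err)
	}
	_, _, _ = sampled, trueExpected, sampledExpected
}
```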
-// Transforms a tf.Example proto (as a string) into typed tensors.
-//
-// Arguments:
-// serialized: A vector containing a batch of binary serialized Example protos.
-// dense_defaults: A list of Tensors (some may be empty), whose length matches
-// the length of `dense_keys`. dense_defaults[j] provides default values
-// when the example's feature_map lacks dense_key[j]. If an empty Tensor is
-// provided for dense_defaults[j], then the Feature dense_keys[j] is required.
-// The input type is inferred from dense_defaults[j], even when it's empty.
-// If dense_defaults[j] is not empty, and dense_shapes[j] is fully defined,
-// then the shape of dense_defaults[j] must match that of dense_shapes[j].
-// If dense_shapes[j] has an undefined major dimension (variable strides dense
-// feature), dense_defaults[j] must contain a single element:
-// the padding element.
-// num_sparse: The number of sparse features to be parsed from the example. This
-// must match the lengths of `sparse_keys` and `sparse_types`.
-// sparse_keys: A list of `num_sparse` strings.
-// The keys expected in the Examples' features associated with sparse values.
-// dense_keys: The keys expected in the Examples' features associated with dense
-// values.
-// sparse_types: A list of `num_sparse` types; the data types of data in each
-// Feature given in sparse_keys.
-// Currently the ParseSingleExample op supports DT_FLOAT (FloatList),
-// DT_INT64 (Int64List), and DT_STRING (BytesList).
-// dense_shapes: The shapes of data in each Feature given in dense_keys.
-// The length of this list must match the length of `dense_keys`. The
-// number of elements in the Feature corresponding to dense_key[j] must
-// always equal dense_shapes[j].NumEntries(). If dense_shapes[j] ==
-// (D0, D1, ..., DN) then the shape of output Tensor dense_values[j]
-// will be (D0, D1, ..., DN): In the case dense_shapes[j] = (-1, D1,
-// ..., DN), the shape of the output Tensor dense_values[j] will be (M,
-// D1, .., DN), where M is the number of blocks of elements of length
-// D1 * .... * DN, in the input.
-func ParseSingleExample(scope *Scope, serialized tf.Output, dense_defaults []tf.Output, num_sparse int64, sparse_keys []string, dense_keys []string, sparse_types []tf.DataType, dense_shapes []tf.Shape) (sparse_indices []tf.Output, sparse_values []tf.Output, sparse_shapes []tf.Output, dense_values []tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"num_sparse": num_sparse, "sparse_keys": sparse_keys, "dense_keys": dense_keys, "sparse_types": sparse_types, "dense_shapes": dense_shapes}
- opspec := tf.OpSpec{
- Type: "ParseSingleExample",
- Input: []tf.Input{
- serialized, tf.OutputList(dense_defaults),
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- if scope.Err() != nil {
- return
- }
- var idx int
- var err error
- if sparse_indices, idx, err = makeOutputList(op, idx, "sparse_indices"); err != nil {
- scope.UpdateErr("ParseSingleExample", err)
- return
- }
- if sparse_values, idx, err = makeOutputList(op, idx, "sparse_values"); err != nil {
- scope.UpdateErr("ParseSingleExample", err)
- return
- }
- if sparse_shapes, idx, err = makeOutputList(op, idx, "sparse_shapes"); err != nil {
- scope.UpdateErr("ParseSingleExample", err)
- return
- }
- if dense_values, idx, err = makeOutputList(op, idx, "dense_values"); err != nil {
- scope.UpdateErr("ParseSingleExample", err)
- return
- }
- return sparse_indices, sparse_values, sparse_shapes, dense_values
-}
-
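To illustrate how the sparse/dense argument lists above line up, the following is a minimal sketch with one sparse int64 feature and one dense float feature. The feature names ("tags", "age"), shapes, and defaults are invented for illustration, and the `op.Placeholder`/`op.Const`/`tf.MakeShape` helpers are assumed from the Go binding rather than taken from this change.

```go
package main

import (
	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
	s := op.NewScope()
	// A single serialized tf.Example proto, fed at run time.
	serialized := op.Placeholder(s, tf.String)
	// Dense feature "age" of shape (1) with default 0; a non-empty default
	// makes the feature optional.
	denseDefaults := []tf.Output{op.Const(s, []float32{0})}
	sparseIndices, sparseValues, sparseShapes, denseValues := op.ParseSingleExample(
		s, serialized, denseDefaults, 1,
		[]string{"tags"}, []string{"age"},
		[]tf.DataType{tf.Int64}, []tf.Shape{tf.MakeShape(1)})
	if _, err := s.Finalize(); err != nil {
		panic(err)
	}
	_, _, _, _ = sparseIndices, sparseValues, sparseShapes, denseValues
}
```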
-// WholeFileReaderV2Attr is an optional argument to WholeFileReaderV2.
-type WholeFileReaderV2Attr func(optionalAttr)
-
-// WholeFileReaderV2Container sets the optional container attribute to value.
-//
-// value: If non-empty, this reader is placed in the given container.
-// Otherwise, a default container is used.
-// If not specified, defaults to ""
-func WholeFileReaderV2Container(value string) WholeFileReaderV2Attr {
- return func(m optionalAttr) {
- m["container"] = value
- }
-}
-
-// WholeFileReaderV2SharedName sets the optional shared_name attribute to value.
-//
-// value: If non-empty, this reader is named in the given bucket
-// with this shared_name. Otherwise, the node name is used instead.
-// If not specified, defaults to ""
-func WholeFileReaderV2SharedName(value string) WholeFileReaderV2Attr {
- return func(m optionalAttr) {
- m["shared_name"] = value
- }
-}
-
-// A Reader that outputs the entire contents of a file as a value.
-//
-// To use, enqueue filenames in a Queue. The output of ReaderRead will
-// be a filename (key) and the contents of that file (value).
-//
-// Returns The handle to reference the Reader.
-func WholeFileReaderV2(scope *Scope, optional ...WholeFileReaderV2Attr) (reader_handle tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "WholeFileReaderV2",
-
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Deserializes a serialized tree ensemble config and replaces current tree
-//
-// ensemble.
-//
-// Arguments:
-// tree_ensemble_handle: Handle to the tree ensemble.
-// stamp_token: Token to use as the new value of the resource stamp.
-// tree_ensemble_serialized: Serialized proto of the ensemble.
-//
-// Returns the created operation.
-func BoostedTreesDeserializeEnsemble(scope *Scope, tree_ensemble_handle tf.Output, stamp_token tf.Output, tree_ensemble_serialized tf.Output) (o *tf.Operation) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "BoostedTreesDeserializeEnsemble",
- Input: []tf.Input{
- tree_ensemble_handle, stamp_token, tree_ensemble_serialized,
- },
- }
- return scope.AddOperation(opspec)
-}
diff --git a/tensorflow/java/README.md b/tensorflow/java/README.md
index c7382ff231..7ef862ae79 100644
--- a/tensorflow/java/README.md
+++ b/tensorflow/java/README.md
@@ -10,7 +10,7 @@
## Quickstart
-- Refer to [Installing TensorFlow for Java](https://www.tensorflow.org/install/install_java)
+- Refer to [Installing TensorFlow for Java](https://www.tensorflow.org/install/lang_java)
- [Javadoc](https://www.tensorflow.org/api_docs/java/reference/org/tensorflow/package-summary)
- [![Maven Central](https://maven-badges.herokuapp.com/maven-central/org.tensorflow/tensorflow/badge.svg)](https://maven-badges.herokuapp.com/maven-central/org.tensorflow/tensorflow)
@@ -22,8 +22,7 @@ native libraries will need to be built from source.
1. Install [bazel](https://www.bazel.build/versions/master/docs/install.html)
2. Setup the environment to build TensorFlow from source code
- ([Linux](https://www.tensorflow.org/install/install_sources#PrepareLinux)
- or [macOS](https://www.tensorflow.org/install/install_sources#PrepareMac)).
+ ([Linux or macOS](https://www.tensorflow.org/install/source)).
If you'd like to skip reading those details and do not care about GPU
support, try the following:
@@ -35,7 +34,7 @@ native libraries will need to be built from source.
brew install swig
```
-3. [Configure](https://www.tensorflow.org/install/install_sources#configure_the_installation)
+3. [Configure](https://www.tensorflow.org/install/source)
(e.g., enable GPU support) and build:
```sh
diff --git a/tensorflow/java/maven/libtensorflow/pom.xml b/tensorflow/java/maven/libtensorflow/pom.xml
index f9093ce385..6b3e305e5d 100644
--- a/tensorflow/java/maven/libtensorflow/pom.xml
+++ b/tensorflow/java/maven/libtensorflow/pom.xml
@@ -6,7 +6,7 @@
<parent>
<groupId>org.tensorflow</groupId>
<artifactId>parentpom</artifactId>
- <version>1.10.0</version>
+ <version>1.11.0</version>
<relativePath>../</relativePath>
</parent>
<artifactId>libtensorflow</artifactId>
diff --git a/tensorflow/java/maven/libtensorflow_jni/pom.xml b/tensorflow/java/maven/libtensorflow_jni/pom.xml
index 1208956dec..f130515934 100644
--- a/tensorflow/java/maven/libtensorflow_jni/pom.xml
+++ b/tensorflow/java/maven/libtensorflow_jni/pom.xml
@@ -6,7 +6,7 @@
<parent>
<groupId>org.tensorflow</groupId>
<artifactId>parentpom</artifactId>
- <version>1.10.0</version>
+ <version>1.11.0</version>
<relativePath>../</relativePath>
</parent>
<artifactId>libtensorflow_jni</artifactId>
diff --git a/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml b/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml
index 755449cb3c..67ecc2d597 100644
--- a/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml
+++ b/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml
@@ -6,7 +6,7 @@
<parent>
<groupId>org.tensorflow</groupId>
<artifactId>parentpom</artifactId>
- <version>1.10.0</version>
+ <version>1.11.0</version>
<relativePath>../</relativePath>
</parent>
<artifactId>libtensorflow_jni_gpu</artifactId>
diff --git a/tensorflow/java/maven/pom.xml b/tensorflow/java/maven/pom.xml
index e1bf2c7dba..8ba859da01 100644
--- a/tensorflow/java/maven/pom.xml
+++ b/tensorflow/java/maven/pom.xml
@@ -6,7 +6,7 @@
<modelVersion>4.0.0</modelVersion>
<groupId>org.tensorflow</groupId>
<artifactId>parentpom</artifactId>
- <version>1.10.0</version>
+ <version>1.11.0</version>
<packaging>pom</packaging>
<url>https://www.tensorflow.org</url>
diff --git a/tensorflow/java/maven/proto/pom.xml b/tensorflow/java/maven/proto/pom.xml
index b89f042567..dcd654d713 100644
--- a/tensorflow/java/maven/proto/pom.xml
+++ b/tensorflow/java/maven/proto/pom.xml
@@ -6,7 +6,7 @@
<parent>
<groupId>org.tensorflow</groupId>
<artifactId>parentpom</artifactId>
- <version>1.10.0</version>
+ <version>1.11.0</version>
<relativePath>../</relativePath>
</parent>
<artifactId>proto</artifactId>
diff --git a/tensorflow/java/maven/spark-tensorflow-connector/pom.xml b/tensorflow/java/maven/spark-tensorflow-connector/pom.xml
index 1b7995be2c..45214f834c 100644
--- a/tensorflow/java/maven/spark-tensorflow-connector/pom.xml
+++ b/tensorflow/java/maven/spark-tensorflow-connector/pom.xml
@@ -6,7 +6,7 @@
<groupId>org.tensorflow</groupId>
<artifactId>spark-tensorflow-connector_2.11</artifactId>
<packaging>jar</packaging>
- <version>1.10.0</version>
+ <version>1.11.0</version>
<name>spark-tensorflow-connector</name>
<url>https://www.tensorflow.org</url>
<description>TensorFlow TFRecord connector for Apache Spark DataFrames</description>
diff --git a/tensorflow/java/maven/tensorflow-hadoop/pom.xml b/tensorflow/java/maven/tensorflow-hadoop/pom.xml
index 0fe6f4dce4..a8669ee72b 100644
--- a/tensorflow/java/maven/tensorflow-hadoop/pom.xml
+++ b/tensorflow/java/maven/tensorflow-hadoop/pom.xml
@@ -5,7 +5,7 @@
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow-hadoop</artifactId>
<packaging>jar</packaging>
- <version>1.10.0</version>
+ <version>1.11.0</version>
<name>tensorflow-hadoop</name>
<url>https://www.tensorflow.org</url>
<description>TensorFlow TFRecord InputFormat/OutputFormat for Apache Hadoop</description>
diff --git a/tensorflow/java/maven/tensorflow/pom.xml b/tensorflow/java/maven/tensorflow/pom.xml
index 0de90244b1..67d628ba11 100644
--- a/tensorflow/java/maven/tensorflow/pom.xml
+++ b/tensorflow/java/maven/tensorflow/pom.xml
@@ -6,7 +6,7 @@
<parent>
<groupId>org.tensorflow</groupId>
<artifactId>parentpom</artifactId>
- <version>1.10.0</version>
+ <version>1.11.0</version>
<relativePath>../</relativePath>
</parent>
<artifactId>tensorflow</artifactId>
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 2dc2808152..9275ad767e 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -333,6 +333,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//third_party/python_runtime:headers",
+ "@com_google_absl//absl/memory",
],
)
@@ -1638,6 +1639,15 @@ tf_gen_op_wrapper_private_py(
)
tf_gen_op_wrapper_private_py(
+ name = "experimental_dataset_ops_gen",
+ visibility = [
+ "//learning/brain/python/ops:__pkg__",
+ "//tensorflow:__subpackages__",
+ "//tensorflow/python/kernel_tests:__pkg__",
+ ],
+)
+
+tf_gen_op_wrapper_private_py(
name = "image_ops_gen",
visibility = ["//learning/brain/python/ops:__pkg__"],
)
@@ -1998,6 +2008,30 @@ py_library(
)
py_library(
+ name = "while_v2",
+ srcs = [
+ "ops/while_v2.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":array_ops",
+ ":cond_v2_impl",
+ ":constant_op",
+ ":control_flow_ops",
+ ":control_flow_util",
+ ":framework_ops",
+ ":function_def_to_graph",
+ ":functional_ops_gen",
+ ":gradients_impl",
+ ":list_ops",
+ ":tensor_shape",
+ ":util",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python/eager:function",
+ ],
+)
+
+py_library(
name = "cond_v2_impl",
srcs = [
"ops/cond_v2_impl.py",
@@ -2301,6 +2335,8 @@ py_library(
deps = [
":framework_for_generated_wrappers",
":logging_ops_gen",
+ ":platform",
+ ":string_ops",
":util",
],
)
@@ -3090,7 +3126,7 @@ cuda_py_test(
cuda_py_test(
name = "image_grad_test",
- size = "small",
+ size = "medium",
srcs = ["ops/image_grad_test.py"],
additional_deps = [
":client_testlib",
@@ -3738,6 +3774,19 @@ cuda_py_tests(
],
)
+cc_library(
+ name = "session_ref",
+ srcs = ["client/session_ref.cc"],
+ hdrs = ["client/session_ref.h"],
+ deps = [
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:master_proto_cc",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:replay_log_proto_cc",
+ ],
+)
+
tf_cuda_library(
name = "tf_session_helper",
srcs = ["client/tf_session_helper.cc"],
@@ -3748,6 +3797,7 @@ tf_cuda_library(
":ndarray_tensor_bridge",
":numpy_lib",
":safe_ptr",
+ ":session_ref",
":test_ops_kernels",
"//tensorflow/c:c_api",
"//tensorflow/c:c_api_internal",
@@ -3760,7 +3810,6 @@ tf_cuda_library(
"//tensorflow/core:graph",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
- "//tensorflow/core:session_ref",
"//third_party/py/numpy:headers",
"//third_party/python_runtime:headers",
],
diff --git a/tensorflow/python/autograph/README.md b/tensorflow/python/autograph/README.md
index cc54da4daa..bfe21b4765 100644
--- a/tensorflow/python/autograph/README.md
+++ b/tensorflow/python/autograph/README.md
@@ -65,7 +65,7 @@ pip install -U tf-nightly
Then import the `autograph` module from `tf.contrib`:
```
-from tensorflow.contrib import autograph as ag
+from tensorflow.python import autograph as ag
```
### Related links
diff --git a/tensorflow/python/autograph/__init__.py b/tensorflow/python/autograph/__init__.py
index c3448e6e58..5ed5e85158 100644
--- a/tensorflow/python/autograph/__init__.py
+++ b/tensorflow/python/autograph/__init__.py
@@ -27,6 +27,7 @@ from tensorflow.python.autograph import utils
from tensorflow.python.autograph.core.errors import GraphConstructionError
from tensorflow.python.autograph.core.errors import TfRuntimeError
from tensorflow.python.autograph.core.errors import improved_errors
+from tensorflow.python.autograph.impl.api import ConversionOptions
from tensorflow.python.autograph.impl.api import RunMode
from tensorflow.python.autograph.impl.api import convert
from tensorflow.python.autograph.impl.api import converted_call
@@ -42,6 +43,7 @@ from tensorflow.python.util.all_util import remove_undocumented
_allowed_symbols = [
# Main API
+ 'ConversionOptions',
'RunMode',
'convert',
'converted_call',
diff --git a/tensorflow/python/autograph/converters/builtin_functions.py b/tensorflow/python/autograph/converters/builtin_functions.py
index b8b268d8ce..583c978395 100644
--- a/tensorflow/python/autograph/converters/builtin_functions.py
+++ b/tensorflow/python/autograph/converters/builtin_functions.py
@@ -48,8 +48,13 @@ class BuiltinFunctionTransformer(converter.Base):
node = self.generic_visit(node)
if anno.hasanno(node.func, 'live_val'):
live_val = anno.getanno(node.func, 'live_val')
- if live_val in py_builtins.SUPPORTED_BUILTINS:
- node = self._convert_builtin(live_val, node.args, as_expression=True)
+ try:
+ if live_val in py_builtins.SUPPORTED_BUILTINS:
+ node = self._convert_builtin(live_val, node.args, as_expression=True)
+ except TypeError:
+ # Not everything in Python is hashable. If it isn't then it's definitely
+ # not a supported built-in.
+ return node
return node
def visit_Print(self, node):
diff --git a/tensorflow/python/autograph/converters/builtin_functions_test.py b/tensorflow/python/autograph/converters/builtin_functions_test.py
index c87c304cdb..2ed14c14e7 100644
--- a/tensorflow/python/autograph/converters/builtin_functions_test.py
+++ b/tensorflow/python/autograph/converters/builtin_functions_test.py
@@ -36,7 +36,7 @@ class BuiltinFunctionsTest(converter_testing.TestCase):
return len(a)
with self.converted(test_fn, builtin_functions, {'len': len}) as result:
- with self.cached_session() as sess:
+ with self.test_session() as sess:
p = array_ops.placeholder(dtype=dtypes.int32, shape=None)
ops = result.test_fn(p)
self.assertEqual(sess.run(ops, {p: [0, 0, 0]}), 3)
@@ -50,7 +50,7 @@ class BuiltinFunctionsTest(converter_testing.TestCase):
return print(a)
with self.converted(test_fn, builtin_functions, {'print': print}) as result:
- with self.cached_session() as sess:
+ with self.test_session() as sess:
with self.assertPrints('a\n'):
sess.run(result.test_fn('a'))
@@ -63,12 +63,22 @@ class BuiltinFunctionsTest(converter_testing.TestCase):
return print(a, b, c)
with self.converted(test_fn, builtin_functions, {'print': print}) as result:
- with self.cached_session() as sess:
+ with self.test_session() as sess:
with self.assertPrints('a 1 [2, 3]\n'):
sess.run(
result.test_fn(
constant_op.constant('a'), constant_op.constant(1), [2, 3]))
+ def test_conversion_robust_to_unhashable_callables(self):
+
+ def test_fn():
+ return foo() # pylint:disable=undefined-variable
+
+ with self.converted(test_fn, builtin_functions, {'foo': {
+ 'a': 'b'
+ }.keys}) as result:
+ self.assertListEqual(list(result.test_fn()), ['a'])
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/autograph/converters/call_trees.py b/tensorflow/python/autograph/converters/call_trees.py
index 6a606c450d..fc2075b781 100644
--- a/tensorflow/python/autograph/converters/call_trees.py
+++ b/tensorflow/python/autograph/converters/call_trees.py
@@ -238,9 +238,16 @@ class CallTreeTransformer(converter.Base):
# Before we could convert all the time though, we'd need a reasonable
# caching mechanism.
template = """
- ag__.converted_call(func, True, False, False, {}, args)
+ ag__.converted_call(
+ func,
+ ag__.ConversionOptions.new(recursive=recursive_val),
+ args)
"""
- call_expr = templates.replace(template, func=node.func, args=node.args)
+ call_expr = templates.replace(
+ template,
+ func=node.func,
+ recursive_val=parser.parse_expression(str(self.ctx.program.recursive)),
+ args=node.args)
new_call = call_expr[0].value
# TODO(mdan): Improve the template mechanism to better support this.
new_call.keywords = node.keywords
diff --git a/tensorflow/python/autograph/converters/return_statements.py b/tensorflow/python/autograph/converters/return_statements.py
index 62da045d6a..496c99e3b5 100644
--- a/tensorflow/python/autograph/converters/return_statements.py
+++ b/tensorflow/python/autograph/converters/return_statements.py
@@ -212,6 +212,7 @@ class DetectReturnInUnsupportedControlFlow(gast.NodeVisitor):
def __init__(self):
self.cant_return = False
+ self.function_level = 0
super(DetectReturnInUnsupportedControlFlow, self).__init__()
def visit_While(self, node):
@@ -229,6 +230,12 @@ class DetectReturnInUnsupportedControlFlow(gast.NodeVisitor):
self.generic_visit(node)
self.cant_return = False
+ def visit_FunctionDef(self, node):
+ if not self.function_level:
+ self.function_level += 1
+ self.generic_visit(node)
+ self.function_level -= 1
+
def visit_Return(self, node):
if self.cant_return:
raise ValueError(
@@ -242,6 +249,7 @@ class DetectReturnInConditional(gast.NodeVisitor):
def __init__(self):
self.cant_return = False
+ self.function_level = 0
super(DetectReturnInConditional, self).__init__()
def visit_If(self, node):
@@ -249,6 +257,12 @@ class DetectReturnInConditional(gast.NodeVisitor):
self.generic_visit(node)
self.cant_return = False
+ def visit_FunctionDef(self, node):
+ if not self.function_level:
+ self.function_level += 1
+ self.generic_visit(node)
+ self.function_level -= 1
+
def visit_Return(self, node):
if self.cant_return:
raise ValueError(
diff --git a/tensorflow/python/autograph/converters/return_statements_test.py b/tensorflow/python/autograph/converters/return_statements_test.py
index 01dd03da0b..762fbc6f60 100644
--- a/tensorflow/python/autograph/converters/return_statements_test.py
+++ b/tensorflow/python/autograph/converters/return_statements_test.py
@@ -151,6 +151,18 @@ class SingleReturnTest(converter_testing.TestCase):
self.assertTransformedEquivalent(test_fn, 2)
self.assertTransformedEquivalent(test_fn, -2)
+ def test_nested_functions_in_control_flow(self):
+
+ def test_fn(x):
+
+ if x:
+ def inner_fn(y):
+ return y
+ inner_fn(x)
+
+ self.assertTransformedEquivalent(test_fn, 2)
+ self.assertTransformedEquivalent(test_fn, -2)
+
def test_loop(self):
def test_fn(x):
diff --git a/tensorflow/python/autograph/core/converter.py b/tensorflow/python/autograph/core/converter.py
index 7b3905fdee..80928ae7f4 100644
--- a/tensorflow/python/autograph/core/converter.py
+++ b/tensorflow/python/autograph/core/converter.py
@@ -63,10 +63,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import collections
from enum import Enum
-
from tensorflow.python.autograph.core import config
from tensorflow.python.autograph.core import naming
from tensorflow.python.autograph.pyct import anno
@@ -129,9 +127,8 @@ class ProgramContext(object):
self.autograph_module = autograph_module
self.uncompiled_modules = uncompiled_modules
- # Required to output dependencies in discovery order, which should match
- # the reverse dependency order.
- self.dependency_cache = collections.OrderedDict()
+ self.conversion_order = []
+ self.dependency_cache = {}
self.additional_imports = set()
self.name_map = {}
@@ -177,6 +174,7 @@ class ProgramContext(object):
self.name_map[o] = name
def add_to_cache(self, original_entity, converted_ast):
+ self.conversion_order.append(original_entity)
self.dependency_cache[original_entity] = converted_ast
diff --git a/tensorflow/python/autograph/core/converter_testing.py b/tensorflow/python/autograph/core/converter_testing.py
index 0a0c6f9002..7ce1b7c4c5 100644
--- a/tensorflow/python/autograph/core/converter_testing.py
+++ b/tensorflow/python/autograph/core/converter_testing.py
@@ -93,11 +93,21 @@ class TestCase(test.TestCase):
self.dynamic_calls.append(args)
return 7
+ class ConversionOptions(object):
+ """Mock version of api.ConversionOptions."""
+
+ def __init__(self, recursive):
+ self.recursive = recursive
+
+ @classmethod
+ def new(cls, recursive):
+ return cls(recursive)
+
try:
result, source = compiler.ast_to_object(node, include_source_map=True)
result.tf = self.make_fake_mod('fake_tf', *symbols)
- fake_ag = self.make_fake_mod('fake_ag', converted_call)
+ fake_ag = self.make_fake_mod('fake_ag', converted_call, ConversionOptions)
fake_ag.__dict__.update(operators.__dict__)
fake_ag.__dict__['utils'] = utils
fake_ag.__dict__['rewrite_graph_construction_error'] = (
diff --git a/tensorflow/python/autograph/core/errors.py b/tensorflow/python/autograph/core/errors.py
index 0750353423..23f8c5b52b 100644
--- a/tensorflow/python/autograph/core/errors.py
+++ b/tensorflow/python/autograph/core/errors.py
@@ -208,7 +208,6 @@ def rewrite_tf_runtime_error(error, source_map):
"""
try:
cleaned_traceback = _cut_traceback_loops(source_map, error.op.traceback)
- # cleaned_traceback = error.op.traceback
cleaned_traceback = _rewrite_tb(source_map, cleaned_traceback)
op_name = error.op.name
diff --git a/tensorflow/python/autograph/core/errors_test.py b/tensorflow/python/autograph/core/errors_test.py
index 0444ed7eab..aa6c293268 100644
--- a/tensorflow/python/autograph/core/errors_test.py
+++ b/tensorflow/python/autograph/core/errors_test.py
@@ -54,7 +54,7 @@ class RuntimeErrorsTest(test.TestCase):
ops = zero_div_caller()
with self.assertRaises(errors.TfRuntimeError) as cm:
with errors.improved_errors(zero_div_caller):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(ops)
for frame in cm.exception.custom_traceback:
@@ -69,7 +69,7 @@ class RuntimeErrorsTest(test.TestCase):
ops = zero_div_caller()
with self.assertRaises(errors.TfRuntimeError) as cm:
with errors.improved_errors(zero_div_caller):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(ops)
all_function_names = set()
@@ -86,7 +86,7 @@ class RuntimeErrorsTest(test.TestCase):
ops = zero_div_caller()
with self.assertRaises(tf_errors.InvalidArgumentError):
with errors.improved_errors(zero_div_caller):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(ops)
def test_improved_errors_validation(self):
diff --git a/tensorflow/python/autograph/impl/api.py b/tensorflow/python/autograph/impl/api.py
index 669d36bd28..1dc97d2331 100644
--- a/tensorflow/python/autograph/impl/api.py
+++ b/tensorflow/python/autograph/impl/api.py
@@ -18,7 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from functools import wraps
+import collections
+import functools
from enum import Enum
@@ -38,6 +39,41 @@ from tensorflow.python.util import tf_inspect
# (currently we require (module + class name, type))
+class ConversionOptions(
+ collections.namedtuple('ConversionOptions',
+ ('recursive', 'verbose', 'strip_decorators',
+ 'force_conversion', 'arg_types'))):
+ """Container for conversion flags.
+
+ Attributes:
+ recursive: bool, whether to recursively convert any user functions or
+ classes that the converted function may use.
+ verbose: bool, whether to log the compiled code.
+ strip_decorators: Tuple[Callable], contains decorators that should be
+ excluded from the compiled output. By default, when converting a
+ function before the decorators are applied, the compiled output will
+ include those decorators.
+ force_conversion: bool, whether to force converting the target entity.
+ When force_conversion is turned off, the converter may decide to
+ return the function as-is.
+ arg_types: Optional[Dict[Text, Type]], type hints for symbols including
+ function arguments.
+ """
+
+ @classmethod
+ def new(cls,
+ recursive=False,
+ verbose=False,
+ strip_decorators=None,
+ force_conversion=False,
+ arg_types=None):
+ return cls(recursive=recursive,
+ verbose=verbose,
+ strip_decorators=strip_decorators or (),
+ force_conversion=force_conversion,
+ arg_types=arg_types or {})
+
+
# TODO(mdan): This should behave like to_graph (e.g. convert statically).
def convert(recursive=False, verbose=False):
"""Decorator that compiles a function to use TensorFlow ops.
@@ -59,9 +95,15 @@ def convert(recursive=False, verbose=False):
def decorator(f):
"""Decorator implementation."""
- @wraps(f)
+ @functools.wraps(f)
def wrapper(*args, **kwargs):
- return converted_call(f, recursive, verbose, True, {}, *args, **kwargs)
+ return converted_call(
+ f,
+ ConversionOptions.new(
+ recursive=recursive,
+ verbose=verbose,
+ force_conversion=True,
+ ), *args, **kwargs)
wrapper = tf_decorator.make_decorator(f, wrapper)
@@ -107,11 +149,11 @@ def do_not_convert(run_as=RunMode.GRAPH, return_dtypes=None):
def decorator(f):
"""Decorator implementation."""
- @wraps(f)
+ @functools.wraps(f)
def graph_wrapper(*args, **kwargs):
return f(*args, **kwargs)
- @wraps(f)
+ @functools.wraps(f)
def py_func_wrapper(*args, **kwargs):
if kwargs:
raise NotImplementedError('RunMode.PY_FUNC does not yet support kwargs')
@@ -135,12 +177,11 @@ def do_not_convert(run_as=RunMode.GRAPH, return_dtypes=None):
# TODO(mdan): Move to a private, undocumented module.
-def converted_call(f, recursive, verbose, force_conversion, arg_types, *args,
- **kwargs):
+def converted_call(f, options, *args, **kwargs):
"""Compiles a function call inline. For internal use only."""
# TODO(mdan): This needs cleanup.
# In particular, we may want to avoid renaming functions altogether.
- if not force_conversion and conversion.is_whitelisted_for_graph(f):
+ if not options.force_conversion and conversion.is_whitelisted_for_graph(f):
return f(*args, **kwargs)
unknown_arg_value = object() # Sentinel for arguments of unknown value
@@ -183,8 +224,8 @@ def converted_call(f, recursive, verbose, force_conversion, arg_types, *args,
continue
arg_class = arg.__class__
# If arg_value_hints specifies any name, use that instead.
- if name not in arg_types:
- arg_types[name] = (arg_class.__name__, arg_class)
+ if name not in options.arg_types:
+ options.arg_types[name] = (arg_class.__name__, arg_class)
# When called from within a decorator, this is the only indication that
# the function is a method - it appears that the decorator is applied
@@ -199,23 +240,25 @@ def converted_call(f, recursive, verbose, force_conversion, arg_types, *args,
converted_f = to_graph(
target_entity,
- recursive=recursive,
- verbose=verbose,
+ recursive=options.recursive,
+ verbose=options.verbose,
arg_values=arg_values,
- arg_types=arg_types,
- partial_types=partial_types)
+ arg_types=options.arg_types,
+ partial_types=partial_types,
+ strip_decorators=options.strip_decorators)
return converted_f(*effective_args, **kwargs)
# TODO(mdan): Rename: to_ops?
-# TODO(mdan): Looki into overloading as function and decorator, like tfe.defun.
+# TODO(mdan): Look into overloading as function and decorator, like tfe.defun?
# TODO(mdan): Remove partial_types.
def to_graph(e,
recursive=True,
verbose=False,
arg_values=None,
arg_types=None,
- partial_types=None):
+ partial_types=None,
+ strip_decorators=None):
"""Converts a Python entity into equivalent code that uses TensorFlow ops.
Supported Python entities include:
@@ -234,6 +277,8 @@ def to_graph(e,
arg_types: Optional[Dict[Text, Type]], type hints for symbols including
function arguments.
partial_types: Set[Type], reserved for internal use.
+ strip_decorators: Tuple[Callable], same as
+ ConversionOptions.strip_decorators.
Returns:
Union[Callable, Type], the converted entity, which is the same kind as e
@@ -243,9 +288,13 @@ def to_graph(e,
Raises:
ValueError: If the entity could not be converted.
"""
+ if strip_decorators is None:
+ strip_decorators = ()
+ strip_decorators += (convert, do_not_convert, converted_call)
+
program_ctx = converter.ProgramContext(
recursive=recursive,
- autograph_decorators=(convert, do_not_convert, converted_call),
+ autograph_decorators=strip_decorators,
partial_types=partial_types,
autograph_module=tf_inspect.getmodule(to_graph),
uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
@@ -253,8 +302,9 @@ def to_graph(e,
arg_types)
nodes = []
- for dep in reversed(tuple(program_ctx.dependency_cache.values())):
- nodes.extend(dep)
+ for dep in reversed(program_ctx.conversion_order):
+ nodes.extend(program_ctx.dependency_cache[dep])
+
compiled_module, compiled_src = compiler.ast_to_object(
nodes,
source_prefix=program_ctx.required_imports,
@@ -322,7 +372,7 @@ def to_code(e,
conversion.entity_to_graph(e, program_ctx, arg_values, arg_types)
code = '\n'.join(
- compiler.ast_to_source(dep, indentation)
- for dep in reversed(tuple(program_ctx.dependency_cache.values())))
+ compiler.ast_to_source(program_ctx.dependency_cache[dep], indentation)
+ for dep in reversed(program_ctx.conversion_order))
return program_ctx.required_imports + '\n\n' + code
diff --git a/tensorflow/python/autograph/impl/api_test.py b/tensorflow/python/autograph/impl/api_test.py
index 54e12f0223..8ce5022c0a 100644
--- a/tensorflow/python/autograph/impl/api_test.py
+++ b/tensorflow/python/autograph/impl/api_test.py
@@ -32,7 +32,6 @@ from tensorflow.python.util import tf_inspect
tf = utils.fake_tf()
-
class ApiTest(test.TestCase):
def setUp(self):
@@ -56,7 +55,7 @@ class ApiTest(test.TestCase):
return x
tc = TestClass()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x = tc.test_method(
constant_op.constant([2, 4]), constant_op.constant(1),
constant_op.constant(-2))
@@ -76,7 +75,7 @@ class ApiTest(test.TestCase):
return x
tc = TestClass()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x = tc.test_method(
constant_op.constant([2, 4]), constant_op.constant(1),
constant_op.constant(-2))
@@ -97,7 +96,7 @@ class ApiTest(test.TestCase):
return x
tc = TestClass()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x = tc.test_method(
constant_op.constant([2, 4]), constant_op.constant(1),
constant_op.constant(-2))
@@ -123,7 +122,7 @@ class ApiTest(test.TestCase):
return x
tc = TestClass()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x = tc.test_method(
constant_op.constant([2, 4]), constant_op.constant(1),
constant_op.constant(-2))
@@ -146,7 +145,7 @@ class ApiTest(test.TestCase):
return x
tc = TestClass()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x = tc.test_method(
constant_op.constant([2, 4]), constant_op.constant(1),
constant_op.constant(-2))
@@ -180,19 +179,20 @@ class ApiTest(test.TestCase):
@api.convert(recursive=True)
def test_method(self, x, s, a):
while tf.reduce_sum(x) > s:
- x //= api.converted_call(self.called_member, False, False, False, {},
- self, a)
+ x //= api.converted_call(
+ self.called_member,
+ api.ConversionOptions.new(), self, a)
return x
tc = TestClass()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x = tc.test_method(
constant_op.constant([2, 4]), constant_op.constant(1),
constant_op.constant(-2))
self.assertListEqual([0, 1], sess.run(x).tolist())
def test_converted_call_builtin(self):
- x = api.converted_call(range, False, False, False, {}, 3)
+ x = api.converted_call(range, api.ConversionOptions.new(), 3)
self.assertEqual((0, 1, 2), tuple(x))
def test_converted_call_function(self):
@@ -202,8 +202,8 @@ class ApiTest(test.TestCase):
return -x
return x
- with self.test_session() as sess:
- x = api.converted_call(test_fn, False, False, False, {},
+ with self.cached_session() as sess:
+ x = api.converted_call(test_fn, api.ConversionOptions.new(),
constant_op.constant(-1))
self.assertEqual(1, sess.run(x))
@@ -219,9 +219,9 @@ class ApiTest(test.TestCase):
return -self.x
return self.x
- with self.test_session() as sess:
+ with self.cached_session() as sess:
tc = TestClass(constant_op.constant(-1))
- x = api.converted_call(tc.test_method, False, False, False, {}, tc)
+ x = api.converted_call(tc.test_method, api.ConversionOptions.new(), tc)
self.assertEqual(1, sess.run(x))
def test_converted_call_method_by_class(self):
@@ -236,9 +236,11 @@ class ApiTest(test.TestCase):
return -self.x
return self.x
- with self.test_session() as sess:
+ with self.cached_session() as sess:
tc = TestClass(constant_op.constant(-1))
- x = api.converted_call(TestClass.test_method, False, False, False, {}, tc)
+ x = api.converted_call(
+ TestClass.test_method,
+ api.ConversionOptions.new(), tc)
self.assertEqual(1, sess.run(x))
def test_converted_call_callable_object(self):
@@ -253,9 +255,9 @@ class ApiTest(test.TestCase):
return -self.x
return self.x
- with self.test_session() as sess:
+ with self.cached_session() as sess:
tc = TestClass(constant_op.constant(-1))
- x = api.converted_call(tc, False, False, False, {})
+ x = api.converted_call(tc, api.ConversionOptions.new())
self.assertEqual(1, sess.run(x))
def test_converted_call_constructor(self):
@@ -270,8 +272,8 @@ class ApiTest(test.TestCase):
return -self.x
return self.x
- with self.test_session() as sess:
- tc = api.converted_call(TestClass, False, False, False, {},
+ with self.cached_session() as sess:
+ tc = api.converted_call(TestClass, api.ConversionOptions.new(),
constant_op.constant(-1))
# tc is now a converted object.
x = tc.test_method()
@@ -282,13 +284,13 @@ class ApiTest(test.TestCase):
def f(x):
return x == 0
- with self.test_session() as sess:
- x = api.converted_call(f, False, False, False, {},
+ with self.cached_session() as sess:
+ x = api.converted_call(f, api.ConversionOptions.new(),
constant_op.constant(0))
self.assertTrue(sess.run(x))
converted_f = api.to_graph(f)
- x = api.converted_call(converted_f, False, False, False, {},
+ x = api.converted_call(converted_f, api.ConversionOptions.new(),
constant_op.constant(0))
self.assertTrue(sess.run(x))
@@ -301,7 +303,7 @@ class ApiTest(test.TestCase):
compiled_fn = api.to_graph(test_fn)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x = compiled_fn(constant_op.constant([4, 8]), 4)
self.assertListEqual([1, 2], sess.run(x).tolist())
diff --git a/tensorflow/python/autograph/impl/conversion.py b/tensorflow/python/autograph/impl/conversion.py
index 928ff9e7ea..a0d13c82a8 100644
--- a/tensorflow/python/autograph/impl/conversion.py
+++ b/tensorflow/python/autograph/impl/conversion.py
@@ -255,6 +255,7 @@ def _add_self_references(namespace, autograph_module):
# internal modules.
ag_internal = imp.new_module('autograph')
ag_internal.converted_call = autograph_module.converted_call
+ ag_internal.ConversionOptions = autograph_module.ConversionOptions
ag_internal.utils = utils
ag_internal.rewrite_graph_construction_error = (
errors.rewrite_graph_construction_error)
diff --git a/tensorflow/python/autograph/lang/special_functions_test.py b/tensorflow/python/autograph/lang/special_functions_test.py
index 1f1cec18f7..545dd11729 100644
--- a/tensorflow/python/autograph/lang/special_functions_test.py
+++ b/tensorflow/python/autograph/lang/special_functions_test.py
@@ -33,7 +33,7 @@ class SpecialFunctionsTest(test.TestCase):
l = special_functions.tensor_list(elements)
sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertAllEqual(sess.run(sl), [[1, 2], [3, 4]])
def test_tensor_list_array_from_elements(self):
@@ -41,7 +41,7 @@ class SpecialFunctionsTest(test.TestCase):
l = special_functions.tensor_list(elements, use_tensor_array=True)
sl = l.stack()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertAllEqual(sess.run(sl), [[1, 2], [3, 4]])
def test_stack(self):
diff --git a/tensorflow/python/autograph/operators/py_builtins.py b/tensorflow/python/autograph/operators/py_builtins.py
index 1d37ae72d3..91a2a22cc2 100644
--- a/tensorflow/python/autograph/operators/py_builtins.py
+++ b/tensorflow/python/autograph/operators/py_builtins.py
@@ -193,11 +193,18 @@ def range_(start_or_stop, stop=UNDEFINED, step=UNDEFINED):
def _tf_range(start_or_stop, stop, step):
+ # Note: for static inputs (e.g. constants), tf.range errors out at graph
+ # construction time, instead of returning an empty tensor. Preventing the
+ # graph construction error aligns the semantics with Python.
+
# TODO(mdan): We should optimize this when a full tensor is not required.
if step is not UNDEFINED:
+ # TODO(mdan): Add argument coercion similar to other cases.
return math_ops.range(start_or_stop, stop, step)
if stop is not UNDEFINED:
+ stop = math_ops.maximum(start_or_stop, stop)
return math_ops.range(start_or_stop, stop)
+ start_or_stop = math_ops.maximum(start_or_stop, 0)
return math_ops.range(start_or_stop)
diff --git a/tensorflow/python/autograph/operators/py_builtins_test.py b/tensorflow/python/autograph/operators/py_builtins_test.py
index a021263ffa..c94a918d5a 100644
--- a/tensorflow/python/autograph/operators/py_builtins_test.py
+++ b/tensorflow/python/autograph/operators/py_builtins_test.py
@@ -36,7 +36,7 @@ class PyBuiltinsTest(test.TestCase):
def test_abs(self):
self.assertEqual(py_builtins.abs_(-1), 1)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
t = py_builtins.abs_(constant_op.constant(-1))
self.assertEqual(sess.run(t), 1)
t = py_builtins.abs_(constant_op.constant([-1, 2, -3]))
@@ -45,7 +45,7 @@ class PyBuiltinsTest(test.TestCase):
def test_float(self):
self.assertEqual(py_builtins.float_(10), 10.0)
self.assertEqual(py_builtins.float_('10.0'), 10.0)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
t = py_builtins.float_(constant_op.constant(1, dtype=dtypes.int64))
self.assertEqual(sess.run(t), 1.0)
st = py_builtins.float_(constant_op.constant('1.0'))
@@ -54,7 +54,7 @@ class PyBuiltinsTest(test.TestCase):
def test_int(self):
self.assertEqual(py_builtins.int_(10.0), 10)
self.assertEqual(py_builtins.int_('11', 2), 3)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
t = py_builtins.int_(constant_op.constant(1, dtype=dtypes.float64))
self.assertEqual(sess.run(t), 1)
st = py_builtins.int_(constant_op.constant('1'))
@@ -69,7 +69,7 @@ class PyBuiltinsTest(test.TestCase):
def test_len(self):
self.assertEqual(py_builtins.len_([1, 2, 3]), 3)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
t = py_builtins.len_(constant_op.constant([[1], [2], [3]]))
self.assertEqual(t, 3)
ta = py_builtins.len_(tensor_array_ops.TensorArray(dtypes.int32, size=5))
@@ -82,7 +82,7 @@ class PyBuiltinsTest(test.TestCase):
py_builtins.len_(constant_op.constant(1))
def test_len_dynamic_shape(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
p = array_ops.placeholder(dtype=dtypes.int32, shape=None)
t = py_builtins.len_(p)
self.assertEqual(sess.run(t, {p: [1, 2, 3]}), 3)
@@ -95,7 +95,7 @@ class PyBuiltinsTest(test.TestCase):
try:
out_capturer = six.StringIO()
sys.stdout = out_capturer
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(py_builtins.print_(constant_op.constant('test message'), 1))
self.assertEqual(out_capturer.getvalue(), 'test message 1\n')
finally:
@@ -105,7 +105,7 @@ class PyBuiltinsTest(test.TestCase):
try:
out_capturer = six.StringIO()
sys.stdout = out_capturer
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
py_builtins.print_(constant_op.constant('test message'), [1, 2]))
self.assertEqual(out_capturer.getvalue(), 'test message [1, 2]\n')
@@ -118,7 +118,7 @@ class PyBuiltinsTest(test.TestCase):
self.assertListEqual(list(py_builtins.range_(2, 0, -1)), [2, 1])
def test_range_tensor(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
r = py_builtins.range_(constant_op.constant(3))
self.assertAllEqual(sess.run(r), [0, 1, 2])
r = py_builtins.range_(1, constant_op.constant(3))
@@ -126,6 +126,13 @@ class PyBuiltinsTest(test.TestCase):
r = py_builtins.range_(2, 0, constant_op.constant(-1))
self.assertAllEqual(sess.run(r), [2, 1])
+ def test_range_tensor_empty_range(self):
+ with self.cached_session() as sess:
+ r = py_builtins.range_(constant_op.constant(-3))
+ self.assertAllEqual(sess.run(r), [])
+ r = py_builtins.range_(5, constant_op.constant(2))
+ self.assertAllEqual(sess.run(r), [])
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/autograph/operators/slices_test.py b/tensorflow/python/autograph/operators/slices_test.py
index d8b8418750..9e4865b3c6 100644
--- a/tensorflow/python/autograph/operators/slices_test.py
+++ b/tensorflow/python/autograph/operators/slices_test.py
@@ -51,14 +51,14 @@ class SlicesTest(test.TestCase):
t = slices.get_item(initial_str, 1,
slices.GetItemOpts(element_dtype=initial_str.dtype))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual(sess.run(t), b'b')
initial_list_str = constant_op.constant(['abcd', 'bcde'])
t = slices.get_item(initial_list_str, 1,
slices.GetItemOpts(element_dtype=initial_str.dtype))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual(sess.run(t), b'bcde')
diff --git a/tensorflow/python/autograph/pyct/cfg.py b/tensorflow/python/autograph/pyct/cfg.py
index 1433f9ac83..fca0eb62e4 100644
--- a/tensorflow/python/autograph/pyct/cfg.py
+++ b/tensorflow/python/autograph/pyct/cfg.py
@@ -27,6 +27,7 @@ from __future__ import division
from __future__ import print_function
import collections
+import weakref
from enum import Enum
# pylint:disable=g-bad-import-order
@@ -61,7 +62,10 @@ class Node(object):
def freeze(self):
self.next = frozenset(self.next)
- self.prev = frozenset(self.prev)
+ # Assumption: All CFG nodes have identical life spans, because the graph
+ # owns them. Nodes should never be used outside the context of an existing
+ # graph.
+ self.prev = weakref.WeakSet(self.prev)
def __repr__(self):
if isinstance(self.ast_node, gast.FunctionDef):
@@ -256,7 +260,7 @@ class GraphBuilder(object):
"""Resets the state of this factory."""
self.head = None
self.errors = set()
- self.node_index = collections.OrderedDict()
+ self.node_index = {}
# TODO(mdan): Too many primitives. Use classes.
self.leaves = set()
@@ -309,7 +313,10 @@ class GraphBuilder(object):
"""Grows the graph by adding a CFG node following the current leaves."""
if ast_node is self.node_index:
raise ValueError('%s added twice' % ast_node)
- node = Node(next_=set(), prev=set(), ast_node=ast_node)
+ # Assumption: All CFG nodes have identical life spans, because the graph
+ # owns them. Nodes should never be used outside the context of an existing
+ # graph.
+ node = Node(next_=set(), prev=weakref.WeakSet(), ast_node=ast_node)
self.node_index[ast_node] = node
self.owners[node] = frozenset(self.active_stmts)
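
A small standalone illustration (class name is illustrative, independent of autograph) of why prev is held in a weakref.WeakSet: forward edges keep CFG nodes alive while back edges do not, so dropping the graph lets nodes be reclaimed without relying on the cycle collector:

    import weakref

    class DemoNode(object):
        def __init__(self):
            self.next = set()
            self.prev = weakref.WeakSet()

    a, b = DemoNode(), DemoNode()
    a.next.add(b)        # forward edge: strong reference
    b.prev.add(a)        # back edge: weak reference, no strong cycle

    del a                # in CPython the node is reclaimed immediately
    assert len(b.prev) == 0
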
diff --git a/tensorflow/python/autograph/pyct/compiler.py b/tensorflow/python/autograph/pyct/compiler.py
index 9e1b6bdbe8..21281aeb56 100644
--- a/tensorflow/python/autograph/pyct/compiler.py
+++ b/tensorflow/python/autograph/pyct/compiler.py
@@ -57,8 +57,15 @@ def ast_to_source(node, indentation=' '):
# In some versions of Python, literals may appear as actual values. This
# ensures everything is string.
- code = map(str, generator.result)
- code = astor.source_repr.pretty_source(code).lstrip()
+ code = ''.join(map(str, generator.result))
+
+ # Strip leading blank lines.
+ code_lines = code.split('\n')
+ trimmed_code_lines = []
+ for l in code_lines:
+ if l.rstrip() or trimmed_code_lines:
+ trimmed_code_lines.append(l)
+ code = '\n'.join(trimmed_code_lines)
return code
@@ -108,7 +115,7 @@ def ast_to_object(nodes,
indices = (-1,)
if include_source_map:
- source_map = origin_info.source_map(nodes, source, f.name, indices)
+ source_map = origin_info.create_source_map(nodes, source, f.name, indices)
# TODO(mdan): Try flush() and delete=False instead.
if delete_on_exit:
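
A standalone sketch of the leading-blank-line stripping added to ast_to_source above; the function name is illustrative:

    def strip_leading_blank_lines(code):
        trimmed = []
        for l in code.split('\n'):
            # Drop lines until the first non-blank one, then keep everything.
            if l.rstrip() or trimmed:
                trimmed.append(l)
        return '\n'.join(trimmed)

    assert strip_leading_blank_lines('\n\ndef f():\n\n  pass\n') == 'def f():\n\n  pass\n'
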
diff --git a/tensorflow/python/autograph/pyct/origin_info.py b/tensorflow/python/autograph/pyct/origin_info.py
index 4c7c4165ef..102bd42c91 100644
--- a/tensorflow/python/autograph/pyct/origin_info.py
+++ b/tensorflow/python/autograph/pyct/origin_info.py
@@ -75,7 +75,7 @@ class OriginInfo(
# TODO(mdan): This source map should be a class - easier to refer to.
-def source_map(nodes, code, filename, indices_in_code):
+def create_source_map(nodes, code, filename, indices_in_code):
"""Creates a source map between an annotated AST and the code it compiles to.
Args:
diff --git a/tensorflow/python/autograph/pyct/origin_info_test.py b/tensorflow/python/autograph/pyct/origin_info_test.py
index 6b9c30dbd0..3b1d5f2040 100644
--- a/tensorflow/python/autograph/pyct/origin_info_test.py
+++ b/tensorflow/python/autograph/pyct/origin_info_test.py
@@ -27,49 +27,41 @@ from tensorflow.python.platform import test
class OriginInfoTest(test.TestCase):
- def test_source_map(self):
+ def test_create_source_map(self):
def test_fn(x):
- if x > 0:
- x += 1
- return x
-
- node, source = parser.parse_entity(test_fn)
+ return x + 1
+
+ node, _ = parser.parse_entity(test_fn)
+ fake_origin = origin_info.OriginInfo(
+ loc=origin_info.Location('fake_filename', 3, 7),
+ function_name='fake_function_name',
+ source_code_line='fake source line',
+ comment=None)
fn_node = node.body[0]
- origin_info.resolve(fn_node, source)
-
- # Insert a traced line.
- new_node = parser.parse_str('x = abs(x)').body[0]
- anno.copyanno(fn_node.body[0], new_node, anno.Basic.ORIGIN)
- fn_node.body.insert(0, new_node)
+ anno.setanno(fn_node.body[0], anno.Basic.ORIGIN, fake_origin)
+ converted_code = compiler.ast_to_source(fn_node)
- # Insert an untraced line.
- fn_node.body.insert(0, parser.parse_str('x = 0').body[0])
+ source_map = origin_info.create_source_map(
+ fn_node, converted_code, 'test_filename', [0])
- modified_source = compiler.ast_to_source(fn_node)
+ loc = origin_info.LineLocation('test_filename', 2)
+ self.assertIn(loc, source_map)
+ self.assertIs(source_map[loc], fake_origin)
- source_map = origin_info.source_map(fn_node, modified_source,
- 'test_filename', [0])
+ def test_source_map_no_origin(self):
- loc = origin_info.LineLocation('test_filename', 1)
- origin = source_map[loc]
- self.assertEqual(origin.source_code_line, 'def test_fn(x):')
- self.assertEqual(origin.loc.lineno, 1)
+ def test_fn(x):
+ return x + 1
- # The untraced line, inserted second.
- loc = origin_info.LineLocation('test_filename', 2)
- self.assertFalse(loc in source_map)
+ node, _ = parser.parse_entity(test_fn)
+ fn_node = node.body[0]
+ converted_code = compiler.ast_to_source(fn_node)
- # The traced line, inserted first.
- loc = origin_info.LineLocation('test_filename', 3)
- origin = source_map[loc]
- self.assertEqual(origin.source_code_line, ' if x > 0:')
- self.assertEqual(origin.loc.lineno, 2)
+ source_map = origin_info.create_source_map(
+ fn_node, converted_code, 'test_filename', [0])
- loc = origin_info.LineLocation('test_filename', 4)
- origin = source_map[loc]
- self.assertEqual(origin.source_code_line, ' if x > 0:')
- self.assertEqual(origin.loc.lineno, 2)
+ self.assertEqual(len(source_map), 0)
def test_resolve(self):
@@ -79,6 +71,7 @@ class OriginInfoTest(test.TestCase):
node, source = parser.parse_entity(test_fn)
fn_node = node.body[0]
+
origin_info.resolve(fn_node, source)
origin = anno.getanno(fn_node, anno.Basic.ORIGIN)
diff --git a/tensorflow/python/autograph/pyct/parser.py b/tensorflow/python/autograph/pyct/parser.py
index 112ed46a1e..63686350d5 100644
--- a/tensorflow/python/autograph/pyct/parser.py
+++ b/tensorflow/python/autograph/pyct/parser.py
@@ -31,8 +31,21 @@ from tensorflow.python.util import tf_inspect
def parse_entity(entity):
"""Returns the AST of given entity."""
source = tf_inspect.getsource(entity)
+ # Comments and multiline strings can appear at arbitrary indentation levels,
+ # causing textwrap.dedent to not correctly dedent source code.
+ # TODO(b/115884650): Automatic handling of comments/multiline strings.
source = textwrap.dedent(source)
- return parse_str(source), source
+ try:
+ return parse_str(source), source
+ except IndentationError:
+ # Because we are parsing the source code of entities that have already
+ # been parsed successfully once, any IndentationErrors are guaranteed to
+ # be caused by insufficient dedenting.
+ raise ValueError(
+ 'Failed to dedent prior to parsing source code. If you have comments '
+ 'or multiline strings in your code, try indenting them. '
+ 'Multiline strings can be rewritten using textwrap.dedent.\n'
+ 'Offending source code: \n %s' % source)
def parse_str(src):
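
The two new parser tests below exercise exactly this failure mode. A standalone sketch, using only the standard library, of why textwrap.dedent cannot handle an unindented comment:

    import ast
    import textwrap

    source = (
        "    def f():\n"
        "# unindented comment\n"
        "        pass\n")

    # The comment line has no leading whitespace, so the common prefix is
    # empty and nothing gets stripped.
    dedented = textwrap.dedent(source)
    assert dedented == source

    try:
        ast.parse(dedented)
    except IndentationError:
        print('insufficient dedent, which parse_entity now reports as ValueError')
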
diff --git a/tensorflow/python/autograph/pyct/parser_test.py b/tensorflow/python/autograph/pyct/parser_test.py
index d0b465eb73..d3a7b7a014 100644
--- a/tensorflow/python/autograph/pyct/parser_test.py
+++ b/tensorflow/python/autograph/pyct/parser_test.py
@@ -42,6 +42,22 @@ class ParserTest(test.TestCase):
"""))
self.assertEqual('f', mod.body[0].name)
+ def test_parse_comments(self):
+ def f():
+# unindented comment
+ pass
+ with self.assertRaises(ValueError):
+ parser.parse_entity(f)
+
+ def test_parse_multiline_strings(self):
+ def f():
+ print("""
+some
+multiline
+string""")
+ with self.assertRaises(ValueError):
+ parser.parse_entity(f)
+
def test_parse_expression(self):
node = parser.parse_expression('a.b')
self.assertEqual('a', node.value.id)
diff --git a/tensorflow/python/autograph/pyct/static_analysis/activity.py b/tensorflow/python/autograph/pyct/static_analysis/activity.py
index 9cb5991322..086eda7574 100644
--- a/tensorflow/python/autograph/pyct/static_analysis/activity.py
+++ b/tensorflow/python/autograph/pyct/static_analysis/activity.py
@@ -22,6 +22,7 @@ from __future__ import division
from __future__ import print_function
import copy
+import weakref
import gast
@@ -126,7 +127,10 @@ class Scope(object):
self.parent.mark_read(name)
def mark_param(self, name, owner):
- self.params[name] = owner
+ # Assumption: all AST nodes have the same life span. This lets us use
+ # a weak reference to mark the connection between a symbol node and the
+ # function node whose argument that symbol is.
+ self.params[name] = weakref.ref(owner)
def mark_creation(self, name, writes_create_symbol=False):
"""Mark a qualified name as created."""
diff --git a/tensorflow/python/autograph/pyct/static_analysis/live_values.py b/tensorflow/python/autograph/pyct/static_analysis/live_values.py
index 48b442f3bd..36b9e7074d 100644
--- a/tensorflow/python/autograph/pyct/static_analysis/live_values.py
+++ b/tensorflow/python/autograph/pyct/static_analysis/live_values.py
@@ -29,10 +29,11 @@ from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import transformer
from tensorflow.python.autograph.pyct.static_analysis.annos import NodeAnno
+
# TODO(aqj): Do we need this? Do other builtins fail in similar ways
# See b/114389775 for a related bug in pyct
# These symbols are legal in Python, but don't appear in the namespace.
-_special_symbols = {'range': range}
+_SPECIAL_SYMBOLS = {'range': range, 'print': print}
class LiveValueResolver(transformer.Base):
@@ -71,8 +72,10 @@ class LiveValueResolver(transformer.Base):
# If the symbol value is for example a primitive, then it will not
# have a name.
pass
- elif node.id in _special_symbols:
- anno.setanno(node, 'live_val', _special_symbols[node.id])
+ elif node.id in _SPECIAL_SYMBOLS:
+ # Note: if the user redefined any of these symbols, then they would
+ # be visible in the namespace and we would never reach this branch.
+ anno.setanno(node, 'live_val', _SPECIAL_SYMBOLS[node.id])
else:
pass
# TODO(mdan): Should we raise an error here?
@@ -86,7 +89,8 @@ class LiveValueResolver(transformer.Base):
if has_single_def:
def_, = defs
- if def_.param_of is self.enclosing_entities[0]:
+ # Note: param_of is a weakref.
+ if def_.param_of and def_.param_of() is self.enclosing_entities[0]:
if node.id in self.entity_info.arg_values:
obj = self.entity_info.arg_values[node.id]
anno.setanno(node, 'live_val', obj)
diff --git a/tensorflow/python/autograph/pyct/templates.py b/tensorflow/python/autograph/pyct/templates.py
index 68c2a35fac..1af8fca599 100644
--- a/tensorflow/python/autograph/pyct/templates.py
+++ b/tensorflow/python/autograph/pyct/templates.py
@@ -109,6 +109,7 @@ class ReplaceTransformer(gast.NodeTransformer):
if not node.ctx:
raise ValueError('node %s is missing ctx value' % node)
+ # TODO(mdan): Rewrite _check and _set using a separate transformer.
def _check_inner_children_have_context(self, node):
if isinstance(node, gast.Attribute):
self._check_inner_children_have_context(node.value)
@@ -122,6 +123,8 @@ class ReplaceTransformer(gast.NodeTransformer):
self._check_inner_children_have_context(e)
for e in node.values:
self._check_inner_children_have_context(e)
+ elif isinstance(node, gast.Index):
+ self._check_inner_children_have_context(node.value)
elif isinstance(node, gast.Subscript):
self._check_inner_children_have_context(node.value)
self._check_inner_children_have_context(node.slice)
@@ -131,6 +134,11 @@ class ReplaceTransformer(gast.NodeTransformer):
self._check_inner_children_have_context(node.upper)
if node.step:
self._check_inner_children_have_context(node.step)
+ elif isinstance(node, gast.BinOp):
+ self._check_inner_children_have_context(node.left)
+ self._check_inner_children_have_context(node.right)
+ elif isinstance(node, gast.UnaryOp):
+ self._check_inner_children_have_context(node.operand)
elif isinstance(node, gast.Name):
self._check_has_context(node)
elif isinstance(node, (gast.Str, gast.Num)):
@@ -166,6 +174,11 @@ class ReplaceTransformer(gast.NodeTransformer):
elif isinstance(node, gast.Subscript):
self._set_inner_child_context(node.value, ctx)
self._check_inner_children_have_context(node.slice)
+ elif isinstance(node, gast.BinOp):
+ self._check_inner_children_have_context(node.left)
+ self._check_inner_children_have_context(node.right)
+ elif isinstance(node, gast.UnaryOp):
+ self._check_inner_children_have_context(node.operand)
elif isinstance(node, (gast.Str, gast.Num)):
pass
else:
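
A quick check (using the standard ast module rather than gast) of the node shapes the new BinOp/UnaryOp branches have to recurse into: the Name nodes carrying a ctx sit under .left, .right and .operand, matching the assertions in the new template test below:

    import ast

    expr = ast.parse('a + 2 * b / -c', mode='eval').body
    assert isinstance(expr, ast.BinOp)
    assert isinstance(expr.left.ctx, ast.Load)                 # `a`
    assert isinstance(expr.right.left.right.ctx, ast.Load)     # `b`
    assert isinstance(expr.right.right.operand.ctx, ast.Load)  # `c` under USub
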
diff --git a/tensorflow/python/autograph/pyct/templates_test.py b/tensorflow/python/autograph/pyct/templates_test.py
index 66268cfaad..3032241846 100644
--- a/tensorflow/python/autograph/pyct/templates_test.py
+++ b/tensorflow/python/autograph/pyct/templates_test.py
@@ -132,6 +132,18 @@ class TemplatesTest(test.TestCase):
self.assertIsInstance(node.body[0].targets[0].elts[0].ctx, gast.Store)
self.assertIsInstance(node.body[0].targets[0].elts[1].ctx, gast.Store)
+ def test_replace_expression_context(self):
+ template = """
+ def test_fn(foo):
+ foo
+ """
+
+ node = templates.replace(
+ template, foo=parser.parse_expression('a + 2 * b / -c'))[0]
+ self.assertIsInstance(node.body[0].ctx, gast.Load)
+ self.assertIsInstance(node.body[0].left.ctx, gast.Load)
+ self.assertIsInstance(node.body[0].right.left.right.ctx, gast.Load)
+
def test_replace_complex_context(self):
template = """
def test_fn(foo):
@@ -146,6 +158,18 @@ class TemplatesTest(test.TestCase):
self.assertIsInstance(function_call_arg.elts[0].elts[0].ctx, gast.Load)
self.assertIsInstance(function_call_arg.elts[0].elts[1].ctx, gast.Load)
+ def test_replace_index(self):
+ template = """
+ def test_fn(foo):
+ foo = 0
+ """
+
+ node = templates.replace(
+ template, foo=parser.parse_expression('foo(a[b]).bar'))[0]
+ function_call_arg = node.body[0].targets[0].value.args[0]
+ self.assertIsInstance(function_call_arg.ctx, gast.Load)
+ self.assertIsInstance(function_call_arg.slice.value.ctx, gast.Load)
+
def test_replace_call_keyword(self):
template = """
def test_fn():
diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py
index ae0ad27f15..c963cfd334 100644
--- a/tensorflow/python/client/session.py
+++ b/tensorflow/python/client/session.py
@@ -178,16 +178,30 @@ def register_session_run_conversion_functions(
feed_function_for_partial_run: A callable for specifying tensor values to
feed when setting up a partial run, which takes a `tensor_type` type
object as input, and returns a list of Tensors.
+
+ Raises:
+ ValueError: If `tensor_type` has already been registered.
"""
for conversion_function in _REGISTERED_EXPANSIONS:
if issubclass(conversion_function[0], tensor_type):
- raise ValueError('%s has already been registered so ignore it.',
+ raise ValueError('%s has already been registered.' %
tensor_type)
- return
+
_REGISTERED_EXPANSIONS.insert(0, (tensor_type, fetch_function, feed_function,
feed_function_for_partial_run))
+def _is_attrs_instance(obj):
+ """Returns True if the given obj is an instance of attrs-decorated class."""
+ return getattr(obj.__class__, '__attrs_attrs__', None) is not None
+
+
+def _get_attrs_values(obj):
+ """Returns the list of values from an attrs instance."""
+ attrs = getattr(obj.__class__, '__attrs_attrs__')
+ return [getattr(obj, a.name) for a in attrs]
+
+
class _FetchMapper(object):
"""Definition of the interface provided by fetch mappers.
@@ -247,6 +261,8 @@ class _FetchMapper(object):
return _ListFetchMapper(fetch)
elif isinstance(fetch, collections.Mapping):
return _DictFetchMapper(fetch)
+ elif _is_attrs_instance(fetch):
+ return _AttrsFetchMapper(fetch)
else:
# Look for a handler in the registered expansions.
for tensor_type, fetch_fn, _, _ in _REGISTERED_EXPANSIONS:
@@ -398,6 +414,32 @@ class _DictFetchMapper(_FetchMapper):
return results
+class _AttrsFetchMapper(_FetchMapper):
+ """Fetch mapper for attrs decorated classes."""
+
+ def __init__(self, fetches):
+ """Creates a _AttrsFetchMapper.
+
+ Args:
+ fetches: An instance of an attrs decorated class.
+ """
+ values = _get_attrs_values(fetches)
+ self._fetch_type = type(fetches)
+ self._mappers = [
+ _FetchMapper.for_fetch(fetch) for fetch in values
+ ]
+ self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers)
+
+ def unique_fetches(self):
+ return self._unique_fetches
+
+ def build_results(self, values):
+ results = []
+ for m, vi in zip(self._mappers, self._value_indices):
+ results.append(m.build_results([values[j] for j in vi]))
+ return self._fetch_type(*results)
+
+
class _FetchHandler(object):
"""Handler for structured fetches.
diff --git a/tensorflow/python/client/session_ref.cc b/tensorflow/python/client/session_ref.cc
new file mode 100644
index 0000000000..4d361612b7
--- /dev/null
+++ b/tensorflow/python/client/session_ref.cc
@@ -0,0 +1,525 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/python/client/session_ref.h"
+
+#include <stdlib.h>
+#include <memory>
+#include <utility>
+
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/lib/io/record_writer.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/protobuf/master.pb.h"
+#include "tensorflow/core/protobuf/named_tensor.pb.h"
+#include "tensorflow/core/protobuf/replay_log.pb.h"
+
+namespace tensorflow {
+
+namespace {
+
+// Scope helper to track active calls and manage session lifetime.
+// SessionRef blocks closing until all active calls complete or are cancelled.
+struct RunCounter {
+ std::shared_ptr<Session> session;
+ uint64* value;
+ mutex* m;
+ condition_variable* cv;
+
+ explicit RunCounter(std::shared_ptr<Session> s, uint64* v, mutex* m,
+ condition_variable* cv)
+ : session(std::move(s)), value(v), m(m), cv(cv) {
+ mutex_lock l(*m);
+ ++*value;
+ }
+
+ ~RunCounter() {
+ mutex_lock l(*m);
+ if (--*value == 0) {
+ cv->notify_all();
+ }
+ }
+};
+
+std::string SessionToHandle(Session* session) {
+ return strings::Printf("%llu", reinterpret_cast<uint64>(session));
+}
+
+// The Session interface has many methods of the form:
+//
+// X(a, b);
+// X(RunOptions, a, b);
+//
+// Not all sessions support the second case (with an empty RunOptions()).
+// We use this variable as a sentinel to dispatch to the correct call.
+RunOptions* kEmptyRunOptions() {
+ static RunOptions* options = new RunOptions();
+ return options;
+}
+
+} // namespace
+
+// Run the given session operation, recording start and end timestamps.
+// If the operation returns a bad status, return after flushing the current
+// log request. This should be run _after_ all request information has been
+// added to the current op.
+#define RUN_WITH_TIMESTAMP(OpName, ...) \
+ op.set_start_time_us(Env::Default()->NowMicros()); \
+ Status status = session->OpName(__VA_ARGS__); \
+ op.set_end_time_us(Env::Default()->NowMicros()); \
+ if (!status.ok()) { \
+ Flush(op).IgnoreError(); \
+ return status; \
+ }
+
+// Records requests (and optionally responses) performed against a session.
+// The resulting replay log can be used with the `tf_replay` tool to replicate
+// the operations against a simulated environment, without requiring the
+// original code or cluster setup.
+//
+// Session logging is enabled by setting the TF_REPLAY_LOG_FILE environment
+// variable.
+class SessionLogger {
+ public:
+ SessionLogger() {
+ std::string log_name = getenv("TF_REPLAY_LOG_FILE");
+ LOG(INFO) << "Constructing new session logger for " << log_name;
+ TF_CHECK_OK(
+ Env::Default()->RecursivelyCreateDir(string(io::Dirname(log_name))));
+ Env::Default()->DeleteFile(log_name).IgnoreError();
+
+ TF_CHECK_OK(Env::Default()->NewWritableFile(log_name, &log_file_));
+ log_writer_ = absl::make_unique<io::RecordWriter>(log_file_.get());
+ }
+
+ ~SessionLogger() {
+ log_writer_->Close().IgnoreError();
+ log_writer_.release();
+ log_file_->Close().IgnoreError();
+ }
+
+ Status RecordNewSession(Session* session) {
+ LOG(INFO) << "New session discovered. Capturing devices...";
+ ReplayOp op;
+ NewReplaySession* req = op.mutable_new_replay_session();
+
+ std::vector<DeviceAttributes> devices;
+ Status status = session->ListDevices(&devices);
+ if (status.ok()) {
+ LOG(INFO) << "Found: " << devices.size() << " devices.";
+ for (const DeviceAttributes& dev : devices) {
+ *req->mutable_devices()->add_local_device() = dev;
+ }
+ } else {
+ LOG(WARNING) << "Failed to list devices on session. Continuing.";
+ }
+
+ req->set_session_handle(SessionToHandle(session));
+ return Flush(op);
+ }
+
+ Status RecordRun(Session* session,
+ const std::vector<std::pair<string, Tensor> >& inputs,
+ const std::vector<string>& output_tensor_names,
+ const std::vector<string>& target_node_names,
+ std::vector<Tensor>* outputs) {
+ return RecordRun(session, *kEmptyRunOptions(), inputs, output_tensor_names,
+ target_node_names, outputs, nullptr);
+ }
+
+ Status RecordRun(Session* session, const RunOptions& run_options,
+ const std::vector<std::pair<string, Tensor> >& inputs,
+ const std::vector<string>& output_tensor_names,
+ const std::vector<string>& target_node_names,
+ std::vector<Tensor>* outputs, RunMetadata* run_metadata) {
+ ReplayOp op;
+ RunStepRequest* req = op.mutable_run_step();
+ RunStepResponse* resp = op.mutable_run_step_response();
+
+ req->set_session_handle(SessionToHandle(session));
+ *req->mutable_options() = run_options;
+
+ for (const auto& it : inputs) {
+ NamedTensorProto* feed = req->add_feed();
+ feed->set_name(it.first);
+ it.second.AsProtoField(feed->mutable_tensor());
+ }
+
+ // Build an index from fetch tensor name to first index in
+ // output_tensor_names.
+ std::unordered_map<string, int> output_name_to_offset;
+ for (int i = 0; i < output_tensor_names.size(); ++i) {
+ const string& name = output_tensor_names[i];
+ if (output_name_to_offset.insert(std::make_pair(name, i)).second) {
+ req->add_fetch(name);
+ }
+ }
+ for (const string& target : target_node_names) {
+ req->add_target(target);
+ }
+
+ if (&run_options == kEmptyRunOptions()) {
+ RUN_WITH_TIMESTAMP(Run, inputs, output_tensor_names, target_node_names,
+ outputs);
+ } else {
+ RUN_WITH_TIMESTAMP(Run, run_options, inputs, output_tensor_names,
+ target_node_names, outputs, run_metadata);
+ }
+
+ for (size_t i = 0; i < outputs->size(); ++i) {
+ const Tensor& tensor = (*outputs)[i];
+ NamedTensorProto* tproto = resp->add_tensor();
+ tensor.AsProtoField(tproto->mutable_tensor());
+ tproto->set_name(output_tensor_names[i]);
+ }
+
+ if (run_metadata) {
+ *resp->mutable_metadata() = *run_metadata;
+ }
+
+ return Flush(op);
+ }
+
+ Status RecordCreate(Session* session, const GraphDef& graph) {
+ return RecordCreate(session, *kEmptyRunOptions(), graph);
+ }
+
+ // N.B. RunOptions is not stored (it has no entry in CreateRequest)
+ Status RecordCreate(Session* session, const RunOptions& run_options,
+ const GraphDef& graph) {
+ ReplayOp op;
+ CreateSessionRequest* req = op.mutable_create_session();
+ *req->mutable_graph_def() = graph;
+
+ CreateSessionResponse* resp = op.mutable_create_session_response();
+ if (&run_options == kEmptyRunOptions()) {
+ RUN_WITH_TIMESTAMP(Create, graph);
+ } else {
+ RUN_WITH_TIMESTAMP(Create, run_options, graph);
+ }
+ resp->set_session_handle(SessionToHandle(session));
+ return Flush(op);
+ }
+
+ Status RecordExtend(Session* session, const GraphDef& graph) {
+ return RecordExtend(session, *kEmptyRunOptions(), graph);
+ }
+
+ // N.B. RunOptions is not stored (it has no entry in ExtendRequest)
+ Status RecordExtend(Session* session, const RunOptions& run_options,
+ const GraphDef& graph) {
+ ReplayOp op;
+ ExtendSessionRequest* req = op.mutable_extend_session();
+ op.mutable_extend_session_response();
+ req->set_session_handle(SessionToHandle(session));
+ *req->mutable_graph_def() = graph;
+ if (&run_options == kEmptyRunOptions()) {
+ RUN_WITH_TIMESTAMP(Extend, graph);
+ } else {
+ RUN_WITH_TIMESTAMP(Extend, run_options, graph);
+ }
+
+ return Flush(op);
+ }
+
+ Status RecordClose(Session* session) {
+ return RecordClose(session, *kEmptyRunOptions());
+ }
+
+ // N.B. RunOptions is not stored (it has no entry in CloseRequest)
+ Status RecordClose(Session* session, const RunOptions& run_options) {
+ ReplayOp op;
+ CloseSessionRequest* req = op.mutable_close_session();
+ req->set_session_handle(SessionToHandle(session));
+ op.mutable_close_session_response();
+ if (&run_options == kEmptyRunOptions()) {
+ RUN_WITH_TIMESTAMP(Close);
+ } else {
+ RUN_WITH_TIMESTAMP(Close, run_options);
+ }
+ return Flush(op);
+ }
+
+ Status RecordListDevices(Session* session,
+ std::vector<DeviceAttributes>* response) {
+ ReplayOp op;
+ ListDevicesRequest* req = op.mutable_list_devices();
+ ListDevicesResponse* resp = op.mutable_list_devices_response();
+ req->set_session_handle(SessionToHandle(session));
+ RUN_WITH_TIMESTAMP(ListDevices, response);
+
+ // TODO(power) -- local vs remote device distinction is lost here!
+ *resp->mutable_local_device() = {response->begin(), response->end()};
+ return Flush(op);
+ }
+
+ Status RecordPRunSetup(Session* session,
+ const std::vector<string>& input_names,
+ const std::vector<string>& output_names,
+ const std::vector<string>& target_nodes,
+ string* handle) {
+ ReplayOp op;
+ PartialRunSetupRequest* req = op.mutable_partial_run_setup();
+ req->set_session_handle(SessionToHandle(session));
+ for (auto& input : input_names) {
+ req->add_feed(input);
+ }
+ for (auto& output : output_names) {
+ req->add_fetch(output);
+ }
+ for (auto& target : target_nodes) {
+ req->add_target(target);
+ }
+ RUN_WITH_TIMESTAMP(PRunSetup, input_names, output_names, target_nodes,
+ handle);
+ op.mutable_partial_run_setup_response()->set_partial_run_handle(*handle);
+ return Flush(op);
+ }
+
+ Status RecordPRun(Session* session, const string& handle,
+ const std::vector<std::pair<string, Tensor> >& inputs,
+ const std::vector<string>& output_names,
+ std::vector<Tensor>* outputs) {
+ ReplayOp op;
+ RunStepRequest* req = op.mutable_run_step();
+ RunStepResponse* resp = op.mutable_run_step_response();
+ req->set_session_handle(SessionToHandle(session));
+
+ // Mark this step as a partial run for replay.
+ req->set_partial_run_handle(handle);
+ for (auto& input : inputs) {
+ auto* feed = req->add_feed();
+ feed->set_name(input.first);
+ input.second.AsProtoField(feed->mutable_tensor());
+ }
+
+ for (auto& output : output_names) {
+ req->add_fetch(output);
+ }
+
+ RUN_WITH_TIMESTAMP(PRun, handle, inputs, output_names, outputs);
+
+ for (size_t i = 0; i < outputs->size(); ++i) {
+ const Tensor& tensor = (*outputs)[i];
+ NamedTensorProto* tproto = resp->add_tensor();
+ tensor.AsProtoField(tproto->mutable_tensor());
+ tproto->set_name(output_names[i]);
+ }
+
+ return Flush(op);
+ }
+
+ Status RecordMakeCallable(Session* session,
+ const CallableOptions& callable_options,
+ Session::CallableHandle* handle) {
+ ReplayOp op;
+ MakeCallableRequest* req = op.mutable_make_callable();
+ req->set_session_handle(SessionToHandle(session));
+ *req->mutable_options() = callable_options;
+
+ RUN_WITH_TIMESTAMP(MakeCallable, callable_options, handle);
+
+ MakeCallableResponse* resp = op.mutable_make_callable_response();
+ resp->set_handle(*handle);
+
+ return Flush(op);
+ }
+
+ Status RecordRunCallable(Session* session, Session::CallableHandle handle,
+ const std::vector<Tensor>& feed_tensors,
+ std::vector<Tensor>* fetch_tensors,
+ RunMetadata* run_metadata) {
+ ReplayOp op;
+ RunCallableRequest* req = op.mutable_run_callable();
+ req->set_session_handle(SessionToHandle(session));
+ req->set_handle(handle);
+ for (auto& tensor : feed_tensors) {
+ tensor.AsProtoField(req->add_feed());
+ }
+ RUN_WITH_TIMESTAMP(RunCallable, handle, feed_tensors, fetch_tensors,
+ run_metadata);
+
+ RunCallableResponse* resp = op.mutable_run_callable_response();
+ if (run_metadata) {
+ *resp->mutable_metadata() = *run_metadata;
+ }
+ for (const Tensor& tensor : *fetch_tensors) {
+ tensor.AsProtoTensorContent(resp->add_fetch());
+ }
+ return Flush(op);
+ }
+
+ Status RecordReleaseCallable(Session* session,
+ Session::CallableHandle handle) {
+ ReplayOp op;
+ ReleaseCallableRequest* req = op.mutable_release_callable();
+ req->set_session_handle(SessionToHandle(session));
+ req->set_handle(handle);
+ RUN_WITH_TIMESTAMP(ReleaseCallable, handle);
+ return Flush(op);
+ }
+
+ private:
+ Status Flush(const ReplayOp& op) {
+ mutex_lock l(log_mutex_);
+
+ string buf;
+ op.SerializeToString(&buf);
+ TF_RETURN_IF_ERROR(log_writer_->WriteRecord(buf));
+
+ // TODO(b/116624106): Not all file-systems respect calls to `Sync()`
+ return log_file_->Sync();
+ }
+
+ std::unique_ptr<WritableFile> log_file_;
+ std::unique_ptr<io::RecordWriter> log_writer_;
+ mutex log_mutex_;
+};
+
+static SessionLogger* global_session_logger() {
+ static SessionLogger* logger = new SessionLogger();
+ return logger;
+}
+
+SessionRef::SessionRef(Session* session) : session_(session) {
+ if (getenv("TF_REPLAY_LOG_FILE") != nullptr) {
+ logger_ = global_session_logger();
+ logger_->RecordNewSession(this->session_.get()).IgnoreError();
+ } else {
+ logger_ = nullptr;
+ }
+}
+
+SessionRef::~SessionRef() = default;
+
+Status SessionRef::CheckNotClosed() {
+ mutex_lock l(run_lock_);
+ if (session_ == nullptr) return errors::Cancelled("Session has been closed.");
+ return ::tensorflow::Status::OK();
+}
+
+// If logging is active, log the start and end time of the operation along with
+// the request and response.
+#define LOG_AND_RUN_OPERATION(OpName, ...) \
+ TF_RETURN_IF_ERROR(CheckNotClosed()); \
+ RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_); \
+ if (!logger_) { \
+ return rc.session->OpName(__VA_ARGS__); \
+ } \
+ return logger_->Record##OpName(rc.session.get(), __VA_ARGS__);
+
+Status SessionRef::Run(const RunOptions& run_options,
+ const std::vector<std::pair<string, Tensor> >& inputs,
+ const std::vector<string>& output_tensor_names,
+ const std::vector<string>& target_node_names,
+ std::vector<Tensor>* outputs,
+ RunMetadata* run_metadata) {
+ LOG_AND_RUN_OPERATION(Run, run_options, inputs, output_tensor_names,
+ target_node_names, outputs, run_metadata);
+}
+
+Status SessionRef::Run(const std::vector<std::pair<string, Tensor> >& inputs,
+ const std::vector<string>& output_tensor_names,
+ const std::vector<string>& target_node_names,
+ std::vector<Tensor>* outputs) {
+ LOG_AND_RUN_OPERATION(Run, inputs, output_tensor_names, target_node_names,
+ outputs);
+}
+
+Status SessionRef::Create(const GraphDef& graph) {
+ LOG_AND_RUN_OPERATION(Create, graph);
+}
+
+Status SessionRef::Create(const RunOptions& run_options,
+ const GraphDef& graph) {
+ LOG_AND_RUN_OPERATION(Create, run_options, graph);
+}
+
+Status SessionRef::Extend(const RunOptions& run_options,
+ const GraphDef& graph) {
+ LOG_AND_RUN_OPERATION(Extend, run_options, graph);
+}
+
+Status SessionRef::Extend(const GraphDef& graph) {
+ LOG_AND_RUN_OPERATION(Extend, graph);
+}
+
+Status SessionRef::ListDevices(std::vector<DeviceAttributes>* response) {
+ LOG_AND_RUN_OPERATION(ListDevices, response);
+}
+
+Status SessionRef::PRunSetup(const std::vector<string>& input_names,
+ const std::vector<string>& output_names,
+ const std::vector<string>& target_nodes,
+ string* handle) {
+ LOG_AND_RUN_OPERATION(PRunSetup, input_names, output_names, target_nodes,
+ handle);
+}
+
+Status SessionRef::PRun(const string& handle,
+ const std::vector<std::pair<string, Tensor> >& inputs,
+ const std::vector<string>& output_names,
+ std::vector<Tensor>* outputs) {
+ LOG_AND_RUN_OPERATION(PRun, handle, inputs, output_names, outputs);
+}
+
+Status SessionRef::MakeCallable(const CallableOptions& callable_options,
+ CallableHandle* out_handle) {
+ LOG_AND_RUN_OPERATION(MakeCallable, callable_options, out_handle);
+}
+
+Status SessionRef::RunCallable(CallableHandle handle,
+ const std::vector<Tensor>& feed_tensors,
+ std::vector<Tensor>* fetch_tensors,
+ RunMetadata* run_metadata) {
+ LOG_AND_RUN_OPERATION(RunCallable, handle, feed_tensors, fetch_tensors,
+ run_metadata);
+}
+
+Status SessionRef::ReleaseCallable(CallableHandle handle) {
+ LOG_AND_RUN_OPERATION(ReleaseCallable, handle);
+}
+
+Status SessionRef::Close(const RunOptions& run_options) {
+ TF_RETURN_IF_ERROR(CheckNotClosed());
+ mutex_lock l(run_lock_);
+ Status status;
+ if (logger_) {
+ status = logger_->RecordClose(session_.get(), run_options);
+ } else {
+ status = session_->Close(run_options);
+ }
+ session_.reset();
+ while (run_count_ > 0) {
+ run_finished_.wait(l);
+ }
+ return status;
+}
+
+Status SessionRef::Close() {
+ TF_RETURN_IF_ERROR(CheckNotClosed());
+ mutex_lock l(run_lock_);
+ Status status;
+ if (logger_) {
+ status = logger_->RecordClose(session_.get());
+ } else {
+ status = session_->Close();
+ }
+ session_.reset();
+ while (run_count_ > 0) {
+ run_finished_.wait(l);
+ }
+ return status;
+}
+
+} // namespace tensorflow
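
A hedged usage sketch for the new logger: per the patch, replay logging is keyed off TF_REPLAY_LOG_FILE, which must be set before the session is created; the log path is illustrative and the snippet assumes the TF 1.x Python client wired up by this change:

    import os
    os.environ['TF_REPLAY_LOG_FILE'] = '/tmp/tf_replay.log'   # illustrative path

    import tensorflow as tf   # TF 1.x; the variable is read when the session is built

    with tf.Session() as sess:
        total = tf.constant(2.0) + tf.constant(3.0)
        print(sess.run(total))   # recorded as a RunStepRequest/RunStepResponse pair
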
diff --git a/tensorflow/core/common_runtime/session_ref.h b/tensorflow/python/client/session_ref.h
index 9459e7edbe..b0fb12b189 100644
--- a/tensorflow/core/common_runtime/session_ref.h
+++ b/tensorflow/python/client/session_ref.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_SESSION_REF_H_
-#define TENSORFLOW_CORE_COMMON_RUNTIME_SESSION_REF_H_
+#ifndef TENSORFLOW_PYTHON_CLIENT_SESSION_REF_H_
+#define TENSORFLOW_PYTHON_CLIENT_SESSION_REF_H_
#include <memory>
@@ -22,6 +22,8 @@ limitations under the License.
namespace tensorflow {
+class SessionLogger;
+
// A `SessionRef` manages the lifetime of a wrapped `Session` pointer.
//
// SessionRef blocks the return of Close() until all pending operations have
@@ -29,8 +31,8 @@ namespace tensorflow {
// subsequent operations on the SessionRef object will return errors::Cancelled.
class SessionRef : public Session {
public:
- SessionRef(Session* session) : session_(session) {}
- virtual ~SessionRef() {}
+ explicit SessionRef(Session* session);
+ ~SessionRef() override;
Status Create(const GraphDef& graph) override;
Status Extend(const GraphDef& graph) override;
@@ -78,9 +80,12 @@ class SessionRef : public Session {
uint64 run_count_ GUARDED_BY(run_lock_) = {0};
std::shared_ptr<Session> session_;
+ // Borrowed reference to global session logger.
+ SessionLogger* logger_;
+
Status CheckNotClosed();
};
} // namespace tensorflow
-#endif // TENSORFLOW_CORE_COMMON_RUNTIME_SESSION_REF_H_
+#endif // TENSORFLOW_PYTHON_CLIENT_SESSION_REF_H_
diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py
index 4afc6399d5..347833ce8f 100644
--- a/tensorflow/python/client/session_test.py
+++ b/tensorflow/python/client/session_test.py
@@ -61,6 +61,12 @@ from tensorflow.python.platform import googletest
from tensorflow.python.training import server_lib
from tensorflow.python.util import compat
+try:
+ import attr # pylint:disable=g-import-not-at-top
+except ImportError:
+ attr = None
+
+
# NOTE(mrry): Dummy shape registration for ops used in the tests, since they
# don't have C++ op registrations on which to attach C++ shape fns.
ops.RegisterShape('ConstructionFails')(common_shapes.unknown_shape)
@@ -114,11 +120,17 @@ class SessionTest(test_util.TensorFlowTestCase):
inp = constant_op.constant(10.0, name='W1')
self.assertAllEqual(inp.eval(), 10.0)
- devices = sess.list_devices()
- self.assertEqual(2, len(devices))
- for device in devices:
- self.assertEqual('CPU', framework_device_lib.DeviceSpec.from_string(
- device.name).device_type)
+ num_cpu_devices = 0
+ num_gpu_devices = 0
+ for device in sess.list_devices():
+ device_type = framework_device_lib.DeviceSpec.from_string(
+ device.name).device_type
+ if device_type == 'CPU':
+ num_cpu_devices += 1
+ elif device_type == 'GPU':
+ num_gpu_devices += 1
+ self.assertEqual(2, num_cpu_devices)
+ self.assertEqual(0, num_gpu_devices)
def testPerSessionThreads(self):
with session.Session(
@@ -300,6 +312,82 @@ class SessionTest(test_util.TensorFlowTestCase):
self.assertEqual(None, res[2])
self.assertEqual(44.0, res[1])
+ def testFetchAttrs(self):
+ if attr is None:
+ self.skipTest('attr module is unavailable.')
+
+ @attr.s
+ class SampleAttr(object):
+ field1 = attr.ib()
+ field2 = attr.ib()
+
+ val1 = np.array([1.2, 3.4, 5.6])
+ val2 = np.array([[1, 2], [4, 3]])
+ val3 = np.array([10, 20, 30])
+
+ t1 = constant_op.constant(val1)
+ t2 = constant_op.constant(val2)
+
+ sample = SampleAttr(t1, t2)
+ with session.Session() as sess:
+ result = sess.run(sample)
+ self.assertIsInstance(result, SampleAttr)
+ self.assertAllEqual(val1, result.field1)
+ self.assertAllEqual(val2, result.field2)
+
+ result = sess.run(sample, feed_dict={sample.field1: val3})
+ self.assertIsInstance(result, SampleAttr)
+ self.assertAllEqual(val3, result.field1)
+ self.assertAllEqual(val2, result.field2)
+
+ def testFetchNestedAttrs(self):
+ if attr is None:
+ self.skipTest('attr module is unavailable.')
+
+ @attr.s
+ class SampleAttr(object):
+ field0 = attr.ib()
+ field1 = attr.ib()
+
+ v1 = 10
+ v2 = 20
+ v3 = np.float32(1.2)
+ v4 = np.float32(3.4)
+ v5 = np.float64(100.001)
+ v6 = np.float64(-23.451)
+ arr1 = np.array([1.2, 6.7, 3.4])
+ arr2 = np.array([7, 11, 3])
+ sample = SampleAttr(
+ SampleAttr(
+ SampleAttr(constant_op.constant(v1), constant_op.constant(v2)),
+ SampleAttr(constant_op.constant(arr1), constant_op.constant(arr2))),
+ {'A': SampleAttr(constant_op.constant(v3), constant_op.constant(v4)),
+ 'B': [SampleAttr(constant_op.constant(v5), constant_op.constant(v6))]})
+
+ with session.Session() as sess:
+ result = sess.run(sample)
+ self.assertIsInstance(result, SampleAttr)
+ self.assertIsInstance(result.field0, SampleAttr)
+ self.assertIsInstance(result.field0.field0, SampleAttr)
+ self.assertIsInstance(result.field0.field1, SampleAttr)
+ self.assertIsInstance(result.field0.field1.field0, np.ndarray)
+ self.assertAllEqual(arr1, result.field0.field1.field0)
+ self.assertIsInstance(result.field0.field1.field1, np.ndarray)
+ self.assertAllEqual(arr2, result.field0.field1.field1)
+ self.assertIsInstance(result.field1, dict)
+ self.assertIn('A', result.field1)
+ self.assertIn('B', result.field1)
+ self.assertIsInstance(result.field1['A'], SampleAttr)
+ self.assertAllEqual(
+ [v3, v4],
+ [result.field1['A'].field0, result.field1['A'].field1])
+ self.assertIsInstance(result.field1['B'], list)
+ self.assertEqual(1, len(result.field1['B']))
+ self.assertIsInstance(result.field1['B'][0], SampleAttr)
+ self.assertAllEqual(
+ [v5, v6],
+ [result.field1['B'][0].field0, result.field1['B'][0].field1])
+
def testFetchNestingEmptyOneLevel(self):
with session.Session() as sess:
a_val = 11.0
@@ -940,7 +1028,7 @@ class SessionTest(test_util.TensorFlowTestCase):
with session.Session():
a = constant_op.constant(1.0, shape=[1, 2])
b = constant_op.constant(2.0, shape=[1, 2], name='b')
- v = variables.Variable(a, a.dtype)
+ v = variables.VariableV1(a, a.dtype)
assign_a_to_v = state_ops.assign(v, a)
assign_a_to_v.eval()
diff --git a/tensorflow/python/client/tf_session.i b/tensorflow/python/client/tf_session.i
index 39a2922ac0..ef7527d887 100644
--- a/tensorflow/python/client/tf_session.i
+++ b/tensorflow/python/client/tf_session.i
@@ -463,7 +463,7 @@ TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper{
}
// Override default py3 behavior of attempting to encode into Unicode.
-%typemap(out) std::string tensorflow::GetResourceHandleShapeAndType {
+%typemap(out) std::string tensorflow::GetHandleShapeAndType {
$result = PyBytes_FromStringAndSize($1.data(), $1.size());
}
@@ -782,7 +782,7 @@ def TF_Reset(target, containers=None, config=None):
%unignore TF_TryEvaluateConstant_wrapper;
%noexception TF_TryEvaluateConstant_wrapper;
%unignore ExtendSession;
-%unignore ResourceHandleShapeAndType;
+%unignore HandleShapeAndType;
%include "tensorflow/python/client/tf_session_helper.h"
diff --git a/tensorflow/python/client/tf_session_helper.cc b/tensorflow/python/client/tf_session_helper.cc
index bcd4af2912..dc0c10bab7 100644
--- a/tensorflow/python/client/tf_session_helper.cc
+++ b/tensorflow/python/client/tf_session_helper.cc
@@ -20,7 +20,6 @@ limitations under the License.
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/tf_status_helper.h"
-#include "tensorflow/core/common_runtime/session_ref.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
@@ -31,6 +30,7 @@ limitations under the License.
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/equal_graph_def.h"
+#include "tensorflow/python/client/session_ref.h"
#include "tensorflow/python/lib/core/ndarray_tensor.h"
#include "tensorflow/python/lib/core/ndarray_tensor_bridge.h"
#include "tensorflow/python/lib/core/safe_ptr.h"
diff --git a/tensorflow/python/client/timeline.py b/tensorflow/python/client/timeline.py
index 1e96ac5ed4..c3f38294b5 100644
--- a/tensorflow/python/client/timeline.py
+++ b/tensorflow/python/client/timeline.py
@@ -588,7 +588,8 @@ class Timeline(object):
alloc_tensor_set = set()
alloc_maxes[allocator] = AllocationMaximum(
timestamp=0, num_bytes=0, tensors=set())
- for time, num_bytes, name in alloc_list:
+ for time, num_bytes, name in sorted(
+ alloc_list, key=lambda allocation: allocation[0]):
total_bytes += num_bytes
if num_bytes < 0:
alloc_tensor_set.discard(name)
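
A small standalone sketch of why the allocation records need to be time-ordered before accumulation: the running total, and therefore the reported maximum, is only meaningful chronologically. The sample tuples are illustrative (timestamp_us, num_bytes, name), with num_bytes < 0 marking a deallocation:

    alloc_list = [(30, -8, 'num1'), (10, 8, 'num1'), (20, 8, 'num2')]

    total = peak = 0
    for _, num_bytes, _ in sorted(alloc_list, key=lambda allocation: allocation[0]):
        total += num_bytes
        peak = max(peak, total)

    assert peak == 16   # accumulating in the unsorted order above would report 8
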
diff --git a/tensorflow/python/client/timeline_test.py b/tensorflow/python/client/timeline_test.py
index 281d7f2e2b..032bbf7c4e 100644
--- a/tensorflow/python/client/timeline_test.py
+++ b/tensorflow/python/client/timeline_test.py
@@ -134,7 +134,7 @@ class TimelineTest(test.TestCase):
ctf = tl.generate_chrome_trace_format()
self._validateTrace(ctf)
- def disabled_testAnalysisAndAllocations(self):
+ def testAnalysisAndAllocations(self):
run_options = config_pb2.RunOptions(
trace_level=config_pb2.RunOptions.FULL_TRACE)
run_metadata = config_pb2.RunMetadata()
@@ -163,8 +163,6 @@ class TimelineTest(test.TestCase):
# At least num1 + num2, both float32s (4 bytes each)
self.assertGreaterEqual(cpu_max.num_bytes, 8)
self.assertGreater(cpu_max.timestamp, 0)
- self.assertTrue('num1' in cpu_max.tensors or 'num1/read' in cpu_max.tensors)
- self.assertTrue('num2' in cpu_max.tensors or 'num2/read' in cpu_max.tensors)
def testManyCPUs(self):
run_options = config_pb2.RunOptions(
diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py
index c246a98237..b74fce3a4c 100644
--- a/tensorflow/python/compat/compat.py
+++ b/tensorflow/python/compat/compat.py
@@ -26,7 +26,7 @@ import datetime
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util.tf_export import tf_export
-_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2018, 9, 16)
+_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2018, 9, 28)
@tf_export("compat.forward_compatible")
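
For context, a sketch of how the horizon date is consumed by guarded code paths; the choice of what to do in each branch is illustrative:

    from tensorflow.python.compat import compat

    if compat.forward_compatible(2018, 9, 28):
        pass   # safe to emit the newer graph construct
    else:
        pass   # keep emitting the older, forward-compatible form
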
diff --git a/tensorflow/python/data/BUILD b/tensorflow/python/data/BUILD
index 3e08c1587e..138141f4fc 100644
--- a/tensorflow/python/data/BUILD
+++ b/tensorflow/python/data/BUILD
@@ -12,6 +12,7 @@ py_library(
"//tensorflow/python:util",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/ops:iterator_ops",
+ "//tensorflow/python/data/ops:multi_device_iterator_ops",
"//tensorflow/python/data/ops:readers",
],
)
diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD
index 631b87a718..5f9818566f 100644
--- a/tensorflow/python/data/kernel_tests/BUILD
+++ b/tensorflow/python/data/kernel_tests/BUILD
@@ -15,6 +15,7 @@ tf_py_test(
size = "small",
srcs = ["batch_dataset_op_test.py"],
additional_deps = [
+ ":test_base",
"@absl_py//absl/testing:parameterized",
"//third_party/py/numpy",
"//tensorflow/python:array_ops",
@@ -31,10 +32,44 @@ tf_py_test(
)
tf_py_test(
+ name = "cache_dataset_op_test",
+ size = "small",
+ srcs = ["cache_dataset_op_test.py"],
+ additional_deps = [
+ ":test_base",
+ "//third_party/py/numpy",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:variables",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/ops:iterator_ops",
+ ],
+)
+
+tf_py_test(
+ name = "concatenate_dataset_op_test",
+ size = "small",
+ srcs = ["concatenate_dataset_op_test.py"],
+ additional_deps = [
+ ":test_base",
+ "//third_party/py/numpy",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:tensor_shape",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
+ ],
+)
+
+tf_py_test(
name = "dataset_constructor_op_test",
size = "small",
srcs = ["dataset_constructor_op_test.py"],
additional_deps = [
+ ":test_base",
"//third_party/py/numpy",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
@@ -63,6 +98,7 @@ tf_py_test(
size = "medium",
srcs = ["dataset_from_generator_op_test.py"],
additional_deps = [
+ ":test_base",
"//third_party/py/numpy",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
@@ -78,6 +114,7 @@ tf_py_test(
size = "small",
srcs = ["dataset_ops_test.py"],
additional_deps = [
+ ":test_base",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:client_testlib",
"//tensorflow/python/data/ops:dataset_ops",
@@ -89,6 +126,7 @@ tf_py_test(
size = "small",
srcs = ["filter_dataset_op_test.py"],
additional_deps = [
+ ":test_base",
"//third_party/py/numpy",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
@@ -106,6 +144,7 @@ tf_py_test(
size = "small",
srcs = ["flat_map_dataset_op_test.py"],
additional_deps = [
+ ":test_base",
"//third_party/py/numpy",
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
@@ -123,6 +162,7 @@ tf_py_test(
size = "small",
srcs = ["list_files_dataset_op_test.py"],
additional_deps = [
+ ":test_base",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
@@ -133,10 +173,25 @@ tf_py_test(
)
tf_py_test(
+ name = "inputs_test",
+ size = "small",
+ srcs = ["inputs_test.py"],
+ additional_deps = [
+ ":test_base",
+ "@absl_py//absl/testing:parameterized",
+ "//third_party/py/numpy",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+tf_py_test(
name = "interleave_dataset_op_test",
size = "small",
srcs = ["interleave_dataset_op_test.py"],
additional_deps = [
+ ":test_base",
"@absl_py//absl/testing:parameterized",
"//third_party/py/numpy",
"//tensorflow/python:array_ops",
@@ -151,11 +206,80 @@ tf_py_test(
],
)
+cuda_py_test(
+ name = "iterator_ops_test",
+ size = "small",
+ srcs = ["iterator_ops_test.py"],
+ additional_deps = [
+ "//third_party/py/numpy",
+ "//tensorflow/python/data/ops:readers",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/ops:iterator_ops",
+ "//tensorflow/python/data/util:sparse",
+ "//tensorflow/python/eager:context",
+ "//tensorflow/python/training/checkpointable:util",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dataset_ops_gen",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:function",
+ "//tensorflow/python:functional_ops",
+ "//tensorflow/python:gradients",
+ "//tensorflow/python:io_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:parsing_ops",
+ "//tensorflow/python:random_ops",
+ "//tensorflow/python:script_ops",
+ "//tensorflow/python:session",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python:tensor_shape",
+ "//tensorflow/python:training",
+ "//tensorflow/python/compat:compat",
+ "//tensorflow/python:util",
+ "//tensorflow/python:variables",
+ ],
+ grpc_enabled = True,
+)
+
+tf_py_test(
+ name = "iterator_ops_cluster_test",
+ size = "small",
+ srcs = ["iterator_ops_cluster_test.py"],
+ additional_deps = [
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:function",
+ "//tensorflow/python:functional_ops",
+ "//tensorflow/python:session",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/ops:iterator_ops",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:string_ops",
+ "//tensorflow/python:lookup_ops",
+ ],
+ grpc_enabled = True,
+ tags = [
+ "no_oss", # Test flaky due to port collisions.
+ "no_windows",
+ ],
+)
+
tf_py_test(
name = "map_dataset_op_test",
size = "small",
srcs = ["map_dataset_op_test.py"],
additional_deps = [
+ ":test_base",
"@absl_py//absl/testing:parameterized",
"//third_party/py/numpy",
"//tensorflow/python:array_ops",
@@ -177,11 +301,54 @@ tf_py_test(
],
)
+cuda_py_test(
+ name = "multi_device_iterator_test",
+ size = "small",
+ srcs = ["multi_device_iterator_test.py"],
+ additional_deps = [
+ ":test_base",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/ops:multi_device_iterator_ops",
+ "//tensorflow/python/data/ops:iterator_ops",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_test_lib",
+ ],
+ tags = [
+ "no_windows_gpu",
+ ],
+)
+
+cuda_py_test(
+ name = "optional_ops_test",
+ size = "small",
+ srcs = ["optional_ops_test.py"],
+ additional_deps = [
+ ":test_base",
+ "@absl_py//absl/testing:parameterized",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/ops:iterator_ops",
+ "//tensorflow/python/data/ops:optional_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:tensor_shape",
+ ],
+)
+
tf_py_test(
name = "prefetch_dataset_op_test",
size = "small",
srcs = ["prefetch_dataset_op_test.py"],
additional_deps = [
+ ":test_base",
"@absl_py//absl/testing:parameterized",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
@@ -197,6 +364,7 @@ tf_py_test(
size = "small",
srcs = ["range_dataset_op_test.py"],
additional_deps = [
+ ":test_base",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dataset_ops_gen",
@@ -218,6 +386,7 @@ tf_py_test(
size = "small",
srcs = ["reader_dataset_ops_test.py"],
additional_deps = [
+ ":test_base",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
@@ -236,32 +405,35 @@ tf_py_test(
)
tf_py_test(
- name = "sequence_dataset_op_test",
+ name = "reduce_dataset_op_test",
size = "small",
- srcs = ["sequence_dataset_op_test.py"],
+ srcs = ["reduce_dataset_op_test.py"],
additional_deps = [
+ ":test_base",
+ "@absl_py//absl/testing:parameterized",
"//third_party/py/numpy",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:sparse_tensor",
"//tensorflow/python/data/ops:dataset_ops",
],
)
tf_py_test(
- name = "shuffle_dataset_op_test",
+ name = "sequence_dataset_op_test",
size = "small",
- srcs = ["shuffle_dataset_op_test.py"],
+ srcs = ["sequence_dataset_op_test.py"],
additional_deps = [
+ ":test_base",
"//third_party/py/numpy",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/ops:iterator_ops",
],
)
@@ -270,6 +442,7 @@ tf_py_test(
size = "small",
srcs = ["shard_dataset_op_test.py"],
additional_deps = [
+ ":test_base",
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
"//tensorflow/python/data/ops:dataset_ops",
@@ -277,133 +450,59 @@ tf_py_test(
)
tf_py_test(
- name = "cache_dataset_op_test",
+ name = "shuffle_dataset_op_test",
size = "small",
- srcs = ["cache_dataset_op_test.py"],
+ srcs = ["shuffle_dataset_op_test.py"],
additional_deps = [
+ ":test_base",
"//third_party/py/numpy",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
- "//tensorflow/python:variables",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/ops:iterator_ops",
],
)
-tf_py_test(
- name = "zip_dataset_op_test",
- size = "small",
- srcs = ["zip_dataset_op_test.py"],
- additional_deps = [
- "//third_party/py/numpy",
- "//tensorflow/python:array_ops",
+py_library(
+ name = "test_base",
+ srcs = ["test_base.py"],
+ deps = [
"//tensorflow/python:client_testlib",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python/data/ops:dataset_ops",
],
)
tf_py_test(
- name = "concatenate_dataset_op_test",
- size = "small",
- srcs = ["concatenate_dataset_op_test.py"],
- additional_deps = [
- "//third_party/py/numpy",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:errors",
- "//tensorflow/python:tensor_shape",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/util:nest",
- ],
-)
-
-cuda_py_test(
- name = "iterator_ops_test",
+ name = "window_dataset_op_test",
size = "small",
- srcs = ["iterator_ops_test.py"],
+ srcs = ["window_dataset_op_test.py"],
additional_deps = [
+ ":test_base",
+ "@absl_py//absl/testing:parameterized",
"//third_party/py/numpy",
- "//tensorflow/python/data/ops:readers",
- "//tensorflow/core:protos_all_py",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/ops:iterator_ops",
- "//tensorflow/python/data/util:sparse",
- "//tensorflow/python/eager:context",
- "//tensorflow/python/training/checkpointable:util",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:dataset_ops_gen",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:framework_test_lib",
- "//tensorflow/python:function",
- "//tensorflow/python:functional_ops",
- "//tensorflow/python:gradients",
- "//tensorflow/python:io_ops",
"//tensorflow/python:math_ops",
- "//tensorflow/python:parsing_ops",
- "//tensorflow/python:random_ops",
- "//tensorflow/python:script_ops",
- "//tensorflow/python:session",
"//tensorflow/python:sparse_tensor",
- "//tensorflow/python:tensor_shape",
- "//tensorflow/python:training",
- "//tensorflow/python/compat:compat",
- "//tensorflow/python:util",
- "//tensorflow/python:variables",
+ "//tensorflow/python/data/ops:dataset_ops",
],
- grpc_enabled = True,
)
tf_py_test(
- name = "iterator_ops_cluster_test",
+ name = "zip_dataset_op_test",
size = "small",
- srcs = ["iterator_ops_cluster_test.py"],
+ srcs = ["zip_dataset_op_test.py"],
additional_deps = [
- "//tensorflow/core:protos_all_py",
+ ":test_base",
+ "//third_party/py/numpy",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:framework_test_lib",
- "//tensorflow/python:function",
- "//tensorflow/python:functional_ops",
- "//tensorflow/python:session",
"//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/ops:iterator_ops",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:string_ops",
- "//tensorflow/python:lookup_ops",
- ],
- grpc_enabled = True,
- tags = [
- "no_oss", # Test flaky due to port collisions.
- "no_windows",
- ],
-)
-
-cuda_py_test(
- name = "optional_ops_test",
- size = "small",
- srcs = ["optional_ops_test.py"],
- additional_deps = [
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/ops:iterator_ops",
- "//tensorflow/python/data/ops:optional_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:framework_test_lib",
- "//tensorflow/python:tensor_shape",
],
)
diff --git a/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py b/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py
index c48708a2b9..9cb4daf284 100644
--- a/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py
@@ -24,6 +24,7 @@ from absl.testing import parameterized
import numpy as np
from tensorflow.python.client import session
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -37,7 +38,7 @@ from tensorflow.python.platform import test
from tensorflow.python.util import compat
-class BatchDatasetTest(test.TestCase, parameterized.TestCase):
+class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
@parameterized.named_parameters(
('even', 28, 14, False),
@@ -115,11 +116,6 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
with self.assertRaises(errors.InvalidArgumentError):
sess.run(get_next)
- def assertSparseValuesEqual(self, a, b):
- self.assertAllEqual(a.indices, b.indices)
- self.assertAllEqual(a.values, b.values)
- self.assertAllEqual(a.dense_shape, b.dense_shape)
-
def testBatchSparse(self):
def _sparse(i):
@@ -227,7 +223,7 @@ def _random_seq_lens(count):
return np.random.randint(20, size=(count,)).astype(np.int32)
-class PaddedBatchDatasetTest(test.TestCase, parameterized.TestCase):
+class PaddedBatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
@parameterized.named_parameters(
('default_padding', _random_seq_lens(32), 4, [-1], False),
diff --git a/tensorflow/python/data/kernel_tests/cache_dataset_op_test.py b/tensorflow/python/data/kernel_tests/cache_dataset_op_test.py
index d5f5b2fe05..63625fac03 100644
--- a/tensorflow/python/data/kernel_tests/cache_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/cache_dataset_op_test.py
@@ -23,6 +23,7 @@ import tempfile
import numpy as np
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import constant_op
@@ -34,7 +35,7 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import test
-class FileCacheDatasetTest(test.TestCase):
+class FileCacheDatasetTest(test_base.DatasetTestBase):
def setUp(self):
self.tmp_dir = tempfile.mkdtemp()
@@ -200,7 +201,7 @@ class FileCacheDatasetTest(test.TestCase):
self.assertAllEqual(elements, elements_itr2)
-class MemoryCacheDatasetTest(test.TestCase):
+class MemoryCacheDatasetTest(test_base.DatasetTestBase):
def testCacheDatasetPassthrough(self):
with ops.device("cpu:0"):
diff --git a/tensorflow/python/data/kernel_tests/concatenate_dataset_op_test.py b/tensorflow/python/data/kernel_tests/concatenate_dataset_op_test.py
index 5dfb84f28e..83af31f380 100644
--- a/tensorflow/python/data/kernel_tests/concatenate_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/concatenate_dataset_op_test.py
@@ -19,6 +19,7 @@ from __future__ import print_function
import numpy as np
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.framework import errors
@@ -26,7 +27,7 @@ from tensorflow.python.framework import tensor_shape
from tensorflow.python.platform import test
-class ConcatenateDatasetTest(test.TestCase):
+class ConcatenateDatasetTest(test_base.DatasetTestBase):
def testConcatenateDataset(self):
input_components = (
diff --git a/tensorflow/python/data/kernel_tests/dataset_constructor_op_test.py b/tensorflow/python/data/kernel_tests/dataset_constructor_op_test.py
index e43564a2eb..bc6b36285a 100644
--- a/tensorflow/python/data/kernel_tests/dataset_constructor_op_test.py
+++ b/tensorflow/python/data/kernel_tests/dataset_constructor_op_test.py
@@ -23,6 +23,7 @@ import numpy as np
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.framework import dtypes
@@ -36,7 +37,7 @@ from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import test
-class DatasetConstructorTest(test.TestCase):
+class DatasetConstructorTest(test_base.DatasetTestBase):
def testFromTensors(self):
"""Test a dataset that represents a single tuple of tensors."""
@@ -58,11 +59,6 @@ class DatasetConstructorTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
- def assertSparseValuesEqual(self, a, b):
- self.assertAllEqual(a.indices, b.indices)
- self.assertAllEqual(a.values, b.values)
- self.assertAllEqual(a.dense_shape, b.dense_shape)
-
def testFromTensorsSparse(self):
"""Test a dataset that represents a single tuple of tensors."""
components = (sparse_tensor.SparseTensorValue(
diff --git a/tensorflow/python/data/kernel_tests/dataset_from_generator_op_test.py b/tensorflow/python/data/kernel_tests/dataset_from_generator_op_test.py
index cd0c1ddf1e..cb8cb9a77d 100644
--- a/tensorflow/python/data/kernel_tests/dataset_from_generator_op_test.py
+++ b/tensorflow/python/data/kernel_tests/dataset_from_generator_op_test.py
@@ -22,6 +22,7 @@ import threading
import numpy as np
from tensorflow.python.client import session
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -30,7 +31,7 @@ from tensorflow.python.ops import script_ops
from tensorflow.python.platform import test
-class DatasetConstructorTest(test.TestCase):
+class DatasetConstructorTest(test_base.DatasetTestBase):
def _testFromGenerator(self, generator, elem_sequence, num_repeats,
output_types=None):
diff --git a/tensorflow/python/data/kernel_tests/dataset_ops_test.py b/tensorflow/python/data/kernel_tests/dataset_ops_test.py
index 239aa85175..f115f9d9c7 100644
--- a/tensorflow/python/data/kernel_tests/dataset_ops_test.py
+++ b/tensorflow/python/data/kernel_tests/dataset_ops_test.py
@@ -19,11 +19,12 @@ from __future__ import division
from __future__ import print_function
from tensorflow.core.framework import graph_pb2
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.platform import test
-class DatasetOpsTest(test.TestCase):
+class DatasetOpsTest(test_base.DatasetTestBase):
def testAsSerializedGraph(self):
dataset = dataset_ops.Dataset.range(10)
diff --git a/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py b/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py
index 19944d389f..6b7afafa5d 100644
--- a/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py
@@ -22,6 +22,7 @@ import time
import numpy as np
from tensorflow.python.client import session
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -33,7 +34,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
-class FilterDatasetTest(test.TestCase):
+class FilterDatasetTest(test_base.DatasetTestBase):
def testFilterDataset(self):
components = (
@@ -129,11 +130,6 @@ class FilterDatasetTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
- def assertSparseValuesEqual(self, a, b):
- self.assertAllEqual(a.indices, b.indices)
- self.assertAllEqual(a.values, b.values)
- self.assertAllEqual(a.dense_shape, b.dense_shape)
-
def testSparse(self):
def _map_fn(i):
diff --git a/tensorflow/python/data/kernel_tests/flat_map_dataset_op_test.py b/tensorflow/python/data/kernel_tests/flat_map_dataset_op_test.py
index 1123cbff62..68038f9cfc 100644
--- a/tensorflow/python/data/kernel_tests/flat_map_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/flat_map_dataset_op_test.py
@@ -22,6 +22,7 @@ import random
import numpy as np
from tensorflow.python.client import session
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
from tensorflow.python.framework import sparse_tensor
@@ -30,7 +31,7 @@ from tensorflow.python.platform import test
from tensorflow.python.training import server_lib
-class FlatMapDatasetTest(test.TestCase):
+class FlatMapDatasetTest(test_base.DatasetTestBase):
# pylint: disable=g-long-lambda
def testFlatMapDataset(self):
diff --git a/tensorflow/python/data/kernel_tests/inputs_test.py b/tensorflow/python/data/kernel_tests/inputs_test.py
new file mode 100644
index 0000000000..d089b49bcc
--- /dev/null
+++ b/tensorflow/python/data/kernel_tests/inputs_test.py
@@ -0,0 +1,149 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+import numpy as np
+
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.ops import readers
+from tensorflow.python.data.util import nest
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.platform import test
+
+
+class InputsTest(test_base.DatasetTestBase, parameterized.TestCase):
+
+ @staticmethod
+ def make_apply_fn(dataset):
+
+ def apply_fn(dataset):
+
+ def _apply_fn(dataset):
+ return dataset.cache()
+
+ return dataset.apply(_apply_fn)
+
+ return apply_fn
+
+ @staticmethod
+ def make_gen():
+
+ def gen():
+ yield 42
+
+ return gen
+
+ @staticmethod
+ def make_interleave_fn(dataset, num_parallel_calls=None):
+
+ def interleave_fn(dataset):
+ return dataset.interleave(
+ lambda x: dataset_ops.Dataset.range(0),
+ cycle_length=2,
+ num_parallel_calls=num_parallel_calls)
+
+ return interleave_fn
+
+ @parameterized.named_parameters(
+ ("FixedLengthRecord", readers.FixedLengthRecordDataset("", 42)),
+ ("FromGenerator",
+ dataset_ops.Dataset.from_generator(make_gen.__func__(), dtypes.int32),
+ 1),
+ ("FromSparseTensorSlices",
+ dataset_ops.Dataset.from_sparse_tensor_slices(
+ sparse_tensor.SparseTensor(
+ indices=np.array([[0, 0], [1, 0], [2, 0]]),
+ values=np.array([0, 0, 0]),
+ dense_shape=np.array([3, 1])))),
+ ("FromTensors", dataset_ops.Dataset.from_tensors([42])),
+      ("FromTensorSlices", dataset_ops.Dataset.from_tensor_slices([42])),
+ ("Range", dataset_ops.Dataset.range(10)),
+ ("TextLine", readers.TextLineDataset("")),
+ ("TFRecord", readers.TFRecordDataset(""), 1),
+ )
+ def testDatasetSourceInputs(self, dataset, num_inputs=0):
+ self.assertEqual(num_inputs, len(dataset._inputs()))
+
+ @parameterized.named_parameters(
+ ("Apply", make_apply_fn.__func__(dataset_ops.Dataset.range(0)),
+ dataset_ops.Dataset.range(0)),
+ ("Batch", lambda x: x.batch(10), dataset_ops.Dataset.range(0)),
+ ("Cache", lambda x: x.cache(), dataset_ops.Dataset.range(0)),
+ ("Filter", lambda x: x.filter(lambda x: True),
+ dataset_ops.Dataset.range(0)),
+ ("FlatMap", lambda x: x.flat_map(lambda x: dataset_ops.Dataset.range(0)),
+ dataset_ops.Dataset.range(0)),
+ ("Interleave", make_interleave_fn.__func__(dataset_ops.Dataset.range(0)),
+ dataset_ops.Dataset.range(0)),
+ ("Map", lambda x: x.map(lambda x: x), dataset_ops.Dataset.range(0)),
+ ("PaddedBatch", lambda x: x.padded_batch(10, []),
+ dataset_ops.Dataset.range(0)),
+ ("ParallelInterleave",
+ make_interleave_fn.__func__(dataset_ops.Dataset.range(0), 2),
+ dataset_ops.Dataset.range(0)),
+ ("ParallelMap", lambda x: x.map(lambda x: x, num_parallel_calls=2),
+ dataset_ops.Dataset.range(0)),
+ ("Repeat", lambda x: x.repeat(), dataset_ops.Dataset.range(0)),
+ ("Shuffle", lambda x: x.shuffle(10), dataset_ops.Dataset.range(0)),
+ ("Skip", lambda x: x.skip(1), dataset_ops.Dataset.range(0)),
+ ("Take", lambda x: x.take(1), dataset_ops.Dataset.range(0)),
+ ("Window", lambda x: x.window(10), dataset_ops.Dataset.range(0)),
+ )
+ def testUnaryTransformationInputs(self, dataset_fn, input_dataset):
+ self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())
+
+ @parameterized.named_parameters(
+ ("Concatenate", lambda x, y: x.concatenate(y),
+ dataset_ops.Dataset.range(0), dataset_ops.Dataset.range(1)))
+ def testBinaryTransformationInputs(self, dataset_fn, input1, input2):
+ self.assertEqual([input1, input2], dataset_fn(input1, input2)._inputs())
+
+ @parameterized.named_parameters(
+ ("ZipOne", dataset_ops.Dataset.zip, (dataset_ops.Dataset.range(0))),
+ ("ZipNest", dataset_ops.Dataset.zip,
+ (dataset_ops.Dataset.range(0),
+ (dataset_ops.Dataset.range(1), dataset_ops.Dataset.range(2)))),
+ ("ZipTuple", dataset_ops.Dataset.zip,
+ (dataset_ops.Dataset.range(0), dataset_ops.Dataset.range(1))))
+ def testVariadicTransformationInputs(self, dataset_fn, input_datasets):
+ self.assertEqual(
+ nest.flatten(input_datasets),
+ dataset_fn(input_datasets)._inputs())
+
+ def testCollectInputs(self):
+ ds1 = dataset_ops.Dataset.range(0)
+ ds2 = ds1.concatenate(ds1)
+ ds3 = dataset_ops.Dataset.zip((ds2, ds1, ds2))
+
+ inputs = []
+ queue = [ds3]
+ while queue:
+ ds = queue[0]
+ queue = queue[1:]
+ queue.extend(ds._inputs())
+ inputs.append(ds)
+
+ self.assertEqual(5, inputs.count(ds1))
+ self.assertEqual(2, inputs.count(ds2))
+ self.assertEqual(1, inputs.count(ds3))
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py b/tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py
index a35cee594a..92bb67b6ff 100644
--- a/tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py
@@ -22,6 +22,7 @@ import itertools
from absl.testing import parameterized
import numpy as np
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
from tensorflow.python.framework import sparse_tensor
@@ -30,7 +31,7 @@ from tensorflow.python.ops import sparse_ops
from tensorflow.python.platform import test
-class InterleaveDatasetTest(test.TestCase, parameterized.TestCase):
+class InterleaveDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
def _interleave(self, lists, cycle_length, block_length):
num_open = 0
@@ -134,7 +135,7 @@ class InterleaveDatasetTest(test.TestCase, parameterized.TestCase):
result.append([value] * value)
return result * count
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for expected_element in self._interleave(
repeat(input_values, count), cycle_length, block_length):
self.assertEqual(expected_element, sess.run(get_next))
@@ -169,7 +170,7 @@ class InterleaveDatasetTest(test.TestCase, parameterized.TestCase):
num_parallel_calls)
get_next = dataset.make_one_shot_iterator().get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for value in input_values:
if np.isnan(value):
with self.assertRaises(errors.InvalidArgumentError):
@@ -195,7 +196,7 @@ class InterleaveDatasetTest(test.TestCase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for i in range(10):
for j in range(2):
diff --git a/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py b/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py
index c4b338a58f..8eb13815d4 100644
--- a/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py
@@ -22,6 +22,7 @@ from os import path
import shutil
import tempfile
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -30,7 +31,7 @@ from tensorflow.python.platform import test
from tensorflow.python.util import compat
-class ListFilesDatasetOpTest(test.TestCase):
+class ListFilesDatasetOpTest(test_base.DatasetTestBase):
def setUp(self):
self.tmp_dir = tempfile.mkdtemp()
diff --git a/tensorflow/python/data/kernel_tests/map_dataset_op_test.py b/tensorflow/python/data/kernel_tests/map_dataset_op_test.py
index 7685d8dbdc..230ae3f3fd 100644
--- a/tensorflow/python/data/kernel_tests/map_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/map_dataset_op_test.py
@@ -27,6 +27,7 @@ import numpy as np
from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.client import session
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -47,7 +48,7 @@ from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import test
-class MapDatasetTest(test.TestCase, parameterized.TestCase):
+class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
def _buildMapDataset(self, components, count):
def _map_fn(x, y, z):
@@ -397,6 +398,28 @@ class MapDatasetTest(test.TestCase, parameterized.TestCase):
# Randomness is repeatable given same seed
self.assertAllClose(random_values, random_values_2)
+ def testStatefulMapKeepsStateAcrossIterators(self):
+ iterator = (dataset_ops.Dataset.from_tensors(0).repeat(10)
+ .map(lambda _: random_ops.random_uniform((), seed=11))
+ .repeat(1000)
+ .batch(10)
+ .make_initializable_iterator())
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ random_values = sess.run(get_next)
+
+ # Assert that one of the next 99 batches yielded by the iterator is
+ # different from the first.
+ i = 0
+ while i < 99:
+ if np.any(random_values != sess.run(get_next)):
+ break
+ i += 1
+ self.assertLess(i, 99)
+
def testMapDict(self):
iterator = (dataset_ops.Dataset.range(10)
.map(lambda x: {"foo": x * 2, "bar": x ** 2})
@@ -552,11 +575,6 @@ class MapDatasetTest(test.TestCase, parameterized.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
- def assertSparseValuesEqual(self, a, b):
- self.assertAllEqual(a.indices, b.indices)
- self.assertAllEqual(a.values, b.values)
- self.assertAllEqual(a.dense_shape, b.dense_shape)
-
def testSparse(self):
def _sparse(i):
@@ -731,7 +749,7 @@ class MapDatasetTest(test.TestCase, parameterized.TestCase):
iterator = dataset.make_one_shot_iterator()
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
tids = sess.run(get_next)
self.assertTrue(all(tids[0] == tid for tid in tids))
# pylint: enable=g-long-lambda
diff --git a/tensorflow/python/data/kernel_tests/multi_device_iterator_test.py b/tensorflow/python/data/kernel_tests/multi_device_iterator_test.py
new file mode 100644
index 0000000000..1cf6dd1bea
--- /dev/null
+++ b/tensorflow/python/data/kernel_tests/multi_device_iterator_test.py
@@ -0,0 +1,191 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""MultiDeviceIterator tests."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.ops import multi_device_iterator_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+
+class MultiDeviceIteratorTest(test_base.DatasetTestBase):
+
+ def testNoGetNext(self):
+ dataset = dataset_ops.Dataset.range(10)
+ multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
+ dataset, ["/cpu:1", "/cpu:2"])
+
+ config = config_pb2.ConfigProto(device_count={"CPU": 3})
+ with self.test_session(config=config) as sess:
+ sess.run(multi_device_iterator.initializer)
+
+ def testBasic(self):
+ dataset = dataset_ops.Dataset.range(10)
+ multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
+ dataset, ["/cpu:1", "/cpu:2"])
+ elem_on_1, elem_on_2 = multi_device_iterator.get_next()
+
+ config = config_pb2.ConfigProto(device_count={"CPU": 3})
+ with self.test_session(config=config) as sess:
+ sess.run(multi_device_iterator.initializer)
+ for i in range(0, 10, 2):
+ self.assertEqual(i, sess.run(elem_on_1))
+ self.assertEqual(i + 1, sess.run(elem_on_2))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(elem_on_1)
+ sess.run(elem_on_2)
+
+ def testOneOnSameDevice(self):
+ with ops.device("/cpu:0"):
+ dataset = dataset_ops.Dataset.range(10)
+ multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
+ dataset, ["/cpu:0", "/cpu:1"])
+ elem_on_1, elem_on_2 = multi_device_iterator.get_next()
+
+ config = config_pb2.ConfigProto(device_count={"CPU": 2})
+ with self.test_session(config=config) as sess:
+ sess.run(multi_device_iterator.initializer)
+ for i in range(0, 10, 2):
+ self.assertEqual(i, sess.run(elem_on_1))
+ self.assertEqual(i + 1, sess.run(elem_on_2))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(elem_on_1)
+ sess.run(elem_on_2)
+
+ def testRepeatDevices(self):
+ with ops.device("/cpu:0"):
+ dataset = dataset_ops.Dataset.range(20)
+ multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
+ dataset, ["/cpu:1", "/cpu:2", "/cpu:1", "/cpu:2"])
+ elements = multi_device_iterator.get_next()
+ elem_on_1, elem_on_2, elem_on_3, elem_on_4 = elements
+
+ config = config_pb2.ConfigProto(device_count={"CPU": 3})
+ with self.test_session(config=config) as sess:
+ sess.run(multi_device_iterator.initializer)
+ for i in range(0, 20, 4):
+ self.assertEqual(i, sess.run(elem_on_1))
+ self.assertEqual(i + 1, sess.run(elem_on_2))
+ self.assertEqual(i + 2, sess.run(elem_on_3))
+ self.assertEqual(i + 3, sess.run(elem_on_4))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(elem_on_1)
+ sess.run(elem_on_2)
+ sess.run(elem_on_3)
+ sess.run(elem_on_4)
+
+ def testNotFullyDivisible(self):
+ dataset = dataset_ops.Dataset.range(9)
+ multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
+ dataset, ["/cpu:1", "/cpu:2"])
+ elem_on_1, elem_on_2 = multi_device_iterator.get_next()
+
+ config = config_pb2.ConfigProto(device_count={"CPU": 3})
+ with self.test_session(config=config) as sess:
+ sess.run(multi_device_iterator.initializer)
+ for i in range(0, 8, 2):
+ self.assertEqual(i, sess.run(elem_on_1))
+ self.assertEqual(i + 1, sess.run(elem_on_2))
+ self.assertEqual(8, sess.run(elem_on_1))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(elem_on_1)
+ sess.run(elem_on_2)
+
+ def testUneven(self):
+ dataset = dataset_ops.Dataset.range(10)
+ multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
+ dataset, ["/cpu:1", "/cpu:2"], max_buffer_size=4)
+ elem_on_1, elem_on_2 = multi_device_iterator.get_next()
+
+ config = config_pb2.ConfigProto(device_count={"CPU": 3})
+ with self.test_session(config=config) as sess:
+ sess.run(multi_device_iterator.initializer)
+ for i in range(0, 10, 2):
+ self.assertEqual(i, sess.run(elem_on_1))
+ for i in range(0, 10, 2):
+ self.assertEqual(i + 1, sess.run(elem_on_2))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(elem_on_1)
+ sess.run(elem_on_2)
+
+ def testMultipleInitializations(self):
+ with ops.device("/cpu:0"):
+ epoch = array_ops.placeholder(dtypes.int64, shape=[])
+ dataset1 = dataset_ops.Dataset.from_tensors(epoch).repeat(1000)
+ dataset2 = dataset_ops.Dataset.range(1000)
+ dataset = dataset_ops.Dataset.zip((dataset1, dataset2))
+ multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
+ dataset, ["/cpu:1", "/cpu:2"], prefetch_buffer_size=4)
+ elem_on_1, elem_on_2 = multi_device_iterator.get_next()
+ init_op = multi_device_iterator.initializer
+
+ config = config_pb2.ConfigProto(device_count={"CPU": 3})
+ with self.test_session(config=config) as sess:
+ for i in range(1000):
+ sess.run(init_op, feed_dict={epoch: i})
+ self.assertEqual([(i, 0), (i, 1)], sess.run([elem_on_1, elem_on_2]))
+
+ def testBasicGpu(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ dataset = dataset_ops.Dataset.range(10)
+ multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
+ dataset, ["/cpu:1", "/gpu:0"])
+ elem_on_1, elem_on_2 = multi_device_iterator.get_next()
+
+ config = config_pb2.ConfigProto(device_count={"CPU": 2, "GPU": 1})
+ with self.test_session(config=config) as sess:
+ sess.run(multi_device_iterator.initializer)
+ for i in range(0, 10, 2):
+ self.assertEqual(i, sess.run(elem_on_1))
+ self.assertEqual(i + 1, sess.run(elem_on_2))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(elem_on_1)
+ sess.run(elem_on_2)
+
+ def testUnevenGpu(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ dataset = dataset_ops.Dataset.range(10)
+ multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
+ dataset, ["/cpu:1", "/gpu:0"], max_buffer_size=4)
+ elem_on_1, elem_on_2 = multi_device_iterator.get_next()
+
+ config = config_pb2.ConfigProto(device_count={"CPU": 2, "GPU": 1})
+ with self.test_session(config=config) as sess:
+ sess.run(multi_device_iterator.initializer)
+ for i in range(0, 10, 2):
+ self.assertEqual(i, sess.run(elem_on_1))
+ for i in range(0, 10, 2):
+ self.assertEqual(i + 1, sess.run(elem_on_2))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(elem_on_1)
+ sess.run(elem_on_2)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/kernel_tests/optional_ops_test.py b/tensorflow/python/data/kernel_tests/optional_ops_test.py
index c344513e71..604e3ad88e 100644
--- a/tensorflow/python/data/kernel_tests/optional_ops_test.py
+++ b/tensorflow/python/data/kernel_tests/optional_ops_test.py
@@ -17,11 +17,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from absl.testing import parameterized
import numpy as np
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.ops import optional_ops
+from tensorflow.python.data.util import structure
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -33,14 +36,11 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class OptionalTest(test.TestCase):
+class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
@test_util.run_in_graph_and_eager_modes
def testFromValue(self):
opt = optional_ops.Optional.from_value(constant_op.constant(37.0))
- self.assertEqual(dtypes.float32, opt.output_types)
- self.assertEqual([], opt.output_shapes)
- self.assertEqual(ops.Tensor, opt.output_classes)
self.assertTrue(self.evaluate(opt.has_value()))
self.assertEqual(37.0, self.evaluate(opt.get_value()))
@@ -50,15 +50,6 @@ class OptionalTest(test.TestCase):
"a": constant_op.constant(37.0),
"b": (constant_op.constant(["Foo"]), constant_op.constant("Bar"))
})
- self.assertEqual({
- "a": dtypes.float32,
- "b": (dtypes.string, dtypes.string)
- }, opt.output_types)
- self.assertEqual({"a": [], "b": ([1], [])}, opt.output_shapes)
- self.assertEqual({
- "a": ops.Tensor,
- "b": (ops.Tensor, ops.Tensor)
- }, opt.output_classes)
self.assertTrue(self.evaluate(opt.has_value()))
self.assertEqual({
"a": 37.0,
@@ -76,46 +67,29 @@ class OptionalTest(test.TestCase):
values=np.array([-1., 1.], dtype=np.float32),
dense_shape=np.array([2, 2]))
opt = optional_ops.Optional.from_value((st_0, st_1))
- self.assertEqual((dtypes.int64, dtypes.float32), opt.output_types)
- self.assertEqual(([1], [2, 2]), opt.output_shapes)
- self.assertEqual((sparse_tensor.SparseTensor, sparse_tensor.SparseTensor),
- opt.output_classes)
+ self.assertTrue(self.evaluate(opt.has_value()))
+ val_0, val_1 = opt.get_value()
+ for expected, actual in [(st_0, val_0), (st_1, val_1)]:
+ self.assertAllEqual(expected.indices, self.evaluate(actual.indices))
+ self.assertAllEqual(expected.values, self.evaluate(actual.values))
+ self.assertAllEqual(expected.dense_shape,
+ self.evaluate(actual.dense_shape))
@test_util.run_in_graph_and_eager_modes
def testFromNone(self):
- opt = optional_ops.Optional.none_from_structure(tensor_shape.scalar(),
- dtypes.float32, ops.Tensor)
- self.assertEqual(dtypes.float32, opt.output_types)
- self.assertEqual([], opt.output_shapes)
- self.assertEqual(ops.Tensor, opt.output_classes)
+ value_structure = structure.TensorStructure(dtypes.float32, [])
+ opt = optional_ops.Optional.none_from_structure(value_structure)
+ self.assertTrue(opt.value_structure.is_compatible_with(value_structure))
+ self.assertFalse(
+ opt.value_structure.is_compatible_with(
+ structure.TensorStructure(dtypes.float32, [1])))
+ self.assertFalse(
+ opt.value_structure.is_compatible_with(
+ structure.TensorStructure(dtypes.int32, [])))
self.assertFalse(self.evaluate(opt.has_value()))
with self.assertRaises(errors.InvalidArgumentError):
self.evaluate(opt.get_value())
- def testStructureMismatchError(self):
- tuple_output_shapes = (tensor_shape.scalar(), tensor_shape.scalar())
- tuple_output_types = (dtypes.float32, dtypes.float32)
- tuple_output_classes = (ops.Tensor, ops.Tensor)
-
- dict_output_shapes = {
- "a": tensor_shape.scalar(),
- "b": tensor_shape.scalar()
- }
- dict_output_types = {"a": dtypes.float32, "b": dtypes.float32}
- dict_output_classes = {"a": ops.Tensor, "b": ops.Tensor}
-
- with self.assertRaises(TypeError):
- optional_ops.Optional.none_from_structure(
- tuple_output_shapes, tuple_output_types, dict_output_classes)
-
- with self.assertRaises(TypeError):
- optional_ops.Optional.none_from_structure(
- tuple_output_shapes, dict_output_types, tuple_output_classes)
-
- with self.assertRaises(TypeError):
- optional_ops.Optional.none_from_structure(
- dict_output_shapes, tuple_output_types, tuple_output_classes)
-
@test_util.run_in_graph_and_eager_modes
def testCopyToGPU(self):
if not test_util.is_gpu_available():
@@ -126,17 +100,15 @@ class OptionalTest(test.TestCase):
(constant_op.constant(37.0), constant_op.constant("Foo"),
constant_op.constant(42)))
optional_none = optional_ops.Optional.none_from_structure(
- tensor_shape.scalar(), dtypes.float32, ops.Tensor)
+ structure.TensorStructure(dtypes.float32, []))
with ops.device("/gpu:0"):
gpu_optional_with_value = optional_ops._OptionalImpl(
array_ops.identity(optional_with_value._variant_tensor),
- optional_with_value.output_shapes, optional_with_value.output_types,
- optional_with_value.output_classes)
+ optional_with_value.value_structure)
gpu_optional_none = optional_ops._OptionalImpl(
array_ops.identity(optional_none._variant_tensor),
- optional_none.output_shapes, optional_none.output_types,
- optional_none.output_classes)
+ optional_none.value_structure)
gpu_optional_with_value_has_value = gpu_optional_with_value.has_value()
gpu_optional_with_value_values = gpu_optional_with_value.get_value()
@@ -148,14 +120,101 @@ class OptionalTest(test.TestCase):
self.evaluate(gpu_optional_with_value_values))
self.assertFalse(self.evaluate(gpu_optional_none_has_value))
- def testIteratorGetNextAsOptional(self):
- ds = dataset_ops.Dataset.range(3)
+ def _assertElementValueEqual(self, expected, actual):
+ if isinstance(expected, dict):
+ self.assertItemsEqual(list(expected.keys()), list(actual.keys()))
+ for k in expected.keys():
+ self._assertElementValueEqual(expected[k], actual[k])
+ elif isinstance(expected, sparse_tensor.SparseTensorValue):
+ self.assertAllEqual(expected.indices, actual.indices)
+ self.assertAllEqual(expected.values, actual.values)
+ self.assertAllEqual(expected.dense_shape, actual.dense_shape)
+ else:
+ self.assertAllEqual(expected, actual)
+
+ # pylint: disable=g-long-lambda
+ @parameterized.named_parameters(
+ ("Tensor", lambda: constant_op.constant(37.0),
+ structure.TensorStructure(dtypes.float32, [])),
+ ("SparseTensor", lambda: sparse_tensor.SparseTensor(
+ indices=[[0]], values=constant_op.constant([0], dtype=dtypes.int32),
+ dense_shape=[1]),
+ structure.SparseTensorStructure(dtypes.int32, [1])),
+ ("Nest", lambda: {
+ "a": constant_op.constant(37.0),
+ "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar"))},
+ structure.NestedStructure({
+ "a": structure.TensorStructure(dtypes.float32, []),
+ "b": (structure.TensorStructure(dtypes.string, [1]),
+ structure.TensorStructure(dtypes.string, []))})),
+ ("Optional", lambda: optional_ops.Optional.from_value(37.0),
+ optional_ops.OptionalStructure(
+ structure.TensorStructure(dtypes.float32, []))),
+ )
+ def testOptionalStructure(self, tf_value_fn, expected_value_structure):
+ tf_value = tf_value_fn()
+ opt = optional_ops.Optional.from_value(tf_value)
+
+ self.assertTrue(
+ expected_value_structure.is_compatible_with(opt.value_structure))
+ self.assertTrue(
+ opt.value_structure.is_compatible_with(expected_value_structure))
+
+ opt_structure = structure.Structure.from_value(opt)
+ self.assertIsInstance(opt_structure, optional_ops.OptionalStructure)
+ self.assertTrue(opt_structure.is_compatible_with(opt_structure))
+ self.assertTrue(opt_structure._value_structure.is_compatible_with(
+ expected_value_structure))
+ self.assertEqual([dtypes.variant], opt_structure._flat_types)
+ self.assertEqual([tensor_shape.scalar()], opt_structure._flat_shapes)
+
+    # An OptionalStructure is never compatible with a non-optional value.
+ non_optional_structure = structure.Structure.from_value(
+ constant_op.constant(42.0))
+ self.assertFalse(opt_structure.is_compatible_with(non_optional_structure))
+
+ # Assert that the optional survives a round-trip via _from_tensor_list()
+ # and _to_tensor_list().
+ round_trip_opt = opt_structure._from_tensor_list(
+ opt_structure._to_tensor_list(opt))
+ if isinstance(tf_value, optional_ops.Optional):
+ self.assertEqual(
+ self.evaluate(tf_value.get_value()),
+ self.evaluate(round_trip_opt.get_value().get_value()))
+ else:
+ self.assertEqual(
+ self.evaluate(tf_value), self.evaluate(round_trip_opt.get_value()))
+
+ @parameterized.named_parameters(
+ ("Tensor", np.array([1, 2, 3], dtype=np.int32),
+ lambda: constant_op.constant([4, 5, 6], dtype=dtypes.int32), True),
+ ("SparseTensor", sparse_tensor.SparseTensorValue(
+ indices=[[0, 0], [1, 1]],
+ values=np.array([-1., 1.], dtype=np.float32), dense_shape=[2, 2]),
+ lambda: sparse_tensor.SparseTensor(
+ indices=[[0, 1], [1, 0]], values=[37.0, 42.0], dense_shape=[2, 2]),
+ False),
+ ("Nest", {"a": np.array([1, 2, 3], dtype=np.int32),
+ "b": sparse_tensor.SparseTensorValue(
+ indices=[[0, 0], [1, 1]],
+ values=np.array([-1., 1.], dtype=np.float32),
+ dense_shape=[2, 2])},
+ lambda: {"a": constant_op.constant([4, 5, 6], dtype=dtypes.int32),
+ "b": sparse_tensor.SparseTensor(
+ indices=[[0, 1], [1, 0]], values=[37.0, 42.0],
+ dense_shape=[2, 2])}, False),
+ )
+ def testIteratorGetNextAsOptional(self, np_value, tf_value_fn, works_on_gpu):
+ if not works_on_gpu and test.is_gpu_available():
+ self.skipTest("Test case not yet supported on GPU.")
+ ds = dataset_ops.Dataset.from_tensors(np_value).repeat(3)
iterator = ds.make_initializable_iterator()
next_elem = iterator_ops.get_next_as_optional(iterator)
- self.assertTrue(isinstance(next_elem, optional_ops.Optional))
- self.assertEqual(ds.output_types, next_elem.output_types)
- self.assertEqual(ds.output_shapes, next_elem.output_shapes)
- self.assertEqual(ds.output_classes, next_elem.output_classes)
+ self.assertIsInstance(next_elem, optional_ops.Optional)
+ self.assertTrue(
+ next_elem.value_structure.is_compatible_with(
+ structure.Structure.from_value(tf_value_fn())))
elem_has_value_t = next_elem.has_value()
elem_value_t = next_elem.get_value()
with self.cached_session() as sess:
@@ -169,10 +228,10 @@ class OptionalTest(test.TestCase):
# For each element of the dataset, assert that the optional evaluates to
# the expected value.
sess.run(iterator.initializer)
- for i in range(3):
+ for _ in range(3):
elem_has_value, elem_value = sess.run([elem_has_value_t, elem_value_t])
self.assertTrue(elem_has_value)
- self.assertEqual(i, elem_value)
+ self._assertElementValueEqual(np_value, elem_value)
# After exhausting the iterator, `next_elem.has_value()` will evaluate to
# false, and attempting to get the value will fail.
diff --git a/tensorflow/python/data/kernel_tests/prefetch_dataset_op_test.py b/tensorflow/python/data/kernel_tests/prefetch_dataset_op_test.py
index cc97bac609..76e2697b29 100644
--- a/tensorflow/python/data/kernel_tests/prefetch_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/prefetch_dataset_op_test.py
@@ -19,6 +19,7 @@ from __future__ import print_function
from absl.testing import parameterized
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -26,7 +27,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class PrefetchDatasetTest(test.TestCase, parameterized.TestCase):
+class PrefetchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
@parameterized.parameters((-1), (0), (5))
def testBufferSize(self, buffer_size):
diff --git a/tensorflow/python/data/kernel_tests/range_dataset_op_test.py b/tensorflow/python/data/kernel_tests/range_dataset_op_test.py
index 51e90785e7..b7e2a5f615 100644
--- a/tensorflow/python/data/kernel_tests/range_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/range_dataset_op_test.py
@@ -19,6 +19,7 @@ from __future__ import print_function
import os
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import dtypes
@@ -34,7 +35,7 @@ from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
-class RangeDatasetTest(test.TestCase):
+class RangeDatasetTest(test_base.DatasetTestBase):
def tearDown(self):
# Remove all checkpoint files.
diff --git a/tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py b/tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py
index aa3636364d..aef2dd1d9c 100644
--- a/tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py
+++ b/tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py
@@ -21,6 +21,7 @@ import gzip
import os
import zlib
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.ops import readers
@@ -46,7 +47,7 @@ except ImportError:
psutil_import_succeeded = False
-class TextLineDatasetTest(test.TestCase):
+class TextLineDatasetTest(test_base.DatasetTestBase):
def _lineText(self, f, l):
return compat.as_bytes("%d: %d" % (f, l))
@@ -199,7 +200,7 @@ class TextLineDatasetTest(test.TestCase):
self.assertNotIn(filename, [open_file.path for open_file in open_files])
-class FixedLengthRecordReaderTest(test.TestCase):
+class FixedLengthRecordReaderTest(test_base.DatasetTestBase):
def setUp(self):
super(FixedLengthRecordReaderTest, self).setUp()
@@ -621,7 +622,7 @@ class FixedLengthRecordReaderTest(test.TestCase):
sess.run(get_next_op)
-class TFRecordDatasetTest(test.TestCase):
+class TFRecordDatasetTest(test_base.DatasetTestBase):
def setUp(self):
super(TFRecordDatasetTest, self).setUp()
diff --git a/tensorflow/python/data/kernel_tests/reduce_dataset_op_test.py b/tensorflow/python/data/kernel_tests/reduce_dataset_op_test.py
new file mode 100644
index 0000000000..11e07300b9
--- /dev/null
+++ b/tensorflow/python/data/kernel_tests/reduce_dataset_op_test.py
@@ -0,0 +1,124 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the experimental input pipeline ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+import numpy as np
+
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+class ReduceDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
+
+ def testSum(self):
+ for i in range(10):
+ ds = dataset_ops.Dataset.range(1, i + 1)
+ result = ds.reduce(np.int64(0), lambda x, y: x + y)
+ with self.cached_session() as sess:
+ self.assertEqual(((i + 1) * i) // 2, sess.run(result))
+
+ def testSumTuple(self):
+
+ def reduce_fn(state, value):
+ v1, v2 = value
+ return state + v1 + v2
+
+ for i in range(10):
+ ds = dataset_ops.Dataset.range(1, i + 1)
+ ds = dataset_ops.Dataset.zip((ds, ds))
+ result = ds.reduce(np.int64(0), reduce_fn)
+ with self.cached_session() as sess:
+ self.assertEqual(((i + 1) * i), sess.run(result))
+
+ def testSumAndCount(self):
+
+ def reduce_fn(state, value):
+ s, c = state
+ return s + value, c + 1
+
+ for i in range(10):
+ ds = dataset_ops.Dataset.range(1, i + 1)
+ result = ds.reduce((np.int64(0), np.int64(0)), reduce_fn)
+ with self.cached_session() as sess:
+ s, c = sess.run(result)
+ self.assertEqual(((i + 1) * i) // 2, s)
+ self.assertEqual(i, c)
+
+ def testSquareUsingPlaceholder(self):
+ delta = array_ops.placeholder(dtype=dtypes.int64)
+
+ def reduce_fn(state, _):
+ return state + delta
+
+ for i in range(10):
+ ds = dataset_ops.Dataset.range(1, i + 1)
+ result = ds.reduce(np.int64(0), reduce_fn)
+ with self.cached_session() as sess:
+ square = sess.run(result, feed_dict={delta: i})
+ self.assertEqual(i * i, square)
+
+ def testSparse(self):
+
+ def reduce_fn(_, value):
+ return value
+
+ def make_sparse_fn(i):
+ return sparse_tensor.SparseTensorValue(
+ indices=np.array([[0, 0]]),
+ values=(i * np.array([1])),
+ dense_shape=np.array([1, 1]))
+
+ for i in range(10):
+ ds = dataset_ops.Dataset.from_tensors(make_sparse_fn(i+1))
+ result = ds.reduce(make_sparse_fn(0), reduce_fn)
+ with self.cached_session() as sess:
+ self.assertSparseValuesEqual(make_sparse_fn(i+1), sess.run(result))
+
+ def testNested(self):
+
+ def reduce_fn(state, value):
+ state["dense"] += value["dense"]
+ state["sparse"] = value["sparse"]
+ return state
+
+ def make_sparse_fn(i):
+ return sparse_tensor.SparseTensorValue(
+ indices=np.array([[0, 0]]),
+ values=(i * np.array([1])),
+ dense_shape=np.array([1, 1]))
+
+ def map_fn(i):
+ return {"dense": math_ops.cast(i, dtype=dtypes.int64),
+ "sparse": make_sparse_fn(math_ops.cast(i, dtype=dtypes.int64))}
+
+ for i in range(10):
+ ds = dataset_ops.Dataset.range(1, i + 1).map(map_fn)
+ result = ds.reduce(map_fn(0), reduce_fn)
+ with self.cached_session() as sess:
+ result = sess.run(result)
+ self.assertEqual(((i + 1) * i) // 2, result["dense"])
+ self.assertSparseValuesEqual(make_sparse_fn(i), result["sparse"])
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/kernel_tests/sequence_dataset_op_test.py b/tensorflow/python/data/kernel_tests/sequence_dataset_op_test.py
index 37e2333560..e86356dee7 100644
--- a/tensorflow/python/data/kernel_tests/sequence_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/sequence_dataset_op_test.py
@@ -19,6 +19,7 @@ from __future__ import print_function
import numpy as np
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -26,7 +27,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class SequenceDatasetTest(test.TestCase):
+class SequenceDatasetTest(test_base.DatasetTestBase):
def testRepeatTensorDataset(self):
"""Test a dataset that repeats its input multiple times."""
diff --git a/tensorflow/python/data/kernel_tests/shard_dataset_op_test.py b/tensorflow/python/data/kernel_tests/shard_dataset_op_test.py
index 137f6341ce..b9f3c79da5 100644
--- a/tensorflow/python/data/kernel_tests/shard_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/shard_dataset_op_test.py
@@ -17,12 +17,13 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
from tensorflow.python.platform import test
-class ShardDatasetOpTest(test.TestCase):
+class ShardDatasetOpTest(test_base.DatasetTestBase):
def testSimpleCase(self):
dataset = dataset_ops.Dataset.range(10).shard(5, 2)
diff --git a/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py b/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py
index f294840706..347af18576 100644
--- a/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py
@@ -21,6 +21,7 @@ import collections
import numpy as np
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import constant_op
@@ -30,7 +31,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class ShuffleDatasetTest(test.TestCase):
+class ShuffleDatasetTest(test_base.DatasetTestBase):
def testShuffleDataset(self):
components = (
diff --git a/tensorflow/contrib/data/python/ops/contrib_op_loader.py b/tensorflow/python/data/kernel_tests/test_base.py
index 8f495a9dc9..b4f64115b7 100644
--- a/tensorflow/contrib/data/python/ops/contrib_op_loader.py
+++ b/tensorflow/python/data/kernel_tests/test_base.py
@@ -12,13 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Python helper for loading contrib ops and kernels."""
+"""Test utilities for tf.data functionality."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.util import loader
-from tensorflow.python.platform import resource_loader
+from tensorflow.python.platform import test
-_dataset_ops = loader.load_op_library(
- resource_loader.get_path_to_datafile("../../_dataset_ops.so"))
+
+class DatasetTestBase(test.TestCase):
+ """Base class for dataset tests."""
+
+ def assertSparseValuesEqual(self, a, b):
+ self.assertAllEqual(a.indices, b.indices)
+ self.assertAllEqual(a.values, b.values)
+ self.assertAllEqual(a.dense_shape, b.dense_shape)
diff --git a/tensorflow/python/data/kernel_tests/window_dataset_op_test.py b/tensorflow/python/data/kernel_tests/window_dataset_op_test.py
new file mode 100644
index 0000000000..9d06781094
--- /dev/null
+++ b/tensorflow/python/data/kernel_tests/window_dataset_op_test.py
@@ -0,0 +1,291 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the experimental input pipeline ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+import numpy as np
+
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+class WindowDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
+
+ @parameterized.named_parameters(
+ ("1", 20, 14, 7, 1),
+ ("2", 20, 17, 9, 1),
+ ("3", 20, 14, 14, 1),
+ ("4", 20, 10, 14, 1),
+ ("5", 20, 14, 19, 1),
+ ("6", 20, 4, 1, 2),
+ ("7", 20, 2, 1, 6),
+ ("8", 20, 4, 7, 2),
+ ("9", 20, 2, 7, 6),
+ ("10", 1, 10, 4, 1),
+ ("11", 0, 10, 4, 1),
+ ("12", 20, 14, 7, 1, False),
+ ("13", 20, 17, 9, 1, False),
+ ("14", 20, 14, 14, 1, False),
+ ("15", 20, 10, 14, 1, False),
+ ("16", 20, 14, 19, 1, False),
+ ("17", 20, 4, 1, 2, False),
+ ("18", 20, 2, 1, 6, False),
+ ("19", 20, 4, 7, 2, False),
+ ("20", 20, 2, 7, 6, False),
+ ("21", 1, 10, 4, 1, False),
+ ("22", 0, 10, 4, 1, False),
+ )
+ def testWindowDataset(self, count, size, shift, stride, drop_remainder=True):
+    """Tests a dataset that slides a window over its input elements."""
+ components = (np.arange(7),
+ np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
+ np.array(37.0) * np.arange(7))
+
+ count_t = array_ops.placeholder(dtypes.int64, shape=[])
+ size_t = array_ops.placeholder(dtypes.int64, shape=[])
+ shift_t = array_ops.placeholder(dtypes.int64, shape=[])
+ stride_t = array_ops.placeholder(dtypes.int64, shape=[])
+ drop_remainder_t = array_ops.placeholder(dtypes.bool, shape=[])
+
+ def _map_fn(x, y, z):
+ return math_ops.square(x), math_ops.square(y), math_ops.square(z)
+
+ def _flat_map_fn(x, y, z):
+ return dataset_ops.Dataset.zip((x.batch(batch_size=size_t),
+ y.batch(batch_size=size_t),
+ z.batch(batch_size=size_t)))
+
+ iterator = dataset_ops.Dataset.from_tensor_slices(components).map(
+ _map_fn).repeat(count).window(
+ size=size_t,
+ shift=shift_t,
+ stride=stride_t,
+ drop_remainder=drop_remainder_t).flat_map(
+ _flat_map_fn).make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ self.assertEqual([[None] + list(c.shape[1:]) for c in components],
+ [t.shape.as_list() for t in get_next])
+
+ with self.cached_session() as sess:
+ sess.run(
+ init_op,
+ feed_dict={
+ count_t: count,
+ size_t: size,
+ shift_t: shift,
+ stride_t: stride,
+ drop_remainder_t: drop_remainder
+ })
+ num_full_batches = max(
+ 0, (count * 7 - ((size - 1) * stride + 1)) // shift + 1)
+ for i in range(num_full_batches):
+ result = sess.run(get_next)
+ for component, result_component in zip(components, result):
+ for j in range(size):
+ self.assertAllEqual(component[(i * shift + j * stride) % 7]**2,
+ result_component[j])
+ if not drop_remainder:
+ num_partial_batches = (count * 7) // shift + (
+ (count * 7) % shift > 0) - num_full_batches
+ for i in range(num_partial_batches):
+ result = sess.run(get_next)
+ for component, result_component in zip(components, result):
+ remaining = (count * 7) - ((num_full_batches + i) * shift)
+ num_elements = remaining // stride + ((remaining % stride) > 0)
+ for j in range(num_elements):
+ self.assertAllEqual(
+ component[((num_full_batches + i) * shift + j * stride) % 7]
+ **2, result_component[j])
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ @parameterized.named_parameters(
+ ("1", 14, 0, 3, 1),
+ ("2", 14, 3, 0, 1),
+ ("3", 14, 3, 3, 0),
+ )
+ def testWindowDatasetInvalid(self, count, size, shift, stride):
+ count_t = array_ops.placeholder(dtypes.int64, shape=[])
+ size_t = array_ops.placeholder(dtypes.int64, shape=[])
+ shift_t = array_ops.placeholder(dtypes.int64, shape=[])
+ stride_t = array_ops.placeholder(dtypes.int64, shape=[])
+
+ iterator = dataset_ops.Dataset.range(10).map(lambda x: x).repeat(
+ count_t).window(
+ size=size_t, shift=shift_t,
+ stride=stride_t).flat_map(lambda x: x.batch(batch_size=size_t)
+ ).make_initializable_iterator()
+ init_op = iterator.initializer
+
+ with self.cached_session() as sess:
+ with self.assertRaises(errors.InvalidArgumentError):
+ sess.run(
+ init_op,
+ feed_dict={
+ count_t: count,
+ size_t: size,
+ shift_t: shift,
+ stride_t: stride
+ })
+
+ def testWindowSparse(self):
+
+ def _sparse(i):
+ return sparse_tensor.SparseTensorValue(
+ indices=[[0]], values=(i * [1]), dense_shape=[1])
+
+ iterator = dataset_ops.Dataset.range(10).map(_sparse).window(
+ size=5, shift=3, drop_remainder=True).flat_map(
+ lambda x: x.batch(batch_size=5)).make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ num_batches = (10 - 5) // 3 + 1
+ for i in range(num_batches):
+ actual = sess.run(get_next)
+ expected = sparse_tensor.SparseTensorValue(
+ indices=[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]],
+ values=[i * 3, i * 3 + 1, i * 3 + 2, i * 3 + 3, i * 3 + 4],
+ dense_shape=[5, 1])
+ self.assertTrue(sparse_tensor.is_sparse(actual))
+ self.assertSparseValuesEqual(actual, expected)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def testWindowSparseWithDifferentDenseShapes(self):
+
+ def _sparse(i):
+ return sparse_tensor.SparseTensorValue(
+ indices=array_ops.expand_dims(
+ math_ops.range(i, dtype=dtypes.int64), 1),
+ values=array_ops.fill([math_ops.to_int32(i)], i),
+ dense_shape=[i])
+
+ iterator = dataset_ops.Dataset.range(10).map(_sparse).window(
+ size=5, shift=3, drop_remainder=True).flat_map(
+ lambda x: x.batch(batch_size=5)).make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ num_batches = (10 - 5) // 3 + 1
+ for i in range(num_batches):
+ actual = sess.run(get_next)
+ expected_indices = []
+ expected_values = []
+ for j in range(5):
+ for k in range(i * 3 + j):
+ expected_indices.append([j, k])
+ expected_values.append(i * 3 + j)
+ expected = sparse_tensor.SparseTensorValue(
+ indices=expected_indices,
+ values=expected_values,
+ dense_shape=[5, i * 3 + 5 - 1])
+ self.assertTrue(sparse_tensor.is_sparse(actual))
+ self.assertSparseValuesEqual(actual, expected)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def testNestedWindowSparse(self):
+
+ def _sparse(i):
+ return sparse_tensor.SparseTensorValue(
+ indices=[[0]], values=(i * [1]), dense_shape=[1])
+
+ iterator = dataset_ops.Dataset.range(10).map(_sparse).window(
+ size=4, shift=2,
+ drop_remainder=True).flat_map(lambda x: x.batch(batch_size=4)).window(
+ size=3, shift=1, drop_remainder=True).flat_map(
+ lambda x: x.batch(batch_size=3)).make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ # Slide: 1st batch.
+ actual = sess.run(get_next)
+ expected = sparse_tensor.SparseTensorValue(
+ indices=[[0, 0, 0], [0, 1, 0], [0, 2, 0], [0, 3, 0], [1, 0, 0],
+ [1, 1, 0], [1, 2, 0], [1, 3, 0], [2, 0, 0], [2, 1, 0],
+ [2, 2, 0], [2, 3, 0]],
+ values=[0, 1, 2, 3, 2, 3, 4, 5, 4, 5, 6, 7],
+ dense_shape=[3, 4, 1])
+ self.assertTrue(sparse_tensor.is_sparse(actual))
+ self.assertSparseValuesEqual(actual, expected)
+ # Slide: 2nd batch.
+ actual = sess.run(get_next)
+ expected = sparse_tensor.SparseTensorValue(
+ indices=[[0, 0, 0], [0, 1, 0], [0, 2, 0], [0, 3, 0], [1, 0, 0],
+ [1, 1, 0], [1, 2, 0], [1, 3, 0], [2, 0, 0], [2, 1, 0],
+ [2, 2, 0], [2, 3, 0]],
+ values=[2, 3, 4, 5, 4, 5, 6, 7, 6, 7, 8, 9],
+ dense_shape=[3, 4, 1])
+ self.assertTrue(sparse_tensor.is_sparse(actual))
+ self.assertSparseValuesEqual(actual, expected)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def testWindowShapeError(self):
+
+ def generator():
+ yield [1.0, 2.0, 3.0]
+ yield [4.0, 5.0, 6.0]
+ yield [7.0, 8.0, 9.0, 10.0]
+
+ iterator = dataset_ops.Dataset.from_generator(
+ generator, dtypes.float32, output_shapes=[None]).window(
+ size=3, shift=1).flat_map(
+ lambda x: x.batch(batch_size=3)).make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(iterator.initializer)
+ with self.assertRaisesRegexp(
+ errors.InvalidArgumentError,
+ r"Cannot batch tensors with different shapes in component 0. "
+ r"First element had shape \[3\] and element 2 had shape \[4\]."):
+ sess.run(next_element)
+
+ def testWindowIgnoreErrors(self):
+ input_values = np.float32([1., np.nan, 2., np.nan, 3.])
+ dataset = dataset_ops.Dataset.from_tensor_slices(input_values).map(
+ lambda x: array_ops.check_numerics(x, "message")).window(
+ size=2, shift=2, stride=2,
+ drop_remainder=True).flat_map(lambda x: x.batch(batch_size=2))
+ get_next = dataset.make_one_shot_iterator().get_next()
+
+ with self.cached_session() as sess:
+ self.assertAllEqual(np.float32([1., 2.]), sess.run(get_next))
+ self.assertAllEqual(np.float32([2., 3.]), sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+
+if __name__ == "__main__":
+ test.main()
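
For reference, here is how the full/partial window arithmetic used in `testWindowDataset` above works out for the first named parameter set; the variable names are only for exposition:

```python
# Worked instance of the window-count formulas from testWindowDataset,
# using the first parameter set: count=20, size=14, shift=7, stride=1.
count, size, shift, stride = 20, 14, 7, 1
total = count * 7                    # 140 input elements (7 per repeat)
span = (size - 1) * stride + 1       # 14 input elements covered by one window
num_full = max(0, (total - span) // shift + 1)   # 19 full windows
num_all = total // shift + (total % shift > 0)   # 20 windows in total
print(num_full, num_all - num_full)  # 19 full windows, 1 partial window
```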
diff --git a/tensorflow/python/data/kernel_tests/zip_dataset_op_test.py b/tensorflow/python/data/kernel_tests/zip_dataset_op_test.py
index 3106effbd3..9d76387a34 100644
--- a/tensorflow/python/data/kernel_tests/zip_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/zip_dataset_op_test.py
@@ -19,6 +19,7 @@ from __future__ import print_function
import numpy as np
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -26,7 +27,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class ZipDatasetTest(test.TestCase):
+class ZipDatasetTest(test_base.DatasetTestBase):
def testZipDataset(self):
component_placeholders = [
diff --git a/tensorflow/python/data/ops/BUILD b/tensorflow/python/data/ops/BUILD
index 57517afae8..76bf2470b1 100644
--- a/tensorflow/python/data/ops/BUILD
+++ b/tensorflow/python/data/ops/BUILD
@@ -19,6 +19,7 @@ py_library(
"//tensorflow/python:math_ops",
"//tensorflow/python:random_seed",
"//tensorflow/python:script_ops",
+ "//tensorflow/python:smart_cond",
"//tensorflow/python:sparse_tensor",
"//tensorflow/python:string_ops",
"//tensorflow/python:tensor_shape",
@@ -63,6 +64,7 @@ py_library(
"//tensorflow/python/compat",
"//tensorflow/python/data/util:nest",
"//tensorflow/python/data/util:sparse",
+ "//tensorflow/python/data/util:structure",
"//tensorflow/python/eager:context",
"//tensorflow/python/training/checkpointable:base",
],
@@ -77,8 +79,23 @@ py_library(
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:resource_variable_ops",
- "//tensorflow/python:sparse_tensor",
"//tensorflow/python:tensor_shape",
+ "//tensorflow/python/data/util:structure",
+ ],
+)
+
+py_library(
+ name = "multi_device_iterator_ops",
+ srcs = ["multi_device_iterator_ops.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":dataset_ops",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:dataset_ops_gen",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:functional_ops",
"//tensorflow/python/data/util:nest",
"//tensorflow/python/data/util:sparse",
],
diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py
index c985e00dd1..6bba72a8e9 100644
--- a/tensorflow/python/data/ops/dataset_ops.py
+++ b/tensorflow/python/data/ops/dataset_ops.py
@@ -80,6 +80,12 @@ class Dataset(object):
"""
raise NotImplementedError("Dataset._as_variant_tensor")
+ @abc.abstractmethod
+ def _inputs(self):
+ """Returns a list of the input datasets of the dataset."""
+
+ raise NotImplementedError("Dataset._inputs")
+
def make_initializable_iterator(self, shared_name=None):
"""Creates an `Iterator` for enumerating the elements of this dataset.
@@ -1009,6 +1015,23 @@ class Dataset(object):
def flat_map(self, map_func):
"""Maps `map_func` across this dataset and flattens the result.
+ Use `flat_map` if you want to make sure that the order of your dataset
+ stays the same. For example, to flatten a dataset of batches into a
+ dataset of their elements:
+
+ ```python
+ # NOTE: The following examples use `{ ... }` to represent the
+ # contents of a dataset. '[...]' represents a tensor.
+ a = {[1,2,3,4,5], [6,7,8,9], [10]}
+
+ a.flat_map(lambda x: Dataset.from_tensor_slices(x)) ==
+ {[1,2,3,4,5,6,7,8,9,10]}
+ ```
+
+ `tf.data.Dataset.interleave()` is a generalization of `flat_map`, since
+ `flat_map` produces the same output as
+    `tf.data.Dataset.interleave(cycle_length=1)`.
+
Args:
map_func: A function mapping a nested structure of tensors (having shapes
and types defined by `self.output_shapes` and `self.output_types`) to a
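
A minimal sketch of the equivalence mentioned in the new docstring text above, between `flat_map` and `interleave` with `cycle_length=1` (ordinary graph-mode TF 1.x usage is assumed):

```python
import tensorflow as tf

ds = tf.data.Dataset.from_tensor_slices([[1, 2, 3], [4, 5, 6]])

# Both pipelines yield the scalars 1, 2, 3, 4, 5, 6 in the same order.
flat = ds.flat_map(tf.data.Dataset.from_tensor_slices)
inter = ds.interleave(tf.data.Dataset.from_tensor_slices,
                      cycle_length=1, block_length=1)
```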
@@ -1043,7 +1066,7 @@ class Dataset(object):
elements are produced. `cycle_length` controls the number of input elements
that are processed concurrently. If you set `cycle_length` to 1, this
transformation will handle one input element at a time, and will produce
- identical results = to `tf.data.Dataset.flat_map`. In general,
+ identical results to `tf.data.Dataset.flat_map`. In general,
this transformation will apply `map_func` to `cycle_length` input elements,
open iterators on the returned `Dataset` objects, and cycle through them
producing `block_length` consecutive elements from each iterator, and
@@ -1115,7 +1138,7 @@ class Dataset(object):
return FilterDataset(self, predicate)
def apply(self, transformation_func):
- """Apply a transformation function to this dataset.
+ """Applies a transformation function to this dataset.
`apply` enables chaining of custom `Dataset` transformations, which are
represented as functions that take one `Dataset` argument and return a
@@ -1131,7 +1154,7 @@ class Dataset(object):
Args:
transformation_func: A function that takes one `Dataset` argument and
- returns a `Dataset`.
+ returns a `Dataset`.
Returns:
Dataset: The `Dataset` returned by applying `transformation_func` to this
@@ -1140,10 +1163,188 @@ class Dataset(object):
dataset = transformation_func(self)
if not isinstance(dataset, Dataset):
raise TypeError("`transformation_func` must return a Dataset.")
+ dataset._input_datasets = [self] # pylint: disable=protected-access
return dataset
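
For context, `apply` simply hands the whole dataset to a user-supplied transformation function; a small hedged example of such a function (the helper name is made up for illustration):

```python
import tensorflow as tf

def _double_then_batch(dataset):
  # A custom transformation: takes one Dataset and returns one Dataset.
  return dataset.map(lambda x: x * 2).batch(3)

ds = tf.data.Dataset.range(6).apply(_double_then_batch)
# Equivalent to chaining .map(...).batch(3) directly; yields [0 2 4], [6 8 10].
```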
+ def window(self, size, shift=None, stride=1, drop_remainder=False):
+ """Combines input elements into a dataset of windows.
+
+ Each window is a dataset itself and contains `size` elements (or
+ possibly fewer if there are not enough input elements to fill the window
+ and `drop_remainder` evaluates to false).
+
+    The `stride` argument determines the stride between the input elements
+    within a window, and the `shift` argument determines how many input
+    elements the window moves forward on each iteration.
+
+ For example:
+ - `tf.data.Dataset.range(7).window(2)` produces
+ `{{0, 1}, {2, 3}, {4, 5}, {6}}`
+ - `tf.data.Dataset.range(7).window(3, 2, 1, True)` produces
+ `{{0, 1, 2}, {2, 3, 4}, {4, 5, 6}}`
+ - `tf.data.Dataset.range(7).window(3, 1, 2, True)` produces
+ `{{0, 2, 4}, {1, 3, 5}, {2, 4, 6}}`
+
+ Args:
+ size: A `tf.int64` scalar `tf.Tensor`, representing the number of elements
+ of the input dataset to combine into a window.
+ shift: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
+ forward shift of the sliding window in each iteration. Defaults to
+ `size`.
+ stride: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
+ stride of the input elements in the sliding window.
+ drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing
+ whether a window should be dropped in case its size is smaller than
+        `size`.
+
+ Returns:
+ Dataset: A `Dataset` of windows, each of which is a nested `Dataset` with
+ the same structure as this dataset, but a finite subsequence of its
+ elements.
+ """
+ if shift is None:
+ shift = size
+ return WindowDataset(self, size, shift, stride, drop_remainder)
+
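
As a usage sketch for the new `window` transformation (mirroring what the kernel tests above do), each window is itself a dataset, so it is typically flattened back into tensors with `flat_map`:

```python
import tensorflow as tf

windows = tf.data.Dataset.range(10).window(size=3, shift=2, drop_remainder=True)
# Each element of `windows` is a nested Dataset of up to 3 scalars.
batches = windows.flat_map(lambda w: w.batch(3))
# Yields the tensors [0 1 2], [2 3 4], [4 5 6], [6 7 8].
```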
+ def reduce(self, initial_state, reduce_func):
+ """Reduces the input dataset to a single element.
+
+ The transformation calls `reduce_func` successively on every element of
+ the input dataset until the dataset is exhausted, aggregating information in
+ its internal state. The `initial_state` argument is used for the initial
+ state and the final state is returned as the result.
+
+ For example:
+ - `tf.data.Dataset.range(5).reduce(np.int64(0), lambda x, _: x + 1)`
+ produces `5`
+ - `tf.data.Dataset.range(5).reduce(np.int64(0), lambda x, y: x + y)`
+ produces `10`
+
+ Args:
+ initial_state: A nested structure of tensors, representing the initial
+ state of the transformation.
+ reduce_func: A function that maps `(old_state, input_element)` to
+ `new_state`. It must take two arguments and return a nested structure
+ of tensors. The structure of `new_state` must match the structure of
+ `initial_state`.
+
+ Returns:
+ A nested structure of `tf.Tensor` objects, corresponding to the final
+ state of the transformation.
+
+ """
+
+ with ops.name_scope("initial_state"):
+ # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
+ # values to tensors.
+ initial_state = nest.pack_sequence_as(initial_state, [
+ sparse_tensor_lib.SparseTensor.from_value(t)
+ if sparse_tensor_lib.is_sparse(t) else ops.convert_to_tensor(
+ t, name="component_%d" % i)
+ for i, t in enumerate(nest.flatten(initial_state))
+ ])
+
+ # Compute initial values for the state classes, shapes and types based on
+ # the initial state.
+ state_classes = sparse.get_classes(initial_state)
+ state_shapes = nest.pack_sequence_as(
+ initial_state, [t.get_shape() for t in nest.flatten(initial_state)])
+ state_types = nest.pack_sequence_as(
+ initial_state, [t.dtype for t in nest.flatten(initial_state)])
+
+ # Iteratively rerun the reduce function until reaching a fixed point on
+    # `state_shapes`.
+ need_to_rerun = True
+ while need_to_rerun:
+
+ wrapped_func = StructuredFunctionWrapper(
+ reduce_func,
+ "reduce()",
+ input_classes=(state_classes, self.output_classes),
+ input_shapes=(state_shapes, self.output_shapes),
+ input_types=(state_types, self.output_types),
+ add_to_graph=False)
+
+ # Extract and validate class information from the returned values.
+ output_classes = wrapped_func.output_classes
+ for new_state_class, state_class in zip(
+ nest.flatten(output_classes), nest.flatten(state_classes)):
+ if not issubclass(new_state_class, state_class):
+ raise TypeError(
+ "The element classes for the new state must match the initial "
+ "state. Expected %s; got %s." % (state_classes,
+ wrapped_func.output_classes))
+
+ # Extract and validate type information from the returned values.
+ output_types = wrapped_func.output_types
+ for new_state_type, state_type in zip(
+ nest.flatten(output_types), nest.flatten(state_types)):
+ if new_state_type != state_type:
+ raise TypeError(
+ "The element types for the new state must match the initial "
+ "state. Expected %s; got %s." % (state_types,
+ wrapped_func.output_types))
+
+ # Extract shape information from the returned values.
+ output_shapes = wrapped_func.output_shapes
+ flat_state_shapes = nest.flatten(state_shapes)
+ flat_new_state_shapes = nest.flatten(output_shapes)
+ weakened_state_shapes = [
+ original.most_specific_compatible_shape(new)
+ for original, new in zip(flat_state_shapes, flat_new_state_shapes)
+ ]
+
+ need_to_rerun = False
+ for original_shape, weakened_shape in zip(flat_state_shapes,
+ weakened_state_shapes):
+ if original_shape.ndims is not None and (
+ weakened_shape.ndims is None or
+ original_shape.as_list() != weakened_shape.as_list()):
+ need_to_rerun = True
+ break
+
+ if need_to_rerun:
+ state_shapes = nest.pack_sequence_as(state_shapes,
+ weakened_state_shapes)
+
+ reduce_func = wrapped_func.function
+ reduce_func.add_to_graph(ops.get_default_graph())
+
+ return sparse.deserialize_sparse_tensors(
+ nest.pack_sequence_as(
+ output_types,
+ gen_dataset_ops.reduce_dataset(
+ self._as_variant_tensor(), # pylint: disable=protected-access
+ nest.flatten(sparse.serialize_sparse_tensors(initial_state)),
+ reduce_func.captured_inputs,
+ f=reduce_func,
+ output_shapes=nest.flatten(
+ sparse.as_dense_shapes(output_shapes, output_classes)),
+ output_types=nest.flatten(
+ sparse.as_dense_types(output_types, output_classes)))),
+ output_types,
+ output_shapes,
+ output_classes)
+
+
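
A short graph-mode sketch of the new `reduce` transformation (the session boilerplate is assumed, as in the kernel tests elsewhere in this change):

```python
import numpy as np
import tensorflow as tf

ds = tf.data.Dataset.range(5)
total = ds.reduce(np.int64(0), lambda state, element: state + element)

with tf.Session() as sess:
  print(sess.run(total))  # 10
```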
+class DatasetSource(Dataset):
+ """Abstract class representing a dataset with no inputs."""
+
+ def _inputs(self):
+ return []
+
+
+class UnaryDataset(Dataset):
+ """Abstract class representing a dataset with one input."""
+
+ def __init__(self, input_dataset):
+ super(UnaryDataset, self).__init__()
+ self._input_dataset = input_dataset
+
+ def _inputs(self):
+ return [self._input_dataset]
+
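
The `DatasetSource`/`UnaryDataset` split introduced here lets every dataset report its upstream datasets through the private `_inputs()` hook; a rough illustration of walking that chain (for exposition only, since `_inputs()` is not public API):

```python
import tensorflow as tf

ds = tf.data.Dataset.range(10).map(lambda x: x * 2).batch(4)

node = ds
while True:
  print(type(node).__name__)
  inputs = node._inputs()  # pylint: disable=protected-access
  if not inputs:
    break
  node = inputs[0]
# Prints: BatchDataset, MapDataset, RangeDataset
```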
-class TensorDataset(Dataset):
+class TensorDataset(DatasetSource):
"""A `Dataset` with a single element, viz. a nested structure of tensors."""
def __init__(self, tensors):
@@ -1183,7 +1384,7 @@ class TensorDataset(Dataset):
return self._output_types
-class TensorSliceDataset(Dataset):
+class TensorSliceDataset(DatasetSource):
"""A `Dataset` of slices from a nested structure of tensors."""
def __init__(self, tensors):
@@ -1227,7 +1428,7 @@ class TensorSliceDataset(Dataset):
return self._output_types
-class SparseTensorSliceDataset(Dataset):
+class SparseTensorSliceDataset(DatasetSource):
"""A `Dataset` that splits a rank-N `tf.SparseTensor` into its rows."""
def __init__(self, sparse_tensor):
@@ -1328,6 +1529,9 @@ class _VariantDataset(Dataset):
def _as_variant_tensor(self):
return self._dataset_variant
+ def _inputs(self):
+ return []
+
@property
def output_classes(self):
return self._structure.output_classes
@@ -1568,7 +1772,7 @@ def flat_structure(dataset):
}
-class _GeneratorDataset(Dataset):
+class _GeneratorDataset(DatasetSource):
"""A `Dataset` that generates elements by invoking a function."""
def __init__(self, init_args, init_func, next_func, finalize_func):
@@ -1669,6 +1873,9 @@ class ZipDataset(Dataset):
**flat_structure(self))
# pylint: enable=protected-access
+ def _inputs(self):
+ return nest.flatten(self._datasets)
+
@property
def output_classes(self):
return nest.pack_sequence_as(
@@ -1704,6 +1911,7 @@ class ConcatenateDataset(Dataset):
raise TypeError(
"Two datasets to concatenate have different classes %s and %s" %
(input_dataset.output_classes, dataset_to_concatenate.output_classes))
+ self._input_datasets = [input_dataset, dataset_to_concatenate]
def _as_variant_tensor(self):
# pylint: disable=protected-access
@@ -1713,6 +1921,9 @@ class ConcatenateDataset(Dataset):
**flat_structure(self))
# pylint: enable=protected-access
+ def _inputs(self):
+ return [self._input_dataset, self._dataset_to_concatenate]
+
@property
def output_classes(self):
return self._input_dataset.output_classes
@@ -1731,12 +1942,12 @@ class ConcatenateDataset(Dataset):
return self._input_dataset.output_types
-class RepeatDataset(Dataset):
+class RepeatDataset(UnaryDataset):
"""A `Dataset` that repeats its input several times."""
def __init__(self, input_dataset, count):
"""See `Dataset.repeat()` for details."""
- super(RepeatDataset, self).__init__()
+ super(RepeatDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
if count is None:
self._count = constant_op.constant(-1, dtype=dtypes.int64, name="count")
@@ -1763,7 +1974,7 @@ class RepeatDataset(Dataset):
return self._input_dataset.output_types
-class RangeDataset(Dataset):
+class RangeDataset(DatasetSource):
"""A `Dataset` of a step separated range of values."""
def __init__(self, *args):
@@ -1811,12 +2022,12 @@ class RangeDataset(Dataset):
return dtypes.int64
-class CacheDataset(Dataset):
+class CacheDataset(UnaryDataset):
"""A `Dataset` that caches elements of its input."""
def __init__(self, input_dataset, filename):
"""See `Dataset.cache()` for details."""
- super(CacheDataset, self).__init__()
+ super(CacheDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._filename = ops.convert_to_tensor(
filename, dtype=dtypes.string, name="filename")
@@ -1840,7 +2051,7 @@ class CacheDataset(Dataset):
return self._input_dataset.output_types
-class ShuffleDataset(Dataset):
+class ShuffleDataset(UnaryDataset):
"""A `Dataset` that randomly shuffles the elements of its input."""
def __init__(self,
@@ -1868,7 +2079,7 @@ class ShuffleDataset(Dataset):
Raises:
ValueError: if invalid arguments are provided.
"""
- super(ShuffleDataset, self).__init__()
+ super(ShuffleDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._buffer_size = ops.convert_to_tensor(
buffer_size, dtype=dtypes.int64, name="buffer_size")
@@ -1900,12 +2111,12 @@ class ShuffleDataset(Dataset):
return self._input_dataset.output_types
-class TakeDataset(Dataset):
+class TakeDataset(UnaryDataset):
"""A `Dataset` containing the first `count` elements from its input."""
def __init__(self, input_dataset, count):
"""See `Dataset.take()` for details."""
- super(TakeDataset, self).__init__()
+ super(TakeDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._count = ops.convert_to_tensor(count, dtype=dtypes.int64, name="count")
@@ -1928,12 +2139,12 @@ class TakeDataset(Dataset):
return self._input_dataset.output_types
-class SkipDataset(Dataset):
+class SkipDataset(UnaryDataset):
"""A `Dataset` skipping the first `count` elements from its input."""
def __init__(self, input_dataset, count):
"""See `Dataset.skip()` for details."""
- super(SkipDataset, self).__init__()
+ super(SkipDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._count = ops.convert_to_tensor(count, dtype=dtypes.int64, name="count")
@@ -1956,12 +2167,12 @@ class SkipDataset(Dataset):
return self._input_dataset.output_types
-class BatchDataset(Dataset):
+class BatchDataset(UnaryDataset):
"""A `Dataset` that batches contiguous elements from its input."""
def __init__(self, input_dataset, batch_size, drop_remainder):
"""See `Dataset.batch()` for details."""
- super(BatchDataset, self).__init__()
+ super(BatchDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._batch_size = ops.convert_to_tensor(
batch_size, dtype=dtypes.int64, name="batch_size")
@@ -2110,13 +2321,13 @@ def _default_padding(input_dataset):
return nest.map_structure(make_zero, input_dataset.output_types)
-class PaddedBatchDataset(Dataset):
+class PaddedBatchDataset(UnaryDataset):
"""A `Dataset` that batches and pads contiguous elements from its input."""
def __init__(self, input_dataset, batch_size, padded_shapes, padding_values,
drop_remainder):
"""See `Dataset.batch()` for details."""
- super(PaddedBatchDataset, self).__init__()
+ super(PaddedBatchDataset, self).__init__(input_dataset)
if sparse.any_sparse(input_dataset.output_classes):
# TODO(b/63669786): support batching of sparse tensors
raise TypeError(
@@ -2216,12 +2427,12 @@ def _warn_if_collections(transformation_name):
% transformation_name)
-class MapDataset(Dataset):
+class MapDataset(UnaryDataset):
"""A `Dataset` that maps a function over elements in its input."""
def __init__(self, input_dataset, map_func, use_inter_op_parallelism=True):
"""See `Dataset.map()` for details."""
- super(MapDataset, self).__init__()
+ super(MapDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._use_inter_op_parallelism = use_inter_op_parallelism
@@ -2282,12 +2493,12 @@ class ParallelMapDataset(MapDataset):
# pylint: enable=protected-access
-class FlatMapDataset(Dataset):
+class FlatMapDataset(UnaryDataset):
"""A `Dataset` that maps a function over its input and flattens the result."""
def __init__(self, input_dataset, map_func):
"""See `Dataset.flat_map()` for details."""
- super(FlatMapDataset, self).__init__()
+ super(FlatMapDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
wrapped_func = StructuredFunctionWrapper(
@@ -2378,12 +2589,12 @@ class ParallelInterleaveDataset(FlatMapDataset):
return "Dataset.interleave()"
-class FilterDataset(Dataset):
+class FilterDataset(UnaryDataset):
"""A `Dataset` that filters its input according to a predicate function."""
def __init__(self, input_dataset, predicate):
"""See `Dataset.filter()` for details."""
- super(FilterDataset, self).__init__()
+ super(FilterDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
wrapped_func = StructuredFunctionWrapper(
predicate, "Dataset.filter()", input_dataset)
@@ -2413,12 +2624,12 @@ class FilterDataset(Dataset):
return self._input_dataset.output_types
-class PrefetchDataset(Dataset):
+class PrefetchDataset(UnaryDataset):
"""A `Dataset` that asynchronously prefetches its input."""
def __init__(self, input_dataset, buffer_size):
"""See `Dataset.prefetch()` for details."""
- super(PrefetchDataset, self).__init__()
+ super(PrefetchDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
if buffer_size is None:
buffer_size = -1 # This is the sentinel for auto-tuning.
@@ -2442,3 +2653,53 @@ class PrefetchDataset(Dataset):
@property
def output_types(self):
return self._input_dataset.output_types
+
+
+class WindowDataset(UnaryDataset):
+ """A dataset that creates window datasets from the input elements."""
+
+ def __init__(self, input_dataset, size, shift, stride, drop_remainder):
+ """See `window_dataset()` for more details."""
+ super(WindowDataset, self).__init__(input_dataset)
+ self._input_dataset = input_dataset
+ self._size = ops.convert_to_tensor(size, dtype=dtypes.int64, name="size")
+ self._shift = ops.convert_to_tensor(shift, dtype=dtypes.int64, name="shift")
+ self._stride = ops.convert_to_tensor(
+ stride, dtype=dtypes.int64, name="stride")
+ self._drop_remainder = ops.convert_to_tensor(
+ drop_remainder, dtype=dtypes.bool, name="drop_remainder")
+ self._output_classes = nest.pack_sequence_as(
+ input_dataset.output_classes,
+ [
+ _NestedDatasetComponent( # pylint: disable=protected-access
+ output_classes=output_class,
+ output_shapes=output_shape,
+ output_types=output_type)
+ for output_class, output_shape, output_type in zip(
+ nest.flatten(input_dataset.output_classes),
+ nest.flatten(input_dataset.output_shapes),
+ nest.flatten(input_dataset.output_types))
+ ])
+ self._output_shapes = self._output_classes
+ self._output_types = self._output_classes
+
+ def _as_variant_tensor(self):
+ return gen_dataset_ops.window_dataset(
+ self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
+ self._size,
+ self._shift,
+ self._stride,
+ self._drop_remainder,
+ **flat_structure(self))
+
+ @property
+ def output_classes(self):
+ return self._output_classes
+
+ @property
+ def output_shapes(self):
+ return self._output_shapes
+
+ @property
+ def output_types(self):
+ return self._output_types
diff --git a/tensorflow/python/data/ops/iterator_ops.py b/tensorflow/python/data/ops/iterator_ops.py
index 8f8e026df9..cae00cdbfc 100644
--- a/tensorflow/python/data/ops/iterator_ops.py
+++ b/tensorflow/python/data/ops/iterator_ops.py
@@ -24,6 +24,7 @@ from tensorflow.python.compat import compat
from tensorflow.python.data.ops import optional_ops
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import sparse
+from tensorflow.python.data.util import structure
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -85,10 +86,10 @@ class Iterator(checkpointable.CheckpointableBase):
initializer: A `tf.Operation` that should be run to initialize this
iterator.
output_types: A nested structure of `tf.DType` objects corresponding to
- each component of an element of this dataset.
+ each component of an element of this iterator.
output_shapes: A nested structure of `tf.TensorShape` objects
- corresponding to each component of an element of this dataset.
- output_classes: A nested structure of Python `type` object corresponding
+ corresponding to each component of an element of this iterator.
+ output_classes: A nested structure of Python `type` objects corresponding
to each component of an element of this iterator.
"""
self._iterator_resource = iterator_resource
@@ -670,6 +671,6 @@ def get_next_as_optional(iterator):
output_shapes=nest.flatten(
sparse.as_dense_shapes(iterator.output_shapes,
iterator.output_classes))),
- output_shapes=iterator.output_shapes,
- output_types=iterator.output_types,
- output_classes=iterator.output_classes)
+ structure.Structure._from_legacy_structure(iterator.output_types,
+ iterator.output_shapes,
+ iterator.output_classes))
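
The updated `get_next_as_optional` now builds its return value from a single `Structure`; its observable behaviour is unchanged. A brief usage sketch (module paths as they appear in this diff):

```python
import tensorflow as tf
from tensorflow.python.data.ops import iterator_ops

iterator = tf.data.Dataset.range(2).make_one_shot_iterator()
opt = iterator_ops.get_next_as_optional(iterator)

with tf.Session() as sess:
  # Run both tensors in one step so they refer to the same optional value.
  print(sess.run([opt.has_value(), opt.get_value()]))  # [True, 0]
```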
diff --git a/tensorflow/python/data/ops/multi_device_iterator_ops.py b/tensorflow/python/data/ops/multi_device_iterator_ops.py
new file mode 100644
index 0000000000..b7d3aac206
--- /dev/null
+++ b/tensorflow/python/data/ops/multi_device_iterator_ops.py
@@ -0,0 +1,231 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Python wrapper for prefetching_ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.util import nest
+from tensorflow.python.data.util import sparse
+from tensorflow.python.eager import context
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import function
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import functional_ops
+from tensorflow.python.ops import gen_dataset_ops
+
+
+class _PerDeviceGenerator(dataset_ops.Dataset):
+ """A `dummy` generator dataset."""
+
+ def __init__(self, shard_num, multi_device_iterator_resource, incarnation_id,
+ source_device, target_device, output_shapes, output_types,
+ output_classes):
+ self._target_device = target_device
+ self._output_types = output_types
+ self._output_shapes = output_shapes
+ self._output_classes = output_classes
+ self._flat_output_shapes = nest.flatten(
+ sparse.as_dense_shapes(self._output_shapes, self._output_classes))
+ self._flat_output_types = nest.flatten(
+ sparse.as_dense_types(self._output_types, self._output_classes))
+
+ multi_device_iterator_string_handle = (
+ gen_dataset_ops.multi_device_iterator_to_string_handle(
+ multi_device_iterator_resource))
+
+ @function.Defun()
+ def _init_func():
+ return multi_device_iterator_string_handle
+
+ @function.Defun()
+ def _remote_init_func():
+ return functional_ops.remote_call(
+ target=source_device,
+ args=_init_func.captured_inputs,
+ Tout=[dtypes.string],
+ f=_init_func)
+
+ self._init_func = _remote_init_func
+ self._init_captured_args = _remote_init_func.captured_inputs
+
+ @function.Defun(dtypes.string)
+ def _next_func(string_handle):
+ multi_device_iterator = (
+ gen_dataset_ops.multi_device_iterator_from_string_handle(
+ string_handle=string_handle,
+ output_types=self._flat_output_types,
+ output_shapes=self._flat_output_shapes))
+ return gen_dataset_ops.multi_device_iterator_get_next_from_shard(
+ multi_device_iterator=multi_device_iterator,
+ shard_num=shard_num,
+ incarnation_id=incarnation_id,
+ output_types=self._flat_output_types,
+ output_shapes=self._flat_output_shapes)
+
+ @function.Defun(dtypes.string)
+ def _remote_next_func(string_handle):
+ return functional_ops.remote_call(
+ target=source_device,
+ args=[string_handle] + _next_func.captured_inputs,
+ Tout=self._flat_output_types,
+ f=_next_func)
+
+ self._next_func = _remote_next_func
+ self._next_captured_args = _remote_next_func.captured_inputs
+
+ @function.Defun(dtypes.string)
+ def _finalize_func(unused_string_handle):
+ return array_ops.constant(0, dtypes.int64)
+
+ @function.Defun(dtypes.string)
+ def _remote_finalize_func(string_handle):
+ return functional_ops.remote_call(
+ target=source_device,
+ args=[string_handle] + _finalize_func.captured_inputs,
+ Tout=[dtypes.int64],
+ f=_finalize_func)
+
+ self._finalize_func = _remote_finalize_func
+ self._finalize_captured_args = _remote_finalize_func.captured_inputs
+
+ def _as_variant_tensor(self):
+ with ops.device(self._target_device):
+ return gen_dataset_ops.generator_dataset(
+ self._init_captured_args,
+ self._next_captured_args,
+ self._finalize_captured_args,
+ init_func=self._init_func,
+ next_func=self._next_func,
+ finalize_func=self._finalize_func,
+ output_types=self._flat_output_types,
+ output_shapes=self._flat_output_shapes)
+
+ def _inputs(self):
+ # TODO(b/116506223): Determine which datasets should be used as inputs here.
+ return []
+
+ @property
+ def output_types(self):
+ return self._output_types
+
+ @property
+ def output_shapes(self):
+ return self._output_shapes
+
+ @property
+ def output_classes(self):
+ return self._output_classes
+
+
+class MultiDeviceIterator(object):
+ """An iterator over multiple devices.
+
+ @compatibility(eager)
+ MultiDeviceIterator isn't currently supported in Eager mode but support is
+ coming soon.
+ @end_compatibility
+ """
+
+ def __init__(self,
+ dataset,
+ devices,
+ max_buffer_size=1,
+ prefetch_buffer_size=1,
+ source_device="/cpu:0"):
+ """Constructs a MultiDeviceIterator.
+
+ Args:
+ dataset: The input dataset to be iterated over.
+ devices: The list of devices to fetch data to.
+ max_buffer_size: Maximum size of the host side per device buffer to keep.
+      prefetch_buffer_size: If greater than 0, a buffer of this size is
+        set up on each device to prefetch into.
+ source_device: The host device to place the `dataset` on.
+
+ Raises:
+ RuntimeError: If run in Eager mode.
+ """
+ if context.executing_eagerly():
+ # TODO(rohanj): Fix this. Tracking bug: b/116467184
+ raise RuntimeError("MultiDeviceIterator is not currently supported in "
+ "Eager mode.")
+ self._dataset = dataset
+ self._devices = devices
+ self._source_device = source_device
+ self._source_device_tensor = ops.convert_to_tensor(source_device)
+
+ self._flat_output_shapes = nest.flatten(
+ sparse.as_dense_shapes(self._dataset.output_shapes,
+ self._dataset.output_classes))
+ self._flat_output_types = nest.flatten(
+ sparse.as_dense_types(self._dataset.output_types,
+ self._dataset.output_classes))
+
+ # Create the MultiDeviceIterator.
+ with ops.device(self._source_device):
+ self._multi_device_iterator_resource = (
+ gen_dataset_ops.multi_device_iterator(
+ devices=self._devices,
+ shared_name="",
+ container="",
+ output_types=self._flat_output_types,
+ output_shapes=self._flat_output_shapes))
+
+ # The incarnation ID is used to ensure consistency between the per-device
+ # iterators and the multi-device iterator.
+ self._incarnation_id = gen_dataset_ops.multi_device_iterator_init(
+ self._dataset._as_variant_tensor(), # pylint: disable=protected-access
+ self._multi_device_iterator_resource,
+ max_buffer_size=max_buffer_size)
+
+ # TODO(rohanj): Explore the possibility of the MultiDeviceIterator to
+ # initialize the device side of the pipeline. This would allow the
+ # MultiDeviceIterator to choose, for example, to move some transformations
+ # into the device side from its input. It might be useful in rewriting.
+ # Create the per device iterators.
+ self._device_iterators = []
+ i = 0
+ for device in self._devices:
+ ds = _PerDeviceGenerator(
+ i, self._multi_device_iterator_resource, self._incarnation_id,
+ self._source_device_tensor, device, self._dataset.output_shapes,
+ self._dataset.output_types, self._dataset.output_classes)
+ if prefetch_buffer_size > 0:
+ ds = ds.prefetch(prefetch_buffer_size)
+ with ops.device(device):
+ self._device_iterators.append(ds.make_initializable_iterator())
+ i += 1
+
+ device_iterator_initializers = [
+ iterator.initializer for iterator in self._device_iterators
+ ]
+ self._initializer = control_flow_ops.group(*device_iterator_initializers)
+
+ def get_next(self):
+ result = []
+ i = 0
+ for device in self._devices:
+ with ops.device(device):
+ result.append(self._device_iterators[i].get_next())
+ i += 1
+ return result
+
+ @property
+ def initializer(self):
+ return self._initializer
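
A rough usage sketch for the new `MultiDeviceIterator` (graph mode only, per its docstring; the device names below are hypothetical and assume two visible GPUs):

```python
import tensorflow as tf
from tensorflow.python.data.ops import multi_device_iterator_ops

dataset = tf.data.Dataset.range(100).batch(10)
mdi = multi_device_iterator_ops.MultiDeviceIterator(
    dataset, devices=["/gpu:0", "/gpu:1"], prefetch_buffer_size=2)
per_device = mdi.get_next()  # one element per device, in `devices` order

with tf.Session() as sess:
  sess.run(mdi.initializer)
  first, second = sess.run(per_device)  # e.g. [0..9] on gpu:0, [10..19] on gpu:1
```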
diff --git a/tensorflow/python/data/ops/optional_ops.py b/tensorflow/python/data/ops/optional_ops.py
index b75b98dc72..3bbebd7878 100644
--- a/tensorflow/python/data/ops/optional_ops.py
+++ b/tensorflow/python/data/ops/optional_ops.py
@@ -19,11 +19,9 @@ from __future__ import print_function
import abc
-from tensorflow.python.data.util import nest
-from tensorflow.python.data.util import sparse
+from tensorflow.python.data.util import structure
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
-from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import gen_dataset_ops
@@ -67,36 +65,14 @@ class Optional(object):
raise NotImplementedError("Optional.get_value()")
@abc.abstractproperty
- def output_classes(self):
- """Returns the class of each component of this optional.
-
- The expected values are `tf.Tensor` and `tf.SparseTensor`.
-
- Returns:
- A nested structure of Python `type` objects corresponding to each
- component of this optional.
- """
- raise NotImplementedError("Optional.output_classes")
-
- @abc.abstractproperty
- def output_shapes(self):
- """Returns the shape of each component of this optional.
-
- Returns:
- A nested structure of `tf.TensorShape` objects corresponding to each
- component of this optional.
- """
- raise NotImplementedError("Optional.output_shapes")
-
- @abc.abstractproperty
- def output_types(self):
- """Returns the type of each component of this optional.
+ def value_structure(self):
+ """The structure of the components of this optional.
Returns:
- A nested structure of `tf.DType` objects corresponding to each component
- of this optional.
+ A `Structure` object representing the structure of the components of this
+ optional.
"""
- raise NotImplementedError("Optional.output_types")
+ raise NotImplementedError("Optional.value_structure")
@staticmethod
def from_value(value):
@@ -108,48 +84,30 @@ class Optional(object):
Returns:
An `Optional` that wraps `value`.
"""
- # TODO(b/110122868): Consolidate this destructuring logic with the
- # similar code in `Dataset.from_tensors()`.
with ops.name_scope("optional") as scope:
with ops.name_scope("value"):
- value = nest.pack_sequence_as(value, [
- sparse_tensor_lib.SparseTensor.from_value(t)
- if sparse_tensor_lib.is_sparse(t) else ops.convert_to_tensor(
- t, name="component_%d" % i)
- for i, t in enumerate(nest.flatten(value))
- ])
-
- encoded_value = nest.flatten(sparse.serialize_sparse_tensors(value))
- output_classes = sparse.get_classes(value)
- output_shapes = nest.pack_sequence_as(
- value, [t.get_shape() for t in nest.flatten(value)])
- output_types = nest.pack_sequence_as(
- value, [t.dtype for t in nest.flatten(value)])
+ value_structure = structure.Structure.from_value(value)
+ encoded_value = value_structure._to_tensor_list(value) # pylint: disable=protected-access
return _OptionalImpl(
gen_dataset_ops.optional_from_value(encoded_value, name=scope),
- output_shapes, output_types, output_classes)
+ value_structure)
@staticmethod
- def none_from_structure(output_shapes, output_types, output_classes):
+ def none_from_structure(value_structure):
"""Returns an `Optional` that has no value.
- NOTE: This method takes arguments that define the structure of the value
+ NOTE: This method takes an argument that defines the structure of the value
that would be contained in the returned `Optional` if it had a value.
Args:
- output_shapes: A nested structure of `tf.TensorShape` objects
- corresponding to each component of this optional.
- output_types: A nested structure of `tf.DType` objects corresponding to
- each component of this optional.
- output_classes: A nested structure of Python `type` objects corresponding
- to each component of this optional.
+ value_structure: A `Structure` object representing the structure of the
+ components of this optional.
Returns:
An `Optional` that has no value.
"""
- return _OptionalImpl(gen_dataset_ops.optional_none(), output_shapes,
- output_types, output_classes)
+ return _OptionalImpl(gen_dataset_ops.optional_none(), value_structure)
class _OptionalImpl(Optional):
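
With this refactor an `Optional` carries one `value_structure` object in place of the three legacy `output_*` properties; a brief sketch using the APIs as they appear in this diff:

```python
import tensorflow as tf
from tensorflow.python.data.ops import optional_ops

opt = optional_ops.Optional.from_value(tf.constant(37.0))
none = optional_ops.Optional.none_from_structure(opt.value_structure)

with tf.Session() as sess:
  print(sess.run(opt.has_value()), sess.run(opt.get_value()))  # True 37.0
  print(sess.run(none.has_value()))                            # False
```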
@@ -159,20 +117,9 @@ class _OptionalImpl(Optional):
`Optional.__init__()` in the public API.
"""
- def __init__(self, variant_tensor, output_shapes, output_types,
- output_classes):
- # TODO(b/110122868): Consolidate the structure validation logic with the
- # similar logic in `Iterator.from_structure()` and
- # `Dataset.from_generator()`.
- output_types = nest.map_structure(dtypes.as_dtype, output_types)
- output_shapes = nest.map_structure_up_to(
- output_types, tensor_shape.as_shape, output_shapes)
- nest.assert_same_structure(output_types, output_shapes)
- nest.assert_same_structure(output_types, output_classes)
+ def __init__(self, variant_tensor, value_structure):
self._variant_tensor = variant_tensor
- self._output_shapes = output_shapes
- self._output_types = output_types
- self._output_classes = output_classes
+ self._value_structure = value_structure
def has_value(self, name=None):
return gen_dataset_ops.optional_has_value(self._variant_tensor, name=name)
@@ -182,28 +129,55 @@ class _OptionalImpl(Optional):
# in `Iterator.get_next()` and `StructuredFunctionWrapper`.
with ops.name_scope(name, "OptionalGetValue",
[self._variant_tensor]) as scope:
- return sparse.deserialize_sparse_tensors(
- nest.pack_sequence_as(
- self._output_types,
- gen_dataset_ops.optional_get_value(
- self._variant_tensor,
- name=scope,
- output_types=nest.flatten(
- sparse.as_dense_types(self._output_types,
- self._output_classes)),
- output_shapes=nest.flatten(
- sparse.as_dense_shapes(self._output_shapes,
- self._output_classes)))),
- self._output_types, self._output_shapes, self._output_classes)
+ # pylint: disable=protected-access
+ return self._value_structure._from_tensor_list(
+ gen_dataset_ops.optional_get_value(
+ self._variant_tensor,
+ name=scope,
+ output_types=self._value_structure._flat_types,
+ output_shapes=self._value_structure._flat_shapes))
@property
- def output_classes(self):
- return self._output_classes
+ def value_structure(self):
+ return self._value_structure
+
+
+class OptionalStructure(structure.Structure):
+ """Represents an optional potentially containing a structured value."""
+
+ def __init__(self, value_structure):
+ self._value_structure = value_structure
@property
- def output_shapes(self):
- return self._output_shapes
+ def _flat_shapes(self):
+ return [tensor_shape.scalar()]
@property
- def output_types(self):
- return self._output_types
+ def _flat_types(self):
+ return [dtypes.variant]
+
+ def is_compatible_with(self, other):
+ # pylint: disable=protected-access
+ return (isinstance(other, OptionalStructure) and
+ self._value_structure.is_compatible_with(other._value_structure))
+
+ def _to_tensor_list(self, value):
+ return [value._variant_tensor] # pylint: disable=protected-access
+
+ def _from_tensor_list(self, flat_value):
+ if (len(flat_value) != 1 or flat_value[0].dtype != dtypes.variant or
+ not flat_value[0].shape.is_compatible_with(tensor_shape.scalar())):
+ raise ValueError(
+ "OptionalStructure corresponds to a single tf.variant scalar.")
+ # pylint: disable=protected-access
+ return _OptionalImpl(flat_value[0], self._value_structure)
+
+ @staticmethod
+ def from_value(value):
+ return OptionalStructure(value.value_structure)
+
+
+# pylint: disable=protected-access
+structure.Structure._register_custom_converter(Optional,
+ OptionalStructure.from_value)
+# pylint: enable=protected-access
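
The registration above is what makes `Structure.from_value` aware of `Optional` values via the new converter registry; a minimal illustration (relying only on code added in this diff):

```python
import tensorflow as tf
from tensorflow.python.data.ops import optional_ops
from tensorflow.python.data.util import structure

opt = optional_ops.Optional.from_value(tf.constant([1, 2, 3]))
s = structure.Structure.from_value(opt)  # dispatches through the registry
print(type(s).__name__)                  # OptionalStructure
print(s.is_compatible_with(structure.Structure.from_value(opt)))  # True
```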
diff --git a/tensorflow/python/data/ops/readers.py b/tensorflow/python/data/ops/readers.py
index 066e09969c..b0f26631f9 100644
--- a/tensorflow/python/data/ops/readers.py
+++ b/tensorflow/python/data/ops/readers.py
@@ -61,6 +61,9 @@ class TextLineDataset(dataset_ops.Dataset):
return gen_dataset_ops.text_line_dataset(
self._filenames, self._compression_type, self._buffer_size)
+ def _inputs(self):
+ return []
+
@property
def output_classes(self):
return ops.Tensor
@@ -105,6 +108,9 @@ class _TFRecordDataset(dataset_ops.Dataset):
return gen_dataset_ops.tf_record_dataset(
self._filenames, self._compression_type, self._buffer_size)
+ def _inputs(self):
+ return []
+
@property
def output_classes(self):
return ops.Tensor
@@ -224,6 +230,9 @@ class TFRecordDataset(dataset_ops.Dataset):
def _as_variant_tensor(self):
return self._impl._as_variant_tensor() # pylint: disable=protected-access
+ def _inputs(self):
+ return self._impl._inputs() # pylint: disable=protected-access
+
@property
def output_classes(self):
return self._impl.output_classes
@@ -278,6 +287,9 @@ class FixedLengthRecordDataset(dataset_ops.Dataset):
self._filenames, self._header_bytes, self._record_bytes,
self._footer_bytes, self._buffer_size)
+ def _inputs(self):
+ return []
+
@property
def output_classes(self):
return ops.Tensor
diff --git a/tensorflow/python/data/util/structure.py b/tensorflow/python/data/util/structure.py
index c5764b8dfe..a90ca258c0 100644
--- a/tensorflow/python/data/util/structure.py
+++ b/tensorflow/python/data/util/structure.py
@@ -28,6 +28,9 @@ from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import sparse_ops
+_STRUCTURE_CONVERSION_FUNCTION_REGISTRY = {}
+
+
class Structure(object):
"""Represents structural information, such as type and shape, about a value.
@@ -64,12 +67,10 @@ class Structure(object):
raise NotImplementedError("Structure._flat_shapes")
@abc.abstractmethod
- def is_compatible_with(self, value):
- """Returns `True` if `value` is compatible with this structure.
+ def is_compatible_with(self, other):
+ """Returns `True` if `other` is compatible with this structure.
- A value `value` is compatible with a structure `s` if
- `Structure.from_value(value)` would return a structure `t` that is a
- "subtype" of `s`. A structure `t` is a "subtype" of `s` if:
+    A structure `t` (`other`) is a "subtype" of a structure `s` (`self`) if:
* `s` and `t` are instances of the same `Structure` subclass.
* The nested structures (if any) of `s` and `t` are the same, according to
@@ -83,10 +84,10 @@ class Structure(object):
`tf.TensorShape.is_compatible_with`.
Args:
- value: A potentially structured value.
+ other: A `Structure`.
Returns:
- `True` if `value` matches this structure, otherwise `False`.
+ `True` if `other` is a subtype of this structure, otherwise `False`.
"""
raise NotImplementedError("Structure.is_compatible_with()")
@@ -98,7 +99,7 @@ class Structure(object):
`self._flat_types` to represent structured values in lower level APIs
(such as plain TensorFlow operations) that do not understand structure.
- Requires: `self.is_compatible_with(value)`.
+ Requires: `self.is_compatible_with(Structure.from_value(value))`.
Args:
value: A value with compatible structure.
@@ -137,9 +138,8 @@ class Structure(object):
TypeError: If a structure cannot be built for `value`, because its type
or one of its component types is not supported.
"""
-
- # TODO(b/110122868): Add support for custom types, Dataset, and Optional
- # to this method.
+ # TODO(b/110122868): Add support for custom types and Dataset to this
+ # method.
if isinstance(
value,
(sparse_tensor_lib.SparseTensor, sparse_tensor_lib.SparseTensorValue)):
@@ -147,12 +147,76 @@ class Structure(object):
elif isinstance(value, (tuple, dict)):
return NestedStructure.from_value(value)
else:
+ for converter_type, converter_fn in (
+ _STRUCTURE_CONVERSION_FUNCTION_REGISTRY.items()):
+ if isinstance(value, converter_type):
+ return converter_fn(value)
try:
tensor = ops.convert_to_tensor(value)
except (ValueError, TypeError):
raise TypeError("Could not build a structure for %r" % value)
return TensorStructure.from_value(tensor)
+ @staticmethod
+ def _from_legacy_structure(output_types, output_shapes, output_classes):
+ """Returns a `Structure` that represents the given legacy structure.
+
+ This method provides a way to convert from the existing `Dataset` and
+ `Iterator` structure-related properties to a `Structure` object.
+
+ TODO(b/110122868): Remove this method once `Structure` is used throughout
+ `tf.data`.
+
+ Args:
+ output_types: A nested structure of `tf.DType` objects corresponding to
+ each component of a structured value.
+ output_shapes: A nested structure of `tf.TensorShape` objects
+        corresponding to each component of a structured value.
+ output_classes: A nested structure of Python `type` objects corresponding
+ to each component of a structured value.
+
+ Returns:
+ A `Structure`.
+
+ Raises:
+      TypeError: If a structure cannot be built from the arguments, because
+        one of the component classes in `output_classes` is not supported.
+ """
+ flat_types = nest.flatten(output_types)
+ flat_shapes = nest.flatten(output_shapes)
+ flat_classes = nest.flatten(output_classes)
+ flat_ret = []
+ for flat_type, flat_shape, flat_class in zip(flat_types, flat_shapes,
+ flat_classes):
+ if issubclass(flat_class, sparse_tensor_lib.SparseTensor):
+ flat_ret.append(SparseTensorStructure(flat_type, flat_shape))
+ elif issubclass(flat_class, ops.Tensor):
+ flat_ret.append(TensorStructure(flat_type, flat_shape))
+ else:
+ # NOTE(mrry): Since legacy structures produced by iterators only
+ # comprise Tensors, SparseTensors, and nests, we do not need to support
+ # all structure types here.
+ raise TypeError(
+ "Could not build a structure for output class %r" % flat_type)
+
+ ret = nest.pack_sequence_as(output_classes, flat_ret)
+ if isinstance(ret, Structure):
+ return ret
+ else:
+ return NestedStructure(ret)
+
+ @staticmethod
+ def _register_custom_converter(type_object, converter_fn):
+ """Registers `converter_fn` for converting values of the given type.
+
+ Args:
+ type_object: A Python `type` object representing the type of values
+ accepted by `converter_fn`.
+ converter_fn: A function that takes one argument (an instance of the
+ type represented by `type_object`) and returns a `Structure`.
+ """
+ _STRUCTURE_CONVERSION_FUNCTION_REGISTRY[type_object] = converter_fn
+
# NOTE(mrry): The following classes make extensive use of non-public methods of
# their base class, so we disable the protected-access lint warning once here.
@@ -179,16 +243,21 @@ class NestedStructure(Structure):
def _flat_types(self):
return self._flat_types_list
- def is_compatible_with(self, value):
+ def is_compatible_with(self, other):
+ if not isinstance(other, NestedStructure):
+ return False
try:
- nest.assert_shallow_structure(self._nested_structure, value)
+ # pylint: disable=protected-access
+ nest.assert_same_structure(self._nested_structure,
+ other._nested_structure)
except (ValueError, TypeError):
return False
return all(
- s.is_compatible_with(v) for s, v in zip(
+ substructure.is_compatible_with(other_substructure)
+ for substructure, other_substructure in zip(
nest.flatten(self._nested_structure),
- nest.flatten_up_to(self._nested_structure, value)))
+ nest.flatten(other._nested_structure)))
def _to_tensor_list(self, value):
ret = []
@@ -201,7 +270,7 @@ class NestedStructure(Structure):
for sub_value, structure in zip(flat_value,
nest.flatten(self._nested_structure)):
- if not structure.is_compatible_with(sub_value):
+ if not structure.is_compatible_with(Structure.from_value(sub_value)):
raise ValueError("Component value %r is not compatible with the nested "
"structure %r." % (sub_value, structure))
ret.extend(structure._to_tensor_list(sub_value))
@@ -242,17 +311,13 @@ class TensorStructure(Structure):
def _flat_types(self):
return [self._dtype]
- def is_compatible_with(self, value):
- try:
- value = ops.convert_to_tensor(value, dtype=self._dtype)
- except (ValueError, TypeError):
- return False
-
- return (self._dtype.is_compatible_with(value.dtype) and
- self._shape.is_compatible_with(value.shape))
+ def is_compatible_with(self, other):
+ return (isinstance(other, TensorStructure) and
+ self._dtype.is_compatible_with(other._dtype) and
+ self._shape.is_compatible_with(other._shape))
def _to_tensor_list(self, value):
- if not self.is_compatible_with(value):
+ if not self.is_compatible_with(Structure.from_value(value)):
raise ValueError("Value %r is not convertible to a tensor with dtype %s "
"and shape %s." % (value, self._dtype, self._shape))
return [value]
@@ -260,7 +325,7 @@ class TensorStructure(Structure):
def _from_tensor_list(self, flat_value):
if len(flat_value) != 1:
raise ValueError("TensorStructure corresponds to a single tf.Tensor.")
- if not self.is_compatible_with(flat_value[0]):
+ if not self.is_compatible_with(Structure.from_value(flat_value[0])):
raise ValueError("Cannot convert %r to a tensor with dtype %s and shape "
"%s." % (flat_value[0], self._dtype, self._shape))
return flat_value[0]
@@ -285,16 +350,10 @@ class SparseTensorStructure(Structure):
def _flat_types(self):
return [dtypes.variant]
- def is_compatible_with(self, value):
- try:
- value = sparse_tensor_lib.SparseTensor.from_value(value)
- except TypeError:
- return False
- return (isinstance(value, (sparse_tensor_lib.SparseTensor,
- sparse_tensor_lib.SparseTensorValue)) and
- self._dtype.is_compatible_with(value.dtype) and
- self._dense_shape.is_compatible_with(
- tensor_util.constant_value_as_shape(value.dense_shape)))
+ def is_compatible_with(self, other):
+ return (isinstance(other, SparseTensorStructure) and
+ self._dtype.is_compatible_with(other._dtype) and
+ self._dense_shape.is_compatible_with(other._dense_shape))
def _to_tensor_list(self, value):
return [sparse_ops.serialize_sparse(value, out_type=dtypes.variant)]
diff --git a/tensorflow/python/data/util/structure_test.py b/tensorflow/python/data/util/structure_test.py
index d0c7df67ae..2982763181 100644
--- a/tensorflow/python/data/util/structure_test.py
+++ b/tensorflow/python/data/util/structure_test.py
@@ -25,7 +25,9 @@ from tensorflow.python.data.util import nest
from tensorflow.python.data.util import structure
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@@ -106,13 +108,17 @@ class StructureTest(test.TestCase, parameterized.TestCase):
indices=[[0], [1], [2]], values=[4, 5, 6], dense_shape=[3])
}, (constant_op.constant(15.0), constant_op.constant([4, 5, 6]))]),
)
- def testIsCompatibleWith(self, original_value, compatible_values,
- incompatible_values):
+ def testIsCompatibleWithStructure(self, original_value, compatible_values,
+ incompatible_values):
s = structure.Structure.from_value(original_value)
for compatible_value in compatible_values:
- self.assertTrue(s.is_compatible_with(compatible_value))
+ self.assertTrue(
+ s.is_compatible_with(
+ structure.Structure.from_value(compatible_value)))
for incompatible_value in incompatible_values:
- self.assertFalse(s.is_compatible_with(incompatible_value))
+ self.assertFalse(
+ s.is_compatible_with(
+ structure.Structure.from_value(incompatible_value)))
# NOTE(mrry): The arguments must be lifted into lambdas because otherwise they
# will be executed before the (eager- or graph-mode) test environment has been
@@ -322,6 +328,28 @@ class StructureTest(test.TestCase, parameterized.TestCase):
ValueError, "Expected 3 flat values in NestedStructure but got 2."):
s_2._from_tensor_list(flat_s_1)
+ @parameterized.named_parameters(
+ ("Tensor", dtypes.float32, tensor_shape.scalar(), ops.Tensor,
+ structure.TensorStructure(dtypes.float32, [])),
+ ("SparseTensor", dtypes.int32, tensor_shape.matrix(2, 2),
+ sparse_tensor.SparseTensor,
+ structure.SparseTensorStructure(dtypes.int32, [2, 2])),
+ ("Nest",
+ {"a": dtypes.float32, "b": (dtypes.int32, dtypes.string)},
+ {"a": tensor_shape.scalar(),
+ "b": (tensor_shape.matrix(2, 2), tensor_shape.scalar())},
+ {"a": ops.Tensor, "b": (sparse_tensor.SparseTensor, ops.Tensor)},
+ structure.NestedStructure({
+ "a": structure.TensorStructure(dtypes.float32, []),
+ "b": (structure.SparseTensorStructure(dtypes.int32, [2, 2]),
+ structure.TensorStructure(dtypes.string, []))})),
+ )
+ def testFromLegacyStructure(self, output_types, output_shapes, output_classes,
+ expected_structure):
+ actual_structure = structure.Structure._from_legacy_structure(
+ output_types, output_shapes, output_classes)
+ self.assertTrue(expected_structure.is_compatible_with(actual_structure))
+ self.assertTrue(actual_structure.is_compatible_with(expected_structure))
if __name__ == "__main__":
test.main()
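
The new `testFromLegacyStructure` case exercises `Structure._from_legacy_structure`, which folds the legacy `(output_types, output_shapes, output_classes)` triple carried by `tf.data` datasets into a single `Structure`. A hedged sketch of that mapping for the scalar-tensor case used in the test above:

    from tensorflow.python.data.util import structure
    from tensorflow.python.framework import dtypes
    from tensorflow.python.framework import ops
    from tensorflow.python.framework import tensor_shape

    # Legacy description of one dataset component: dtype, static shape, class.
    s = structure.Structure._from_legacy_structure(
        dtypes.float32, tensor_shape.scalar(), ops.Tensor)

    # The modern equivalent is a TensorStructure with the same dtype and shape.
    expected = structure.TensorStructure(dtypes.float32, [])
    assert s.is_compatible_with(expected)
    assert expected.is_compatible_with(s)
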
diff --git a/tensorflow/python/debug/BUILD b/tensorflow/python/debug/BUILD
index 849d165bfa..e84482d2b2 100644
--- a/tensorflow/python/debug/BUILD
+++ b/tensorflow/python/debug/BUILD
@@ -18,6 +18,7 @@ exports_files(["LICENSE"])
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
load("//tensorflow:tensorflow.bzl", "py_test")
+load("//tensorflow:tensorflow.bzl", "py_binary")
load("//tensorflow:tensorflow.bzl", "if_not_windows")
py_library(
diff --git a/tensorflow/python/debug/cli/analyzer_cli_test.py b/tensorflow/python/debug/cli/analyzer_cli_test.py
index 55231954d1..f197a9e4dc 100644
--- a/tensorflow/python/debug/cli/analyzer_cli_test.py
+++ b/tensorflow/python/debug/cli/analyzer_cli_test.py
@@ -57,7 +57,8 @@ def no_rewrite_session_config():
disable_model_pruning=True,
constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF,
- dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF)
+ dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF,
+ pin_to_host_optimization=rewriter_config_pb2.RewriterConfig.OFF)
graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config)
return config_pb2.ConfigProto(graph_options=graph_options)
@@ -598,11 +599,11 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase):
v_name = "simple_mul_add/v"
u_init = constant_op.constant(u_init_val, shape=[2, 2], name="u_init")
- u = variables.Variable(u_init, name=u_name)
+ u = variables.VariableV1(u_init, name=u_name)
cls._u_line_number = line_number_above()
v_init = constant_op.constant(v_init_val, shape=[2, 1], name="v_init")
- v = variables.Variable(v_init, name=v_name)
+ v = variables.VariableV1(v_init, name=v_name)
cls._v_line_number = line_number_above()
w = math_ops.matmul(u, v, name="simple_mul_add/matmul")
@@ -611,7 +612,7 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase):
x = math_ops.add(w, w, name="simple_mul_add/add")
cls._x_line_number = line_number_above()
- a = variables.Variable([1, 3, 3, 7], name="a")
+ a = variables.VariableV1([1, 3, 3, 7], name="a")
u.initializer.run()
v.initializer.run()
@@ -1370,7 +1371,7 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase):
# Verify the annotation of the line that creates u.
index = self._findSourceLine(out, self._u_line_number)
self.assertEqual(
- ["L%d u = variables.Variable(u_init, name=u_name)" %
+ ["L%d u = variables.VariableV1(u_init, name=u_name)" %
self._u_line_number,
" simple_mul_add/u",
" simple_mul_add/u/Assign",
@@ -1387,7 +1388,7 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase):
# Verify the annotation of the line that creates v.
index = self._findSourceLine(out, self._v_line_number)
self.assertEqual(
- ["L%d v = variables.Variable(v_init, name=v_name)" %
+ ["L%d v = variables.VariableV1(v_init, name=v_name)" %
self._v_line_number,
" simple_mul_add/v"],
out.lines[index : index + 2])
@@ -1424,7 +1425,7 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase):
# Verify the annotation of the line that creates u.
index = self._findSourceLine(out, self._u_line_number)
self.assertEqual(
- ["L%d u = variables.Variable(u_init, name=u_name)" %
+ ["L%d u = variables.VariableV1(u_init, name=u_name)" %
self._u_line_number,
" simple_mul_add/u/read:0",
" simple_mul_add/u:0"],
@@ -1446,7 +1447,7 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase):
index = self._findSourceLine(out, self._u_line_number)
self.assertEqual(
- ["L%d u = variables.Variable(u_init, name=u_name)" %
+ ["L%d u = variables.VariableV1(u_init, name=u_name)" %
self._u_line_number,
" simple_mul_add/u",
" simple_mul_add/u/Assign",
@@ -1469,7 +1470,7 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase):
index = self._findSourceLine(out, self._u_line_number)
self.assertEqual(
- ["L%d u = variables.Variable(u_init, name=u_name)" %
+ ["L%d u = variables.VariableV1(u_init, name=u_name)" %
self._u_line_number,
" simple_mul_add/u",
" (... Omitted 2 of 3 op(s) ...) +5"],
@@ -1579,7 +1580,7 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase):
"""List an input tree containing tensors from non-:0 output slot."""
with session.Session(config=no_rewrite_session_config()) as sess:
- x = variables.Variable([1, 3, 3, 7], name="x")
+ x = variables.VariableV1([1, 3, 3, 7], name="x")
_, idx = array_ops.unique(x, name="x_unique")
idx_times_two = math_ops.multiply(idx, 2, name="idx_times_two")
sess.run(x.initializer)
@@ -1683,7 +1684,7 @@ class AnalyzerCLIControlDepTest(test_util.TensorFlowTestCase):
with session.Session(config=no_rewrite_session_config()) as sess:
x_init_val = np.array([5.0, 3.0])
x_init = constant_op.constant(x_init_val, shape=[2])
- x = variables.Variable(x_init, name="control_deps/x")
+ x = variables.VariableV1(x_init, name="control_deps/x")
y = math_ops.add(x, x, name="control_deps/y")
y = control_flow_ops.with_dependencies(
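
The debugger CLI tests assert on exact node names in the executed graph, so the helper that builds their session config now also switches off the new `pin_to_host_optimization` Grappler pass; otherwise the rewriter could relocate or rename ops and break those assertions. The patched helper, reproduced here for reference, assumes the usual proto imports:

    from tensorflow.core.protobuf import config_pb2
    from tensorflow.core.protobuf import rewriter_config_pb2

    def no_rewrite_session_config():
      # Disable every Grappler rewrite that could change the node names the
      # debugger tests match against, including the new pin-to-host pass.
      rewriter_config = rewriter_config_pb2.RewriterConfig(
          disable_model_pruning=True,
          constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
          arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF,
          dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF,
          pin_to_host_optimization=rewriter_config_pb2.RewriterConfig.OFF)
      graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config)
      return config_pb2.ConfigProto(graph_options=graph_options)
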
diff --git a/tensorflow/python/debug/cli/stepper_cli_test.py b/tensorflow/python/debug/cli/stepper_cli_test.py
index ee8cabca0d..7b8a42c253 100644
--- a/tensorflow/python/debug/cli/stepper_cli_test.py
+++ b/tensorflow/python/debug/cli/stepper_cli_test.py
@@ -132,8 +132,8 @@ def _parse_updated(lines):
class NodeStepperSimpleGraphTest(test_util.TensorFlowTestCase):
def setUp(self):
- self.a = variables.Variable(10.0, name="a")
- self.b = variables.Variable(20.0, name="b")
+ self.a = variables.VariableV1(10.0, name="a")
+ self.b = variables.VariableV1(20.0, name="b")
self.c = math_ops.add(self.a, self.b, name="c") # Should be 30.0.
self.d = math_ops.subtract(self.a, self.c, name="d") # Should be -20.0.
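
Across all of these debugger tests the patch swaps `variables.Variable` for `variables.VariableV1`. The two build the same reference-style variable today; pinning the tests to the V1 class keeps them producing the classic `Assign`/`read` ops their name-based assertions depend on once `tf.Variable` moves to V2 (resource) semantics. A small graph-mode sketch of the pattern:

    from tensorflow.python.client import session
    from tensorflow.python.ops import state_ops
    from tensorflow.python.ops import variables

    with session.Session() as sess:
      # VariableV1 keeps the legacy ref-variable graph: a variable node plus
      # explicit "v/Assign" and "v/read" ops that the debugger tests look up.
      v = variables.VariableV1(10.0, name="v")
      inc_v = state_ops.assign_add(v, 1.0, name="inc_v")
      sess.run(v.initializer)
      print(sess.run(inc_v))  # 11.0
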
diff --git a/tensorflow/python/debug/lib/debug_graph_reconstruction_test.py b/tensorflow/python/debug/lib/debug_graph_reconstruction_test.py
index 676097fde9..1f67f8a0d4 100644
--- a/tensorflow/python/debug/lib/debug_graph_reconstruction_test.py
+++ b/tensorflow/python/debug/lib/debug_graph_reconstruction_test.py
@@ -45,6 +45,7 @@ class ReconstructNonDebugGraphTest(test_util.TensorFlowTestCase):
def _no_rewrite_session_config(self):
rewriter_config = rewriter_config_pb2.RewriterConfig(
dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF,
+ pin_to_host_optimization=rewriter_config_pb2.RewriterConfig.OFF,
min_graph_nodes=-1)
graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config)
return config_pb2.ConfigProto(graph_options=graph_options)
@@ -156,7 +157,7 @@ class ReconstructNonDebugGraphTest(test_util.TensorFlowTestCase):
sess, cond, expected_output=21.0)
def testReconstructGraphWithWhileLoop(self):
- with session.Session() as sess:
+ with session.Session(config=self._no_rewrite_session_config()) as sess:
loop_body = lambda i: math_ops.add(i, 2)
loop_cond = lambda i: math_ops.less(i, 16)
i = constant_op.constant(10, name="i")
diff --git a/tensorflow/python/debug/lib/debug_utils_test.py b/tensorflow/python/debug/lib/debug_utils_test.py
index 5b1875e092..23ab98444c 100644
--- a/tensorflow/python/debug/lib/debug_utils_test.py
+++ b/tensorflow/python/debug/lib/debug_utils_test.py
@@ -46,8 +46,8 @@ class DebugUtilsTest(test_util.TensorFlowTestCase):
cls._b_init = constant_op.constant(
cls._b_init_val, shape=[2, 1], name="b_init")
- cls._a = variables.Variable(cls._a_init, name="a1")
- cls._b = variables.Variable(cls._b_init, name="b")
+ cls._a = variables.VariableV1(cls._a_init, name="a1")
+ cls._b = variables.VariableV1(cls._b_init, name="b")
cls._c = constant_op.constant(cls._c_val, shape=[2, 1], name="c")
# Matrix product of a and b.
diff --git a/tensorflow/python/debug/lib/dist_session_debug_grpc_test.py b/tensorflow/python/debug/lib/dist_session_debug_grpc_test.py
index 46a7be5808..74498c8ea3 100644
--- a/tensorflow/python/debug/lib/dist_session_debug_grpc_test.py
+++ b/tensorflow/python/debug/lib/dist_session_debug_grpc_test.py
@@ -118,8 +118,8 @@ class DistributedSessionDebugTest(test_util.TensorFlowTestCase):
"""
with ops.Graph().as_default() as graph:
with ops.device("/job:worker/task:0/cpu:0"):
- self.a = variables.Variable(10.0, name="a")
- self.b = variables.Variable(100.0, name="b")
+ self.a = variables.VariableV1(10.0, name="a")
+ self.b = variables.VariableV1(100.0, name="b")
self.inc_a = state_ops.assign_add(self.a, 2.0, name="inc_a")
self.dec_b = state_ops.assign_add(self.b, -5.0, name="dec_b")
self.p = math_ops.multiply(self.inc_a, self.dec_b, name="p")
diff --git a/tensorflow/python/debug/lib/grpc_large_data_test.py b/tensorflow/python/debug/lib/grpc_large_data_test.py
index 5bc477a9ba..ccc21bcf94 100644
--- a/tensorflow/python/debug/lib/grpc_large_data_test.py
+++ b/tensorflow/python/debug/lib/grpc_large_data_test.py
@@ -61,7 +61,7 @@ class LargeGraphAndLargeTensorsDebugTest(test_util.TensorFlowTestCase):
with self.test_session(
use_gpu=True,
config=session_debug_testlib.no_rewrite_session_config()) as sess:
- u = variables.Variable(42.0, name="original_u")
+ u = variables.VariableV1(42.0, name="original_u")
for _ in xrange(50 * 1000):
u = array_ops.identity(u)
sess.run(variables.global_variables_initializer())
@@ -94,7 +94,7 @@ class LargeGraphAndLargeTensorsDebugTest(test_util.TensorFlowTestCase):
u_init = constant_op.constant(
u_init_val_array, dtype=dtypes.float32, name="u_init")
- u = variables.Variable(u_init, name="u")
+ u = variables.VariableV1(u_init, name="u")
def watch_fn(fetches, feeds):
del fetches, feeds # Unused by this watch_fn.
@@ -117,7 +117,7 @@ class LargeGraphAndLargeTensorsDebugTest(test_util.TensorFlowTestCase):
b"", b"spam", b"A" * 2500 * 1024, b"B" * 2500 * 1024, b"egg", b""]
u_init = constant_op.constant(
u_init_val, dtype=dtypes.string, name="u_init")
- u = variables.Variable(u_init, name="u")
+ u = variables.VariableV1(u_init, name="u")
def watch_fn(fetches, feeds):
del fetches, feeds
@@ -146,7 +146,7 @@ class LargeGraphAndLargeTensorsDebugTest(test_util.TensorFlowTestCase):
u_init = constant_op.constant(
u_init_val_array, dtype=dtypes.string, name="u_init")
- u = variables.Variable(u_init, name="u")
+ u = variables.VariableV1(u_init, name="u")
def watch_fn(fetches, feeds):
del fetches, feeds
@@ -167,7 +167,7 @@ class LargeGraphAndLargeTensorsDebugTest(test_util.TensorFlowTestCase):
config=session_debug_testlib.no_rewrite_session_config()) as sess:
u_init = constant_op.constant(
[], dtype=dtypes.float32, shape=[0], name="u_init")
- u = variables.Variable(u_init, name="u")
+ u = variables.VariableV1(u_init, name="u")
def watch_fn(fetches, feeds):
del fetches, feeds
@@ -189,7 +189,7 @@ class LargeGraphAndLargeTensorsDebugTest(test_util.TensorFlowTestCase):
config=session_debug_testlib.no_rewrite_session_config()) as sess:
u_init = constant_op.constant(
[], dtype=dtypes.string, shape=[0], name="u_init")
- u = variables.Variable(u_init, name="u")
+ u = variables.VariableV1(u_init, name="u")
def watch_fn(fetches, feeds):
del fetches, feeds
diff --git a/tensorflow/python/debug/lib/session_debug_file_test.py b/tensorflow/python/debug/lib/session_debug_file_test.py
index ba0f15b4e2..1874160dd6 100644
--- a/tensorflow/python/debug/lib/session_debug_file_test.py
+++ b/tensorflow/python/debug/lib/session_debug_file_test.py
@@ -58,9 +58,9 @@ class SessionDebugFileTest(session_debug_testlib.SessionDebugTestBase):
v_name = "diff_Watch/v"
u_init = constant_op.constant(u_init_val, shape=[2, 2])
- u = variables.Variable(u_init, name=u_name)
+ u = variables.VariableV1(u_init, name=u_name)
v_init = constant_op.constant(v_init_val, shape=[2, 1])
- v = variables.Variable(v_init, name=v_name)
+ v = variables.VariableV1(v_init, name=v_name)
w = math_ops.matmul(u, v, name="diff_Watch/matmul")
diff --git a/tensorflow/python/debug/lib/session_debug_grpc_test.py b/tensorflow/python/debug/lib/session_debug_grpc_test.py
index ff49b69547..bfc9a3a382 100644
--- a/tensorflow/python/debug/lib/session_debug_grpc_test.py
+++ b/tensorflow/python/debug/lib/session_debug_grpc_test.py
@@ -148,8 +148,8 @@ class SessionDebugGrpcTest(session_debug_testlib.SessionDebugTestBase):
sess, "localhost:%d" % self._server_port, watch_fn="foo")
def testGrpcDebugWrapperSessionWithoutWatchFnWorks(self):
- u = variables.Variable(2.1, name="u")
- v = variables.Variable(20.0, name="v")
+ u = variables.VariableV1(2.1, name="u")
+ v = variables.VariableV1(20.0, name="v")
w = math_ops.multiply(u, v, name="w")
sess = session.Session(
@@ -175,8 +175,8 @@ class SessionDebugGrpcTest(session_debug_testlib.SessionDebugTestBase):
del feeds, fetch_keys
return ["DebugIdentity", "DebugNumericSummary"], r".*/read", None
- u = variables.Variable(2.1, name="u")
- v = variables.Variable(20.0, name="v")
+ u = variables.VariableV1(2.1, name="u")
+ v = variables.VariableV1(20.0, name="v")
w = math_ops.multiply(u, v, name="w")
sess = session.Session(
@@ -209,8 +209,8 @@ class SessionDebugGrpcTest(session_debug_testlib.SessionDebugTestBase):
op_type_regex_whitelist=None,
tolerate_debug_op_creation_failures=True)
- u = variables.Variable(2.1, name="u")
- v = variables.Variable(20.0, name="v")
+ u = variables.VariableV1(2.1, name="u")
+ v = variables.VariableV1(20.0, name="v")
w = math_ops.multiply(u, v, name="w")
sess = session.Session(
@@ -241,8 +241,8 @@ class SessionDebugGrpcTest(session_debug_testlib.SessionDebugTestBase):
14, len(dump.get_tensors("v/read", 0, "DebugNumericSummary")[0]))
def testTensorBoardDebugHookWorks(self):
- u = variables.Variable(2.1, name="u")
- v = variables.Variable(20.0, name="v")
+ u = variables.VariableV1(2.1, name="u")
+ v = variables.VariableV1(20.0, name="v")
w = math_ops.multiply(u, v, name="w")
sess = session.Session(
@@ -286,8 +286,8 @@ class SessionDebugGrpcTest(session_debug_testlib.SessionDebugTestBase):
self._server.query_source_file_line(__file__, 1)
def testTensorBoardDebugHookDisablingTracebackSourceCodeSendingWorks(self):
- u = variables.Variable(2.1, name="u")
- v = variables.Variable(20.0, name="v")
+ u = variables.VariableV1(2.1, name="u")
+ v = variables.VariableV1(20.0, name="v")
w = math_ops.multiply(u, v, name="w")
sess = session.Session(
@@ -381,8 +381,8 @@ class SessionDebugGrpcGatingTest(test_util.TensorFlowTestCase):
def testToggleEnableTwoDebugWatchesNoCrosstalkBetweenDebugNodes(self):
with session.Session(
config=session_debug_testlib.no_rewrite_session_config()) as sess:
- v_1 = variables.Variable(50.0, name="v_1")
- v_2 = variables.Variable(-50.0, name="v_1")
+ v_1 = variables.VariableV1(50.0, name="v_1")
+ v_2 = variables.VariableV1(-50.0, name="v_1")
delta_1 = constant_op.constant(5.0, name="delta_1")
delta_2 = constant_op.constant(-5.0, name="delta_2")
inc_v_1 = state_ops.assign_add(v_1, delta_1, name="inc_v_1")
@@ -451,8 +451,8 @@ class SessionDebugGrpcGatingTest(test_util.TensorFlowTestCase):
with session.Session(
config=session_debug_testlib.no_rewrite_session_config()) as sess:
- v_1 = variables.Variable(50.0, name="v_1")
- v_2 = variables.Variable(-50.0, name="v_1")
+ v_1 = variables.VariableV1(50.0, name="v_1")
+ v_2 = variables.VariableV1(-50.0, name="v_1")
# These two nodes have names that match those in the
# toggle_watch_on_core_metadata argument used when calling
# start_server_on_separate_thread().
@@ -491,7 +491,7 @@ class SessionDebugGrpcGatingTest(test_util.TensorFlowTestCase):
def testToggleEnableTwoDebugWatchesNoCrosstalkBetweenServers(self):
with session.Session(
config=session_debug_testlib.no_rewrite_session_config()) as sess:
- v = variables.Variable(50.0, name="v")
+ v = variables.VariableV1(50.0, name="v")
delta = constant_op.constant(5.0, name="delta")
inc_v = state_ops.assign_add(v, delta, name="inc_v")
@@ -534,8 +534,8 @@ class SessionDebugGrpcGatingTest(test_util.TensorFlowTestCase):
def testToggleBreakpointsWorks(self):
with session.Session(
config=session_debug_testlib.no_rewrite_session_config()) as sess:
- v_1 = variables.Variable(50.0, name="v_1")
- v_2 = variables.Variable(-50.0, name="v_2")
+ v_1 = variables.VariableV1(50.0, name="v_1")
+ v_2 = variables.VariableV1(-50.0, name="v_2")
delta_1 = constant_op.constant(5.0, name="delta_1")
delta_2 = constant_op.constant(-5.0, name="delta_2")
inc_v_1 = state_ops.assign_add(v_1, delta_1, name="inc_v_1")
@@ -592,8 +592,8 @@ class SessionDebugGrpcGatingTest(test_util.TensorFlowTestCase):
def testTensorBoardDebuggerWrapperToggleBreakpointsWorks(self):
with session.Session(
config=session_debug_testlib.no_rewrite_session_config()) as sess:
- v_1 = variables.Variable(50.0, name="v_1")
- v_2 = variables.Variable(-50.0, name="v_2")
+ v_1 = variables.VariableV1(50.0, name="v_1")
+ v_2 = variables.VariableV1(-50.0, name="v_2")
delta_1 = constant_op.constant(5.0, name="delta_1")
delta_2 = constant_op.constant(-5.0, name="delta_2")
inc_v_1 = state_ops.assign_add(v_1, delta_1, name="inc_v_1")
@@ -665,8 +665,8 @@ class SessionDebugGrpcGatingTest(test_util.TensorFlowTestCase):
def testTensorBoardDebuggerWrapperDisablingTracebackSourceSendingWorks(self):
with session.Session(
config=session_debug_testlib.no_rewrite_session_config()) as sess:
- v_1 = variables.Variable(50.0, name="v_1")
- v_2 = variables.Variable(-50.0, name="v_2")
+ v_1 = variables.VariableV1(50.0, name="v_1")
+ v_2 = variables.VariableV1(-50.0, name="v_2")
delta_1 = constant_op.constant(5.0, name="delta_1")
delta_2 = constant_op.constant(-5.0, name="delta_2")
inc_v_1 = state_ops.assign_add(v_1, delta_1, name="inc_v_1")
@@ -699,7 +699,7 @@ class SessionDebugGrpcGatingTest(test_util.TensorFlowTestCase):
def testGetGrpcDebugWatchesReturnsCorrectAnswer(self):
with session.Session() as sess:
- v = variables.Variable(50.0, name="v")
+ v = variables.VariableV1(50.0, name="v")
delta = constant_op.constant(5.0, name="delta")
inc_v = state_ops.assign_add(v, delta, name="inc_v")
@@ -741,9 +741,9 @@ class DelayedDebugServerTest(test_util.TensorFlowTestCase):
debug_server) = grpc_debug_test_server.start_server_on_separate_thread(
server_start_delay_sec=2.0, dump_to_filesystem=False)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
a_init = constant_op.constant(42.0, name="a_init")
- a = variables.Variable(a_init, name="a")
+ a = variables.VariableV1(a_init, name="a")
def watch_fn(fetches, feeds):
del fetches, feeds
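
One test here also moves from `self.test_session()` to `self.cached_session()`; the latter returns a session that is cached and reused for the duration of the test, and is the spelling the codebase was converging on. A minimal sketch of the replacement in a TensorFlow test case:

    from tensorflow.python.framework import constant_op
    from tensorflow.python.framework import test_util
    from tensorflow.python.platform import test

    class CachedSessionExampleTest(test_util.TensorFlowTestCase):

      def testAdd(self):
        # cached_session() reuses one session per test, replacing the
        # older test_session() call.
        with self.cached_session() as sess:
          x = constant_op.constant(2.0)
          self.assertAllClose(4.0, sess.run(x + x))

    if __name__ == "__main__":
      test.main()
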
diff --git a/tensorflow/python/debug/lib/session_debug_testlib.py b/tensorflow/python/debug/lib/session_debug_testlib.py
index 070d9c4cd7..25ef91b575 100644
--- a/tensorflow/python/debug/lib/session_debug_testlib.py
+++ b/tensorflow/python/debug/lib/session_debug_testlib.py
@@ -70,7 +70,7 @@ class _RNNCellForTest(rnn_cell_impl.RNNCell):
def __init__(self, input_output_size, state_size):
self._input_output_size = input_output_size
self._state_size = state_size
- self._w = variables.Variable(1.0, dtype=dtypes.float32, name="w")
+ self._w = variables.VariableV1(1.0, dtype=dtypes.float32, name="w")
@property
def output_size(self):
@@ -182,9 +182,9 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
w_name = "w"
u_init = constant_op.constant(u_init_val, shape=[2, 2])
- u = variables.Variable(u_init, name=u_name)
+ u = variables.VariableV1(u_init, name=u_name)
v_init = constant_op.constant(v_init_val, shape=[2, 1])
- v = variables.Variable(v_init, name=v_name)
+ v = variables.VariableV1(v_init, name=v_name)
w = math_ops.matmul(u, v, name=w_name)
@@ -221,8 +221,8 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
def testCopyNodesHaveCorrectDebugOpsAndURLsAttributeValues(self):
with session.Session() as sess:
- u = variables.Variable(2.1, name="u")
- v = variables.Variable(20.0, name="v")
+ u = variables.VariableV1(2.1, name="u")
+ v = variables.VariableV1(20.0, name="v")
w = math_ops.multiply(u, v, name="w")
sess.run(variables.global_variables_initializer())
@@ -324,8 +324,8 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
str1_name = "str1"
str2_name = "str2"
- str1 = variables.Variable(str1_init, name=str1_name)
- str2 = variables.Variable(str2_init, name=str2_name)
+ str1 = variables.VariableV1(str1_init, name=str1_name)
+ str2 = variables.VariableV1(str2_init, name=str2_name)
# Concatenate str1 and str2
str_concat = math_ops.add(str1, str2, name="str_concat")
@@ -387,9 +387,9 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
s_name = "%s/s" % op_namespace
u_init = constant_op.constant(u_init_val, shape=[2, 2])
- u = variables.Variable(u_init, name=u_name)
+ u = variables.VariableV1(u_init, name=u_name)
s_init = constant_op.constant(s_init_val)
- s = variables.Variable(s_init, name=s_name)
+ s = variables.VariableV1(s_init, name=s_name)
run_options = config_pb2.RunOptions(output_partition_graphs=True)
debug_urls = self._debug_urls()
@@ -439,7 +439,7 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
u_init_val = np.array(11.0)
u_init = constant_op.constant(u_init_val)
- u = variables.Variable(u_init, name=u_name)
+ u = variables.VariableV1(u_init, name=u_name)
# "v" is the increment.
v_name = "testDumpToFileWhileLoop/v"
@@ -447,7 +447,7 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
v_init_val = np.array(2.0)
v_init = constant_op.constant(v_init_val)
- v = variables.Variable(v_init, name=v_name)
+ v = variables.VariableV1(v_init, name=v_name)
u.initializer.run()
v.initializer.run()
@@ -605,8 +605,8 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
def testDebugCondWatchingWholeGraphWorks(self):
with session.Session() as sess:
- x = variables.Variable(10.0, name="x")
- y = variables.Variable(20.0, name="y")
+ x = variables.VariableV1(10.0, name="x")
+ y = variables.VariableV1(20.0, name="y")
cond = control_flow_ops.cond(
x > y, lambda: math_ops.add(x, 1), lambda: math_ops.add(y, 1))
@@ -628,9 +628,9 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
z_name = "testFindNodesWithBadTensorValues/z"
u_init = constant_op.constant([2.0, 4.0])
- u = variables.Variable(u_init, name=u_name)
+ u = variables.VariableV1(u_init, name=u_name)
v_init = constant_op.constant([2.0, 1.0])
- v = variables.Variable(v_init, name=v_name)
+ v = variables.VariableV1(v_init, name=v_name)
# Expected output: [0.0, 3.0]
w = math_ops.subtract(u, v, name=w_name)
@@ -679,9 +679,9 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
z_name = "testFindInfOrNanWithOpNameExclusion/z"
u_init = constant_op.constant([2.0, 4.0])
- u = variables.Variable(u_init, name=u_name)
+ u = variables.VariableV1(u_init, name=u_name)
v_init = constant_op.constant([2.0, 1.0])
- v = variables.Variable(v_init, name=v_name)
+ v = variables.VariableV1(v_init, name=v_name)
# Expected output: [0.0, 3.0]
w = math_ops.subtract(u, v, name=w_name)
@@ -725,7 +725,7 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
w_name = "testDumpGraphStructureLookup/w"
u_init = constant_op.constant([2.0, 4.0])
- u = variables.Variable(u_init, name=u_name)
+ u = variables.VariableV1(u_init, name=u_name)
v = math_ops.add(u, u, name=v_name)
w = math_ops.add(v, v, name=w_name)
@@ -859,9 +859,9 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
def testGraphPathFindingOnControlEdgesWorks(self):
with session.Session(config=no_rewrite_session_config()) as sess:
- v1 = variables.Variable(1.0, name="v1")
- v2 = variables.Variable(2.0, name="v2")
- v3 = variables.Variable(3.0, name="v3")
+ v1 = variables.VariableV1(1.0, name="v1")
+ v2 = variables.VariableV1(2.0, name="v2")
+ v3 = variables.VariableV1(3.0, name="v3")
a = math_ops.add(v1, v2, name="a")
with ops.control_dependencies([a]):
c = math_ops.subtract(v3, v3, name="c")
@@ -875,8 +875,8 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
def testGraphPathFindingReverseRefEdgeWorks(self):
with session.Session(config=no_rewrite_session_config()) as sess:
- v = variables.Variable(10.0, name="v")
- delta = variables.Variable(1.0, name="delta")
+ v = variables.VariableV1(10.0, name="v")
+ delta = variables.VariableV1(1.0, name="delta")
inc_v = state_ops.assign_add(v, delta, name="inc_v")
sess.run(variables.global_variables_initializer())
@@ -894,7 +894,7 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
w_name = "testDumpCausalityCheck/w"
u_init = constant_op.constant([2.0, 4.0])
- u = variables.Variable(u_init, name=u_name)
+ u = variables.VariableV1(u_init, name=u_name)
v = math_ops.add(u, u, name=v_name)
w = math_ops.add(v, v, name=w_name)
@@ -980,7 +980,7 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
w_name = "oneOfTwoSlots/w"
y_name = "oneOfTwoSlots/y"
- x = variables.Variable([1, 3, 3, 7], dtype=dtypes.int32, name=x_name)
+ x = variables.VariableV1([1, 3, 3, 7], dtype=dtypes.int32, name=x_name)
sess.run(x.initializer)
unique_x, indices, _ = array_ops.unique_with_counts(x, name=u_name)
@@ -1039,9 +1039,9 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
with session.Session(config=no_rewrite_session_config()) as sess:
u_init = constant_op.constant(10.0)
- u = variables.Variable(u_init, name="gdo/u")
+ u = variables.VariableV1(u_init, name="gdo/u")
v_init = constant_op.constant(20.0)
- v = variables.Variable(v_init, name="gdo/v")
+ v = variables.VariableV1(v_init, name="gdo/v")
w = math_ops.multiply(u, v, name="gdo/w")
# gdo stands for GradientDescentOptimizer.
@@ -1085,7 +1085,7 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
with session.Session() as sess:
x_init = constant_op.constant([2, 2, 3, 5, 5])
- x = variables.Variable(x_init, name="unconnected/x")
+ x = variables.VariableV1(x_init, name="unconnected/x")
# The UniqueOp (tf.unique) has two output slots. Use only slot 0 in the
# graph. Let the debugger watch the unused slot 1.
@@ -1225,14 +1225,14 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
def testDebugNumericSummaryOnInitializedTensorGivesCorrectResult(self):
with session.Session(config=no_rewrite_session_config()) as sess:
- a = variables.Variable(
+ a = variables.VariableV1(
[
np.nan, np.nan, 0.0, 0.0, 0.0, -1.0, -3.0, 3.0, 7.0, -np.inf,
-np.inf, np.inf, np.inf, np.inf, np.inf, np.inf, np.nan, np.nan
],
dtype=np.float32,
name="numeric_summary/a")
- b = variables.Variable(
+ b = variables.VariableV1(
[0.0] * 18, dtype=np.float32, name="numeric_summary/b")
c = math_ops.add(a, b, name="numeric_summary/c")
@@ -1249,7 +1249,7 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
def testDebugNumericSummaryOnUninitializedTensorGivesCorrectResult(self):
with session.Session() as sess:
- a = variables.Variable(
+ a = variables.VariableV1(
[42], dtype=np.float32, name="numeric_summary_uninit/a")
_, dump = self._debug_run_and_get_dump(
@@ -1275,9 +1275,9 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
def testDebugNumericSummaryFailureIsToleratedWhenOrdered(self):
with session.Session() as sess:
- a = variables.Variable("1", name="a")
- b = variables.Variable("3", name="b")
- c = variables.Variable("2", name="c")
+ a = variables.VariableV1("1", name="a")
+ b = variables.VariableV1("3", name="b")
+ c = variables.VariableV1("2", name="c")
d = math_ops.add(a, b, name="d")
e = math_ops.add(d, c, name="e")
@@ -1313,9 +1313,9 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
def testDebugNumericSummaryInvalidAttributesStringAreCaught(self):
with session.Session(config=no_rewrite_session_config()) as sess:
- a = variables.Variable(10.0, name="a")
- b = variables.Variable(0.0, name="b")
- c = variables.Variable(0.0, name="c")
+ a = variables.VariableV1(10.0, name="a")
+ b = variables.VariableV1(0.0, name="b")
+ c = variables.VariableV1(0.0, name="c")
x = math_ops.divide(a, b, name="x")
y = math_ops.multiply(x, c, name="y")
@@ -1361,9 +1361,9 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
def testDebugNumericSummaryMuteOnHealthyMutesOnlyHealthyTensorDumps(self):
with session.Session(config=no_rewrite_session_config()) as sess:
- a = variables.Variable(10.0, name="a")
- b = variables.Variable(0.0, name="b")
- c = variables.Variable(0.0, name="c")
+ a = variables.VariableV1(10.0, name="a")
+ b = variables.VariableV1(0.0, name="b")
+ c = variables.VariableV1(0.0, name="c")
x = math_ops.divide(a, b, name="x")
y = math_ops.multiply(x, c, name="y")
@@ -1396,8 +1396,8 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
def testDebugNumericSummaryMuteOnHealthyAndCustomBoundsWork(self):
with session.Session() as sess:
- a = variables.Variable([10.0, 10.0], name="a")
- b = variables.Variable([10.0, 2.0], name="b")
+ a = variables.VariableV1([10.0, 10.0], name="a")
+ b = variables.VariableV1([10.0, 2.0], name="b")
x = math_ops.add(a, b, name="x") # [20.0, 12.0]
y = math_ops.divide(x, b, name="y") # [2.0, 6.0]
@@ -1436,9 +1436,9 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
def testLookUpNodePythonTracebackWorks(self):
with session.Session() as sess:
u_init = constant_op.constant(10.0)
- u = variables.Variable(u_init, name="traceback/u")
+ u = variables.VariableV1(u_init, name="traceback/u")
v_init = constant_op.constant(20.0)
- v = variables.Variable(v_init, name="traceback/v")
+ v = variables.VariableV1(v_init, name="traceback/v")
w = math_ops.multiply(u, v, name="traceback/w")
@@ -1487,7 +1487,7 @@ class DebugConcurrentRunCallsTest(test_util.TensorFlowTestCase):
self.skipTest("No testing concurrent runs on a single GPU.")
with session.Session() as sess:
- v = variables.Variable(30.0, name="v")
+ v = variables.VariableV1(30.0, name="v")
constants = []
for i in xrange(self._num_concurrent_runs):
constants.append(constant_op.constant(1.0, name="c%d" % i))
diff --git a/tensorflow/python/debug/lib/stepper_test.py b/tensorflow/python/debug/lib/stepper_test.py
index 9a3d0efabf..3839c67198 100644
--- a/tensorflow/python/debug/lib/stepper_test.py
+++ b/tensorflow/python/debug/lib/stepper_test.py
@@ -36,8 +36,8 @@ from tensorflow.python.training import gradient_descent
class StepperTest(test_util.TensorFlowTestCase):
def setUp(self):
- self.a = variables.Variable(2.0, name="a")
- self.b = variables.Variable(3.0, name="b")
+ self.a = variables.VariableV1(2.0, name="a")
+ self.b = variables.VariableV1(3.0, name="b")
self.c = math_ops.multiply(self.a, self.b, name="c") # Should be 6.0.
self.d = math_ops.multiply(self.a, self.a, name="d") # Should be 4.0.
@@ -49,7 +49,7 @@ class StepperTest(test_util.TensorFlowTestCase):
# The three nodes x, y and z form a graph with "cross-links" in it. I.e., x
# and y are both direct inputs to z, but x is also a direct input to y.
- self.x = variables.Variable(2.0, name="x") # Should be 2.0
+ self.x = variables.VariableV1(2.0, name="x") # Should be 2.0
self.y = math_ops.negative(self.x, name="y") # Should be -2.0.
self.z = math_ops.multiply(self.x, self.y, name="z") # Should be -4.0.
@@ -580,7 +580,7 @@ class StepperTestWithPlaceHolders(test_util.TensorFlowTestCase):
class StepperAssignAddTest(test_util.TensorFlowTestCase):
def setUp(self):
- self.v = variables.Variable(10.0, name="v")
+ self.v = variables.VariableV1(10.0, name="v")
self.p = math_ops.add(self.v, self.v, name="p")
self.q = math_ops.multiply(self.p, self.p, name="q")
self.delta = constant_op.constant(2.0, name="delta")
@@ -711,9 +711,9 @@ class StepperBackwardRunTest(test_util.TensorFlowTestCase):
Construct a backward graph using the GradientDescentOptimizer.
"""
- self.a = variables.Variable(1.0, name="a")
- self.b = variables.Variable(2.0, name="b")
- self.c = variables.Variable(4.0, name="c")
+ self.a = variables.VariableV1(1.0, name="a")
+ self.b = variables.VariableV1(2.0, name="b")
+ self.c = variables.VariableV1(4.0, name="c")
self.d = math_ops.multiply(self.a, self.b, name="d")
self.e = math_ops.multiply(self.b, self.c, name="e")
self.f = math_ops.multiply(self.d, self.e, name="f")
diff --git a/tensorflow/python/debug/wrappers/dumping_wrapper_test.py b/tensorflow/python/debug/wrappers/dumping_wrapper_test.py
index 254201c393..11011a5c13 100644
--- a/tensorflow/python/debug/wrappers/dumping_wrapper_test.py
+++ b/tensorflow/python/debug/wrappers/dumping_wrapper_test.py
@@ -46,7 +46,7 @@ class DumpingDebugWrapperSessionTest(test_util.TensorFlowTestCase):
def setUp(self):
self.session_root = tempfile.mkdtemp()
- self.v = variables.Variable(10.0, dtype=dtypes.float32, name="v")
+ self.v = variables.VariableV1(10.0, dtype=dtypes.float32, name="v")
self.delta = constant_op.constant(1.0, dtype=dtypes.float32, name="delta")
self.eta = constant_op.constant(-1.4, dtype=dtypes.float32, name="eta")
self.inc_v = state_ops.assign_add(self.v, self.delta, name="inc_v")
diff --git a/tensorflow/python/debug/wrappers/local_cli_wrapper_test.py b/tensorflow/python/debug/wrappers/local_cli_wrapper_test.py
index 05c9eaa4d2..149a7497df 100644
--- a/tensorflow/python/debug/wrappers/local_cli_wrapper_test.py
+++ b/tensorflow/python/debug/wrappers/local_cli_wrapper_test.py
@@ -132,8 +132,8 @@ class LocalCLIDebugWrapperSessionTest(test_util.TensorFlowTestCase):
def setUp(self):
self._tmp_dir = tempfile.mktemp()
- self.v = variables.Variable(10.0, name="v")
- self.w = variables.Variable(21.0, name="w")
+ self.v = variables.VariableV1(10.0, name="v")
+ self.w = variables.VariableV1(21.0, name="w")
self.delta = constant_op.constant(1.0, name="delta")
self.inc_v = state_ops.assign_add(self.v, self.delta, name="inc_v")
@@ -358,7 +358,7 @@ class LocalCLIDebugWrapperSessionTest(test_util.TensorFlowTestCase):
def testDebuggingMakeCallableTensorRunnerWorks(self):
wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
[["run"], ["run"]], self.sess, dump_root=self._tmp_dir)
- v = variables.Variable(42)
+ v = variables.VariableV1(42)
tensor_runner = wrapped_sess.make_callable(v)
self.sess.run(v.initializer)
@@ -382,7 +382,7 @@ class LocalCLIDebugWrapperSessionTest(test_util.TensorFlowTestCase):
def testDebuggingMakeCallableOperationRunnerWorks(self):
wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
[["run"], ["run"]], self.sess, dump_root=self._tmp_dir)
- v = variables.Variable(10.0)
+ v = variables.VariableV1(10.0)
inc_v = state_ops.assign_add(v, 1.0)
op_runner = wrapped_sess.make_callable(inc_v.op)
self.sess.run(v.initializer)
@@ -403,7 +403,7 @@ class LocalCLIDebugWrapperSessionTest(test_util.TensorFlowTestCase):
self.assertEqual(1, len(wrapped_sess.observers["debug_dumps"]))
def testDebuggingMakeCallableFromOptionsWithZeroFeedWorks(self):
- variable_1 = variables.Variable(
+ variable_1 = variables.VariableV1(
10.5, dtype=dtypes.float32, name="variable_1")
a = math_ops.add(variable_1, variable_1, "callable_a")
math_ops.add(a, a, "callable_b")
@@ -480,7 +480,7 @@ class LocalCLIDebugWrapperSessionTest(test_util.TensorFlowTestCase):
self.assertItemsEqual(["callable_a", "callable_b"], node_names)
def testDebugMakeCallableFromOptionsWithCustomOptionsAndMetadataWorks(self):
- variable_1 = variables.Variable(
+ variable_1 = variables.VariableV1(
10.5, dtype=dtypes.float32, name="variable_1")
a = math_ops.add(variable_1, variable_1, "callable_a")
math_ops.add(a, a, "callable_b")
@@ -528,7 +528,7 @@ class LocalCLIDebugWrapperSessionTest(test_util.TensorFlowTestCase):
def testRuntimeErrorBeforeGraphExecutionIsRaised(self):
# Use an impossible device name to cause an error before graph execution.
with ops.device("/device:GPU:1337"):
- w = variables.Variable([1.0] * 10, name="w")
+ w = variables.VariableV1([1.0] * 10, name="w")
wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
[["run"]], self.sess, dump_root=self._tmp_dir)
diff --git a/tensorflow/python/distribute/distribute_coordinator.py b/tensorflow/python/distribute/distribute_coordinator.py
index bd3562f1ff..b9b77d4a5b 100644
--- a/tensorflow/python/distribute/distribute_coordinator.py
+++ b/tensorflow/python/distribute/distribute_coordinator.py
@@ -126,7 +126,7 @@ class _WorkerContext(object):
replicated training.
task_id: an integer indicating id of the corresponding task. It can be
None if it is local training or in-graph replicated training.
- session_config: an optional @{tf.ConfigProto} object.
+ session_config: an optional `tf.ConfigProto` object.
rpc_layer: optional string specifying the RPC protocol for communication
with worker masters. If None or empty, hosts in the `cluster_spec` will
be used directly.
@@ -685,7 +685,7 @@ def run_distribute_coordinator(worker_fn,
in a cluster. If not set or empty, fall back to local training.
task_type: the current task type, optional if this is a client.
task_id: the current task id, optional if this is a client.
- session_config: an optional @{tf.ConfigProto} object which will be passed
+ session_config: an optional `tf.ConfigProto` object which will be passed
to `strategy`'s `configure` method and used to create a session.
rpc_layer: optional string, the protocol for RPC, e.g. "grpc".
diff --git a/tensorflow/python/distribute/estimator_training.py b/tensorflow/python/distribute/estimator_training.py
index e17a598123..0289689134 100644
--- a/tensorflow/python/distribute/estimator_training.py
+++ b/tensorflow/python/distribute/estimator_training.py
@@ -62,7 +62,7 @@ def _get_global_id(cluster_spec, task_type, task_id, chief_task_type):
# Sort task names in cluster by "chief"/"master", "evaluator", "worker"
# and "ps". More details can be found at the documentation of
- # @{tf.estimator.RunConfig.global_id_in_cluster}.
+ # `tf.estimator.RunConfig.global_id_in_cluster`.
task_type_ordered_list = []
if chief_task_type in cluster_spec.jobs:
task_type_ordered_list = [chief_task_type]
@@ -182,6 +182,7 @@ def should_run_distribute_coordinator(config):
# pylint: disable=protected-access
if (not hasattr(config, '_distribute_coordinator_mode') or
config._distribute_coordinator_mode is None):
+ logging.info('Not using Distribute Coordinator.')
return False
if (not isinstance(config._distribute_coordinator_mode, six.string_types) or
config._distribute_coordinator_mode not in [
@@ -221,15 +222,28 @@ def train_and_evaluate(estimator, train_spec, eval_spec, executor_cls):
local_estimator = copy.deepcopy(estimator)
# pylint: disable=protected-access
local_estimator._config._train_distribute = strategy
- _init_run_config_from_worker_context(
- local_estimator._config, dc_context.get_current_worker_context())
+ context = dc_context.get_current_worker_context()
+ _init_run_config_from_worker_context(local_estimator._config, context)
+ logging.info('Updated config: %s', str(vars(local_estimator._config)))
local_estimator._train_distribution = strategy
# pylint: enable=protected-access
+ # In the standalone client, we don't need to run hooks on all threads
+ # because logging hooks on all threads may be too much on the screen; also
+ # a tensor passed to one hook can only be fetched from the graph where the
+ # tensor is defined. Other hooks such as checkpointing hooks will be added
+ # by MonitoredTrainingSession.
+ # TODO(yuefengz): Is there a hook that does need to run on all threads in
+ # standalone client mode?
+ if (run_config._distribute_coordinator_mode == # pylint: disable=protected-access
+ dc.CoordinatorMode.INDEPENDENT_WORKER or context.is_chief):
+ hooks = list(train_spec.hooks)
+ else:
+ hooks = []
local_estimator.train(
input_fn=train_spec.input_fn,
max_steps=train_spec.max_steps,
- hooks=list(train_spec.hooks))
+ hooks=hooks)
def _eval_fn(strategy):
"""Function for evaluator task."""
@@ -238,6 +252,7 @@ def train_and_evaluate(estimator, train_spec, eval_spec, executor_cls):
local_estimator._config._eval_distribute = strategy
_init_run_config_from_worker_context(
local_estimator._config, dc_context.get_current_worker_context())
+ logging.info('Updated config: %s', str(vars(local_estimator._config)))
local_estimator._eval_distribution = strategy
executor = executor_cls(local_estimator, train_spec, eval_spec)
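
The new hook-selection logic in `_train_fn` attaches the user's training hooks on every worker only in independent-worker mode; in standalone-client mode the non-chief threads drop them, since per-step logging from many threads is noisy and a hook's tensors can only be fetched from the graph they were defined in. A self-contained sketch of that rule (the helper and the string mode names are hypothetical, for illustration only):

    def _select_hooks(coordinator_mode, is_chief, user_hooks):
      # Mirror of the rule added to _train_fn: every worker runs the user's
      # hooks in independent-worker mode, but only the chief does so in
      # standalone-client mode. Hypothetical helper, not part of the patch.
      if coordinator_mode == "independent_worker" or is_chief:
        return list(user_hooks)
      return []

    assert _select_hooks("independent_worker", False, ["logger"]) == ["logger"]
    assert _select_hooks("standalone_client", True, ["logger"]) == ["logger"]
    assert _select_hooks("standalone_client", False, ["logger"]) == []
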
diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD
index c1bc27d443..d3d997e6df 100644
--- a/tensorflow/python/eager/BUILD
+++ b/tensorflow/python/eager/BUILD
@@ -17,7 +17,10 @@ cc_library(
"pywrap_tensor.h",
"pywrap_tfe.h",
],
- visibility = ["//tensorflow:internal"],
+ visibility = [
+ "//learning/deepmind/courier:__pkg__",
+ "//tensorflow:internal",
+ ],
deps = [
"//tensorflow/c:c_api",
"//tensorflow/c:c_api_internal",
@@ -34,6 +37,7 @@ cc_library(
"//tensorflow/python:safe_ptr",
"//third_party/py/numpy:headers",
"//third_party/python_runtime:headers",
+ "@com_google_absl//absl/types:variant",
],
)
@@ -45,6 +49,7 @@ py_library(
":backprop",
":context",
":core",
+ ":def_function",
":execute",
":function",
":graph_only_ops",
@@ -146,6 +151,7 @@ cuda_py_test(
"//tensorflow/python:clip_ops",
"//tensorflow/python:init_ops",
"//tensorflow/python:layers",
+ "//tensorflow/python:list_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:resource_variable_ops",
],
@@ -378,3 +384,30 @@ cuda_py_test(
"optonly", # The test is too slow in non-opt mode
],
)
+
+py_library(
+ name = "def_function",
+ srcs = ["def_function.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":context",
+ ":function",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:resource_variable_ops",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python/training/checkpointable:base",
+ ],
+)
+
+py_test(
+ name = "def_function_test",
+ srcs = ["def_function_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":def_function",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:framework_ops",
+ ],
+)
diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py
index be392c7a0f..78f3198011 100644
--- a/tensorflow/python/eager/backprop.py
+++ b/tensorflow/python/eager/backprop.py
@@ -120,27 +120,6 @@ def _gradient_function(op_name, attr_tuple, num_inputs, inputs, outputs,
pywrap_tensorflow.TFE_Py_RegisterGradientFunction(_gradient_function)
-_tracing = False
-
-
-# TODO(agarwal): use an automatic mechanism for handling None arguments to
-# gradient functions.
-# Some gradient functions can accept None arguments for gradients. The following
-# maps the operation name to the indices at which the corresponding gradient
-# function can accept None values.
-# e.g. FusedBatchNorm outputs 5 values and hence receives 5 gradient values
-# during backprop. However the gradient function uses only the first of those
-# values and ignores the rest. The entry, "FusedBatchNorm": [1, 2, 3, 4],
-# indicates that only the gradient corresponding to index 0 is used, and the
-# gradient values at indices 1-4 are ignored (and hence can be None). The
-# backprop algorithm can then leverage this by not constructing zeros to
-# pass for those indices.
-_grad_fn_accepts_none_for_indices = {
- "SoftmaxCrossEntropyWithLogits": [1],
- "FusedBatchNorm": [1, 2, 3, 4]
-}
-
-
def _record_gradient(op_name, inputs, attrs, results, name):
return pywrap_tensorflow.TFE_Py_RecordGradient(op_name, inputs, attrs,
results, name)
@@ -585,7 +564,10 @@ def _aggregate_grads(gradients):
def _num_elements(grad):
"""The number of elements in the `grad` tensor."""
if isinstance(grad, ops.Tensor):
- return functools.reduce(operator.mul, grad._shape_tuple(), 1) # pylint: disable=protected-access
+ shape_tuple = grad._shape_tuple() # pylint: disable=protected-access
+ if shape_tuple is None or None in shape_tuple:
+ return 0
+ return functools.reduce(operator.mul, shape_tuple, 1)
if isinstance(grad, ops.IndexedSlices):
return functools.reduce(operator.mul, grad.values._shape_tuple(), 1) # pylint: disable=protected-access
raise ValueError("`grad` not a Tensor or IndexedSlices.")
@@ -629,8 +611,9 @@ def _ones(shape, dtype):
_default_vspace = imperative_grad.VSpace(
num_elements_fn=_num_elements,
aggregate_fn=_aggregate_grads,
- zeros=_zeros,
- ones=_ones)
+ zeros_fn=_zeros,
+ ones_fn=_ones,
+ graph_shape_fn=gen_array_ops.shape)
pywrap_tensorflow.TFE_Py_RegisterVSpace(_default_vspace)
@@ -648,8 +631,8 @@ class GradientTape(object):
Operations are recorded if they are executed within this context manager and
at least one of their inputs is being "watched".
- Trainable variables (created by `tf.Variable` or `tf.get_variable`,
- trainable=True is default in both cases) are automatically watched. Tensors
+ Trainable variables (created by `tf.Variable` or `tf.get_variable`, where
+ `trainable=True` is default in both cases) are automatically watched. Tensors
can be manually watched by invoking the `watch` method on this context
manager.
@@ -669,6 +652,7 @@ class GradientTape(object):
```python
x = tf.constant(3.0)
with tf.GradientTape() as g:
+ g.watch(x)
with tf.GradientTape() as gg:
gg.watch(x)
y = x * x
@@ -745,7 +729,9 @@ class GradientTape(object):
self._persistent = persistent
self._watch_accessed_variables = watch_accessed_variables
self._recording = False
- context.context().start_step()
+ self._created_eagerly = context.executing_eagerly()
+ if self._created_eagerly:
+ context.context().start_step()
def __enter__(self):
"""Enters a context inside which operations are recorded on this tape."""
@@ -775,7 +761,8 @@ class GradientTape(object):
self._recording = False
def __del__(self):
- context.context().end_step()
+ if self._created_eagerly:
+ context.context().end_step()
def watch(self, tensor):
"""Ensures that `tensor` is being traced by this tape.
diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py
index f938ed5df8..32731747b7 100644
--- a/tensorflow/python/eager/backprop_test.py
+++ b/tensorflow/python/eager/backprop_test.py
@@ -1022,6 +1022,18 @@ class BackpropTest(test.TestCase):
resource_variable_ops.ResourceVariable(2.0))
self.assertAllEqual(gradients_constants, gradients_variables)
+ def testUnknownShapes(self):
+ with context.graph_mode():
+ with backprop.GradientTape() as tape:
+ a = array_ops.placeholder(dtype=dtypes.float32, shape=None)
+ tape.watch(a)
+ b = a**3
+
+ db_da = tape.gradient(b, a)
+
+ with self.cached_session() as sess:
+ self.assertEqual((8.0, 12.0), sess.run((b, db_da), feed_dict={a: 2.0}))
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/eager/def_function.py b/tensorflow/python/eager/def_function.py
new file mode 100644
index 0000000000..8dcacd5c99
--- /dev/null
+++ b/tensorflow/python/eager/def_function.py
@@ -0,0 +1,235 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# pylint: disable=unidiomatic-typecheck
+"""Prototype decorator for defining graph-mode functions with eager semantics."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.eager import context
+from tensorflow.python.eager import function
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.training.checkpointable import base as checkpointable
+
+
+class UnliftedInitializerVariable(resource_variable_ops.ResourceVariable):
+ """Variable which does not lift its initializer out of function context.
+
+ Instances of this variable, when created, build a graph which runs their
+ initializer inside a tf.cond(is_initialized) block.
+
+ This can only be created inside a defun called from (eventually) eager
+ mode. That is, non-function-building graphs are not supported.
+ """
+
+ def __init__(self, # pylint: disable=super-init-not-called
+ initial_value=None,
+ trainable=True,
+ caching_device=None,
+ name=None,
+ dtype=None,
+ constraint=None,
+ **unused_kwargs):
+ """Creates a variable.
+
+ Args:
+ initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
+ which is the initial value for the Variable. The initial value must have
+ a shape specified unless `validate_shape` is set to False. Can also be a
+ callable with no argument that returns the initial value when called.
+ (Note that initializer functions from init_ops.py must first be bound
+ to a shape before being used here.)
+ trainable: If `True`, GradientTapes automatically watch uses of this
+ Variable.
+ caching_device: Optional device string or function describing where the
+ Variable should be cached for reading. Defaults to the Variable's
+ device. If not `None`, caches on another device. Typical use is to
+ cache on the device where the Ops using the Variable reside, to
+ deduplicate copying through `Switch` and other conditional statements.
+ name: Optional name for the variable. Defaults to `'Variable'` and gets
+ uniquified automatically.
+ dtype: If set, initial_value will be converted to the given type.
+ If None, either the datatype will be kept (if initial_value is
+ a Tensor) or float32 will be used (if it is a Python object convertible
+ to a Tensor).
+ constraint: An optional projection function to be applied to the variable
+ after being updated by an `Optimizer` (e.g. used to implement norm
+ constraints or value constraints for layer weights). The function must
+ take as input the unprojected Tensor representing the value of the
+ variable and return the Tensor for the projected value
+ (which must have the same shape). Constraints are not safe to
+ use when doing asynchronous distributed training.
+
+ Raises:
+ ValueError: If the initial value is not specified, or does not have a
+ shape and `validate_shape` is `True`.
+ RuntimeError: If called outside of a function definition.
+ """
+ if context.executing_eagerly():
+ raise RuntimeError(
+ "UnliftedInitializerVariable should not be created "
+ "outside of functions.")
+ with ops.init_scope():
+ if not context.executing_eagerly():
+ raise RuntimeError(
+ "UnliftedInitializerVariable does not support legacy graph mode.")
+ self._in_graph_mode = False
+ if initial_value is None:
+ raise ValueError("initial_value must be specified.")
+ init_from_fn = callable(initial_value)
+
+ if constraint is not None and not callable(constraint):
+ raise ValueError("The `constraint` argument must be a callable.")
+
+ if isinstance(initial_value, checkpointable.CheckpointInitialValue):
+ self._maybe_initialize_checkpointable()
+ self._update_uid = initial_value.checkpoint_position.restore_uid
+ initial_value = initial_value.wrapped_value
+
+ self._trainable = trainable
+ self._save_slice_info = None
+ self._initial_value = None
+ self._initializer_op = None
+ self._is_initialized_op = None
+ self._graph_element = None
+ self._cached_value = None
+ # Store the graph key so optimizers know how to only retrieve variables from
+ # this graph. Guaranteed to be the same as the eager graph_key.
+ self._graph_key = ops.get_default_graph()._graph_key # pylint: disable=protected-access
+ with ops.name_scope(name, "Variable", []
+ if init_from_fn else [initial_value]) as name:
+ # pylint: disable=protected-access
+ with ops.init_scope():
+ assert context.executing_eagerly()
+ shared_name = ops._name_from_scope_name(name)
+ shared_name = "%s_%d" % (shared_name, ops.uid())
+ # Use attr_scope and device(None) to simulate the behavior of
+ # colocate_with when the variable we want to colocate with doesn't
+ # yet exist.
+ with ops.name_scope("Initializer"), ops.device(None):
+ initial_value = ops.convert_to_tensor(
+ initial_value() if init_from_fn else initial_value,
+ name="initial_value", dtype=dtype)
+ with ops.init_scope():
+ self._handle = resource_variable_ops.eager_safe_variable_handle(
+ shape=initial_value.get_shape(),
+ dtype=initial_value.dtype.base_dtype,
+ shared_name=shared_name,
+ name=name,
+ graph_mode=False)
+ self._shape = initial_value.shape
+ self._unique_id = shared_name
+ self._handle_name = shared_name + ":0"
+ self._dtype = initial_value.dtype.base_dtype
+ self._constraint = constraint
+ assert initial_value is not None
+ def assign_fn():
+ with ops.name_scope("Assign") as n, ops.colocate_with(self._handle):
+ resource_variable_ops.assign_variable_op(
+ self._handle,
+ initial_value,
+ name=n)
+ # Returning values to keep tf.cond happy.
+ return ops.convert_to_tensor(1)
+ def not_assign_fn():
+ return ops.convert_to_tensor(0)
+ # Note: this cond is always guaranteed to run because we're inside a defun
+ # which will insert automatic control dependencies.
+ control_flow_ops.cond(
+ resource_variable_ops.var_is_initialized_op(self._handle),
+ not_assign_fn, assign_fn)
+
+ # After the handle has been created, set up a way to clean it up when
+ # executing eagerly. We'll hold the only reference to the deleter, so that
+ # when this object is garbage collected the deleter will be too. This
+ # means ResourceVariables can be part of reference cycles without those
+ # cycles being uncollectable.
+ self._handle_deleter = resource_variable_ops.EagerResourceDeleter(
+ handle=self._handle, handle_device=self._handle.device)
+ self._cached_shape_as_list = None
+
+
+def _defun_with_scope(scope, fn):
+
+ def wrapped_fn(*args, **kwds):
+ with variable_scope.variable_creator_scope(scope):
+ return fn(*args, **kwds)
+
+ return function.defun(wrapped_fn)
+
+
+def def_function(fn):
+ """Defines a function as per the "functions, not sessions" document."""
+
+ # Wrapping the values in lists to bypass Python's lack of a way to mutate
+ # symbols from an outer scope.
+ first_call = [True]
+ function_to_call = []
+
+ # TODO(apassos) represent this as an object and not as a closure.
+ def decorated_fn(*args, **kwds):
+ """Graph function for fn."""
+ if not first_call[0]:
+ return function_to_call[0](*args, **kwds)
+
+ first_call[0] = False
+ created_variables = []
+
+ def variable_creator_scope(unused_next_creator, **kwds):
+ """Creates UnliftedInitializerVariables and saves references to them."""
+ v = UnliftedInitializerVariable(**kwds)
+ created_variables.append(v)
+ return v
+
+ first_graph_function = _defun_with_scope(variable_creator_scope, fn)
+
+ # Force the definition of the function for these arguments
+ first_concrete = first_graph_function.get_concrete_function(*args, **kwds)
+
+ def invalid_creator_scope(*unused_args, **unused_kwds):
+ """Disables variable creation."""
+ raise ValueError(
+ "def_function-decorated function tried to create "
+ "variables on second call.")
+
+ second_graph_function = _defun_with_scope(invalid_creator_scope, fn)
+
+ function_to_call.append(second_graph_function)
+ if not created_variables:
+ # Note: this retracing might be unnecessary, but always running the
+ # function in the scope that disallows variable creation is safer than
+ # not doing so.
+ return second_graph_function(*args, **kwds)
+
+ def fn_with_cond(*inner_args, **inner_kwds):
+ """Conditionally runs initialization if it's needed."""
+ condition = True
+ for variable in created_variables:
+ condition = condition and resource_variable_ops.var_is_initialized_op(
+ variable.handle)
+ # We want to call second_graph_function if possible because it avoids
+ # recomputing potentially expensive initializers.
+ return control_flow_ops.cond(
+ condition,
+ lambda: second_graph_function(*inner_args, **inner_kwds),
+ lambda: first_concrete(*inner_args, **inner_kwds))
+
+ return function.defun(fn_with_cond)(*args, **kwds)
+
+ return decorated_fn
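
In short, `def_function` traces the Python function twice: the first trace runs under a variable-creator scope that builds `UnliftedInitializerVariable`s (whose initializers sit inside a cond on "already initialized"), and the second trace forbids variable creation and serves all later calls, with a `tf.cond` over the variables' initialization status choosing between the two graph functions at runtime. A usage sketch based on the new tests below:

    from tensorflow.python.eager import def_function
    from tensorflow.python.framework import constant_op
    from tensorflow.python.framework import ops
    from tensorflow.python.ops import variables

    ops.enable_eager_execution()

    state = []

    @def_function.def_function
    def scale(x):
      # Variables may only be created on the first call; later calls reuse them.
      if not state:
        state.append(variables.Variable(2.0 * x))
      return state[0] * x

    print(scale(constant_op.constant(1.0)))  # 2.0: variable initialized to 2.0 * 1.0
    print(scale(constant_op.constant(3.0)))  # 6.0: the existing variable is reused
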
diff --git a/tensorflow/python/eager/def_function_test.py b/tensorflow/python/eager/def_function_test.py
new file mode 100644
index 0000000000..804436c4bb
--- /dev/null
+++ b/tensorflow/python/eager/def_function_test.py
@@ -0,0 +1,87 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+from tensorflow.python.eager import def_function
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+
+class DefFunctionTest(test.TestCase):
+
+ def testNoVariables(self):
+
+ @def_function.def_function
+ def fn(x):
+ return 2 * x
+
+ self.assertAllEqual(fn(constant_op.constant(4.0)), 8.0)
+
+ def testFailIfVariablesAreCreatedMoreThanOnce(self):
+
+ @def_function.def_function
+ def fn(x):
+ return variables.Variable(1.0) + x
+
+ with self.assertRaises(ValueError):
+ fn(1.0)
+
+ def testFailIfVariablesAreCreatedMoreThanOnceNoWeakRef(self):
+ state = []
+
+ @def_function.def_function
+ def fn(x):
+ state.append(variables.Variable(1.0))
+ return state[-1] + x
+
+ with self.assertRaises(ValueError):
+ fn(1.0)
+
+ def testCorrectVariableCreation(self):
+
+ state = []
+
+ @def_function.def_function
+ def fn(x):
+ if not state:
+ state.append(variables.Variable(2.0))
+ return state[0] * x
+
+ self.assertAllEqual(fn(constant_op.constant(1.0)), 2.0)
+ self.assertAllEqual(fn(constant_op.constant(3.0)), 6.0)
+
+ def testVariableInitializerNotConstant(self):
+
+ state = []
+
+ @def_function.def_function
+ def fn(x):
+ if not state:
+ state.append(variables.Variable(2.0 * x))
+ return state[0] * x
+
+ self.assertAllEqual(fn(constant_op.constant(1.0)), 2.0)
+ self.assertAllEqual(fn(constant_op.constant(3.0)), 6.0)
+
+
+if __name__ == '__main__':
+ ops.enable_eager_execution()
+ test.main()
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index e2874e25b6..dd3e1a3723 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -23,6 +23,7 @@ import collections
import functools
import sys
import threading
+import weakref
import numpy as np
import six
@@ -72,16 +73,36 @@ def _create_substitute_placeholder(value, name=None, dtype=None):
with ops.control_dependencies(None):
placeholder = graph_placeholder(
dtype=dtype or value.dtype, shape=value.shape, name=name)
- if placeholder.dtype == dtypes_module.resource:
- if isinstance(value, ops.EagerTensor):
- handle_data = value._handle_data # pylint: disable=protected-access
+ _copy_handle_data(value, placeholder)
+ return placeholder
+
+
+def _copy_handle_data(source_t, target_t):
+ """Copies HandleData for variant and resource type tensors if available.
+
+ The CppShapeInferenceResult::HandleData proto contains information about the
+ shapes and types of the element tensors of resource/variant type tensors.
+ We need to copy this across function boundaries, i.e., when capturing a
+  placeholder or when returning a function tensor as output. If we don't, the
+  element tensors will have unknown shapes; e.g., if a TensorList variant
+  tensor is captured as a placeholder, elements popped from that list will
+  have unknown shapes.
+
+ Args:
+ source_t: The tensor to copy HandleData from.
+ target_t: The tensor to copy HandleData to.
+ """
+ if (target_t.dtype == dtypes_module.resource or
+ target_t.dtype == dtypes_module.variant):
+ if isinstance(source_t, ops.EagerTensor):
+ handle_data = source_t._handle_data # pylint: disable=protected-access
else:
- handle_data = resource_variable_ops.get_resource_handle_data(value)
+ handle_data = resource_variable_ops.get_resource_handle_data(source_t)
if handle_data is not None and handle_data.is_set:
# pylint: disable=protected-access
- pywrap_tensorflow.SetResourceHandleShapeAndType(
- placeholder.graph._c_graph, placeholder._as_tf_output(),
- handle_data.SerializeToString())
+ pywrap_tensorflow.SetHandleShapeAndType(target_t.graph._c_graph,
+ target_t._as_tf_output(),
+ handle_data.SerializeToString())
# pylint: enable=protected-access
# Ensure that shapes and dtypes are propagated.
shapes, types = zip(*[(pair.shape, pair.dtype)
@@ -90,12 +111,10 @@ def _create_substitute_placeholder(value, name=None, dtype=None):
shapes = [[d.size for d in s.dim]
if not s.unknown_rank else None for s in shapes]
pywrap_tensorflow.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
- placeholder._op._graph._c_graph, # pylint: disable=protected-access
- placeholder._as_tf_output(), # pylint: disable=protected-access
+ target_t._op._graph._c_graph, # pylint: disable=protected-access
+ target_t._as_tf_output(), # pylint: disable=protected-access
shapes, ranks, types)
- return placeholder
-
def _get_device_functions(ctx, graph):
"""Returns a tuple of device functions representing the device stack."""
@@ -180,7 +199,7 @@ class FuncGraph(ops.Graph):
self.inputs = []
self.outputs = []
self.structured_outputs = None
- self.variables = []
+ self._weak_variables = []
self.outer_graph = ops.get_default_graph()
self.captures = collections.OrderedDict()
@@ -217,6 +236,31 @@ class FuncGraph(ops.Graph):
self._graph_key = graph._graph_key
# pylint: enable=protected-access
+ @property
+ def variables(self):
+ """A list of variables accessed by this FuncGraph.
+
+ Note that functions keep only weak references to variables. Calling the
+ function after a variable it accesses has been deleted is an error.
+
+ Yields:
+ Strong references to variables accessed by this FuncGraph.
+ """
+ for weak_v in self._weak_variables:
+ v = weak_v()
+ if v is None:
+ raise AssertionError(
+ "Called a function referencing variables which have been deleted. "
+ "This likely means that function-local variables were created and "
+ "not referenced elsewhere in the program. This is generally a "
+ "mistake; consider storing variables in an object attribute on "
+ "first call.")
+ yield v
+
+ @variables.setter
+ def variables(self, var_list):
+ self._weak_variables = [weakref.ref(v) for v in var_list]
+
def create_op(
self,
op_type,
@@ -409,6 +453,7 @@ class _EagerDefinedFunction(object):
self._num_outputs = len(self.signature.output_arg)
self._output_types = [o.type for o in self.signature.output_arg]
self._output_shapes = [o.shape for o in outputs]
+ self._func_graph_outputs = outputs
self.grad_func_name = None
self.python_grad_func = None
self._c_func = c_api_util.ScopedTFFunction(fn)
@@ -485,6 +530,8 @@ class _EagerDefinedFunction(object):
else:
for i, shape in enumerate(self._output_shapes):
outputs[i].set_shape(shape)
+ for i, func_graph_output in enumerate(self._func_graph_outputs):
+ _copy_handle_data(func_graph_output, outputs[i])
return outputs
@@ -604,11 +651,6 @@ class Function(object):
return self._func_graph
@property
- def variables(self):
- """Returns all variables touched by this function."""
- return self._func_graph.variables
-
- @property
def inputs(self):
"""Returns tensors in `self.graph` corresponding to arguments."""
return self._func_graph.inputs
@@ -805,7 +847,12 @@ def _get_defun_inputs_from_args(args):
return nest.pack_sequence_as(args, function_inputs)
-def func_graph_from_py_func(name, python_func, args, kwds, signature=None):
+def func_graph_from_py_func(name,
+ python_func,
+ args,
+ kwargs,
+ signature=None,
+ func_graph=None):
"""Returns a `FuncGraph` generated from `python_func`.
Args:
@@ -813,13 +860,15 @@ def func_graph_from_py_func(name, python_func, args, kwds, signature=None):
python_func: the Python function to trace.
args: the positional args with which the Python function should be called;
ignored if a signature is provided.
- kwds: the keyword args with which the Python function should be called;
+ kwargs: the keyword args with which the Python function should be called;
ignored if a signature is provided.
signature: a possibly nested sequence of `TensorSpecs` specifying the shapes
and dtypes of the arguments. When a signature is provided, `args` and
- `kwds` are ignored, and `python_func` is traced with Tensors conforming
+ `kwargs` are ignored, and `python_func` is traced with Tensors conforming
to `signature`. If `None`, the shapes and dtypes are inferred from the
inputs.
+    func_graph: Optional. An instance of FuncGraph. If provided, this graph is
+      used; otherwise a new one is built and returned.
Returns:
A FuncGraph.
@@ -828,22 +877,25 @@ def func_graph_from_py_func(name, python_func, args, kwds, signature=None):
TypeError: If any of `python_func`'s return values is neither `None` nor a
`Tensor`.
"""
- func_graph = FuncGraph(name)
+ if func_graph is None:
+ func_graph = FuncGraph(name)
+ assert isinstance(func_graph, FuncGraph)
with func_graph.as_default(), AutomaticControlDependencies() as a:
variable_scope.get_variable_scope().set_use_resource(True)
if signature is None:
func_args = _get_defun_inputs_from_args(args)
- func_kwds = _get_defun_inputs_from_args(kwds)
+ func_kwargs = _get_defun_inputs_from_args(kwargs)
else:
func_args = _get_defun_inputs_from_signature(signature)
- func_kwds = {}
+ func_kwargs = {}
# Note: `nest.flatten` sorts by keys, as does `_deterministic_dict_values`.
# Variables to help check whether mutation happens in calling the function
# Copy the recursive list, tuple and map structure, but not base objects
func_args_before = nest.pack_sequence_as(func_args, nest.flatten(func_args))
- func_kwds_before = nest.pack_sequence_as(func_kwds, nest.flatten(func_kwds))
+ func_kwargs_before = nest.pack_sequence_as(
+ func_kwargs, nest.flatten(func_kwargs))
def convert(x):
"""Converts an argument to a Tensor."""
@@ -862,7 +914,7 @@ def func_graph_from_py_func(name, python_func, args, kwds, signature=None):
this_tape = tape.push_new_tape()
try:
- func_outputs = python_func(*func_args, **func_kwds)
+ func_outputs = python_func(*func_args, **func_kwargs)
# invariant: `func_outputs` contains only Tensors and `None`s.
func_outputs = nest.map_structure(convert, func_outputs)
@@ -882,16 +934,16 @@ def func_graph_from_py_func(name, python_func, args, kwds, signature=None):
raise ValueError(errmsg)
check_mutation(func_args_before, func_args)
- check_mutation(func_kwds_before, func_kwds)
+ check_mutation(func_kwargs_before, func_kwargs)
finally:
tape.pop_tape(this_tape)
- # Variables in `func_args`, `func_kwds` should be explicit inputs
+ # Variables in `func_args`, `func_kwargs` should be explicit inputs
# to the function, not captured inputs.
tape_variables = this_tape.watched_variables()
arg_variables = set()
inputs = []
- for arg in nest.flatten(func_args) + nest.flatten(func_kwds):
+ for arg in nest.flatten(func_args) + nest.flatten(func_kwargs):
if isinstance(arg, resource_variable_ops.ResourceVariable):
try:
resource_placeholder = func_graph.captures.pop(arg.handle)
@@ -970,7 +1022,16 @@ def _encode_arg(arg):
return tuple(
(_encode_arg(key), _encode_arg(arg[key])) for key in sorted(arg))
else:
- return arg
+ try:
+ # If possible, keep only a weak reference to Python objects. Weak
+ # references hash to the same value as the original object.
+ # TODO(allenl): Clean up dead functions and their cache keys if the cache
+ # gets large. Right now creating objects with a defunned method, calling
+ # the method, and losing a reference to the object in a loop will leak
+ # memory here.
+ return weakref.ref(arg)
+ except TypeError:
+ return arg
def _deterministic_dict_values(dictionary):
@@ -1013,14 +1074,13 @@ class PolymorphicFunction(object):
if isinstance(python_function, functools.partial):
self._python_function = python_function.func
self._args_to_prepend = python_function.args or tuple()
- self._kwds_to_include = python_function.keywords or {}
+ self._kwargs_to_include = python_function.keywords or {}
else:
self._python_function = python_function
self._args_to_prepend = tuple()
- self._kwds_to_include = {}
+ self._kwargs_to_include = {}
self._name = name
self._function_cache = collections.OrderedDict()
- self._variables = []
self._function_attributes = attributes or {}
self._lock = threading.Lock()
@@ -1056,9 +1116,9 @@ class PolymorphicFunction(object):
self._input_signature = tuple(input_signature)
self._flat_input_signature = tuple(nest.flatten(input_signature))
- def __call__(self, *args, **kwds):
+ def __call__(self, *args, **kwargs):
"""Calls a graph function specialized to the inputs."""
- graph_function, inputs = self._maybe_define_function(*args, **kwds)
+ graph_function, inputs = self._maybe_define_function(args, kwargs)
return graph_function(*inputs)
@property
@@ -1066,12 +1126,6 @@ class PolymorphicFunction(object):
"""Returns the wrapped Python function."""
return self._python_function
- # TODO(akshayka): Remove this property.
- @property
- def variables(self):
- """Returns the union of all variables referenced by cached `Function`s`."""
- return self._variables
-
def get_concrete_function(self, *args, **kwargs):
"""Returns a `Function` object specialized to inputs and execution context.
@@ -1082,7 +1136,7 @@ class PolymorphicFunction(object):
*args: inputs to specialize on.
**kwargs: inputs to specialize on.
"""
- graph_function, _ = self._maybe_define_function(*args, **kwargs)
+ graph_function, _ = self._maybe_define_function(args, kwargs)
return graph_function
def __get__(self, instance, owner):
@@ -1103,33 +1157,37 @@ class PolymorphicFunction(object):
# then `instance` will be `foo` (and `owner` will be `Foo`).
return functools.partial(self.__call__, instance)
- def _cache_key(self, args, kwds, ctx, graph):
+ def _cache_key(self, args, kwargs):
"""Computes the cache key given inputs and execution context."""
if self._input_signature is None:
- inputs = (args, kwds) if kwds else args
+ inputs = (args, kwargs) if kwargs else args
cache_key = tuple(_encode_arg(arg) for arg in inputs)
else:
- del args, kwds
+ del args, kwargs
cache_key = self._flat_input_signature
- # The graph, or whether we're executing eagerly, should be a part of the
- # cache key so we don't improperly capture tensors such as variables.
- executing_eagerly = ctx.executing_eagerly()
- execution_context = executing_eagerly or graph
+ with ops.init_scope():
+ init_graph = ops.get_default_graph()
+
+ # The graph, or whether we're executing eagerly, should be a part of the
+ # cache key so we don't improperly capture tensors such as variables.
+ executing_eagerly = context.executing_eagerly()
+ execution_context = executing_eagerly or init_graph
+ default_graph = ops.get_default_graph()
# Putting the device in the cache key ensures that call-site device
# annotations are respected.
- device_functions = _get_device_functions(ctx, graph)
+ device_functions = _get_device_functions(context.context(), default_graph)
# `ops.colocate_with` directives translate into `ops.device` directives when
# eager execution is enabled.
- colocation_stack = (None if executing_eagerly else
- tuple(graph._colocation_stack.peek_objs())) # pylint: disable=protected-access
+ colocation_stack = (() if executing_eagerly else
+ tuple(default_graph._colocation_stack.peek_objs())) # pylint: disable=protected-access
return cache_key + (execution_context, device_functions, colocation_stack)
- def _canonicalize_function_inputs(self, *args, **kwds):
- """Canonicalizes `args` and `kwds`.
+ def _canonicalize_function_inputs(self, *args, **kwargs):
+ """Canonicalizes `args` and `kwargs`.
Canonicalize the inputs to the Python function using its fullargspec. In
    particular, we parse the varargs and kwargs that this
@@ -1139,28 +1197,28 @@ class PolymorphicFunction(object):
Args:
*args: The varargs this object was called with.
- **kwds: The keyword args this function was called with.
+ **kwargs: The keyword args this function was called with.
Returns:
A canonicalized ordering of the inputs.
Raises:
- ValueError: If a keyword in `kwds` cannot be matched with a positional
+ ValueError: If a keyword in `kwargs` cannot be matched with a positional
argument when an input signature is specified, or when the inputs
do not conform to the input signature.
"""
args = self._args_to_prepend + args
- kwds = dict(kwds, **self._kwds_to_include)
+ kwargs = dict(kwargs, **self._kwargs_to_include)
# Maps from index of arg to its corresponding value, according to `args`
- # and `kwds`; seeded with the default values for the named args that aren't
- # in `args`.
+ # and `kwargs`; seeded with the default values for the named args that
+ # aren't in `args`.
arg_indices_to_values = {
index: default
for index, default in six.iteritems(self._arg_indices_to_default_values)
if index >= len(args)
}
consumed_args = []
- for arg, value in six.iteritems(kwds):
+ for arg, value in six.iteritems(kwargs):
index = self._args_to_indices.get(arg, None)
if index is not None:
arg_indices_to_values[index] = value
@@ -1170,9 +1228,9 @@ class PolymorphicFunction(object):
"function with keyword arguments when "
"input_signature is provided.")
for arg in consumed_args:
- # After this loop, `kwds` will only contain true keyword arguments, as
+ # After this loop, `kwargs` will only contain true keyword arguments, as
# opposed to named arguments called in a keyword-like fashion.
- kwds.pop(arg)
+ kwargs.pop(arg)
inputs = args + _deterministic_dict_values(arg_indices_to_values)
flat_inputs = nest.flatten(inputs)
@@ -1186,9 +1244,9 @@ class PolymorphicFunction(object):
inputs = nest.pack_sequence_as(structure=inputs,
flat_sequence=flat_inputs)
if self._input_signature is None:
- return inputs, kwds
+ return inputs, kwargs
else:
- assert not kwds
+ assert not kwargs
try:
nest.assert_same_structure(self._input_signature, inputs)
except (ValueError, TypeError):
@@ -1207,25 +1265,27 @@ class PolymorphicFunction(object):
(str(inputs), str(self._input_signature)))
return inputs, {}
- def _maybe_define_function(self, *args, **kwds):
+ def _maybe_define_function(self, args, kwargs):
"""Gets a function for these inputs, defining it if necessary.
+ `args` and `kwargs` can be None if this `PolymorphicFunction` was created
+ with an `input_signature`.
+
Args:
- *args: args for the Python function.
- **kwds: keywords for the Python function.
+ args: The varargs for the Python function.
+ kwargs: The keyword args for the Python function.
Returns:
A graph function corresponding to the input signature implied by args and
- kwds, as well as the inputs that the object should be called with.
+ kwargs, as well as the inputs that the object should be called with.
Raises:
ValueError: If inputs are incompatible with the input signature.
TypeError: If the function inputs include non-hashable objects
"""
-
- args, kwds = self._canonicalize_function_inputs(*args, **kwds)
- cache_key = self._cache_key(args, kwds, context.context(),
- ops.get_default_graph())
+ if self._input_signature is None or args is not None or kwargs is not None:
+ args, kwargs = self._canonicalize_function_inputs(*args, **kwargs)
+ cache_key = self._cache_key(args, kwargs)
with self._lock:
try:
graph_function = self._function_cache.get(cache_key, None)
@@ -1236,13 +1296,11 @@ class PolymorphicFunction(object):
if graph_function is None:
graph_function = Function(
func_graph_from_py_func(self._name, self._python_function, args,
- kwds, self._input_signature),
+ kwargs, self._input_signature),
self._function_attributes)
- self._variables.extend(
- [v for v in graph_function.variables if v not in self._variables])
self._function_cache[cache_key] = graph_function
return graph_function, [
- t for t in nest.flatten((args, kwds))
+ t for t in nest.flatten((args, kwargs))
if isinstance(t, (ops.Tensor, resource_variable_ops.ResourceVariable))
]
@@ -1270,8 +1328,25 @@ def register(func, *args, **kwargs):
"Got type: %s" % type(func))
concrete_func = func.get_concrete_function(*args, **kwargs)
graph = ops.get_default_graph()
- concrete_func._inference_function.add_to_graph(graph) # pylint: disable=protected-access
- # TODO(scottzhu): support concrete_func._backward_graph_function in future.
+
+ # There are two situations for the actual call of a defun:
+  # 1. If none of the input args are resource variables or watched by any tape,
+  #    it will run the _inference_function of concrete_func for the forward
+  #    pass, and the gradient will be generated by the standard mechanism.
+  # 2. Otherwise, defun will create two functions, one for the forward pass,
+  #    and the backward pass will be created via the tape.
+  # When registering the function, we put both cases into the graph.
+ # pylint: disable=protected-access
+ concrete_func._inference_function.add_to_graph(graph)
+
+ if concrete_func._backward_graph_function is None:
+ concrete_func._construct_backprop_function()
+ forward_function = concrete_func._forward_function
+ backward_function = concrete_func._backward_graph_function._inference_function
+ forward_function.add_to_graph(graph)
+ backward_function.add_to_graph(graph)
+ # pylint: enable=protected-access
+
return concrete_func
@@ -1882,9 +1957,9 @@ def automatic_control_dependencies(f):
The wrapped function.
"""
- def wrapper(*args, **kwds):
+ def wrapper(*args, **kwargs):
with AutomaticControlDependencies() as a:
- result = f(*args, **kwds)
+ result = f(*args, **kwargs)
result_flat = [a.mark_as_return(t) for t in nest.flatten(result)]
return nest.pack_sequence_as(result, result_flat)
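
The `_encode_arg` change above relies on two properties of `weakref.ref`: a
live reference hashes like its referent and compares equal to another
reference to the same object, and objects that cannot be weakly referenced
(ints, strings, tuples) raise TypeError, which is why the plain value is kept
in that case. A small sketch of that standard-library behavior (not code from
this patch):

    import weakref

    class Holder(object):
      pass

    h = Holder()
    r1, r2 = weakref.ref(h), weakref.ref(h)
    assert hash(r1) == hash(h)  # a live ref hashes like its referent
    assert r1 == r2             # two live refs to the same object are equal

    try:
      weakref.ref(42)           # ints cannot be weakly referenced...
    except TypeError:
      pass                      # ...so _encode_arg keeps the value itself
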
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index c168b6060c..34a2648e26 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -21,10 +21,12 @@ import collections
import functools
from multiprocessing.pool import ThreadPool
import sys
+import weakref
import numpy
from tensorflow.core.protobuf import config_pb2
+from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python import keras
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.eager import backprop
@@ -46,6 +48,7 @@ from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import list_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import resource_variable_ops
@@ -74,6 +77,13 @@ class MiniModel(keras_training.Model):
return self.fc(inputs)
+class DefunnedMiniModel(MiniModel):
+
+ @function.defun
+ def call(self, inputs, training=True):
+ return super(DefunnedMiniModel, self).call(inputs, training=training)
+
+
@test_util.with_c_shapes
class FunctionTest(test.TestCase):
@@ -140,8 +150,8 @@ class FunctionTest(test.TestCase):
@function.defun
def f():
- v = resource_variable_ops.ResourceVariable(1.0)
- return v.read_value()
+ self.v = resource_variable_ops.ResourceVariable(1.0)
+ return self.v.read_value()
self.assertAllEqual(f(), 1.0)
@@ -399,9 +409,9 @@ class FunctionTest(test.TestCase):
@function.defun
def tensor_init():
- v = resource_variable_ops.ResourceVariable(
+ self.v = resource_variable_ops.ResourceVariable(
lambda: constant_op.constant(2.0))
- return v.read_value()
+ return self.v.read_value()
value = tensor_init()
if not context.executing_eagerly():
@@ -415,8 +425,8 @@ class FunctionTest(test.TestCase):
def tensor_init():
with ops.init_scope():
const = constant_op.constant(2.0)
- v = resource_variable_ops.ResourceVariable(const)
- return v.read_value()
+ self.v = resource_variable_ops.ResourceVariable(const)
+ return self.v.read_value()
value = tensor_init()
if not context.executing_eagerly():
@@ -429,10 +439,17 @@ class FunctionTest(test.TestCase):
def f():
x = constant_op.constant([[1, 2], [3, 4]])
out = math_ops.matmul(v, x)
- self.assertEqual(out.get_shape(), tensor_shape.TensorShape([2, 2]))
+ self.assertEqual(out.shape, tensor_shape.TensorShape([2, 2]))
+ # We do not return v directly since the tensor conversion function of
+ # ResourceVariable returns the read value and not the resource itself.
+ return v._handle
compiled = function.defun(f)
- compiled()
+ var_handle = compiled()
+ self.assertEqual(var_handle.dtype, dtypes.resource)
+ self.assertEqual(var_handle.shape, tensor_shape.scalar())
+ var_t = resource_variable_ops.read_variable_op(var_handle, dtype=v.dtype)
+ self.assertEqual(var_t.shape, tensor_shape.TensorShape([2, 2]))
def testVariableInLoopInFunction(self):
@@ -456,10 +473,17 @@ class FunctionTest(test.TestCase):
def f():
x = constant_op.constant([[1, 2], [3, 4]])
out = math_ops.matmul(v, x)
- self.assertEqual(out.get_shape(), tensor_shape.TensorShape([2, 2]))
+ self.assertEqual(out.shape, tensor_shape.TensorShape([2, 2]))
+ # We do not return v directly since the tensor conversion function of
+ # ResourceVariable returns the read value and not the resource itself.
+ return v._handle
compiled = function.defun(f)
- compiled()
+ var_handle = compiled()
+ self.assertEqual(var_handle.dtype, dtypes.resource)
+ self.assertEqual(var_handle.shape, tensor_shape.scalar())
+ var_t = resource_variable_ops.read_variable_op(var_handle, dtype=v.dtype)
+ self.assertEqual(var_t.shape, tensor_shape.TensorShape([2, 2]))
def testDefunShapeInferenceWithCapturedVariableInGraphMode(self):
with context.graph_mode():
@@ -468,23 +492,46 @@ class FunctionTest(test.TestCase):
def f():
x = constant_op.constant([[1, 2], [3, 4]])
out = math_ops.matmul(v, x)
- self.assertEqual(out.get_shape(), tensor_shape.TensorShape([2, 2]))
+ self.assertEqual(out.shape, tensor_shape.TensorShape([2, 2]))
# Check that shape inference works while creating the defun
compiled = function.defun(f)
compiled()
+ def testDefunShapeInferenceWithCapturedTensorListInGraphMode(self):
+ with context.graph_mode():
+ tensor_list = list_ops.empty_tensor_list(
+ element_dtype=dtypes.float32,
+ element_shape=ops.convert_to_tensor([], dtype=dtypes.int32))
+ tensor_list = list_ops.tensor_list_push_back(tensor_list,
+ constant_op.constant(1.0))
+ tensor_list = list_ops.tensor_list_push_back(tensor_list,
+ constant_op.constant(2.0))
+
+ def f():
+ tl, value = list_ops.tensor_list_pop_back(
+ tensor_list, element_dtype=dtypes.float32)
+ self.assertEqual(value.shape, tensor_shape.scalar())
+ return tl
+
+ compiled = function.defun(f)
+ output_tensor_list = compiled()
+ _, value = list_ops.tensor_list_pop_back(
+ output_tensor_list, element_dtype=dtypes.float32)
+ self.assertEqual(value.shape, tensor_shape.scalar())
+
@test_util.run_in_graph_and_eager_modes
def testDefunForcesResourceVariables(self):
def variable_creator():
- return variables.Variable(0.0).read_value()
+ self.v = variables.Variable(0.0)
+ return self.v.read_value()
+ self.v = None
defined = function.defun(variable_creator)
defined() # Create the variable.
- self.assertEqual(len(defined.variables), 1)
self.assertIsInstance(
- defined.variables[0], resource_variable_ops.ResourceVariable)
+ self.v, resource_variable_ops.ResourceVariable)
def testDefunDifferentiable(self):
v = resource_variable_ops.ResourceVariable(1.0)
@@ -1184,13 +1231,11 @@ class FunctionTest(test.TestCase):
defined = function.defun(foo)
x = constant_op.constant([1.0])
- self.assertAllEqual(defined.variables, [])
- _ = defined(x)
- self.assertAllEqual(defined.variables, [v])
+ self.assertEqual(1., self.evaluate(defined(x)))
+ v.assign(2.)
x = constant_op.constant([1.0, 2.0])
- _ = defined(x) # ensure the variables list remains the same
- self.assertAllEqual(defined.variables, [v])
+ self.assertAllEqual([2., 4.], self.evaluate(defined(x)))
def testPythonFunctionWithDefaultArgs(self):
@@ -1557,7 +1602,7 @@ class FunctionTest(test.TestCase):
defun_add = function.defun_with_attributes(
add, attributes={'experimental_3': True, 'experimental_4': 1.0})
- with context.graph_mode(), self.test_session():
+ with context.graph_mode(), self.cached_session():
with ops.get_default_graph().as_default():
t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
sq = matmul(t, t)
@@ -1591,7 +1636,7 @@ class FunctionTest(test.TestCase):
with self.assertRaisesRegexp(ValueError,
'.*Attribute name is not whitelisted.*'):
- with context.graph_mode(), self.test_session():
+ with context.graph_mode(), self.cached_session():
with ops.get_default_graph().as_default():
t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
matmul(t, t)
@@ -1602,7 +1647,7 @@ class FunctionTest(test.TestCase):
with self.assertRaisesRegexp(ValueError,
'.*Unsupported attribute type.*'):
- with context.graph_mode(), self.test_session():
+ with context.graph_mode(), self.cached_session():
with ops.get_default_graph().as_default():
t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
add(t, t)
@@ -1624,12 +1669,23 @@ class FunctionTest(test.TestCase):
graph = ops.get_default_graph()
# pylint: disable=protected-access
- self.assertEqual(len(graph._functions), 2)
+ self.assertEqual(len(graph._functions), 6)
+      # Two sets of functions, each consisting of (inference, forward, backward).
functions = list(graph._functions.values())
- pre_register_matmul_func_name = functions[0].definition.signature.name
- self.assertRegexpMatches(pre_register_matmul_func_name, '.*matmul.*')
- pre_register_add_func_name = functions[1].definition.signature.name
- self.assertRegexpMatches(pre_register_add_func_name, '.*add.*')
+ captured_function_names = [
+ f.definition.signature.name for f in functions
+ ]
+ expected_func_name_regex = [
+ '.*inference.*matmul.*',
+ '.*forward.*matmul.*',
+ '.*inference.*backward.*matmul.*',
+ '.*inference.*add.*',
+ '.*forward.*add.*',
+ '.*inference.*backward.*add.*',
+ ]
+ for i in range(len(functions)):
+ self.assertRegexpMatches(captured_function_names[i],
+ expected_func_name_regex[i])
sq = defun_matmul(t, t)
double = add(t, t)
@@ -1637,12 +1693,11 @@ class FunctionTest(test.TestCase):
self.assertAllEqual(double.eval().reshape(-1), [2, 4, 6, 8])
      # Make sure the pre-registered function is used, and no other function
# is added.
- self.assertEqual(len(graph._functions), 2)
+ self.assertEqual(len(graph._functions), 6)
functions = list(graph._functions.values())
- called_func_name = functions[0].definition.signature.name
- self.assertEqual(pre_register_matmul_func_name, called_func_name)
- called_func_name = functions[1].definition.signature.name
- self.assertEqual(pre_register_add_func_name, called_func_name)
+ for i in range(len(functions)):
+ self.assertEquals(captured_function_names[i],
+ functions[i].definition.signature.name)
def testRegisterFunctionWithInputSignature(self):
def matmul(x, y):
@@ -1660,7 +1715,7 @@ class FunctionTest(test.TestCase):
graph = ops.get_default_graph()
# pylint: disable=protected-access
- self.assertEqual(len(graph._functions), 1)
+ self.assertEqual(len(graph._functions), 3)
# Test input param shape mismatch
t2 = constant_op.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
@@ -1683,7 +1738,7 @@ class FunctionTest(test.TestCase):
graph = ops.get_default_graph()
# Only one function is registered since the input param are in same type
# pylint: disable=protected-access
- self.assertEqual(len(graph._functions), 1)
+ self.assertEqual(len(graph._functions), 3)
def testCallingFunctionWithDifferentVariables(self):
@@ -1722,6 +1777,82 @@ class FunctionTest(test.TestCase):
'be Tensors;.*'):
graph_function('Not a Tensor.')
+ # TODO(scottzhu): Revive the test once the grappler plugin is updated.
+ def disabled_testSwapImplementationWithGrapplerPlugin(self):
+ rewrites = rewriter_config_pb2.RewriterConfig()
+    # function_optimizer has to be turned off, otherwise it will delete the
+ # registered function if it does not get called.
+ # TODO(scottzhu): Move the ExperimentalImplementationSelector to be called
+    # before function_optimizer in the future.
+ rewrites.function_optimization = rewriter_config_pb2.RewriterConfig.OFF
+ customer_optimizer = rewrites.custom_optimizers.add()
+ customer_optimizer.name = 'ExperimentalImplementationSelector'
+ rewrites.min_graph_nodes = -1
+ graph_options = config_pb2.GraphOptions(
+ rewrite_options=rewrites, build_cost_model=1)
+ config = config_pb2.ConfigProto(graph_options=graph_options)
+
+ with context.graph_mode(), self.cached_session(
+ config=config, graph=ops.Graph(), use_gpu=True) as sess:
+
+ @function.defun_with_attributes(
+ attributes={
+ 'experimental_api_implements': 'random_boost',
+ 'experimental_api_preferred_device': 'CPU'
+ })
+ def cpu_boost(x):
+ return math_ops.add(x, 2.0)
+
+ @function.defun_with_attributes(
+ attributes={
+ 'experimental_api_implements': 'random_boost',
+ 'experimental_api_preferred_device': 'GPU'
+ })
+ def gpu_boost(x):
+ return math_ops.add(x, 4.0)
+
+ x = constant_op.constant(1.0)
+
+ function.register(cpu_boost, x)
+ y = gpu_boost(x)
+ y_value = sess.run(y)
+
+ if test.is_gpu_available():
+ self.assertEquals(y_value, 5.0)
+ else:
+        # Grappler falls back to the CPU impl even when called with the GPU function.
+ self.assertEquals(y_value, 3.0)
+
+ def testDefunFunctionSeparateGraphs(self):
+ with context.graph_mode():
+
+ @function.defun
+ def add(x):
+ return x + 5
+
+ @function.defun
+ def maybe_add(x, should_add):
+ if should_add:
+ return add(x)
+ else:
+ return x
+
+ with ops.Graph().as_default():
+ x = constant_op.constant(11)
+ maybe_add(x, True)
+ self.assertEqual(len(maybe_add._function_cache), 1)
+ self.assertEqual(len(add._function_cache), 1)
+
+ maybe_add(x, False)
+ self.assertEqual(len(maybe_add._function_cache), 2)
+ self.assertEqual(len(add._function_cache), 1)
+
+ with ops.Graph().as_default():
+ x = constant_op.constant(11)
+ maybe_add(x, True)
+ self.assertEqual(len(maybe_add._function_cache), 3)
+ self.assertEqual(len(add._function_cache), 2)
+
@test_util.with_c_shapes
class AutomaticControlDependenciesTest(test.TestCase):
@@ -1913,10 +2044,10 @@ class AutomaticControlDependenciesTest(test.TestCase):
@function.defun
def train():
- v = resource_variable_ops.ResourceVariable(1.0)
- grad = backprop.implicit_grad(loss)(v)
+ self.v = resource_variable_ops.ResourceVariable(1.0)
+ grad = backprop.implicit_grad(loss)(self.v)
optimizer.apply_gradients(grad)
- return v.read_value()
+ return self.v.read_value()
value = train()
self.assertEqual(value.numpy(), -1.0)
@@ -1943,10 +2074,10 @@ class AutomaticControlDependenciesTest(test.TestCase):
@function.defun
def train():
- v = resource_variable_ops.ResourceVariable(1.0)
- grad = backprop.implicit_grad(loss)(v)
+ self.v = resource_variable_ops.ResourceVariable(1.0)
+ grad = backprop.implicit_grad(loss)(self.v)
optimizer.apply_gradients(grad)
- return v.read_value()
+ return self.v.read_value()
train()
@@ -2133,6 +2264,13 @@ class AutomaticControlDependenciesTest(test.TestCase):
modify_same_flat(nested_input)
+ def testDecoratedMethodVariableCleanup(self):
+ m = DefunnedMiniModel()
+ m(array_ops.ones([1, 2]))
+ weak_variables = weakref.WeakSet(m.variables)
+ self.assertEqual(2, len(weak_variables))
+ del m
+ self.assertEqual([], list(weak_variables))
if __name__ == '__main__':
ops.enable_eager_execution(
diff --git a/tensorflow/python/eager/imperative_grad.py b/tensorflow/python/eager/imperative_grad.py
index 5f027d107c..5f5af4ab6c 100644
--- a/tensorflow/python/eager/imperative_grad.py
+++ b/tensorflow/python/eager/imperative_grad.py
@@ -23,8 +23,9 @@ import collections
from tensorflow.python import pywrap_tensorflow
-VSpace = collections.namedtuple(
- "VSpace", ["aggregate_fn", "num_elements_fn", "zeros", "ones"])
+VSpace = collections.namedtuple("VSpace", [
+ "aggregate_fn", "num_elements_fn", "zeros_fn", "ones_fn", "graph_shape_fn"
+])
def imperative_grad(
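
The renamed VSpace fields mirror the callables that the C++ tape now looks up
by name (`zeros_fn`, `ones_fn`, `graph_shape_fn`). The real implementations
live in backprop.py and are not part of this diff; a hypothetical sketch using
numpy stand-ins, just to show the expected signatures:

    import numpy as np
    from tensorflow.python.eager.imperative_grad import VSpace

    # Hypothetical stand-in callables; the real ones operate on TF tensors.
    vspace = VSpace(
        aggregate_fn=lambda grads: sum(grads),            # sum a list of gradients
        num_elements_fn=lambda t: int(np.prod(t.shape)),  # total element count
        zeros_fn=lambda shape, dtype: np.zeros(shape),    # zeros of a given shape
        ones_fn=lambda shape, dtype: np.ones(shape),      # ones of a given shape
        graph_shape_fn=lambda t: t.shape)                 # dynamic shape lookup
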
diff --git a/tensorflow/python/eager/pywrap_tensor.cc b/tensorflow/python/eager/pywrap_tensor.cc
index f34ce6af79..5f44bd4fec 100644
--- a/tensorflow/python/eager/pywrap_tensor.cc
+++ b/tensorflow/python/eager/pywrap_tensor.cc
@@ -516,25 +516,13 @@ static PyObject* EagerTensor_rank(EagerTensor* self) {
// Getter for `_num_elements`.
static PyObject* EagerTensor_num_elements(EagerTensor* self) {
auto handle = self->handle;
- int n = TFE_TensorHandleNumDims(handle, self->status);
+ int n = TFE_TensorHandleNumElements(handle, self->status);
if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) {
// Cleanup self->status before returning.
TF_SetStatus(self->status, TF_OK, "");
return nullptr;
}
- tensorflow::int64 value = 1;
- if (PyErr_Occurred()) return nullptr;
- for (int i = 0; i < n; ++i) {
- int64_t dim = TFE_TensorHandleDim(handle, i, self->status);
- if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) {
- // Cleanup self->status before returning.
- TF_SetStatus(self->status, TF_OK, "");
- PyErr_SetString(PyExc_RuntimeError, "Error while iterating dimensions");
- return nullptr;
- }
- value *= dim;
- }
- return PyLong_FromLongLong(value);
+ return PyLong_FromLongLong(n);
}
static PyObject* EagerTensor_tensor_handle(EagerTensor* self, void* unused) {
@@ -777,17 +765,34 @@ PyObject* EagerTensorFromHandle(TFE_TensorHandle* handle) {
return reinterpret_cast<PyObject*>(t);
}
-tensorflow::int64 EagerTensor_id(const PyObject* tensor) {
- CHECK(EagerTensor_CheckExact(tensor));
+tensorflow::int64 PyEagerTensor_ID(const PyObject* tensor) {
+ DCHECK(EagerTensor_CheckExact(tensor));
return reinterpret_cast<const EagerTensor*>(tensor)->id;
}
-tensorflow::DataType EagerTensor_dtype(const PyObject* tensor) {
- CHECK(EagerTensor_CheckExact(tensor));
+tensorflow::DataType PyEagerTensor_Dtype(const PyObject* tensor) {
+ DCHECK(EagerTensor_CheckExact(tensor));
return static_cast<tensorflow::DataType>(TFE_TensorHandleDataType(
reinterpret_cast<const EagerTensor*>(tensor)->handle));
}
+tensorflow::int64 PyEagerTensor_NumElements(const PyObject* tensor) {
+ DCHECK(EagerTensor_CheckExact(tensor));
+ const EagerTensor* as_c_eager_tensor =
+ reinterpret_cast<const EagerTensor*>(tensor);
+ tensorflow::int64 result = TFE_TensorHandleNumElements(
+ as_c_eager_tensor->handle, as_c_eager_tensor->status);
+
+ if (MaybeRaiseExceptionFromTFStatus(as_c_eager_tensor->status,
+ PyExc_ValueError)) {
+ // Cleanup status before returning.
+ TF_SetStatus(as_c_eager_tensor->status, TF_OK, "");
+ return -1;
+ }
+
+ return result;
+}
+
PyObject* TFE_Py_InitEagerTensor(PyObject* base_class) {
if (!PyType_Check(base_class)) {
PyErr_SetString(
diff --git a/tensorflow/python/eager/pywrap_tensor.h b/tensorflow/python/eager/pywrap_tensor.h
index bc042eb19e..4eaa1ba536 100644
--- a/tensorflow/python/eager/pywrap_tensor.h
+++ b/tensorflow/python/eager/pywrap_tensor.h
@@ -21,8 +21,9 @@ limitations under the License.
#include "tensorflow/python/lib/core/numpy.h"
bool EagerTensor_CheckExact(const PyObject* o);
-tensorflow::int64 EagerTensor_id(const PyObject* tensor);
-tensorflow::DataType EagerTensor_dtype(const PyObject* tensor);
+tensorflow::int64 PyEagerTensor_ID(const PyObject* tensor);
+tensorflow::DataType PyEagerTensor_Dtype(const PyObject* tensor);
+tensorflow::int64 PyEagerTensor_NumElements(const PyObject* tensor);
namespace tensorflow {
TFE_TensorHandle* ConvertToEagerTensor(PyObject* value, PyObject* dtype);
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index 9f2f4e06ad..196e20e4d7 100644
--- a/tensorflow/python/eager/pywrap_tfe_src.cc
+++ b/tensorflow/python/eager/pywrap_tfe_src.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/python/eager/pywrap_tfe.h"
+#include "absl/types/variant.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/eager/c_api_internal.h"
@@ -860,7 +861,7 @@ static tensorflow::int64 MakeInt(PyObject* integer) {
static tensorflow::int64 FastTensorId(PyObject* tensor) {
if (EagerTensor_CheckExact(tensor)) {
- return EagerTensor_id(tensor);
+ return PyEagerTensor_ID(tensor);
}
PyObject* id_field = PyObject_GetAttrString(tensor, "_id");
if (id_field == nullptr) {
@@ -873,7 +874,7 @@ static tensorflow::int64 FastTensorId(PyObject* tensor) {
static tensorflow::DataType FastTensorDtype(PyObject* tensor) {
if (EagerTensor_CheckExact(tensor)) {
- return EagerTensor_dtype(tensor);
+ return PyEagerTensor_Dtype(tensor);
}
PyObject* dtype_field = PyObject_GetAttrString(tensor, "dtype");
if (dtype_field == nullptr) {
@@ -889,12 +890,239 @@ static tensorflow::DataType FastTensorDtype(PyObject* tensor) {
return static_cast<tensorflow::DataType>(id);
}
+class PyTapeTensor {
+ public:
+ PyTapeTensor(tensorflow::int64 id, tensorflow::DataType dtype,
+ const tensorflow::TensorShape& shape)
+ : id_(id), dtype_(dtype), shape_(shape) {}
+ PyTapeTensor(tensorflow::int64 id, tensorflow::DataType dtype,
+ PyObject* shape)
+ : id_(id), dtype_(dtype), shape_(shape) {
+ Py_INCREF(absl::get<1>(shape_));
+ }
+ PyTapeTensor(const PyTapeTensor& other) {
+ id_ = other.id_;
+ dtype_ = other.dtype_;
+ shape_ = other.shape_;
+ if (shape_.index() == 1) {
+ Py_INCREF(absl::get<1>(shape_));
+ }
+ }
+
+ ~PyTapeTensor() {
+ if (shape_.index() == 1) {
+ Py_DECREF(absl::get<1>(shape_));
+ }
+ }
+ PyObject* GetShape() const;
+ PyObject* GetDType() const { return PyLong_FromLong(dtype_); }
+ tensorflow::int64 GetID() const { return id_; }
+
+ private:
+ tensorflow::int64 id_;
+ tensorflow::DataType dtype_;
+ absl::variant<tensorflow::TensorShape, PyObject*> shape_;
+};
+
+class PyVSpace : public tensorflow::eager::VSpace<PyObject, PyBackwardFunction,
+ PyTapeTensor> {
+ public:
+ explicit PyVSpace(PyObject* py_vspace) : py_vspace_(py_vspace) {
+ Py_INCREF(py_vspace_);
+ }
+
+ tensorflow::Status Initialize() {
+ num_elements_ = PyObject_GetAttrString(py_vspace_, "num_elements_fn");
+ if (num_elements_ == nullptr) {
+ return tensorflow::errors::InvalidArgument("invalid vspace");
+ }
+ aggregate_fn_ = PyObject_GetAttrString(py_vspace_, "aggregate_fn");
+ if (aggregate_fn_ == nullptr) {
+ return tensorflow::errors::InvalidArgument("invalid vspace");
+ }
+ zeros_fn_ = PyObject_GetAttrString(py_vspace_, "zeros_fn");
+ if (zeros_fn_ == nullptr) {
+ return tensorflow::errors::InvalidArgument("invalid vspace");
+ }
+ ones_fn_ = PyObject_GetAttrString(py_vspace_, "ones_fn");
+ if (ones_fn_ == nullptr) {
+ return tensorflow::errors::InvalidArgument("invalid vspace");
+ }
+ graph_shape_fn_ = PyObject_GetAttrString(py_vspace_, "graph_shape_fn");
+ if (graph_shape_fn_ == nullptr) {
+ return tensorflow::errors::InvalidArgument("invalid vspace");
+ }
+ return tensorflow::Status::OK();
+ }
+
+ ~PyVSpace() override {
+ Py_XDECREF(num_elements_);
+ Py_XDECREF(aggregate_fn_);
+ Py_XDECREF(zeros_fn_);
+ Py_XDECREF(ones_fn_);
+ Py_XDECREF(graph_shape_fn_);
+
+ Py_DECREF(py_vspace_);
+ }
+
+ tensorflow::int64 NumElements(PyObject* tensor) const final {
+ if (EagerTensor_CheckExact(tensor)) {
+ return PyEagerTensor_NumElements(tensor);
+ }
+ PyObject* arglist =
+ Py_BuildValue("(O)", reinterpret_cast<PyObject*>(tensor));
+ PyObject* result = PyEval_CallObject(num_elements_, arglist);
+ Py_DECREF(arglist);
+ if (result == nullptr) {
+ // The caller detects whether a python exception has been raised.
+ return -1;
+ }
+ tensorflow::int64 r = MakeInt(result);
+ Py_DECREF(result);
+ return r;
+ }
+
+ PyObject* AggregateGradients(
+ tensorflow::gtl::ArraySlice<PyObject*> gradient_tensors) const final {
+ PyObject* list = PyList_New(gradient_tensors.size());
+ for (int i = 0; i < gradient_tensors.size(); ++i) {
+ // Note: stealing a reference to the gradient tensors.
+ CHECK(gradient_tensors[i] != nullptr);
+ CHECK(gradient_tensors[i] != Py_None);
+ PyList_SET_ITEM(list, i,
+ reinterpret_cast<PyObject*>(gradient_tensors[i]));
+ }
+ PyObject* arglist = Py_BuildValue("(O)", list);
+ CHECK(arglist != nullptr);
+ PyObject* result = PyEval_CallObject(aggregate_fn_, arglist);
+ Py_DECREF(arglist);
+ Py_DECREF(list);
+ return result;
+ }
+
+ void MarkAsResult(PyObject* gradient) const final { Py_INCREF(gradient); }
+
+ PyObject* Zeros(const PyTapeTensor& tensor) const final {
+ PyObject* py_shape = tensor.GetShape();
+ PyObject* py_dtype = tensor.GetDType();
+ PyObject* arg_list = Py_BuildValue("OO", py_shape, py_dtype);
+ PyObject* result = PyEval_CallObject(zeros_fn_, arg_list);
+ Py_DECREF(arg_list);
+ Py_DECREF(py_dtype);
+ Py_DECREF(py_shape);
+ return reinterpret_cast<PyObject*>(result);
+ }
+
+ PyObject* Ones(const PyTapeTensor& tensor) const final {
+ PyObject* py_shape = tensor.GetShape();
+ PyObject* py_dtype = tensor.GetDType();
+ PyObject* arg_list = Py_BuildValue("OO", py_shape, py_dtype);
+ PyObject* result = PyEval_CallObject(ones_fn_, arg_list);
+ Py_DECREF(arg_list);
+ Py_DECREF(py_dtype);
+ Py_DECREF(py_shape);
+ return result;
+ }
+
+ PyObject* GraphShape(PyObject* tensor) const {
+ PyObject* arg_list = Py_BuildValue("(O)", tensor);
+ PyObject* result = PyEval_CallObject(graph_shape_fn_, arg_list);
+ Py_DECREF(arg_list);
+ return result;
+ }
+
+ tensorflow::Status CallBackwardFunction(
+ PyBackwardFunction* backward_function,
+ tensorflow::gtl::ArraySlice<PyObject*> output_gradients,
+ std::vector<PyObject*>* result) const final {
+ PyObject* grads = PyTuple_New(output_gradients.size());
+ for (int i = 0; i < output_gradients.size(); ++i) {
+ if (output_gradients[i] == nullptr) {
+ Py_INCREF(Py_None);
+ PyTuple_SET_ITEM(grads, i, Py_None);
+ } else {
+ PyTuple_SET_ITEM(grads, i,
+ reinterpret_cast<PyObject*>(output_gradients[i]));
+ }
+ }
+ PyObject* py_result = (*backward_function)(grads);
+ Py_DECREF(grads);
+ if (py_result == nullptr) {
+ return tensorflow::errors::Internal("gradient function threw exceptions");
+ }
+ result->clear();
+ PyObject* seq =
+ PySequence_Fast(py_result, "expected a sequence of gradients");
+ if (seq == nullptr) {
+ return tensorflow::errors::InvalidArgument(
+ "gradient function did not return a list");
+ }
+ int len = PySequence_Fast_GET_SIZE(seq);
+ VLOG(1) << "Gradient length is " << len;
+ result->reserve(len);
+ for (int i = 0; i < len; ++i) {
+ PyObject* item = PySequence_Fast_GET_ITEM(seq, i);
+ if (item == Py_None) {
+ result->push_back(nullptr);
+ } else {
+ Py_INCREF(item);
+ result->push_back(item);
+ }
+ }
+ Py_DECREF(seq);
+ Py_DECREF(py_result);
+ return tensorflow::Status::OK();
+ }
+
+ void DeleteGradient(PyObject* tensor) const final { Py_XDECREF(tensor); }
+
+ private:
+ PyObject* py_vspace_;
+
+ PyObject* num_elements_;
+ PyObject* aggregate_fn_;
+ PyObject* zeros_fn_;
+ PyObject* ones_fn_;
+ PyObject* graph_shape_fn_;
+};
+PyVSpace* py_vspace = nullptr;
+
+PyObject* TFE_Py_RegisterVSpace(PyObject* e) {
+ if (py_vspace != nullptr) {
+ delete py_vspace;
+ }
+
+ py_vspace = new PyVSpace(e);
+ auto status = py_vspace->Initialize();
+ if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
+ delete py_vspace;
+ return nullptr;
+ }
+
+ Py_RETURN_NONE;
+}
+
+PyObject* PyTapeTensor::GetShape() const {
+ if (shape_.index() == 0) {
+ auto& shape = absl::get<0>(shape_);
+ PyObject* py_shape = PyTuple_New(shape.dims());
+ for (int i = 0; i < shape.dims(); ++i) {
+ PyTuple_SET_ITEM(py_shape, i, PyLong_FromLong(shape.dim_size(i)));
+ }
+
+ return py_shape;
+ }
+
+ return py_vspace->GraphShape(absl::get<1>(shape_));
+}
+
class GradientTape
- : public tensorflow::eager::GradientTape<PyObject, PyBackwardFunction> {
+ : public tensorflow::eager::GradientTape<PyObject, PyBackwardFunction,
+ PyTapeTensor> {
public:
explicit GradientTape(bool persistent, bool watch_accessed_variables)
- : tensorflow::eager::GradientTape<PyObject, PyBackwardFunction>(
- persistent),
+ : tensorflow::eager::GradientTape<PyObject, PyBackwardFunction,
+ PyTapeTensor>(persistent),
watch_accessed_variables_(watch_accessed_variables) {}
virtual ~GradientTape() {
@@ -1175,24 +1403,41 @@ void TFE_Py_TapeWatch(PyObject* tape, PyObject* tensor) {
reinterpret_cast<TFE_Py_Tape*>(tape)->tape->Watch(tensor_id);
}
-static tensorflow::eager::TapeTensor TapeTensorFromTensor(PyObject* tensor) {
+bool ListContainsNone(PyObject* list) {
+ if (list == Py_None) return true;
+ tensorflow::Safe_PyObjectPtr seq(
+ PySequence_Fast(list, "expected a sequence"));
+ if (seq == nullptr) {
+ return false;
+ }
+
+ int len = PySequence_Size(list);
+ for (int i = 0; i < len; ++i) {
+ PyObject* item = PySequence_Fast_GET_ITEM(seq.get(), i);
+ if (item == Py_None) return true;
+ }
+
+ return false;
+}
+
+static PyTapeTensor TapeTensorFromTensor(PyObject* tensor) {
if (EagerTensor_CheckExact(tensor)) {
TFE_TensorHandle* t = EagerTensor_Handle(tensor);
- tensorflow::int64 id = EagerTensor_id(tensor);
+ tensorflow::int64 id = PyEagerTensor_ID(tensor);
tensorflow::TensorShape tensor_shape;
const tensorflow::Status status = t->handle->Shape(&tensor_shape);
if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
- return tensorflow::eager::TapeTensor{id, t->handle->dtype,
- tensorflow::TensorShape({})};
+ return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
+ tensorflow::TensorShape({}));
} else {
- return tensorflow::eager::TapeTensor{id, t->handle->dtype, tensor_shape};
+ return PyTapeTensor(id, t->handle->dtype, tensor_shape);
}
}
tensorflow::int64 id = FastTensorId(tensor);
if (PyErr_Occurred()) {
- return tensorflow::eager::TapeTensor{
- id, static_cast<tensorflow::DataType>(0), tensorflow::TensorShape({})};
+ return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
+ tensorflow::TensorShape({}));
}
PyObject* dtype_object = PyObject_GetAttrString(tensor, "dtype");
PyObject* dtype_enum = PyObject_GetAttrString(dtype_object, "_type_enum");
@@ -1200,16 +1445,21 @@ static tensorflow::eager::TapeTensor TapeTensorFromTensor(PyObject* tensor) {
tensorflow::DataType dtype =
static_cast<tensorflow::DataType>(MakeInt(dtype_enum));
Py_DECREF(dtype_enum);
- if (PyErr_Occurred() != nullptr) {
- return tensorflow::eager::TapeTensor{id, dtype,
- tensorflow::TensorShape({})};
+ if (PyErr_Occurred()) {
+ return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
+ tensorflow::TensorShape({}));
}
static char _shape_tuple[] = "_shape_tuple";
PyObject* shape_tuple = PyObject_CallMethod(tensor, _shape_tuple, nullptr);
- if (PyErr_Occurred() != nullptr) {
- return tensorflow::eager::TapeTensor{id, dtype,
- tensorflow::TensorShape({})};
+ if (PyErr_Occurred()) {
+ return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
+ tensorflow::TensorShape({}));
}
+
+ if (ListContainsNone(shape_tuple)) {
+ return PyTapeTensor(id, dtype, tensor);
+ }
+
auto l = MakeIntList(shape_tuple);
Py_DECREF(shape_tuple);
// Replace -1, which represents accidental Nones which can occur in graph mode
@@ -1220,7 +1470,7 @@ static tensorflow::eager::TapeTensor TapeTensorFromTensor(PyObject* tensor) {
}
}
tensorflow::TensorShape shape(l);
- return tensorflow::eager::TapeTensor{id, dtype, shape};
+ return PyTapeTensor(id, dtype, shape);
}
std::vector<tensorflow::int64> MakeTensorIDList(PyObject* tensors) {
@@ -1286,7 +1536,7 @@ void TapeSetRecordOperation(
const std::vector<tensorflow::DataType>& input_dtypes,
const std::function<PyBackwardFunction*()>& backward_function_getter,
const std::function<void(PyBackwardFunction*)>& backward_function_killer) {
- std::vector<tensorflow::eager::TapeTensor> output_info;
+ std::vector<PyTapeTensor> output_info;
PyObject* seq = PySequence_Fast(output_tensors,
"expected a sequence of integer tensor ids");
int len = PySequence_Size(output_tensors);
@@ -1362,177 +1612,6 @@ void TFE_Py_TapeSetDeleteTrace(tensorflow::int64 tensor_id) {
}
}
-class PyVSpace
- : public tensorflow::eager::VSpace<PyObject, PyBackwardFunction> {
- public:
- explicit PyVSpace(PyObject* py_vspace) : py_vspace_(py_vspace) {
- Py_INCREF(py_vspace_);
- }
-
- tensorflow::Status Initialize() {
- num_elements_ = PyObject_GetAttrString(py_vspace_, "num_elements_fn");
- if (num_elements_ == nullptr) {
- return tensorflow::errors::InvalidArgument("invalid vspace");
- }
- aggregate_fn_ = PyObject_GetAttrString(py_vspace_, "aggregate_fn");
- if (aggregate_fn_ == nullptr) {
- return tensorflow::errors::InvalidArgument("invalid vspace");
- }
- zeros_ = PyObject_GetAttrString(py_vspace_, "zeros");
- if (zeros_ == nullptr) {
- return tensorflow::errors::InvalidArgument("invalid vspace");
- }
- ones_ =
- PyObject_GetAttrString(reinterpret_cast<PyObject*>(py_vspace_), "ones");
- if (ones_ == nullptr) {
- return tensorflow::errors::InvalidArgument("invalid vspace");
- }
- return tensorflow::Status::OK();
- }
-
- ~PyVSpace() override {
- Py_XDECREF(num_elements_);
- Py_XDECREF(aggregate_fn_);
- Py_XDECREF(zeros_);
- Py_XDECREF(ones_);
-
- Py_DECREF(py_vspace_);
- }
-
- tensorflow::int64 NumElements(PyObject* tensor) const final {
- PyObject* arglist =
- Py_BuildValue("(O)", reinterpret_cast<PyObject*>(tensor));
- PyObject* result = PyEval_CallObject(num_elements_, arglist);
- Py_DECREF(arglist);
- if (result == nullptr) {
- // The caller detects whether a python exception has been raised.
- return -1;
- }
- tensorflow::int64 r = MakeInt(result);
- Py_DECREF(result);
- return r;
- }
-
- PyObject* AggregateGradients(
- tensorflow::gtl::ArraySlice<PyObject*> gradient_tensors) const final {
- PyObject* list = PyList_New(gradient_tensors.size());
- for (int i = 0; i < gradient_tensors.size(); ++i) {
- // Note: stealing a reference to the gradient tensors.
- CHECK(gradient_tensors[i] != nullptr);
- CHECK(gradient_tensors[i] != Py_None);
- PyList_SET_ITEM(list, i,
- reinterpret_cast<PyObject*>(gradient_tensors[i]));
- }
- PyObject* arglist = Py_BuildValue("(O)", list);
- CHECK(arglist != nullptr);
- PyObject* result = PyEval_CallObject(aggregate_fn_, arglist);
- Py_DECREF(arglist);
- Py_DECREF(list);
- return result;
- }
-
- void MarkAsResult(PyObject* gradient) const final { Py_INCREF(gradient); }
-
- PyObject* Zeros(tensorflow::TensorShape shape,
- tensorflow::DataType dtype) const final {
- PyObject* py_shape = PyTuple_New(shape.dims());
- for (int i = 0; i < shape.dims(); ++i) {
- PyTuple_SET_ITEM(py_shape, i, PyLong_FromLong(shape.dim_size(i)));
- }
- PyObject* py_dtype = PyLong_FromLong(static_cast<int>(dtype));
- PyObject* arg_list = Py_BuildValue("OO", py_shape, py_dtype);
- PyObject* result = PyEval_CallObject(zeros_, arg_list);
- Py_DECREF(arg_list);
- Py_DECREF(py_dtype);
- Py_DECREF(py_shape);
- return reinterpret_cast<PyObject*>(result);
- }
-
- PyObject* Ones(tensorflow::TensorShape shape,
- tensorflow::DataType dtype) const final {
- PyObject* py_shape = PyTuple_New(shape.dims());
- for (int i = 0; i < shape.dims(); ++i) {
- PyTuple_SET_ITEM(py_shape, i, PyLong_FromLong(shape.dim_size(i)));
- }
- PyObject* py_dtype = PyLong_FromLong(static_cast<int>(dtype));
- PyObject* arg_list = Py_BuildValue("OO", py_shape, py_dtype);
- PyObject* result = PyEval_CallObject(ones_, arg_list);
- Py_DECREF(arg_list);
- Py_DECREF(py_dtype);
- Py_DECREF(py_shape);
- return result;
- }
-
- tensorflow::Status CallBackwardFunction(
- PyBackwardFunction* backward_function,
- tensorflow::gtl::ArraySlice<PyObject*> output_gradients,
- std::vector<PyObject*>* result) const final {
- PyObject* grads = PyTuple_New(output_gradients.size());
- for (int i = 0; i < output_gradients.size(); ++i) {
- if (output_gradients[i] == nullptr) {
- Py_INCREF(Py_None);
- PyTuple_SET_ITEM(grads, i, Py_None);
- } else {
- PyTuple_SET_ITEM(grads, i,
- reinterpret_cast<PyObject*>(output_gradients[i]));
- }
- }
- PyObject* py_result = (*backward_function)(grads);
- Py_DECREF(grads);
- if (py_result == nullptr) {
- return tensorflow::errors::Internal("gradient function threw exceptions");
- }
- result->clear();
- PyObject* seq =
- PySequence_Fast(py_result, "expected a sequence of gradients");
- if (seq == nullptr) {
- return tensorflow::errors::InvalidArgument(
- "gradient function did not return a list");
- }
- int len = PySequence_Fast_GET_SIZE(seq);
- VLOG(1) << "Gradient length is " << len;
- result->reserve(len);
- for (int i = 0; i < len; ++i) {
- PyObject* item = PySequence_Fast_GET_ITEM(seq, i);
- if (item == Py_None) {
- result->push_back(nullptr);
- } else {
- Py_INCREF(item);
- result->push_back(item);
- }
- }
- Py_DECREF(seq);
- Py_DECREF(py_result);
- return tensorflow::Status::OK();
- }
-
- void DeleteGradient(PyObject* tensor) const final { Py_XDECREF(tensor); }
-
- private:
- PyObject* py_vspace_;
-
- PyObject* num_elements_;
- PyObject* aggregate_fn_;
- PyObject* zeros_;
- PyObject* ones_;
-};
-PyVSpace* py_vspace = nullptr;
-
-PyObject* TFE_Py_RegisterVSpace(PyObject* e) {
- if (py_vspace != nullptr) {
- delete py_vspace;
- }
-
- py_vspace = new PyVSpace(e);
- auto status = py_vspace->Initialize();
- if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
- delete py_vspace;
- return nullptr;
- }
-
- Py_RETURN_NONE;
-}
-
std::vector<PyObject*> MakeTensorList(PyObject* tensors) {
PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
if (seq == nullptr) {
@@ -1744,6 +1823,9 @@ PyObject* MaybeGetDTypeForAttr(const string& attr,
Py_RETURN_NONE;
}
+// TODO(agarwal): use an automatic mechanism for handling None arguments to
+// gradient functions.
+
// Returns a pair where the first value of the pair indicates whether or not all
// outputs are unused. If the first value is false, the second value is a
// set that identifies which of the output indices are unused.
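
With the PyTapeTensor change above, a tensor whose static shape contains None
(detected by ListContainsNone) is recorded by holding the tensor object
itself, and the `graph_shape_fn` from the registered VSpace is consulted only
when zeros or ones of that shape must be built. The concrete callable is
registered from backprop.py and is not shown in this diff; a plausible,
hypothetical Python-side implementation would simply return the runtime shape
tensor:

    from tensorflow.python.ops import array_ops

    def graph_shape_fn(tensor):
      # Fall back to the dynamic shape when static dimensions are unknown.
      return array_ops.shape(tensor)
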
diff --git a/tensorflow/python/estimator/BUILD b/tensorflow/python/estimator/BUILD
index bfcc019dd5..ba1b7ec2b5 100644
--- a/tensorflow/python/estimator/BUILD
+++ b/tensorflow/python/estimator/BUILD
@@ -197,6 +197,7 @@ py_library(
srcs = ["canned/boosted_trees.py"],
srcs_version = "PY2AND3",
deps = [
+ ":boosted_trees_utils",
":estimator",
":head",
":model_fn",
@@ -224,6 +225,35 @@ py_test(
)
py_library(
+ name = "boosted_trees_utils",
+ srcs = ["canned/boosted_trees_utils.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":estimator",
+ ":head",
+ ":model_fn",
+ "//tensorflow:tensorflow_py_no_contrib",
+ ],
+)
+
+py_test(
+ name = "boosted_trees_utils_test",
+ size = "medium",
+ srcs = ["canned/boosted_trees_utils_test.py"],
+ shard_count = 2,
+ srcs_version = "PY2AND3",
+ tags = [
+ "optonly",
+ ],
+ deps = [
+ ":boosted_trees",
+ ":inputs",
+ "//tensorflow:tensorflow_py_no_contrib",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_library(
name = "dnn",
srcs = ["canned/dnn.py"],
srcs_version = "PY2AND3",
@@ -251,6 +281,7 @@ py_library(
":prediction_keys",
"//tensorflow:tensorflow_py_no_contrib",
"//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
"@six_archive//:six",
],
)
@@ -273,6 +304,7 @@ py_test(
":pandas_io",
":prediction_keys",
"//tensorflow:tensorflow_py_no_contrib",
+ "@absl_py//absl/testing:parameterized",
"@six_archive//:six",
],
)
diff --git a/tensorflow/python/estimator/canned/boosted_trees.py b/tensorflow/python/estimator/canned/boosted_trees.py
index 19f18015e4..0278990cfc 100644
--- a/tensorflow/python/estimator/canned/boosted_trees.py
+++ b/tensorflow/python/estimator/canned/boosted_trees.py
@@ -21,8 +21,12 @@ import abc
import collections
import functools
+import numpy as np
+
+from tensorflow.core.kernels.boosted_trees import boosted_trees_pb2
from tensorflow.python.estimator import estimator
-from tensorflow.python.estimator import model_fn
+from tensorflow.python.estimator import model_fn as model_fn_lib
+from tensorflow.python.estimator.canned import boosted_trees_utils
from tensorflow.python.estimator.canned import head as head_lib
from tensorflow.python.feature_column import feature_column as feature_column_lib
from tensorflow.python.framework import dtypes
@@ -36,8 +40,10 @@ from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops.array_ops import identity as tf_identity
from tensorflow.python.ops.losses import losses
from tensorflow.python.summary import summary
+from tensorflow.python.training import checkpoint_utils
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training_util
from tensorflow.python.util.tf_export import estimator_export
@@ -191,14 +197,50 @@ def _calculate_num_features(sorted_feature_columns):
return num_features
+def _generate_feature_name_mapping(sorted_feature_columns):
+ """Return a list of feature name for feature ids.
+
+ Args:
+ sorted_feature_columns: a list/set of tf.feature_column sorted by name.
+
+ Returns:
+ feature_name_mapping: a list of feature names indexed by the feature ids.
+
+ Raises:
+ ValueError: when unsupported feature columns are used.
+ """
+ names = []
+ for column in sorted_feature_columns:
+ if isinstance(column, feature_column_lib._IndicatorColumn): # pylint:disable=protected-access
+ categorical_column = column.categorical_column
+ if isinstance(categorical_column,
+ feature_column_lib._VocabularyListCategoricalColumn): # pylint:disable=protected-access
+ for value in categorical_column.vocabulary_list:
+ names.append('{}:{}'.format(column.name, value))
+ elif isinstance(categorical_column,
+ feature_column_lib._BucketizedColumn): # pylint:disable=protected-access
+ boundaries = [-np.inf] + list(categorical_column.boundaries) + [np.inf]
+ for pair in zip(boundaries[:-1], boundaries[1:]):
+ names.append('{}:{}'.format(column.name, pair))
+ else:
+ for num in range(categorical_column._num_buckets): # pylint:disable=protected-access
+ names.append('{}:{}'.format(column.name, num))
+ elif isinstance(column, feature_column_lib._BucketizedColumn):
+ names.append(column.name)
+ else:
+ raise ValueError(
+ 'For now, only bucketized_column and indicator_column are supported '
+ 'but got: {}'.format(column))
+ return names
+
+
def _cache_transformed_features(features, sorted_feature_columns, batch_size):
"""Transform features and cache, then returns (cached_features, cache_op)."""
num_features = _calculate_num_features(sorted_feature_columns)
cached_features = [
_local_variable(
array_ops.zeros([batch_size], dtype=dtypes.int32),
- name='cached_feature_{}'.format(i))
- for i in range(num_features)
+ name='cached_feature_{}'.format(i)) for i in range(num_features)
]
are_features_cached = _local_variable(False, name='are_features_cached')
@@ -228,8 +270,7 @@ def _cache_transformed_features(features, sorted_feature_columns, batch_size):
return cached, cache_flip_op
input_feature_list, cache_flip_op = control_flow_ops.cond(
- are_features_cached,
- lambda: (cached_features, control_flow_ops.no_op()),
+ are_features_cached, lambda: (cached_features, control_flow_ops.no_op()),
cache_features_and_return)
return input_feature_list, cache_flip_op
@@ -263,8 +304,8 @@ class _CacheTrainingStatesUsingHashTable(object):
elif dtypes.as_dtype(dtypes.string).is_compatible_with(example_ids.dtype):
empty_key = ''
else:
- raise ValueError('Unsupported example_id_feature dtype %s.' %
- example_ids.dtype)
+ raise ValueError(
+ 'Unsupported example_id_feature dtype %s.' % example_ids.dtype)
# Cache holds latest <tree_id, node_id, logits> for each example.
# tree_id and node_id are both int32 but logits is a float32.
# To reduce the overhead, we store all of them together as float32 and
@@ -273,8 +314,8 @@ class _CacheTrainingStatesUsingHashTable(object):
empty_key=empty_key, value_dtype=dtypes.float32, value_shape=[3])
self._example_ids = ops.convert_to_tensor(example_ids)
if self._example_ids.shape.ndims not in (None, 1):
- raise ValueError('example_id should have rank 1, but got %s' %
- self._example_ids)
+ raise ValueError(
+ 'example_id should have rank 1, but got %s' % self._example_ids)
self._logits_dimension = logits_dimension
def lookup(self):
@@ -334,7 +375,7 @@ class _CacheTrainingStatesUsingVariables(object):
array_ops.zeros([batch_size], dtype=dtypes.int32),
name='tree_ids_cache')
self._node_ids = _local_variable(
- _DUMMY_NODE_ID*array_ops.ones([batch_size], dtype=dtypes.int32),
+ _DUMMY_NODE_ID * array_ops.ones([batch_size], dtype=dtypes.int32),
name='node_ids_cache')
self._logits = _local_variable(
array_ops.zeros([batch_size, logits_dimension], dtype=dtypes.float32),
@@ -422,9 +463,13 @@ class _EnsembleGrower(object):
self._pruning_mode_parsed = boosted_trees_ops.PruningMode.from_str(
tree_hparams.pruning_mode)
- if (self._pruning_mode_parsed != boosted_trees_ops.PruningMode.NO_PRUNING
- and tree_hparams.tree_complexity <= 0):
- raise ValueError('For pruning, tree_complexity must be positive.')
+ if tree_hparams.tree_complexity > 0:
+ if self._pruning_mode_parsed == boosted_trees_ops.PruningMode.NO_PRUNING:
+ raise ValueError(
+ 'Tree complexity has no effect unless a pruning mode is chosen.')
+ else:
+ if self._pruning_mode_parsed != boosted_trees_ops.PruningMode.NO_PRUNING:
+ raise ValueError('For pruning, tree_complexity must be positive.')
# pylint: enable=protected-access
@abc.abstractmethod
@@ -719,7 +764,7 @@ def _bt_model_fn(
tree_ensemble = boosted_trees_ops.TreeEnsemble(name=name)
# Create logits.
- if mode != model_fn.ModeKeys.TRAIN:
+ if mode != model_fn_lib.ModeKeys.TRAIN:
input_feature_list = _get_transformed_features(features,
sorted_feature_columns)
logits = boosted_trees_ops.predict(
@@ -886,6 +931,7 @@ def _bt_model_fn(
labels=labels,
train_op_fn=_train_op_fn,
logits=logits)
+
# Add an early stop hook.
estimator_spec = estimator_spec._replace(
training_hooks=estimator_spec.training_hooks +
@@ -927,8 +973,8 @@ def _create_classification_head_and_closed_form(n_classes, weight_column,
label_vocabulary):
"""Creates a head for classifier and the closed form gradients/hessians."""
head = _create_classification_head(n_classes, weight_column, label_vocabulary)
- if (n_classes == 2 and head.logits_dimension == 1 and weight_column is None
- and label_vocabulary is None):
+ if (n_classes == 2 and head.logits_dimension == 1 and
+ weight_column is None and label_vocabulary is None):
# Use the closed-form gradients/hessians for 2 class.
def _grad_and_hess_for_logloss(logits, labels):
"""A closed form gradient and hessian for logistic loss."""
@@ -961,8 +1007,282 @@ def _create_regression_head(label_dimension, weight_column=None):
# pylint: enable=protected-access
+def _compute_feature_importances_per_tree(tree, num_features):
+ """Computes the importance of each feature in the tree."""
+ importances = np.zeros(num_features)
+
+ for node in tree.nodes:
+ node_type = node.WhichOneof('node')
+ if node_type == 'bucketized_split':
+ feature_id = node.bucketized_split.feature_id
+ importances[feature_id] += node.metadata.gain
+ elif node_type == 'leaf':
+ assert node.metadata.gain == 0
+ else:
+ raise ValueError('Unexpected split type %s' % node_type)
+
+ return importances
+
+
+def _compute_feature_importances(tree_ensemble, num_features, normalize):
+ """Computes gain-based feature importances.
+
+ The higher the value, the more important the feature.
+
+ Args:
+ tree_ensemble: a trained tree ensemble, instance of proto
+ boosted_trees.TreeEnsemble.
+ num_features: The total number of feature ids.
+ normalize: If True, normalize the feature importances.
+
+ Returns:
+ sorted_feature_idx: A list of feature ids, sorted by feature
+ importance in descending order.
+ feature_importances: A list of corresponding feature importances.
+
+ Raises:
+ AssertionError: When normalize = True, if feature importances
+ contain negative values, or if normalization is not possible
+ (e.g. ensemble is empty or trees contain only a root node).
+ """
+ tree_importances = [_compute_feature_importances_per_tree(tree, num_features)
+ for tree in tree_ensemble.trees]
+ tree_importances = np.array(tree_importances)
+ tree_weights = np.array(tree_ensemble.tree_weights).reshape(-1, 1)
+ feature_importances = np.sum(tree_importances * tree_weights, axis=0)
+ if normalize:
+ assert np.all(feature_importances >= 0), ('feature_importances '
+ 'must be non-negative.')
+ normalizer = np.sum(feature_importances)
+ assert normalizer > 0, 'Trees are all empty or contain only a root node.'
+ feature_importances /= normalizer
+
+ sorted_feature_idx = np.argsort(feature_importances)[::-1]
+ return sorted_feature_idx, feature_importances[sorted_feature_idx]
+
+
+def _bt_explanations_fn(features,
+ head,
+ sorted_feature_columns,
+ name='boosted_trees'):
+ """Gradient Boosted Trees predict with explanations model_fn.
+
+ Args:
+ features: dict of `Tensor`.
+ head: A `head_lib._Head` instance.
+ sorted_feature_columns: Sorted iterable of `feature_column._FeatureColumn`
+ model inputs.
+ name: Name used for the model.
+
+ Returns:
+ An `EstimatorSpec` instance.
+
+ Raises:
+ ValueError: mode or params are invalid, or features has the wrong type.
+ """
+ mode = model_fn_lib.ModeKeys.PREDICT
+ with ops.name_scope(name) as name:
+ # Create Ensemble resources.
+ tree_ensemble = boosted_trees_ops.TreeEnsemble(name=name)
+
+ input_feature_list = _get_transformed_features(features,
+ sorted_feature_columns)
+
+ logits = boosted_trees_ops.predict(
+ # For non-TRAIN mode, ensemble doesn't change after initialization,
+ # so no local copy is needed; using tree_ensemble directly.
+ tree_ensemble_handle=tree_ensemble.resource_handle,
+ bucketized_features=input_feature_list,
+ logits_dimension=head.logits_dimension)
+
+ estimator_spec = head.create_estimator_spec(
+ features=features,
+ mode=mode,
+ labels=None,
+ train_op_fn=control_flow_ops.no_op,
+ logits=logits)
+
+ debug_op = boosted_trees_ops.example_debug_outputs(
+ tree_ensemble.resource_handle,
+ bucketized_features=input_feature_list,
+ logits_dimension=head.logits_dimension)
+ estimator_spec.predictions[boosted_trees_utils._DEBUG_PROTO_KEY] = debug_op # pylint: disable=protected-access
+ return estimator_spec
+
+
+class _BoostedTreesBase(estimator.Estimator):
+ """Base class for boosted trees estimators.
+
+ This class is intended to keep tree-specific functions (e.g., methods for
+ feature importances and directional feature contributions) in one central
+ place.
+
+ It is not a valid (working) Estimator on its own and should only be used as a
+ base class.
+ """
+
+ def __init__(self, model_fn, model_dir, config, feature_columns, head,
+ center_bias, is_classification):
+ """Initializes a `_BoostedTreesBase` instance.
+
+ Args:
+ model_fn: Model function. See base class for more detail.
+ model_dir: Directory to save model parameters, graph and etc. See base
+ class for more detail.
+ config: `estimator.RunConfig` configuration object.
+ feature_columns: An iterable containing all the feature columns used by
+ the model. All items in the set should be instances of classes derived
+ from `FeatureColumn`
+ head: A `head_lib._Head` instance.
+ center_bias: Whether bias centering needs to occur. Bias centering refers
+ to the first node in the very first tree returning the prediction that
+ is aligned with the original labels distribution. For example, for
+ regression problems, the first node will return the mean of the labels.
+ For binary classification problems, it will return a logit for a prior
+ probability of label 1.
+ is_classification: If the estimator is for classification.
+ """
+ super(_BoostedTreesBase, self).__init__(
+ model_fn=model_fn, model_dir=model_dir, config=config)
+ self._sorted_feature_columns = sorted(
+ feature_columns, key=lambda tc: tc.name)
+ self._head = head
+ self._n_features = _calculate_num_features(self._sorted_feature_columns)
+ self._names_for_feature_id = np.array(
+ _generate_feature_name_mapping(self._sorted_feature_columns))
+ self._center_bias = center_bias
+ self._is_classification = is_classification
+
+ def experimental_feature_importances(self, normalize=False):
+ """Computes gain-based feature importances.
+
+ The higher the value, the more important the corresponding feature.
+
+ Args:
+ normalize: If True, normalize the feature importances.
+
+ Returns:
+ sorted_feature_names: 1-D array of feature names, sorted by
+ feature importance in descending order.
+ feature_importances: 1-D array of the corresponding feature importance.
+
+ Raises:
+ ValueError: When attempting to normalize on an empty ensemble,
+ on an ensemble of trees that have no splits, or when the
+ feature importances contain negative values.
+ """
+ reader = checkpoint_utils.load_checkpoint(self._model_dir)
+ serialized = reader.get_tensor('boosted_trees:0_serialized')
+ if not serialized:
+ raise ValueError('Found empty serialized string for TreeEnsemble. '
+ 'You should only call this method after training.')
+ ensemble_proto = boosted_trees_pb2.TreeEnsemble()
+ ensemble_proto.ParseFromString(serialized)
+
+ sorted_feature_id, importances = _compute_feature_importances(
+ ensemble_proto, self._n_features, normalize)
+ return self._names_for_feature_id[sorted_feature_id], importances
+
+ def experimental_predict_with_explanations(self,
+ input_fn,
+ predict_keys=None,
+ hooks=None,
+ checkpoint_path=None):
+ """Computes model explainability outputs per example along with predictions.
+
+ Currently supports directional feature contributions (DFCs). For each
+ instance, DFCs indicate the aggregate contribution of each feature. See
+ https://arxiv.org/abs/1312.1121 and
+ http://blog.datadive.net/interpreting-random-forests/ for more details.
+ Args:
+ input_fn: A function that provides input data for predicting as
+ minibatches. See [Premade Estimators](
+ https://tensorflow.org/guide/premade_estimators#create_input_functions)
+ for more information. The function should construct and return one of
+ the following: * A `tf.data.Dataset` object: Outputs of `Dataset`
+ object must be a tuple `(features, labels)` with same constraints as
+ below. * A tuple `(features, labels)`: Where `features` is a `tf.Tensor`
+ or a dictionary of string feature name to `Tensor` and `labels` is a
+ `Tensor` or a dictionary of string label name to `Tensor`. Both
+ `features` and `labels` are consumed by `model_fn`. They should
+ satisfy the expectation of `model_fn` from inputs.
+ predict_keys: list of `str`, name of the keys to predict. It is used if
+ the `tf.estimator.EstimatorSpec.predictions` is a `dict`. If
+ `predict_keys` is used then rest of the predictions will be filtered
+ from the dictionary, with the exception of 'bias' and 'dfc', which will
+ always be in the dictionary. If `None`, returns all keys in prediction
+ dict, as well as two new keys 'dfc' and 'bias'.
+ hooks: List of `tf.train.SessionRunHook` subclass instances. Used for
+ callbacks inside the prediction call.
+ checkpoint_path: Path of a specific checkpoint to predict. If `None`, the
+ latest checkpoint in `model_dir` is used. If there are no checkpoints
+ in `model_dir`, prediction is run with newly initialized `Variables`
+ instead of ones restored from checkpoint.
+
+ Yields:
+ Evaluated values of `predictions` tensors. The `predictions` tensors will
+ contain at least two keys 'dfc' and 'bias' for model explanations. The
+ `dfc` value corresponds to the contribution of each feature to the overall
+ prediction for this instance (positive indicating that the feature makes
+ it more likely to select class 1 and negative less likely). The 'bias'
+ value will be the same across all the instances, corresponding to the
+ probability (classification) or prediction (regression) of the training
+ data distribution.
+
+ Raises:
+ ValueError: when wrong arguments are given or unsupported functionalities
+ are requested.
+ """
+ if not self._center_bias:
+ raise ValueError('center_bias must be enabled during estimator '
+ 'instantiation when using '
+ 'experimental_predict_with_explanations.')
+ # pylint: disable=protected-access
+ if not self._is_classification:
+ identity_inverse_link_fn = self._head._inverse_link_fn in (None,
+ tf_identity)
+ # pylint:enable=protected-access
+ if not identity_inverse_link_fn:
+ raise ValueError(
+ 'For now only identity inverse_link_fn in regression_head is '
+ 'supported for experimental_predict_with_explanations.')
+
+ # pylint:disable=unused-argument
+ def new_model_fn(features, labels, mode):
+ return _bt_explanations_fn(features, self._head,
+ self._sorted_feature_columns)
+
+ # pylint:enable=unused-argument
+ est = estimator.Estimator(
+ model_fn=new_model_fn,
+ model_dir=self.model_dir,
+ config=self.config,
+ warm_start_from=self._warm_start_settings)
+ # Make sure bias and dfc will be in prediction dict.
+ user_supplied_predict_keys = predict_keys is not None
+ if user_supplied_predict_keys:
+ predict_keys = set(predict_keys)
+ predict_keys.add(boosted_trees_utils._DEBUG_PROTO_KEY)
+ predictions = est.predict(
+ input_fn,
+ predict_keys=predict_keys,
+ hooks=hooks,
+ checkpoint_path=checkpoint_path,
+ yield_single_examples=True)
+ for pred in predictions:
+ bias, dfcs = boosted_trees_utils._parse_explanations_from_prediction(
+ pred[boosted_trees_utils._DEBUG_PROTO_KEY], self._n_features,
+ self._is_classification)
+ pred['bias'] = bias
+ pred['dfc'] = dfcs
+ # Don't need to expose serialized proto to end user.
+ del pred[boosted_trees_utils._DEBUG_PROTO_KEY]
+ yield pred
+
+
+# pylint: disable=protected-access
@estimator_export('estimator.BoostedTreesClassifier')
-class BoostedTreesClassifier(estimator.Estimator):
+class BoostedTreesClassifier(_BoostedTreesBase):
"""A Classifier for Tensorflow Boosted Trees models.
@compatibility(eager)
@@ -1082,14 +1402,13 @@ class BoostedTreesClassifier(estimator.Estimator):
n_classes = 2
head, closed_form = _create_classification_head_and_closed_form(
n_classes, weight_column, label_vocabulary=label_vocabulary)
-
# HParams for the model.
tree_hparams = _TreeHParams(
n_trees, max_depth, learning_rate, l1_regularization, l2_regularization,
tree_complexity, min_node_weight, center_bias, pruning_mode)
def _model_fn(features, labels, mode, config):
- return _bt_model_fn( # pylint: disable=protected-access
+ return _bt_model_fn(
features,
labels,
mode,
@@ -1101,11 +1420,17 @@ class BoostedTreesClassifier(estimator.Estimator):
closed_form_grad_and_hess_fn=closed_form)
super(BoostedTreesClassifier, self).__init__(
- model_fn=_model_fn, model_dir=model_dir, config=config)
+ model_fn=_model_fn,
+ model_dir=model_dir,
+ config=config,
+ feature_columns=feature_columns,
+ head=head,
+ center_bias=center_bias,
+ is_classification=True)
@estimator_export('estimator.BoostedTreesRegressor')
-class BoostedTreesRegressor(estimator.Estimator):
+class BoostedTreesRegressor(_BoostedTreesBase):
"""A Regressor for Tensorflow Boosted Trees models.
@compatibility(eager)
@@ -1223,9 +1548,17 @@ class BoostedTreesRegressor(estimator.Estimator):
tree_complexity, min_node_weight, center_bias, pruning_mode)
def _model_fn(features, labels, mode, config):
- return _bt_model_fn( # pylint: disable=protected-access
- features, labels, mode, head, feature_columns, tree_hparams,
- n_batches_per_layer, config)
+ return _bt_model_fn(features, labels, mode, head, feature_columns,
+ tree_hparams, n_batches_per_layer, config)
super(BoostedTreesRegressor, self).__init__(
- model_fn=_model_fn, model_dir=model_dir, config=config)
+ model_fn=_model_fn,
+ model_dir=model_dir,
+ config=config,
+ feature_columns=feature_columns,
+ head=head,
+ center_bias=center_bias,
+ is_classification=False)
+
+
+# pylint: enable=protected-access
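As a quick orientation for the change above, here is a minimal, hedged usage sketch of the two new methods (experimental_feature_importances and experimental_predict_with_explanations). The feature column, hyperparameters, and the train_input_fn/predict_input_fn names are placeholders for illustration only, not part of this change:

import tensorflow as tf

# Boosted trees estimators only accept bucketized or indicator columns.
bucketized = tf.feature_column.bucketized_column(
    tf.feature_column.numeric_column('f_0'), boundaries=[-2.0, 0.5, 12.0])

# center_bias=True is required for experimental_predict_with_explanations.
est = tf.estimator.BoostedTreesClassifier(
    feature_columns=[bucketized], n_batches_per_layer=1,
    n_trees=2, max_depth=5, center_bias=True)
est.train(train_input_fn, steps=10)  # train_input_fn: placeholder input_fn.

# Gain-based importances, most important feature first; with normalize=True
# the returned importances sum to 1.
names, importances = est.experimental_feature_importances(normalize=True)

# Per-example directional feature contributions: every prediction dict gains
# 'bias' and 'dfc' keys, and sum(pred['dfc'].values()) + pred['bias'] equals
# the prediction (a probability for classification, a value for regression).
for pred in est.experimental_predict_with_explanations(predict_input_fn):
  print(pred['bias'], pred['dfc'])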
diff --git a/tensorflow/python/estimator/canned/boosted_trees_test.py b/tensorflow/python/estimator/canned/boosted_trees_test.py
index 6e28c72151..23687a738b 100644
--- a/tensorflow/python/estimator/canned/boosted_trees_test.py
+++ b/tensorflow/python/estimator/canned/boosted_trees_test.py
@@ -17,9 +17,13 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import os
+
+from google.protobuf import text_format
import numpy as np
from tensorflow.core.kernels.boosted_trees import boosted_trees_pb2
+from tensorflow.python.client import session
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.estimator import model_fn
from tensorflow.python.estimator import run_config
@@ -31,10 +35,12 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import gen_boosted_trees_ops
+from tensorflow.python.ops import boosted_trees_ops
from tensorflow.python.ops import resources
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
from tensorflow.python.training import checkpoint_utils
+from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import session_run_hook
NUM_FEATURES = 3
@@ -564,6 +570,704 @@ class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase):
self.assertEqual(1, ensemble.trees[0].nodes[0].bucketized_split.feature_id)
self.assertEqual(0, ensemble.trees[0].nodes[0].bucketized_split.threshold)
+ def testFeatureImportancesWithTrainedEnsemble(self):
+ input_fn = _make_train_input_fn(is_classification=True)
+
+ est = boosted_trees.BoostedTreesClassifier(
+ feature_columns=self._feature_columns,
+ n_batches_per_layer=1,
+ n_trees=2,
+ max_depth=5)
+
+ # It will stop after 5 steps because of the max depth and num trees.
+ num_steps = 100
+ # Train for a few steps, and validate final checkpoint.
+ est.train(input_fn, steps=num_steps)
+
+ feature_names_expected = ['f_0_bucketized',
+ 'f_2_bucketized',
+ 'f_1_bucketized']
+
+ feature_names, importances = est.experimental_feature_importances(
+ normalize=False)
+ self.assertAllEqual(feature_names_expected, feature_names)
+ self.assertAllClose([0.833933, 0.606342, 0.0], importances)
+
+ feature_names, importances = est.experimental_feature_importances(
+ normalize=True)
+ self.assertAllEqual(feature_names_expected, feature_names)
+ self.assertAllClose([0.579010, 0.420990, 0.0], importances)
+
+ def testFeatureImportancesOnEmptyEnsemble(self):
+ input_fn = _make_train_input_fn(is_classification=True)
+
+ est = boosted_trees.BoostedTreesClassifier(
+ feature_columns=self._feature_columns,
+ n_batches_per_layer=1,
+ n_trees=1,
+ max_depth=5)
+
+ class BailOutWithoutTraining(session_run_hook.SessionRunHook):
+
+ def before_run(self, run_context):
+ raise StopIteration('to bail out.')
+
+ # The step-0 checkpoint will have only an empty ensemble.
+ est.train(input_fn,
+ steps=100, # must stop at 0 anyway.
+ hooks=[BailOutWithoutTraining()])
+
+ with self.assertRaisesRegexp(ValueError, 'empty serialized string'):
+ est.experimental_feature_importances(normalize=False)
+
+ with self.assertRaisesRegexp(ValueError, 'empty serialized string'):
+ est.experimental_feature_importances(normalize=True)
+
+ def _create_fake_checkpoint_with_tree_ensemble_proto(self,
+ est,
+ tree_ensemble_text):
+ with ops.Graph().as_default():
+ with ops.name_scope('boosted_trees') as name:
+ tree_ensemble = boosted_trees_ops.TreeEnsemble(name=name)
+ tree_ensemble_proto = boosted_trees_pb2.TreeEnsemble()
+ text_format.Merge(tree_ensemble_text, tree_ensemble_proto)
+ stamp_token, _ = tree_ensemble.serialize()
+ restore_op = tree_ensemble.deserialize(
+ stamp_token, tree_ensemble_proto.SerializeToString())
+
+ with session.Session() as sess:
+ resources.initialize_resources(resources.shared_resources()).run()
+ restore_op.run()
+ saver = saver_lib.Saver()
+ save_path = os.path.join(est.model_dir, 'model.ckpt')
+ saver.save(sess, save_path)
+
+ def testFeatureImportancesOnNonEmptyEnsemble(self):
+ est = boosted_trees.BoostedTreesClassifier(
+ feature_columns=self._feature_columns,
+ n_batches_per_layer=1,
+ n_trees=2,
+ max_depth=5)
+
+ tree_ensemble_text = """
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 2
+ left_id: 1
+ right_id: 2
+ }
+ metadata {
+ gain: 2.0
+ }
+ }
+ nodes {
+ bucketized_split {
+ feature_id: 0
+ left_id: 3
+ right_id: 4
+ }
+ metadata {
+ gain: 3.0
+ }
+ }
+ nodes {
+ bucketized_split {
+ feature_id: 1
+ left_id: 5
+ right_id: 6
+ }
+ metadata {
+ gain: 2.0
+ }
+ }
+ nodes {
+ leaf {
+ scalar: -0.34
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 1.34
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 0.0
+ }
+ }
+ nodes {
+ bucketized_split {
+ feature_id: 0
+ left_id: 7
+ right_id: 8
+ }
+ metadata {
+ gain: 1.0
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 3.34
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 1.34
+ }
+ }
+ }
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 0
+ left_id: 1
+ right_id: 2
+ }
+ metadata {
+ gain: 1.0
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 3.34
+ }
+ }
+ nodes {
+ bucketized_split {
+ feature_id: 2
+ left_id: 3
+ right_id: 4
+ }
+ metadata {
+ gain: 1.0
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 3.34
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 1.34
+ }
+ }
+ }
+ tree_weights: 1.0
+ tree_weights: 1.0
+ """
+ self._create_fake_checkpoint_with_tree_ensemble_proto(
+ est, tree_ensemble_text)
+
+ feature_names_expected = ['f_0_bucketized',
+ 'f_2_bucketized',
+ 'f_1_bucketized']
+ feature_names, importances = est.experimental_feature_importances(
+ normalize=False)
+ self.assertAllEqual(feature_names_expected, feature_names)
+ # Gain sum for each feature:
+ # = 1.0 * [3 + 1, 2, 2] + 1.0 * [1, 1, 0]
+ self.assertAllClose([5.0, 3.0, 2.0], importances)
+
+ feature_names, importances = est.experimental_feature_importances(
+ normalize=True)
+ self.assertAllEqual(feature_names_expected, feature_names)
+ self.assertAllClose([0.5, 0.3, 0.2], importances)
+
+ def testFeatureImportancesWithTreeWeights(self):
+ est = boosted_trees.BoostedTreesClassifier(
+ feature_columns=self._feature_columns,
+ n_batches_per_layer=1,
+ n_trees=3,
+ max_depth=5)
+
+ tree_ensemble_text = """
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 0
+ left_id: 1
+ right_id: 2
+ }
+ metadata {
+ gain: 12.5
+ }
+ }
+ nodes {
+ bucketized_split {
+ feature_id: 1
+ left_id: 3
+ right_id: 4
+ }
+ metadata {
+ gain: 5.0
+ }
+ }
+ nodes {
+ leaf {
+ scalar: -0.34
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 1.34
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 0.0
+ }
+ }
+ }
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 2
+ left_id: 1
+ right_id: 2
+ }
+ metadata {
+ gain: 5.0
+ }
+ }
+ nodes {
+ leaf {
+ scalar: -0.34
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 1.34
+ }
+ }
+ }
+ trees {
+ nodes {
+ leaf {
+ scalar: 0.0
+ }
+ }
+ }
+ tree_weights: 0.4
+ tree_weights: 0.6
+ tree_weights: 1.0
+ """
+ self._create_fake_checkpoint_with_tree_ensemble_proto(
+ est, tree_ensemble_text)
+
+ feature_names_expected = ['f_0_bucketized',
+ 'f_2_bucketized',
+ 'f_1_bucketized']
+ feature_names, importances = est.experimental_feature_importances(
+ normalize=False)
+ self.assertAllEqual(feature_names_expected, feature_names)
+ # Gain sum for each feature:
+ # = 0.4 * [12.5, 0, 5] + 0.6 * [0, 5, 0] + 1.0 * [0, 0, 0]
+ self.assertAllClose([5.0, 3.0, 2.0], importances)
+
+ feature_names, importances = est.experimental_feature_importances(
+ normalize=True)
+ self.assertAllEqual(feature_names_expected, feature_names)
+ self.assertAllClose([0.5, 0.3, 0.2], importances)
+
+ def testFeatureImportancesWithAllEmptyTree(self):
+ est = boosted_trees.BoostedTreesClassifier(
+ feature_columns=self._feature_columns,
+ n_batches_per_layer=1,
+ n_trees=2,
+ max_depth=5)
+
+ tree_ensemble_text = """
+ trees {
+ nodes {
+ leaf {
+ scalar: 0.0
+ }
+ }
+ }
+ trees {
+ nodes {
+ leaf {
+ scalar: 0.0
+ }
+ }
+ }
+ tree_weights: 1.0
+ tree_weights: 1.0
+ """
+ self._create_fake_checkpoint_with_tree_ensemble_proto(
+ est, tree_ensemble_text)
+
+ # Reverse order because feature importances are sorted by np.argsort(f)[::-1]
+ feature_names_expected = ['f_2_bucketized',
+ 'f_1_bucketized',
+ 'f_0_bucketized']
+ feature_names, importances = est.experimental_feature_importances(
+ normalize=False)
+ self.assertAllEqual(feature_names_expected, feature_names)
+ self.assertAllClose([0.0, 0.0, 0.0], importances)
+
+ with self.assertRaisesRegexp(AssertionError,
+ 'all empty or contain only a root node'):
+ est.experimental_feature_importances(normalize=True)
+
+ def testNegativeFeatureImportances(self):
+ est = boosted_trees.BoostedTreesClassifier(
+ feature_columns=self._feature_columns,
+ n_batches_per_layer=1,
+ n_trees=1,
+ max_depth=5)
+
+ # In order to generate negative feature importances,
+ # we assign an invalid value of -1 to tree_weights here.
+ tree_ensemble_text = """
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 1
+ left_id: 1
+ right_id: 2
+ }
+ metadata {
+ gain: 5.0
+ }
+ }
+ nodes {
+ leaf {
+ scalar: -0.34
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 1.34
+ }
+ }
+ }
+ tree_weights: -1.0
+ """
+ self._create_fake_checkpoint_with_tree_ensemble_proto(
+ est, tree_ensemble_text)
+
+ # Github #21509 (nataliaponomareva):
+ # The gains stored in the splits can be negative
+ # if people are using complexity regularization.
+ feature_names_expected = ['f_2_bucketized',
+ 'f_0_bucketized',
+ 'f_1_bucketized']
+ feature_names, importances = est.experimental_feature_importances(
+ normalize=False)
+ self.assertAllEqual(feature_names_expected, feature_names)
+ self.assertAllClose([0.0, 0.0, -5.0], importances)
+
+ with self.assertRaisesRegexp(AssertionError, 'non-negative'):
+ est.experimental_feature_importances(normalize=True)
+
+ def testFeatureImportancesNamesForCategoricalColumn(self):
+ categorical = feature_column.categorical_column_with_vocabulary_list(
+ key='categorical', vocabulary_list=('bad', 'good', 'ok'))
+ feature_indicator = feature_column.indicator_column(categorical)
+ bucketized_col = feature_column.bucketized_column(
+ feature_column.numeric_column(
+ 'continuous', dtype=dtypes.float32),
+ BUCKET_BOUNDARIES)
+ bucketized_indicator = feature_column.indicator_column(bucketized_col)
+
+ est = boosted_trees.BoostedTreesRegressor(
+ feature_columns=[feature_indicator,
+ bucketized_col,
+ bucketized_indicator],
+ n_batches_per_layer=1,
+ n_trees=2,
+ learning_rate=1.0,
+ max_depth=1)
+
+ tree_ensemble_text = """
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 2
+ left_id: 1
+ right_id: 2
+ }
+ metadata {
+ gain: 5.0
+ }
+ }
+ nodes {
+ bucketized_split {
+ feature_id: 4
+ left_id: 3
+ right_id: 4
+ }
+ metadata {
+ gain: 2.0
+ }
+ }
+ nodes {
+ leaf {
+ scalar: -0.34
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 1.34
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 0.0
+ }
+ }
+ }
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 0
+ left_id: 1
+ right_id: 2
+ }
+ metadata {
+ gain: 1.0
+ }
+ }
+ nodes {
+ bucketized_split {
+ feature_id: 5
+ left_id: 3
+ right_id: 4
+ }
+ metadata {
+ gain: 2.0
+ }
+ }
+ nodes {
+ leaf {
+ scalar: -2.34
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 3.34
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 4.34
+ }
+ }
+ }
+ tree_weights: 1.0
+ tree_weights: 1.0
+ """
+ self._create_fake_checkpoint_with_tree_ensemble_proto(
+ est, tree_ensemble_text)
+
+ feature_names_expected = ['categorical_indicator:ok',
+ 'continuous_bucketized_indicator:(-2.0, 0.5)',
+ 'continuous_bucketized_indicator:(-inf, -2.0)',
+ 'categorical_indicator:bad',
+ # Reverse order because feature importances
+ # are sorted by np.argsort(f)[::-1]
+ 'continuous_bucketized_indicator:(12.0, inf)',
+ 'continuous_bucketized_indicator:(0.5, 12.0)',
+ 'continuous_bucketized',
+ 'categorical_indicator:good']
+
+ feature_names, importances = est.experimental_feature_importances(
+ normalize=False)
+ self.assertAllEqual(feature_names_expected, feature_names)
+ # Gain sum for each feature:
+ # = 1.0 * [5, 0, 2, 0, 0, 0, 0, 0] + 1.0 * [0, 2, 0, 1, 0, 0, 0, 0]
+ self.assertAllClose([5.0, 2.0, 2.0, 1.0, 0.0, 0.0, 0.0, 0.0], importances)
+
+ feature_names, importances = est.experimental_feature_importances(
+ normalize=True)
+ self.assertAllEqual(feature_names_expected, feature_names)
+ self.assertAllClose([0.5, 0.2, 0.2, 0.1, 0.0, 0.0, 0.0, 0.0], importances)
+
+ def testFeatureImportancesNamesForUnsupportedColumn(self):
+ numeric_col = feature_column.numeric_column(
+ 'continuous', dtype=dtypes.float32)
+
+ with self.assertRaisesRegexp(ValueError,
+ 'only bucketized_column and indicator_column'):
+ _ = boosted_trees.BoostedTreesRegressor(
+ feature_columns=[numeric_col],
+ n_batches_per_layer=1,
+ n_trees=2,
+ learning_rate=1.0,
+ max_depth=1)
+
+ def testTreeComplexityIsSetCorrectly(self):
+ input_fn = _make_train_input_fn(is_classification=True)
+
+ num_steps = 10
+ # Tree complexity is set but no pruning.
+ est = boosted_trees.BoostedTreesClassifier(
+ feature_columns=self._feature_columns,
+ n_batches_per_layer=1,
+ n_trees=1,
+ max_depth=5,
+ tree_complexity=1e-3)
+ with self.assertRaisesRegexp(ValueError, 'Tree complexity has no effect'):
+ est.train(input_fn, steps=num_steps)
+
+ # Pruning but no tree complexity.
+ est = boosted_trees.BoostedTreesClassifier(
+ feature_columns=self._feature_columns,
+ n_batches_per_layer=1,
+ n_trees=1,
+ max_depth=5,
+ pruning_mode='pre')
+ with self.assertRaisesRegexp(ValueError,
+ 'tree_complexity must be positive'):
+ est.train(input_fn, steps=num_steps)
+
+ # All is good.
+ est = boosted_trees.BoostedTreesClassifier(
+ feature_columns=self._feature_columns,
+ n_batches_per_layer=1,
+ n_trees=1,
+ max_depth=5,
+ pruning_mode='pre',
+ tree_complexity=1e-3)
+ est.train(input_fn, steps=num_steps)
+
+
+class BoostedTreesDebugOutputsTest(test_util.TensorFlowTestCase):
+ """Test debug/model explainability outputs for individual predictions.
+
+ Includes directional feature contributions (DFC).
+ """
+
+ def setUp(self):
+ self._feature_columns = {
+ feature_column.bucketized_column(
+ feature_column.numeric_column('f_%d' % i, dtype=dtypes.float32),
+ BUCKET_BOUNDARIES) for i in range(NUM_FEATURES)
+ }
+
+ def testBinaryClassifierThatDFCIsInPredictions(self):
+ train_input_fn = _make_train_input_fn(is_classification=True)
+ predict_input_fn = numpy_io.numpy_input_fn(
+ x=FEATURES_DICT, y=None, batch_size=3, num_epochs=1, shuffle=False)
+
+ est = boosted_trees.BoostedTreesClassifier(
+ feature_columns=self._feature_columns,
+ n_batches_per_layer=1,
+ n_trees=1,
+ max_depth=5,
+ center_bias=True)
+
+ num_steps = 100
+ # Train for a few steps. Validate debug outputs in prediction dicts.
+ est.train(train_input_fn, steps=num_steps)
+ debug_predictions = est.experimental_predict_with_explanations(
+ predict_input_fn)
+ biases, dfcs = zip(*[(pred['bias'], pred['dfc'])
+ for pred in debug_predictions])
+ self.assertAllClose([0.4] * 5, biases)
+ self.assertAllClose(({
+ 0: -0.12108613453574479,
+ 1: 0.0,
+ 2: -0.039254929814481143
+ }, {
+ 0: 0.19650601422250574,
+ 1: 0.0,
+ 2: 0.02693827052766018
+ }, {
+ 0: 0.16057487356133376,
+ 1: 0.0,
+ 2: 0.02693827052766018
+ }, {
+ 0: -0.12108613453574479,
+ 1: 0.0,
+ 2: -0.039254929814481143
+ }, {
+ 0: -0.10832468554550384,
+ 1: 0.0,
+ 2: 0.02693827052766018
+ }), dfcs)
+
+ # Assert sum(dfcs) + bias == probabilities.
+ expected_probabilities = [
+ 0.23965894, 0.62344426, 0.58751315, 0.23965894, 0.31861359
+ ]
+ probabilities = [
+ sum(dfc.values()) + bias for (dfc, bias) in zip(dfcs, biases)
+ ]
+ self.assertAllClose(expected_probabilities, probabilities)
+
+ # When user doesn't include bias or dfc in predict_keys, make sure to still
+ # include dfc and bias.
+ debug_predictions = est.experimental_predict_with_explanations(
+ predict_input_fn, predict_keys=['probabilities'])
+ for prediction_dict in debug_predictions:
+ self.assertTrue('bias' in prediction_dict)
+ self.assertTrue('dfc' in prediction_dict)
+ self.assertTrue('probabilities' in prediction_dict)
+ self.assertEqual(len(prediction_dict), 3)
+
+ def testRegressorThatDFCIsInPredictions(self):
+ train_input_fn = _make_train_input_fn(is_classification=False)
+ predict_input_fn = numpy_io.numpy_input_fn(
+ x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False)
+
+ est = boosted_trees.BoostedTreesRegressor(
+ feature_columns=self._feature_columns,
+ n_batches_per_layer=1,
+ n_trees=1,
+ max_depth=5,
+ center_bias=True)
+
+ num_steps = 100
+ # Train for a few steps. Validate debug outputs in prediction dicts.
+ est.train(train_input_fn, steps=num_steps)
+ debug_predictions = est.experimental_predict_with_explanations(
+ predict_input_fn)
+ biases, dfcs = zip(*[(pred['bias'], pred['dfc'])
+ for pred in debug_predictions])
+ self.assertAllClose([1.8] * 5, biases)
+ self.assertAllClose(({
+ 0: -0.070499420166015625,
+ 1: -0.095000028610229492,
+ 2: 0.0
+ }, {
+ 0: -0.53763031959533691,
+ 1: 0.063333392143249512,
+ 2: 0.0
+ }, {
+ 0: -0.51756942272186279,
+ 1: -0.095000028610229492,
+ 2: 0.0
+ }, {
+ 0: 0.1563495397567749,
+ 1: 0.063333392143249512,
+ 2: 0.0
+ }, {
+ 0: 0.96934974193572998,
+ 1: 0.063333392143249512,
+ 2: 0.0
+ }), dfcs)
+
+ # Assert sum(dfcs) + bias == predictions.
+ expected_predictions = [[1.6345005], [1.32570302], [1.1874305],
+ [2.01968288], [2.83268309]]
+ predictions = [
+ [sum(dfc.values()) + bias] for (dfc, bias) in zip(dfcs, biases)
+ ]
+ self.assertAllClose(expected_predictions, predictions)
+
+ # Test when user doesn't include bias or dfc in predict_keys.
+ debug_predictions = est.experimental_predict_with_explanations(
+ predict_input_fn, predict_keys=['predictions'])
+ for prediction_dict in debug_predictions:
+ self.assertTrue('bias' in prediction_dict)
+ self.assertTrue('dfc' in prediction_dict)
+ self.assertTrue('predictions' in prediction_dict)
+ self.assertEqual(len(prediction_dict), 3)
+
class ModelFnTests(test_util.TensorFlowTestCase):
"""Tests bt_model_fn including unexposed internal functionalities."""
diff --git a/tensorflow/python/estimator/canned/boosted_trees_utils.py b/tensorflow/python/estimator/canned/boosted_trees_utils.py
new file mode 100644
index 0000000000..85efc2304a
--- /dev/null
+++ b/tensorflow/python/estimator/canned/boosted_trees_utils.py
@@ -0,0 +1,80 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Debug and model explainability logic for boosted trees."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.core.kernels.boosted_trees import boosted_trees_pb2
+
+# For directional feature contributions.
+_DEBUG_PROTO_KEY = '_serialized_debug_outputs_proto'
+_BIAS_ID = 0
+
+
+def _parse_debug_proto_string(example_proto_serialized):
+ example_debug_outputs = boosted_trees_pb2.DebugOutput()
+ example_debug_outputs.ParseFromString(example_proto_serialized)
+ feature_ids = example_debug_outputs.feature_ids
+ logits_path = example_debug_outputs.logits_path
+ return feature_ids, logits_path
+
+
+def _compute_directional_feature_contributions(example_feature_ids,
+ example_logits_paths, activation,
+ num_bucketized_features):
+ """Directional feature contributions and bias, per example."""
+ # Initialize contributions to 0.
+ dfcs = {k: 0 for k in range(num_bucketized_features)}
+
+ # Traverse tree subtracting child prediction from parent prediction and
+ # associating change with feature id used to split.
+ predictions = np.array(activation(example_logits_paths))
+ delta_pred = predictions[_BIAS_ID + 1:] - predictions[:-1]
+ # Group by feature id, then sum delta_pred.
+ contribs = np.bincount(
+ example_feature_ids,
+ weights=delta_pred,
+ minlength=num_bucketized_features)
+ for f, dfc in zip(range(num_bucketized_features), contribs):
+ dfcs[f] = dfc
+ return predictions[_BIAS_ID], dfcs
+
+
+def _identity(logits):
+ return logits
+
+
+def _sigmoid(logits):
+ # TODO(crawles): Change to softmax once multiclass support is available.
+ return 1 / (1 + np.exp(-np.array(logits)))
+
+
+def _parse_explanations_from_prediction(serialized_debug_proto,
+ n_features,
+ classification=False):
+ """Parse serialized explanability proto, compute dfc, and return bias, dfc."""
+ feature_ids, logits_path = _parse_debug_proto_string(serialized_debug_proto)
+ if classification:
+ activation = _sigmoid
+ else:
+ activation = _identity
+ bias, dfcs = _compute_directional_feature_contributions(
+ feature_ids, logits_path, activation, n_features)
+ # TODO(crawles): Prediction path and leaf IDs.
+ return bias, dfcs
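To make the helper above concrete, the following hedged sketch reproduces the regression example from http://blog.datadive.net/interpreting-random-forests/ that the new unit test below also uses; the numeric values come from that example, and the call goes through the private helpers exactly as the test does:

import numpy as np

from tensorflow.python.estimator.canned import boosted_trees_utils

# One example whose prediction path split on feature ids 2, 1 and then 0.
feature_ids = (2, 1, 0)
# Logits along the path: the bias (root of tree_0) followed by each node visited.
logits_path = (22.60, 19.96, 14.91, 18.11)

bias, dfcs = boosted_trees_utils._compute_directional_feature_contributions(
    feature_ids, logits_path, boosted_trees_utils._identity,
    num_bucketized_features=4)
# bias == 22.60 and dfcs is approximately {0: 3.20, 1: -5.05, 2: -2.64, 3: 0}:
# each change in prediction along the path is attributed to the feature that
# made the corresponding split, so the contributions plus the bias recover the
# final leaf prediction.
assert np.isclose(bias + sum(dfcs.values()), logits_path[-1])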
diff --git a/tensorflow/python/estimator/canned/boosted_trees_utils_test.py b/tensorflow/python/estimator/canned/boosted_trees_utils_test.py
new file mode 100644
index 0000000000..506d4ea6fb
--- /dev/null
+++ b/tensorflow/python/estimator/canned/boosted_trees_utils_test.py
@@ -0,0 +1,187 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests boosted_trees estimators and model_fn."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.estimator.canned import boosted_trees_utils
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import googletest
+
+
+class BoostedTreesDFCTest(test_util.TensorFlowTestCase):
+ """Test directional feature contributions (DFC) helper functions. """
+
+ def testDirectionalFeatureContributionsCompute(self):
+ """Tests logic to compute DFCs given feature ids and logits paths."""
+ num_bucketized_features = 3 # Includes one unused feature.
+ examples_feature_ids = ((2, 2, 0, 0), (2, 2, 0))
+ e1_feature_ids, e2_feature_ids = examples_feature_ids
+
+ # DFCs are computed by traversing the prediction path and subtracting each
+ # child prediction from its parent prediction and associating the change in
+ # prediction with the respective feature id used for the split.
+ # For each activation function, f, (currently identity or sigmoid), DFCs are
+ # calculated for the two examples as:
+ # example 1:
+ # feature_0 = (f(1.114) - f(1.214)) + (f(6.114) - f(1.114))
+ # feature_1 = 0 # Feature not in ensemble, thus zero contrib.
+ # feature_2 = (f(0.114) - bias_pred) + (f(1.214) - f(0.114))
+ # example 2:
+ # feature_0 = f(-5.486) - f(1.514)
+ # feature_1 = 0 # Feature not in ensemble, thus zero contrib.
+ # feature_2 = (f(0.114) - bias_pred) + (f(1.514) - f(0.114))
+ # where bias_pred is f(0.21) or f(0), with center_bias = {True, False},
+ # respectively.
+ # Keys are center_bias.
+ expected_dfcs_identity = {
+ False: ({
+ 0: 4.9,
+ 1: 0,
+ 2: 1.214
+ }, {
+ 0: -7.0,
+ 1: 0,
+ 2: 1.514
+ }),
+ True: ({
+ 0: 4.9,
+ 1: 0,
+ 2: 1.0039999999999998
+ }, {
+ 0: -7.0,
+ 1: 0,
+ 2: 1.3039999999999998
+ })
+ }
+ expected_dfcs_sigmoid = {
+ False: ({
+ 0: 0.22678725678805578,
+ 1: 0,
+ 2: 0.2710059376234506
+ }, {
+ 0: -0.81552596670046507,
+ 1: 0,
+ 2: 0.319653250251275
+ }),
+ True: ({
+ 0: 0.22678725678805578,
+ 1: 0,
+ 2: 0.2186980280491253
+ }, {
+ 0: -0.81552596670046507,
+ 1: 0,
+ 2: 0.26734534067694971
+ })
+ }
+ # pylint: disable=protected-access
+ for f, expected_dfcs in zip(
+ (boosted_trees_utils._identity, boosted_trees_utils._sigmoid),
+ (expected_dfcs_identity, expected_dfcs_sigmoid)):
+ for center_bias in [False, True]:
+ # If not center_bias, the bias logit (before activation) is 0.
+ if center_bias:
+ bias_logit = 0.21 # Root node of tree_0.
+ else:
+ bias_logit = 0 # 0 is default value when there is no original_leaf.
+ f_bias = f(bias_logit)
+
+ # Logits before and after, as output from
+ # boosted_trees_ops.example_debug_outputs
+ examples_logits_paths = ((bias_logit, 0.114, 1.214, 1.114, 6.114),
+ (bias_logit, 0.114, 1.514, -5.486))
+ e1_logits_path, e2_logits_path = examples_logits_paths
+ e1_expected_dfcs, e2_expected_dfcs = expected_dfcs[center_bias]
+ # Check feature contributions are correct for both examples.
+ # Example 1.
+ # pylint:disable=line-too-long
+ e1_bias, e1_dfc = boosted_trees_utils._compute_directional_feature_contributions(
+ e1_feature_ids, e1_logits_path, f, num_bucketized_features)
+ self.assertAllClose(e1_bias, f_bias)
+ self.assertAllClose(e1_dfc, e1_expected_dfcs)
+ # Example 2.
+ e2_bias, e2_dfc = boosted_trees_utils._compute_directional_feature_contributions(
+ e2_feature_ids, e2_logits_path, f, num_bucketized_features)
+ # pylint:enable=line-too-long
+ self.assertAllClose(e2_bias, f_bias)
+ self.assertAllClose(e2_dfc, e2_expected_dfcs)
+ # Check if contributions sum to final prediction.
+ # For each tree, get leaf of last tree.
+ expected_logits = (e1_logits_path[-1], e2_logits_path[-1])
+ # Predictions should be the sum of contributions + bias.
+ expected_preds = [f(logit) for logit in expected_logits]
+ e1_pred = e1_bias + sum(e1_dfc.values())
+ e2_pred = e2_bias + sum(e2_dfc.values())
+ preds = [e1_pred, e2_pred]
+ self.assertAllClose(preds, expected_preds)
+ # pylint: enable=protected-access
+
+ def testDFCComputeComparedToExternalExample(self):
+ """Tests `compute_dfc` compared to external example (regression).
+
+ Example from http://blog.datadive.net/interpreting-random-forests.
+ """
+ # DIS:3, RM: 2, LSTAT:1, NOX:0
+ num_bucketized_features = 4
+ e1_feature_ids = (2, 1, 0)
+ e2_feature_ids = (2, 2, 2)
+ e3_feature_ids = (2, 2, 0)
+
+ bias_logit = 22.60 # Root node of tree_0.
+ activation = boosted_trees_utils._identity
+ f_bias = activation(bias_logit)
+ # Logits before and after, as output from
+ # boosted_trees_ops.example_debug_outputs
+ e1_logits_path = (bias_logit, 19.96, 14.91, 18.11)
+ e2_logits_path = (bias_logit, 37.42, 45.10, 45.90)
+ e3_logits_path = (bias_logit, 37.42, 32.30, 33.58)
+ e1_expected_dfcs = {0: 3.20, 1: -5.05, 2: -2.64, 3: 0}
+ e2_expected_dfcs = {0: 0, 1: 0, 2: 23.3, 3: 0}
+ e3_expected_dfcs = {0: 1.28, 1: 0, 2: 9.7, 3: 0}
+ # Check feature contributions are correct for both examples.
+ # Example 1.
+ # pylint: disable=protected-access
+ # pylint: disable=line-too-long
+ e1_bias, e1_dfc = boosted_trees_utils._compute_directional_feature_contributions(
+ e1_feature_ids, e1_logits_path, activation, num_bucketized_features)
+ self.assertAllClose(e1_bias, f_bias)
+ self.assertAllClose(e1_dfc, e1_expected_dfcs)
+ # Example 2.
+ e2_bias, e2_dfc = boosted_trees_utils._compute_directional_feature_contributions(
+ e2_feature_ids, e2_logits_path, activation, num_bucketized_features)
+ self.assertAllClose(e2_bias, f_bias)
+ self.assertAllClose(e2_dfc, e2_expected_dfcs)
+ # Example 3.
+ e3_bias, e3_dfc = boosted_trees_utils._compute_directional_feature_contributions(
+ e3_feature_ids, e3_logits_path, activation, num_bucketized_features)
+ # pylint: enable=line-too-long
+ self.assertAllClose(e3_bias, f_bias)
+ self.assertAllClose(e3_dfc, e3_expected_dfcs)
+ # pylint: enable=protected-access
+ # Check if contributions sum to final prediction.
+ # For each example, use the leaf logit of the last tree visited.
+ expected_logits = (18.11, 45.90, 33.58)
+ # Predictions should be the sum of contributions + bias.
+ expected_preds = [activation(logit) for logit in expected_logits]
+ e1_pred = e1_bias + sum(e1_dfc.values())
+ e2_pred = e2_bias + sum(e2_dfc.values())
+ e3_pred = e3_bias + sum(e3_dfc.values())
+ preds = [e1_pred, e2_pred, e3_pred]
+ self.assertAllClose(preds, expected_preds)
+
+
+if __name__ == '__main__':
+ googletest.main()
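The weighted, gain-based aggregation that testFeatureImportancesWithTreeWeights exercises can also be followed by hand; this small numpy sketch mirrors the arithmetic in _compute_feature_importances for that test's ensemble and is an illustration only, not an additional API:

import numpy as np

# Per-tree gain sums indexed by feature id [f_0, f_1, f_2], read off the
# ensemble proto in the test: tree 0 splits on f_0 (gain 12.5) and f_1
# (gain 5.0), tree 1 splits on f_2 (gain 5.0), tree 2 is a bare root.
tree_importances = np.array([[12.5, 5.0, 0.0],
                             [0.0, 0.0, 5.0],
                             [0.0, 0.0, 0.0]])
tree_weights = np.array([0.4, 0.6, 1.0]).reshape(-1, 1)

importances = np.sum(tree_importances * tree_weights, axis=0)  # [5.0, 2.0, 3.0]
normalized = importances / np.sum(importances)                 # [0.5, 0.2, 0.3]
order = np.argsort(importances)[::-1]                          # f_0, f_2, f_1
# Reordering gives the values asserted in the test: [5.0, 3.0, 2.0], and
# [0.5, 0.3, 0.2] after normalization.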
diff --git a/tensorflow/python/estimator/canned/dnn.py b/tensorflow/python/estimator/canned/dnn.py
index 1c0c4581c0..97971f9561 100644
--- a/tensorflow/python/estimator/canned/dnn.py
+++ b/tensorflow/python/estimator/canned/dnn.py
@@ -24,7 +24,10 @@ from tensorflow.python.estimator import estimator
from tensorflow.python.estimator import model_fn
from tensorflow.python.estimator.canned import head as head_lib
from tensorflow.python.estimator.canned import optimizers
-from tensorflow.python.feature_column import feature_column as feature_column_lib
+from tensorflow.python.feature_column import feature_column
+from tensorflow.python.feature_column import feature_column_v2
+from tensorflow.python.framework import ops
+from tensorflow.python.keras.engine import training
from tensorflow.python.layers import core as core_layers
from tensorflow.python.layers import normalization
from tensorflow.python.ops import init_ops
@@ -45,8 +48,14 @@ def _add_hidden_layer_summary(value, tag):
summary.histogram('%s/activation' % tag, value)
-def _dnn_logit_fn_builder(units, hidden_units, feature_columns, activation_fn,
- dropout, input_layer_partitioner, batch_norm):
+def _dnn_logit_fn_builder(units,
+ hidden_units,
+ feature_columns,
+ activation_fn,
+ dropout,
+ input_layer_partitioner,
+ batch_norm,
+ shared_state_manager=None):
"""Function builder for a dnn logit_fn.
Args:
@@ -60,6 +69,8 @@ def _dnn_logit_fn_builder(units, hidden_units, feature_columns, activation_fn,
coordinate.
input_layer_partitioner: Partitioner for input layer.
batch_norm: Whether to use batch normalization after each hidden layer.
+ shared_state_manager: A SharedEmbeddingStateManager object to hold the
+ shared state for SharedEmbeddingColumns.
Returns:
A logit_fn (see below).
@@ -85,50 +96,132 @@ def _dnn_logit_fn_builder(units, hidden_units, feature_columns, activation_fn,
A `Tensor` representing the logits, or a list of `Tensor`'s representing
multiple logits in the MultiHead case.
"""
- is_training = mode == model_fn.ModeKeys.TRAIN
- with variable_scope.variable_scope(
- 'input_from_feature_columns',
- values=tuple(six.itervalues(features)),
- partitioner=input_layer_partitioner):
- net = feature_column_lib.input_layer(
- features=features, feature_columns=feature_columns)
+ dnn_model = _DNNModel(
+ units,
+ hidden_units,
+ feature_columns,
+ activation_fn,
+ dropout,
+ input_layer_partitioner,
+ batch_norm,
+ shared_state_manager,
+ name='dnn')
+ return dnn_model(features, mode)
+
+ return dnn_logit_fn
+
+
+def _get_previous_name_scope():
+ current_name_scope = ops.get_name_scope()
+ return current_name_scope.rsplit('/', 1)[0] + '/'
+
+
+class _DNNModel(training.Model):
+ """A DNN Model."""
+
+ def __init__(self,
+ units,
+ hidden_units,
+ feature_columns,
+ activation_fn,
+ dropout,
+ input_layer_partitioner,
+ batch_norm,
+ shared_state_manager,
+ name=None,
+ **kwargs):
+ super(_DNNModel, self).__init__(name=name, **kwargs)
+ self._is_v2 = False
+ if feature_column_v2.is_feature_column_v2(feature_columns):
+ self._is_v2 = True
+ self._input_layer = feature_column_v2.FeatureLayer(
+ feature_columns=feature_columns,
+ name='input_layer',
+ shared_state_manager=shared_state_manager)
+ else:
+ self._input_layer = feature_column.InputLayer(
+ feature_columns=feature_columns,
+ name='input_layer',
+ create_scope_now=False)
+
+ self._add_layer(self._input_layer, 'input_layer')
+
+ self._dropout = dropout
+ self._batch_norm = batch_norm
+
+ self._hidden_layers = []
+ self._dropout_layers = []
+ self._batch_norm_layers = []
+ self._hidden_layer_scope_names = []
for layer_id, num_hidden_units in enumerate(hidden_units):
with variable_scope.variable_scope(
- 'hiddenlayer_%d' % layer_id, values=(net,)) as hidden_layer_scope:
- net = core_layers.dense(
- net,
+ 'hiddenlayer_%d' % layer_id) as hidden_layer_scope:
+ hidden_layer = core_layers.Dense(
units=num_hidden_units,
activation=activation_fn,
kernel_initializer=init_ops.glorot_uniform_initializer(),
- name=hidden_layer_scope)
- if dropout is not None and is_training:
- net = core_layers.dropout(net, rate=dropout, training=True)
- if batch_norm:
- # TODO(hjm): In future, if this becomes popular, we can enable
- # customization of the batch normalization params by accepting a
- # list of `BatchNormalization` instances as `batch_norm`.
- net = normalization.batch_normalization(
- net,
+ name=hidden_layer_scope,
+ _scope=hidden_layer_scope)
+ self._add_layer(hidden_layer, hidden_layer_scope.name)
+ self._hidden_layer_scope_names.append(hidden_layer_scope.name)
+ self._hidden_layers.append(hidden_layer)
+ if self._dropout is not None:
+ dropout_layer = core_layers.Dropout(rate=self._dropout)
+ self._add_layer(dropout_layer, dropout_layer.name)
+ self._dropout_layers.append(dropout_layer)
+ if self._batch_norm:
+ batch_norm_layer = normalization.BatchNormalization(
# The default momentum 0.99 actually crashes on certain
# problems, so here we use 0.999, which is the default of
# tf.contrib.layers.batch_norm.
momentum=0.999,
- training=is_training,
- name='batchnorm_%d' % layer_id)
- _add_hidden_layer_summary(net, hidden_layer_scope.name)
-
- with variable_scope.variable_scope('logits', values=(net,)) as logits_scope:
- logits = core_layers.dense(
- net,
+ trainable=True,
+ name='batchnorm_%d' % layer_id,
+ _scope='batchnorm_%d' % layer_id)
+ self._add_layer(batch_norm_layer, batch_norm_layer.name)
+ self._batch_norm_layers.append(batch_norm_layer)
+
+ with variable_scope.variable_scope('logits') as logits_scope:
+ self._logits_layer = core_layers.Dense(
units=units,
activation=None,
kernel_initializer=init_ops.glorot_uniform_initializer(),
- name=logits_scope)
- _add_hidden_layer_summary(logits, logits_scope.name)
-
- return logits
-
- return dnn_logit_fn
+ name=logits_scope,
+ _scope=logits_scope)
+ self._add_layer(self._logits_layer, logits_scope.name)
+ self._logits_scope_name = logits_scope.name
+ self._logits_layer._use_resource_variables = False # pylint: disable=protected-access
+ self._input_layer_partitioner = input_layer_partitioner
+
+ def call(self, features, mode):
+ is_training = mode == model_fn.ModeKeys.TRAIN
+ # The Keras training.Model adds a name_scope with the name of the model,
+ # which modifies the constructed graph. Hence we add another name_scope
+ # here to restore the scope that was active before training.Model's was applied.
+ # TODO(rohanj): Remove this in TF 2.0 (b/116728605)
+ with ops.name_scope(name=_get_previous_name_scope()):
+ # TODO(rohanj): Remove dependence on variable scope for partitioning.
+ with variable_scope.variable_scope(
+ 'input_from_feature_columns',
+ partitioner=self._input_layer_partitioner):
+ net = self._input_layer(features)
+ for i in range(len(self._hidden_layers)):
+ net = self._hidden_layers[i](net)
+ if self._dropout is not None and is_training:
+ net = self._dropout_layers[i](net, training=True)
+ if self._batch_norm:
+ net = self._batch_norm_layers[i](net, training=is_training)
+ _add_hidden_layer_summary(net, self._hidden_layer_scope_names[i])
+
+ logits = self._logits_layer(net)
+ _add_hidden_layer_summary(logits, self._logits_scope_name)
+ return logits
+
+ def _add_layer(self, layer, layer_name):
+ # "Magic" required for keras.Model classes to track all the variables in
+ # a list of layers.Layer objects.
+ # TODO(ashankar): Figure out API so user code doesn't have to do this.
+ setattr(self, layer_name, layer)
def _dnn_model_fn(features,
@@ -143,7 +236,8 @@ def _dnn_model_fn(features,
input_layer_partitioner=None,
config=None,
use_tpu=False,
- batch_norm=False):
+ batch_norm=False,
+ shared_state_manager=None):
"""Deep Neural Net model_fn.
Args:
@@ -167,6 +261,8 @@ def _dnn_model_fn(features,
use_tpu: Whether to make a DNN model able to run on TPU. Will make function
return a `_TPUEstimatorSpec` instance and disable variable partitioning.
batch_norm: Whether to use batch normalization after each hidden layer.
+    shared_state_manager: A SharedEmbeddingStateManager object that holds the
+      shared state for SharedEmbeddingColumns.
Returns:
An `EstimatorSpec` instance.
@@ -202,7 +298,8 @@ def _dnn_model_fn(features,
activation_fn=activation_fn,
dropout=dropout,
input_layer_partitioner=input_layer_partitioner,
- batch_norm=batch_norm)
+ batch_norm=batch_norm,
+ shared_state_manager=shared_state_manager)
logits = logit_fn(features=features, mode=mode)
if use_tpu:
@@ -370,6 +467,10 @@ class DNNClassifier(estimator.Estimator):
"""
head = head_lib._binary_logistic_or_multi_class_head( # pylint: disable=protected-access
n_classes, weight_column, label_vocabulary, loss_reduction)
+
+ shared_state_manager = feature_column_v2.maybe_create_shared_state_manager(
+ feature_columns)
+
def _model_fn(features, labels, mode, config):
"""Call the defined shared _dnn_model_fn."""
return _dnn_model_fn(
@@ -384,7 +485,8 @@ class DNNClassifier(estimator.Estimator):
dropout=dropout,
input_layer_partitioner=input_layer_partitioner,
config=config,
- batch_norm=batch_norm)
+ batch_norm=batch_norm,
+ shared_state_manager=shared_state_manager)
super(DNNClassifier, self).__init__(
model_fn=_model_fn, model_dir=model_dir, config=config,
@@ -532,6 +634,10 @@ class DNNRegressor(estimator.Estimator):
batch_norm: Whether to use batch normalization after each hidden layer.
"""
+ shared_state_manager = None
+ if feature_column_v2.is_feature_column_v2(feature_columns):
+ shared_state_manager = feature_column_v2.SharedEmbeddingStateManager()
+
def _model_fn(features, labels, mode, config):
"""Call the defined shared _dnn_model_fn."""
return _dnn_model_fn(
@@ -539,7 +645,8 @@ class DNNRegressor(estimator.Estimator):
labels=labels,
mode=mode,
head=head_lib._regression_head( # pylint: disable=protected-access
- label_dimension=label_dimension, weight_column=weight_column,
+ label_dimension=label_dimension,
+ weight_column=weight_column,
loss_reduction=loss_reduction),
hidden_units=hidden_units,
feature_columns=tuple(feature_columns or []),
@@ -548,7 +655,8 @@ class DNNRegressor(estimator.Estimator):
dropout=dropout,
input_layer_partitioner=input_layer_partitioner,
config=config,
- batch_norm=batch_norm)
+ batch_norm=batch_norm,
+ shared_state_manager=shared_state_manager)
super(DNNRegressor, self).__init__(
model_fn=_model_fn, model_dir=model_dir, config=config,
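
The rewritten logit function in dnn.py above is now a Keras-style model: every Dense, Dropout and BatchNormalization layer is registered through _add_layer, which simply assigns the layer as an attribute so the Model machinery tracks its variables, and call() threads the activations through the stored layers. A minimal sketch of that tracking pattern using only the public tf.keras API (TinyDNN and all names below are illustrative, not part of this patch):

    import tensorflow as tf


    class TinyDNN(tf.keras.Model):
      """Illustrative sketch of the attribute-assignment layer tracking above."""

      def __init__(self, hidden_units, logits_dimension, dropout=None):
        super(TinyDNN, self).__init__()
        self._dropout = dropout
        self._hidden_layers = []
        self._dropout_layers = []
        for layer_id, num_units in enumerate(hidden_units):
          hidden = tf.keras.layers.Dense(num_units, activation=tf.nn.relu)
          # Equivalent of _add_layer: assigning the layer as an attribute is
          # what lets the Model collect its kernel and bias variables.
          setattr(self, 'hiddenlayer_%d' % layer_id, hidden)
          self._hidden_layers.append(hidden)
          if dropout is not None:
            drop = tf.keras.layers.Dropout(rate=dropout)
            setattr(self, 'dropout_%d' % layer_id, drop)
            self._dropout_layers.append(drop)
        self._logits_layer = tf.keras.layers.Dense(logits_dimension)

      def call(self, features, training=False):
        net = features
        for i, hidden in enumerate(self._hidden_layers):
          net = hidden(net)
          if self._dropout is not None and training:
            net = self._dropout_layers[i](net, training=True)
        return self._logits_layer(net)


    model = TinyDNN(hidden_units=(4, 4), logits_dimension=3, dropout=0.5)
    logits = model(tf.ones([2, 8]), training=True)
    print(logits.shape)           # (2, 3)
    print(len(model.variables))   # 6: kernel and bias for each Dense layer

Calling the model once builds the layers, after which model.variables exposes the kernels and biases of everything registered this way.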
diff --git a/tensorflow/python/estimator/canned/dnn_linear_combined.py b/tensorflow/python/estimator/canned/dnn_linear_combined.py
index 9799cf9e98..f712244c8d 100644
--- a/tensorflow/python/estimator/canned/dnn_linear_combined.py
+++ b/tensorflow/python/estimator/canned/dnn_linear_combined.py
@@ -27,6 +27,7 @@ from tensorflow.python.estimator.canned import dnn
from tensorflow.python.estimator.canned import head as head_lib
from tensorflow.python.estimator.canned import linear
from tensorflow.python.estimator.canned import optimizers
+from tensorflow.python.feature_column import feature_column_v2
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import nn
@@ -142,6 +143,9 @@ def _dnn_linear_combined_model_fn(features,
max_partitions=num_ps_replicas,
min_slice_size=64 << 20))
+ shared_state_manager = feature_column_v2.maybe_create_shared_state_manager(
+ list(linear_feature_columns) + list(dnn_feature_columns))
+
# Build DNN Logits.
dnn_parent_scope = 'dnn'
@@ -169,8 +173,9 @@ def _dnn_linear_combined_model_fn(features,
feature_columns=dnn_feature_columns,
activation_fn=dnn_activation_fn,
dropout=dnn_dropout,
+ batch_norm=batch_norm,
input_layer_partitioner=input_layer_partitioner,
- batch_norm=batch_norm)
+ shared_state_manager=shared_state_manager)
dnn_logits = dnn_logit_fn(features=features, mode=mode)
linear_parent_scope = 'linear'
diff --git a/tensorflow/python/estimator/canned/dnn_test.py b/tensorflow/python/estimator/canned/dnn_test.py
index fc90b7c35e..756696cea0 100644
--- a/tensorflow/python/estimator/canned/dnn_test.py
+++ b/tensorflow/python/estimator/canned/dnn_test.py
@@ -21,6 +21,7 @@ from __future__ import print_function
import shutil
import tempfile
+from absl.testing import parameterized
import numpy as np
import six
@@ -33,6 +34,7 @@ from tensorflow.python.estimator.export import export
from tensorflow.python.estimator.inputs import numpy_io
from tensorflow.python.estimator.inputs import pandas_io
from tensorflow.python.feature_column import feature_column
+from tensorflow.python.feature_column import feature_column_v2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import data_flow_ops
@@ -62,15 +64,32 @@ class DNNModelFnTest(dnn_testing_utils.BaseDNNModelFnTest, test.TestCase):
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
- dnn_testing_utils.BaseDNNModelFnTest.__init__(self, dnn._dnn_model_fn)
+ dnn_testing_utils.BaseDNNModelFnTest.__init__(
+ self, dnn._dnn_model_fn, fc_impl=feature_column)
+
+
+class DNNModelFnV2Test(dnn_testing_utils.BaseDNNModelFnTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ dnn_testing_utils.BaseDNNModelFnTest.__init__(
+ self, dnn._dnn_model_fn, fc_impl=feature_column_v2)
class DNNLogitFnTest(dnn_testing_utils.BaseDNNLogitFnTest, test.TestCase):
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
- dnn_testing_utils.BaseDNNLogitFnTest.__init__(self,
- dnn._dnn_logit_fn_builder)
+ dnn_testing_utils.BaseDNNLogitFnTest.__init__(
+ self, dnn._dnn_logit_fn_builder, fc_impl=feature_column)
+
+
+class DNNLogitFnV2Test(dnn_testing_utils.BaseDNNLogitFnTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ dnn_testing_utils.BaseDNNLogitFnTest.__init__(
+ self, dnn._dnn_logit_fn_builder, fc_impl=feature_column_v2)
class DNNWarmStartingTest(dnn_testing_utils.BaseDNNWarmStartingTest,
@@ -78,8 +97,17 @@ class DNNWarmStartingTest(dnn_testing_utils.BaseDNNWarmStartingTest,
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
- dnn_testing_utils.BaseDNNWarmStartingTest.__init__(self, _dnn_classifier_fn,
- _dnn_regressor_fn)
+ dnn_testing_utils.BaseDNNWarmStartingTest.__init__(
+ self, _dnn_classifier_fn, _dnn_regressor_fn, fc_impl=feature_column)
+
+
+class DNNWarmStartingV2Test(dnn_testing_utils.BaseDNNWarmStartingTest,
+ test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ dnn_testing_utils.BaseDNNWarmStartingTest.__init__(
+ self, _dnn_classifier_fn, _dnn_regressor_fn, fc_impl=feature_column_v2)
class DNNClassifierEvaluateTest(
@@ -88,7 +116,16 @@ class DNNClassifierEvaluateTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
dnn_testing_utils.BaseDNNClassifierEvaluateTest.__init__(
- self, _dnn_classifier_fn)
+ self, _dnn_classifier_fn, fc_impl=feature_column)
+
+
+class DNNClassifierEvaluateV2Test(
+ dnn_testing_utils.BaseDNNClassifierEvaluateTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ dnn_testing_utils.BaseDNNClassifierEvaluateTest.__init__(
+ self, _dnn_classifier_fn, fc_impl=feature_column_v2)
class DNNClassifierPredictTest(
@@ -97,7 +134,16 @@ class DNNClassifierPredictTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
dnn_testing_utils.BaseDNNClassifierPredictTest.__init__(
- self, _dnn_classifier_fn)
+ self, _dnn_classifier_fn, fc_impl=feature_column)
+
+
+class DNNClassifierPredictV2Test(dnn_testing_utils.BaseDNNClassifierPredictTest,
+ test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ dnn_testing_utils.BaseDNNClassifierPredictTest.__init__(
+ self, _dnn_classifier_fn, fc_impl=feature_column_v2)
class DNNClassifierTrainTest(
@@ -106,7 +152,16 @@ class DNNClassifierTrainTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
dnn_testing_utils.BaseDNNClassifierTrainTest.__init__(
- self, _dnn_classifier_fn)
+ self, _dnn_classifier_fn, fc_impl=feature_column)
+
+
+class DNNClassifierTrainV2Test(dnn_testing_utils.BaseDNNClassifierTrainTest,
+ test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ dnn_testing_utils.BaseDNNClassifierTrainTest.__init__(
+ self, _dnn_classifier_fn, fc_impl=feature_column_v2)
def _dnn_regressor_fn(*args, **kwargs):
@@ -119,7 +174,16 @@ class DNNRegressorEvaluateTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
dnn_testing_utils.BaseDNNRegressorEvaluateTest.__init__(
- self, _dnn_regressor_fn)
+ self, _dnn_regressor_fn, fc_impl=feature_column)
+
+
+class DNNRegressorEvaluateV2Test(dnn_testing_utils.BaseDNNRegressorEvaluateTest,
+ test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ dnn_testing_utils.BaseDNNRegressorEvaluateTest.__init__(
+ self, _dnn_regressor_fn, fc_impl=feature_column_v2)
class DNNRegressorPredictTest(
@@ -128,7 +192,16 @@ class DNNRegressorPredictTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
dnn_testing_utils.BaseDNNRegressorPredictTest.__init__(
- self, _dnn_regressor_fn)
+ self, _dnn_regressor_fn, fc_impl=feature_column)
+
+
+class DNNRegressorPredictV2Test(dnn_testing_utils.BaseDNNRegressorPredictTest,
+ test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ dnn_testing_utils.BaseDNNRegressorPredictTest.__init__(
+ self, _dnn_regressor_fn, fc_impl=feature_column_v2)
class DNNRegressorTrainTest(
@@ -137,7 +210,16 @@ class DNNRegressorTrainTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
dnn_testing_utils.BaseDNNRegressorTrainTest.__init__(
- self, _dnn_regressor_fn)
+ self, _dnn_regressor_fn, fc_impl=feature_column)
+
+
+class DNNRegressorTrainV2Test(dnn_testing_utils.BaseDNNRegressorTrainTest,
+ test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ dnn_testing_utils.BaseDNNRegressorTrainTest.__init__(
+ self, _dnn_regressor_fn, fc_impl=feature_column_v2)
def _queue_parsed_features(feature_map):
@@ -156,7 +238,8 @@ def _queue_parsed_features(feature_map):
return {keys[i]: dequeued_tensors[i] for i in range(len(dequeued_tensors))}
-class DNNRegressorIntegrationTest(test.TestCase):
+@parameterized.parameters((feature_column,), (feature_column_v2,))
+class DNNRegressorIntegrationTest(test.TestCase, parameterized.TestCase):
def setUp(self):
self._model_dir = tempfile.mkdtemp()
@@ -166,11 +249,11 @@ class DNNRegressorIntegrationTest(test.TestCase):
writer_cache.FileWriterCache.clear()
shutil.rmtree(self._model_dir)
- def _test_complete_flow(
- self, train_input_fn, eval_input_fn, predict_input_fn, input_dimension,
- label_dimension, batch_size):
- feature_columns = [
- feature_column.numeric_column('x', shape=(input_dimension,))]
+ def _test_complete_flow(self, train_input_fn, eval_input_fn, predict_input_fn,
+ input_dimension, label_dimension, batch_size,
+ fc_impl):
+ feature_columns = [fc_impl.numeric_column('x', shape=(input_dimension,))]
+
est = dnn.DNNRegressor(
hidden_units=(2, 2),
feature_columns=feature_columns,
@@ -194,14 +277,14 @@ class DNNRegressorIntegrationTest(test.TestCase):
self.assertAllEqual((batch_size, label_dimension), predictions.shape)
# EXPORT
- feature_spec = feature_column.make_parse_example_spec(feature_columns)
+ feature_spec = fc_impl.make_parse_example_spec(feature_columns)
serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
feature_spec)
export_dir = est.export_savedmodel(tempfile.mkdtemp(),
serving_input_receiver_fn)
self.assertTrue(gfile.Exists(export_dir))
- def test_numpy_input_fn(self):
+ def test_numpy_input_fn(self, fc_impl):
"""Tests complete flow with numpy_input_fn."""
label_dimension = 2
batch_size = 10
@@ -230,9 +313,10 @@ class DNNRegressorIntegrationTest(test.TestCase):
predict_input_fn=predict_input_fn,
input_dimension=label_dimension,
label_dimension=label_dimension,
- batch_size=batch_size)
+ batch_size=batch_size,
+ fc_impl=fc_impl)
- def test_pandas_input_fn(self):
+ def test_pandas_input_fn(self, fc_impl):
"""Tests complete flow with pandas_input_fn."""
if not HAS_PANDAS:
return
@@ -263,9 +347,10 @@ class DNNRegressorIntegrationTest(test.TestCase):
predict_input_fn=predict_input_fn,
input_dimension=label_dimension,
label_dimension=label_dimension,
- batch_size=batch_size)
+ batch_size=batch_size,
+ fc_impl=fc_impl)
- def test_input_fn_from_parse_example(self):
+ def test_input_fn_from_parse_example(self, fc_impl):
"""Tests complete flow with input_fn constructed from parse_example."""
label_dimension = 2
batch_size = 10
@@ -313,9 +398,11 @@ class DNNRegressorIntegrationTest(test.TestCase):
predict_input_fn=_predict_input_fn,
input_dimension=label_dimension,
label_dimension=label_dimension,
- batch_size=batch_size)
+ batch_size=batch_size,
+ fc_impl=fc_impl)
+@parameterized.parameters((feature_column,), (feature_column_v2,))
class DNNClassifierIntegrationTest(test.TestCase):
def setUp(self):
@@ -329,11 +416,10 @@ class DNNClassifierIntegrationTest(test.TestCase):
def _as_label(self, data_in_float):
return np.rint(data_in_float).astype(np.int64)
- def _test_complete_flow(
- self, train_input_fn, eval_input_fn, predict_input_fn, input_dimension,
- n_classes, batch_size):
- feature_columns = [
- feature_column.numeric_column('x', shape=(input_dimension,))]
+ def _test_complete_flow(self, train_input_fn, eval_input_fn, predict_input_fn,
+ input_dimension, n_classes, batch_size, fc_impl):
+ feature_columns = [fc_impl.numeric_column('x', shape=(input_dimension,))]
+
est = dnn.DNNClassifier(
hidden_units=(2, 2),
feature_columns=feature_columns,
@@ -357,14 +443,14 @@ class DNNClassifierIntegrationTest(test.TestCase):
self.assertAllEqual((batch_size, n_classes), predicted_proba.shape)
# EXPORT
- feature_spec = feature_column.make_parse_example_spec(feature_columns)
+ feature_spec = fc_impl.make_parse_example_spec(feature_columns)
serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
feature_spec)
export_dir = est.export_savedmodel(tempfile.mkdtemp(),
serving_input_receiver_fn)
self.assertTrue(gfile.Exists(export_dir))
- def test_numpy_input_fn(self):
+ def test_numpy_input_fn(self, fc_impl):
"""Tests complete flow with numpy_input_fn."""
n_classes = 3
input_dimension = 2
@@ -396,9 +482,10 @@ class DNNClassifierIntegrationTest(test.TestCase):
predict_input_fn=predict_input_fn,
input_dimension=input_dimension,
n_classes=n_classes,
- batch_size=batch_size)
+ batch_size=batch_size,
+ fc_impl=fc_impl)
- def test_pandas_input_fn(self):
+ def test_pandas_input_fn(self, fc_impl):
"""Tests complete flow with pandas_input_fn."""
if not HAS_PANDAS:
return
@@ -430,9 +517,10 @@ class DNNClassifierIntegrationTest(test.TestCase):
predict_input_fn=predict_input_fn,
input_dimension=input_dimension,
n_classes=n_classes,
- batch_size=batch_size)
+ batch_size=batch_size,
+ fc_impl=fc_impl)
- def test_input_fn_from_parse_example(self):
+ def test_input_fn_from_parse_example(self, fc_impl):
"""Tests complete flow with input_fn constructed from parse_example."""
input_dimension = 2
n_classes = 3
@@ -484,7 +572,8 @@ class DNNClassifierIntegrationTest(test.TestCase):
predict_input_fn=_predict_input_fn,
input_dimension=input_dimension,
n_classes=n_classes,
- batch_size=batch_size)
+ batch_size=batch_size,
+ fc_impl=fc_impl)
if __name__ == '__main__':
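
The dnn_test.py changes above exercise both feature-column implementations in two ways: dedicated *V2Test subclasses pass fc_impl=feature_column_v2 into the shared base tests, and the integration tests are decorated with @parameterized.parameters so each test method runs once per implementation. A small self-contained sketch of the class-level parameterization (test and parameter names here are illustrative):

    from absl.testing import absltest
    from absl.testing import parameterized


    @parameterized.parameters(('feature_column',), ('feature_column_v2',))
    class ImplSwitchTest(parameterized.TestCase):
      """Each test method runs once per parameter tuple; the extra positional
      argument plays the role of fc_impl in the tests above."""

      def test_runs_per_impl(self, impl_name):
        self.assertIn(impl_name, ('feature_column', 'feature_column_v2'))


    if __name__ == '__main__':
      absltest.main()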
diff --git a/tensorflow/python/estimator/canned/dnn_testing_utils.py b/tensorflow/python/estimator/canned/dnn_testing_utils.py
index 11f1e93630..cd66d0a3bd 100644
--- a/tensorflow/python/estimator/canned/dnn_testing_utils.py
+++ b/tensorflow/python/estimator/canned/dnn_testing_utils.py
@@ -104,6 +104,7 @@ def create_checkpoint(weights_and_biases,
weights_and_biases: Iterable of tuples of weight and bias values.
global_step: Initial global step to save in checkpoint.
model_dir: Directory into which checkpoint is saved.
+ batch_norm_vars: Variables used for batch normalization.
"""
weights, biases = zip(*weights_and_biases)
if batch_norm_vars:
@@ -244,8 +245,9 @@ def mock_optimizer(testcase, hidden_units, expected_loss=None):
class BaseDNNModelFnTest(object):
"""Tests that _dnn_model_fn passes expected logits to mock head."""
- def __init__(self, dnn_model_fn):
+ def __init__(self, dnn_model_fn, fc_impl=feature_column):
self._dnn_model_fn = dnn_model_fn
+ self._fc_impl = fc_impl
def setUp(self):
self._model_dir = tempfile.mkdtemp()
@@ -272,7 +274,7 @@ class BaseDNNModelFnTest(object):
head=head,
hidden_units=hidden_units,
feature_columns=[
- feature_column.numeric_column(
+ self._fc_impl.numeric_column(
'age', shape=np.array(inputs).shape[1:])
],
optimizer=mock_optimizer(self, hidden_units))
@@ -462,8 +464,8 @@ class BaseDNNModelFnTest(object):
head=head,
hidden_units=hidden_units,
feature_columns=[
- feature_column.numeric_column('age'),
- feature_column.numeric_column('height')
+ self._fc_impl.numeric_column('age'),
+ self._fc_impl.numeric_column('height')
],
optimizer=mock_optimizer(self, hidden_units))
with monitored_session.MonitoredTrainingSession(
@@ -499,7 +501,7 @@ class BaseDNNModelFnTest(object):
head=head,
hidden_units=hidden_units,
feature_columns=[
- feature_column.numeric_column(
+ self._fc_impl.numeric_column(
'age', shape=np.array(inputs).shape[1:])
],
optimizer=mock_optimizer(self, hidden_units))
@@ -508,8 +510,9 @@ class BaseDNNModelFnTest(object):
class BaseDNNLogitFnTest(object):
"""Tests correctness of logits calculated from _dnn_logit_fn_builder."""
- def __init__(self, dnn_logit_fn_builder):
+ def __init__(self, dnn_logit_fn_builder, fc_impl=feature_column):
self._dnn_logit_fn_builder = dnn_logit_fn_builder
+ self._fc_impl = fc_impl
def setUp(self):
self._model_dir = tempfile.mkdtemp()
@@ -541,7 +544,7 @@ class BaseDNNLogitFnTest(object):
units=logits_dimension,
hidden_units=hidden_units,
feature_columns=[
- feature_column.numeric_column(
+ self._fc_impl.numeric_column(
'age', shape=np.array(inputs).shape[1:])
],
activation_fn=nn.relu,
@@ -786,8 +789,8 @@ class BaseDNNLogitFnTest(object):
units=logits_dimension,
hidden_units=hidden_units,
feature_columns=[
- feature_column.numeric_column('age'),
- feature_column.numeric_column('height')
+ self._fc_impl.numeric_column('age'),
+ self._fc_impl.numeric_column('height')
],
activation_fn=nn.relu,
dropout=None,
@@ -806,9 +809,13 @@ class BaseDNNLogitFnTest(object):
class BaseDNNWarmStartingTest(object):
- def __init__(self, _dnn_classifier_fn, _dnn_regressor_fn):
+ def __init__(self,
+ _dnn_classifier_fn,
+ _dnn_regressor_fn,
+ fc_impl=feature_column):
self._dnn_classifier_fn = _dnn_classifier_fn
self._dnn_regressor_fn = _dnn_regressor_fn
+ self._fc_impl = fc_impl
def setUp(self):
# Create a directory to save our old checkpoint and vocabularies to.
@@ -843,8 +850,8 @@ class BaseDNNWarmStartingTest(object):
def test_classifier_basic_warm_starting(self):
"""Tests correctness of DNNClassifier default warm-start."""
- city = feature_column.embedding_column(
- feature_column.categorical_column_with_vocabulary_list(
+ city = self._fc_impl.embedding_column(
+ self._fc_impl.categorical_column_with_vocabulary_list(
'city', vocabulary_list=['Mountain View', 'Palo Alto']),
dimension=5)
@@ -875,8 +882,8 @@ class BaseDNNWarmStartingTest(object):
def test_regressor_basic_warm_starting(self):
"""Tests correctness of DNNRegressor default warm-start."""
- city = feature_column.embedding_column(
- feature_column.categorical_column_with_vocabulary_list(
+ city = self._fc_impl.embedding_column(
+ self._fc_impl.categorical_column_with_vocabulary_list(
'city', vocabulary_list=['Mountain View', 'Palo Alto']),
dimension=5)
@@ -905,8 +912,8 @@ class BaseDNNWarmStartingTest(object):
def test_warm_starting_selective_variables(self):
"""Tests selecting variables to warm-start."""
- city = feature_column.embedding_column(
- feature_column.categorical_column_with_vocabulary_list(
+ city = self._fc_impl.embedding_column(
+ self._fc_impl.categorical_column_with_vocabulary_list(
'city', vocabulary_list=['Mountain View', 'Palo Alto']),
dimension=5)
@@ -958,8 +965,8 @@ class BaseDNNWarmStartingTest(object):
vocab_file = os.path.join(self._ckpt_and_vocab_dir, 'occupation_vocab')
with open(vocab_file, 'w') as f:
f.write('\n'.join(vocab_list))
- occupation = feature_column.embedding_column(
- feature_column.categorical_column_with_vocabulary_file(
+ occupation = self._fc_impl.embedding_column(
+ self._fc_impl.categorical_column_with_vocabulary_file(
'occupation',
vocabulary_file=vocab_file,
vocabulary_size=len(vocab_list)),
@@ -985,8 +992,8 @@ class BaseDNNWarmStartingTest(object):
'new_occupation_vocab')
with open(new_vocab_file, 'w') as f:
f.write('\n'.join(new_vocab_list))
- new_occupation = feature_column.embedding_column(
- feature_column.categorical_column_with_vocabulary_file(
+ new_occupation = self._fc_impl.embedding_column(
+ self._fc_impl.categorical_column_with_vocabulary_file(
'occupation',
vocabulary_file=new_vocab_file,
vocabulary_size=len(new_vocab_list)),
@@ -1051,8 +1058,8 @@ class BaseDNNWarmStartingTest(object):
def test_warm_starting_with_naming_change(self):
"""Tests warm-starting with a Tensor name remapping."""
- locality = feature_column.embedding_column(
- feature_column.categorical_column_with_vocabulary_list(
+ locality = self._fc_impl.embedding_column(
+ self._fc_impl.categorical_column_with_vocabulary_list(
'locality', vocabulary_list=['Mountain View', 'Palo Alto']),
dimension=5)
@@ -1068,8 +1075,8 @@ class BaseDNNWarmStartingTest(object):
# Create a second DNNClassifier, warm-started from the first. Use a
# learning_rate = 0.0 optimizer to check values (use SGD so we don't have
# accumulator values that change).
- city = feature_column.embedding_column(
- feature_column.categorical_column_with_vocabulary_list(
+ city = self._fc_impl.embedding_column(
+ self._fc_impl.categorical_column_with_vocabulary_list(
'city', vocabulary_list=['Mountain View', 'Palo Alto']),
dimension=5)
warm_started_dnn_classifier = self._dnn_classifier_fn(
@@ -1101,8 +1108,9 @@ class BaseDNNWarmStartingTest(object):
class BaseDNNClassifierEvaluateTest(object):
- def __init__(self, dnn_classifier_fn):
+ def __init__(self, dnn_classifier_fn, fc_impl=feature_column):
self._dnn_classifier_fn = dnn_classifier_fn
+ self._fc_impl = fc_impl
def setUp(self):
self._model_dir = tempfile.mkdtemp()
@@ -1121,7 +1129,7 @@ class BaseDNNClassifierEvaluateTest(object):
dnn_classifier = self._dnn_classifier_fn(
hidden_units=(2, 2),
- feature_columns=[feature_column.numeric_column('age')],
+ feature_columns=[self._fc_impl.numeric_column('age')],
model_dir=self._model_dir)
def _input_fn():
# batch_size = 2, one false label, and one true.
@@ -1161,7 +1169,7 @@ class BaseDNNClassifierEvaluateTest(object):
dnn_classifier = self._dnn_classifier_fn(
hidden_units=(2, 2),
- feature_columns=[feature_column.numeric_column('age', shape=[2])],
+ feature_columns=[self._fc_impl.numeric_column('age', shape=[2])],
n_classes=n_classes,
model_dir=self._model_dir)
def _input_fn():
@@ -1192,7 +1200,7 @@ class BaseDNNClassifierEvaluateTest(object):
dnn_classifier = self._dnn_classifier_fn(
hidden_units=(2, 2),
- feature_columns=[feature_column.numeric_column('age')],
+ feature_columns=[self._fc_impl.numeric_column('age')],
model_dir=self._model_dir)
def _input_fn():
# batch_size = 2, one false label, and one true.
@@ -1218,7 +1226,7 @@ class BaseDNNClassifierEvaluateTest(object):
dnn_classifier = self._dnn_classifier_fn(
hidden_units=(2, 2),
- feature_columns=[feature_column.numeric_column('age', shape=[2])],
+ feature_columns=[self._fc_impl.numeric_column('age', shape=[2])],
n_classes=n_classes,
weight_column='w',
model_dir=self._model_dir)
@@ -1238,8 +1246,9 @@ class BaseDNNClassifierEvaluateTest(object):
class BaseDNNRegressorEvaluateTest(object):
- def __init__(self, dnn_regressor_fn):
+ def __init__(self, dnn_regressor_fn, fc_impl=feature_column):
self._dnn_regressor_fn = dnn_regressor_fn
+ self._fc_impl = fc_impl
def setUp(self):
self._model_dir = tempfile.mkdtemp()
@@ -1259,7 +1268,7 @@ class BaseDNNRegressorEvaluateTest(object):
dnn_regressor = self._dnn_regressor_fn(
hidden_units=(2, 2),
- feature_columns=[feature_column.numeric_column('age')],
+ feature_columns=[self._fc_impl.numeric_column('age')],
model_dir=self._model_dir)
def _input_fn():
return {'age': [[10.]]}, [[1.]]
@@ -1289,7 +1298,7 @@ class BaseDNNRegressorEvaluateTest(object):
dnn_regressor = self._dnn_regressor_fn(
hidden_units=(2, 2),
- feature_columns=[feature_column.numeric_column('age', shape=[2])],
+ feature_columns=[self._fc_impl.numeric_column('age', shape=[2])],
label_dimension=label_dimension,
model_dir=self._model_dir)
def _input_fn():
@@ -1320,7 +1329,7 @@ class BaseDNNRegressorEvaluateTest(object):
dnn_regressor = self._dnn_regressor_fn(
hidden_units=(2, 2),
- feature_columns=[feature_column.numeric_column('age', shape=[2])],
+ feature_columns=[self._fc_impl.numeric_column('age', shape=[2])],
label_dimension=label_dimension,
weight_column='w',
model_dir=self._model_dir)
@@ -1339,8 +1348,9 @@ class BaseDNNRegressorEvaluateTest(object):
class BaseDNNClassifierPredictTest(object):
- def __init__(self, dnn_classifier_fn):
+ def __init__(self, dnn_classifier_fn, fc_impl=feature_column):
self._dnn_classifier_fn = dnn_classifier_fn
+ self._fc_impl = fc_impl
def setUp(self):
self._model_dir = tempfile.mkdtemp()
@@ -1361,7 +1371,7 @@ class BaseDNNClassifierPredictTest(object):
dnn_classifier = self._dnn_classifier_fn(
hidden_units=(2, 2),
label_vocabulary=label_vocabulary,
- feature_columns=(feature_column.numeric_column('x'),),
+ feature_columns=(self._fc_impl.numeric_column('x'),),
model_dir=self._model_dir)
input_fn = numpy_io.numpy_input_fn(
x={'x': np.array([[10.]])}, batch_size=1, shuffle=False)
@@ -1405,7 +1415,7 @@ class BaseDNNClassifierPredictTest(object):
dnn_classifier = self._dnn_classifier_fn(
hidden_units=(2, 2),
- feature_columns=(feature_column.numeric_column('x', shape=(2,)),),
+ feature_columns=(self._fc_impl.numeric_column('x', shape=(2,)),),
label_vocabulary=label_vocabulary,
n_classes=3,
model_dir=self._model_dir)
@@ -1453,8 +1463,9 @@ class BaseDNNClassifierPredictTest(object):
class BaseDNNRegressorPredictTest(object):
- def __init__(self, dnn_regressor_fn):
+ def __init__(self, dnn_regressor_fn, fc_impl=feature_column):
self._dnn_regressor_fn = dnn_regressor_fn
+ self._fc_impl = fc_impl
def setUp(self):
self._model_dir = tempfile.mkdtemp()
@@ -1475,7 +1486,7 @@ class BaseDNNRegressorPredictTest(object):
dnn_regressor = self._dnn_regressor_fn(
hidden_units=(2, 2),
- feature_columns=(feature_column.numeric_column('x'),),
+ feature_columns=(self._fc_impl.numeric_column('x'),),
model_dir=self._model_dir)
input_fn = numpy_io.numpy_input_fn(
x={'x': np.array([[10.]])}, batch_size=1, shuffle=False)
@@ -1497,7 +1508,7 @@ class BaseDNNRegressorPredictTest(object):
dnn_regressor = self._dnn_regressor_fn(
hidden_units=(2, 2),
- feature_columns=(feature_column.numeric_column('x', shape=(2,)),),
+ feature_columns=(self._fc_impl.numeric_column('x', shape=(2,)),),
label_dimension=3,
model_dir=self._model_dir)
input_fn = numpy_io.numpy_input_fn(
@@ -1594,8 +1605,9 @@ def _assert_simple_summary(testcase, expected_values, actual_summary):
class BaseDNNClassifierTrainTest(object):
- def __init__(self, dnn_classifier_fn):
+ def __init__(self, dnn_classifier_fn, fc_impl=feature_column):
self._dnn_classifier_fn = dnn_classifier_fn
+ self._fc_impl = fc_impl
def setUp(self):
self._model_dir = tempfile.mkdtemp()
@@ -1609,7 +1621,7 @@ class BaseDNNClassifierTrainTest(object):
hidden_units = (2, 2)
dnn_classifier = self._dnn_classifier_fn(
hidden_units=hidden_units,
- feature_columns=(feature_column.numeric_column('age'),),
+ feature_columns=(self._fc_impl.numeric_column('age'),),
model_dir=self._model_dir)
# Train for a few steps, then validate final checkpoint.
@@ -1625,7 +1637,7 @@ class BaseDNNClassifierTrainTest(object):
n_classes = 3
dnn_classifier = self._dnn_classifier_fn(
hidden_units=hidden_units,
- feature_columns=(feature_column.numeric_column('age'),),
+ feature_columns=(self._fc_impl.numeric_column('age'),),
n_classes=n_classes,
model_dir=self._model_dir)
@@ -1643,7 +1655,7 @@ class BaseDNNClassifierTrainTest(object):
self, hidden_units=hidden_units)
dnn_classifier = self._dnn_classifier_fn(
hidden_units=hidden_units,
- feature_columns=(feature_column.numeric_column('age'),),
+ feature_columns=(self._fc_impl.numeric_column('age'),),
optimizer=opt,
model_dir=self._model_dir)
self.assertEqual(0, opt.minimize.call_count)
@@ -1682,7 +1694,7 @@ class BaseDNNClassifierTrainTest(object):
self, hidden_units=hidden_units, expected_loss=expected_loss)
dnn_classifier = self._dnn_classifier_fn(
hidden_units=hidden_units,
- feature_columns=(feature_column.numeric_column('age'),),
+ feature_columns=(self._fc_impl.numeric_column('age'),),
optimizer=opt,
model_dir=self._model_dir)
self.assertEqual(0, opt.minimize.call_count)
@@ -1728,7 +1740,7 @@ class BaseDNNClassifierTrainTest(object):
self, hidden_units=hidden_units, expected_loss=expected_loss)
dnn_classifier = self._dnn_classifier_fn(
hidden_units=hidden_units,
- feature_columns=(feature_column.numeric_column('age'),),
+ feature_columns=(self._fc_impl.numeric_column('age'),),
optimizer=opt,
model_dir=self._model_dir)
self.assertEqual(0, opt.minimize.call_count)
@@ -1759,7 +1771,7 @@ class BaseDNNClassifierTrainTest(object):
dnn_classifier = self._dnn_classifier_fn(
n_classes=n_classes,
hidden_units=hidden_units,
- feature_columns=(feature_column.numeric_column('age'),),
+ feature_columns=(self._fc_impl.numeric_column('age'),),
optimizer=opt,
model_dir=self._model_dir)
self.assertEqual(0, opt.minimize.call_count)
@@ -1793,8 +1805,9 @@ class BaseDNNClassifierTrainTest(object):
class BaseDNNRegressorTrainTest(object):
- def __init__(self, dnn_regressor_fn):
+ def __init__(self, dnn_regressor_fn, fc_impl=feature_column):
self._dnn_regressor_fn = dnn_regressor_fn
+ self._fc_impl = fc_impl
def setUp(self):
self._model_dir = tempfile.mkdtemp()
@@ -1808,7 +1821,7 @@ class BaseDNNRegressorTrainTest(object):
hidden_units = (2, 2)
dnn_regressor = self._dnn_regressor_fn(
hidden_units=hidden_units,
- feature_columns=(feature_column.numeric_column('age'),),
+ feature_columns=(self._fc_impl.numeric_column('age'),),
model_dir=self._model_dir)
# Train for a few steps, then validate final checkpoint.
@@ -1824,7 +1837,7 @@ class BaseDNNRegressorTrainTest(object):
opt = mock_optimizer(self, hidden_units=hidden_units)
dnn_regressor = self._dnn_regressor_fn(
hidden_units=hidden_units,
- feature_columns=(feature_column.numeric_column('age'),),
+ feature_columns=(self._fc_impl.numeric_column('age'),),
optimizer=opt,
model_dir=self._model_dir)
self.assertEqual(0, opt.minimize.call_count)
@@ -1864,7 +1877,7 @@ class BaseDNNRegressorTrainTest(object):
self, hidden_units=hidden_units, expected_loss=expected_loss)
dnn_regressor = self._dnn_regressor_fn(
hidden_units=hidden_units,
- feature_columns=(feature_column.numeric_column('age'),),
+ feature_columns=(self._fc_impl.numeric_column('age'),),
optimizer=opt,
model_dir=self._model_dir)
self.assertEqual(0, opt.minimize.call_count)
@@ -1917,7 +1930,8 @@ class BaseDNNRegressorTrainTest(object):
dnn_regressor = self._dnn_regressor_fn(
hidden_units=hidden_units,
feature_columns=[
- feature_column.numeric_column('age', shape=[input_dimension])],
+ self._fc_impl.numeric_column('age', shape=[input_dimension])
+ ],
label_dimension=label_dimension,
optimizer=opt,
model_dir=self._model_dir)
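
The dnn_testing_utils.py base classes above take the complementary route: the feature-column module itself is injected through the new fc_impl constructor argument, and every numeric_column, embedding_column or categorical_column_* call goes through self._fc_impl. A minimal sketch of that module-injection pattern, using the internal modules as they exist at this revision (FakeBaseTest is illustrative):

    from tensorflow.python.feature_column import feature_column
    from tensorflow.python.feature_column import feature_column_v2


    class FakeBaseTest(object):
      """Base test mixin parameterized by the feature-column module to use."""

      def __init__(self, fc_impl=feature_column):
        self._fc_impl = fc_impl

      def build_feature_columns(self):
        # Same call site, different implementation depending on fc_impl.
        return [self._fc_impl.numeric_column('age')]


    v1_columns = FakeBaseTest(fc_impl=feature_column).build_feature_columns()
    v2_columns = FakeBaseTest(fc_impl=feature_column_v2).build_feature_columns()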
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py
index 90280fd25d..b933cedb99 100644
--- a/tensorflow/python/estimator/estimator.py
+++ b/tensorflow/python/estimator/estimator.py
@@ -41,7 +41,6 @@ from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import tensor_util
-from tensorflow.python.keras import metrics
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import metrics as metrics_lib
@@ -145,7 +144,7 @@ class Estimator(object):
* `labels`: This is the second item returned from the `input_fn`
passed to `train`, `evaluate`, and `predict`. This should be a
single `tf.Tensor` or `dict` of same (for multi-head models).
- If mode is @{tf.estimator.ModeKeys.PREDICT}, `labels=None` will
+ If mode is `tf.estimator.ModeKeys.PREDICT`, `labels=None` will
be passed. If the `model_fn`'s signature does not accept
`mode`, the `model_fn` must still be able to handle
`labels=None`.
@@ -475,11 +474,31 @@ class Estimator(object):
return _evaluate()
def _convert_eval_steps_to_hooks(self, steps):
+    """Creates hooks to run the correct number of steps in evaluation.
+
+ Args:
+      steps: Number of steps to run during evaluation.
+
+ Raises:
+ ValueError: if steps is less than or equal to zero.
+
+ Returns:
+ List of hooks to be passed to the estimator.
+ """
if steps is None:
return []
if steps <= 0:
raise ValueError('Must specify steps > 0, given: {}'.format(steps))
+
+    # The hooks are declared as private in evaluation.py to discourage their
+    # use by other libraries or open-source users. This should be the only
+    # usage of the estimator evaluation hooks.
+ if self._eval_distribution:
+ steps_per_run = getattr(self._eval_distribution, 'steps_per_run', 1)
+ if steps_per_run > 1:
+ return [evaluation._MultiStepStopAfterNEvalsHook( # pylint: disable=protected-access
+ num_evals=steps, steps_per_run=steps_per_run)]
return [evaluation._StopAfterNEvalsHook(num_evals=steps)] # pylint: disable=protected-access
def predict(self,
@@ -490,6 +509,10 @@ class Estimator(object):
yield_single_examples=True):
"""Yields predictions for given features.
+ Please note that interleaving two predict outputs does not work. See:
+ [issue/20506](
+ https://github.com/tensorflow/tensorflow/issues/20506#issuecomment-422208517)
+
Args:
input_fn: A function that constructs the features. Prediction continues
until `input_fn` raises an end-of-input exception
@@ -611,7 +634,7 @@ class Estimator(object):
# pylint: disable=line-too-long,g-doc-args,g-doc-return-or-yield
"""Exports inference graph as a `SavedModel` into the given dir.
- Note that `export_to_savedmodel` will be renamed to `export_to_saved_model`
+ Note that `export_to_savedmodel` will be renamed to `export_saved_model`
in TensorFlow 2.0. At that time, `export_to_savedmodel` without the
additional underscore will be available only through tf.compat.v1.
@@ -696,7 +719,7 @@ class Estimator(object):
"""
# pylint: enable=line-too-long
# TODO(b/111442174): `export_to_savedmodel` will be renamed to
- # `export_to_saved_model` in TensorFlow 2.0. This function is a wrapper
+ # `export_saved_model` in TensorFlow 2.0. This function is a wrapper
# while staging the new version; do not add any logic here.
return self.export_savedmodel(
export_dir_base,
@@ -780,9 +803,9 @@ class Estimator(object):
those features and labels, and restores the given checkpoint
(or, lacking that, the most recent checkpoint) into the graph.
Only one of the modes is used for saving variables to the `SavedModel`
- (order of preference: @{tf.estimator.ModeKeys#TRAIN$TRAIN},
- @{tf.estimator.ModeKeys#EVAL$EVAL}, then
- @{tf.estimator.ModeKeys#PREDICT$PREDICT}), such that up to three
+ (order of preference: `tf.estimator.ModeKeys.TRAIN`,
+ `tf.estimator.ModeKeys.EVAL`, then
+ `tf.estimator.ModeKeys.PREDICT`), such that up to three
`tf.MetaGraphDefs` are saved with a single set of variables in a single
`SavedModel` directory.
@@ -1078,7 +1101,7 @@ class Estimator(object):
"""Creates the global step tensor in graph.
The global step tensor must be an integer type with name 'global_step' and
- be added to the collection @{tf.GraphKeys#GLOBAL_STEP$GLOBAL_STEP}.
+ be added to the collection `tf.GraphKeys.GLOBAL_STEP`.
Args:
graph: The graph in which to create the global step tensor.
@@ -1471,6 +1494,7 @@ class Estimator(object):
self._eval_distribution.__class__.__name__ == 'TPUStrategy')
if is_tpu_strategy:
+ steps_per_run_variable = training.get_or_create_steps_per_run_variable()
def step_fn(ctx, features, labels=None):
"""Runs one step of the eval computation and captures outputs."""
estimator_spec = self._eval_distribution.call_for_each_tower(
@@ -1487,7 +1511,7 @@ class Estimator(object):
# TODO(priyag): Fix eval step hook to account for steps_per_run.
ctx = self._eval_distribution.run_steps_on_dataset(
- step_fn, iterator, iterations=self._eval_distribution.steps_per_run)
+ step_fn, iterator, iterations=steps_per_run_variable)
update_op = ctx.run_op
eval_dict = ctx.non_tensor_outputs['eval_dict']
grouped_estimator_spec = ctx.non_tensor_outputs['estimator_spec']
@@ -1653,7 +1677,7 @@ def _combine_distributed_scaffold(grouped_scaffold, distribution):
def _unwrap_and_concat(value):
value = nest.flatten(distribution.unwrap(value))
if len(value) != 1:
- return array_ops.concat(value)
+ return array_ops.concat(value, 0)
return value[0]
ready_op = distribution.call_for_each_tower(
@@ -1788,18 +1812,9 @@ def _extract_metric_update_ops(eval_dict, distribution=None):
value_ops = {}
# Sort metrics lexicographically so graph is identical every time.
for name, value in sorted(six.iteritems(eval_dict)):
- if isinstance(value, metrics.Metric):
- metric_result = value.result()
- # We expect only one update op for every metric when there is no
- # distribution strategy.
- metric_update = value.updates if distribution else value.updates[0]
- else:
- metric_result = value[0]
- metric_update = value[1]
-
- value_ops[name] = metric_result
+ value_ops[name] = value[0]
update_ops.append(
- distribution.group(metric_update) if distribution else metric_update)
+ distribution.group(value[1]) if distribution else value[1])
update_op = control_flow_ops.group(*update_ops) if update_ops else None
return update_op, value_ops
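
Two of the estimator.py changes above are easy to miss: array_ops.concat gains its required axis argument (tf.concat has no default axis, so the old single-argument call could not have run), and _extract_metric_update_ops now treats every eval_dict value uniformly as a (value_tensor, update_op) tuple, since EstimatorSpec converts Metric objects up front. A tiny sketch of the concat fix, assuming the public tf.concat API:

    import tensorflow as tf

    # Per-tower values are concatenated along the leading dimension,
    # mirroring array_ops.concat(value, 0) above.
    tower_a = tf.constant([[1.0], [2.0]])
    tower_b = tf.constant([[3.0]])
    merged = tf.concat([tower_a, tower_b], 0)
    print(merged.shape)  # (3, 1)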
diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py
index 1ed5e30b0e..bc2504ca19 100644
--- a/tensorflow/python/estimator/estimator_test.py
+++ b/tensorflow/python/estimator/estimator_test.py
@@ -1017,7 +1017,7 @@ class EstimatorGetVariablesTest(test.TestCase):
def _model_fn(features, labels, mode):
_, _ = features, labels
- variables.Variable(1., name='one')
+ variables.VariableV1(1., name='one')
return model_fn_lib.EstimatorSpec(
mode=mode,
loss=constant_op.constant(0.),
@@ -1033,8 +1033,8 @@ class EstimatorGetVariablesTest(test.TestCase):
def _model_fn(features, labels, mode):
_, _ = features, labels
- variables.Variable(1., name='one')
- variables.Variable(3., name='three')
+ variables.VariableV1(1., name='one')
+ variables.VariableV1(3., name='three')
return model_fn_lib.EstimatorSpec(
mode=mode,
loss=constant_op.constant(0.),
@@ -1178,13 +1178,13 @@ class EstimatorEvaluateTest(test.TestCase):
def _model_fn(features, labels, mode, params):
del features, labels, params
mean = metrics_module.Mean()
- mean.update_state(variables.Variable(2.) + 1)
+ mean.update_state(variables.VariableV1(2.) + 1)
return model_fn_lib.EstimatorSpec(
mode,
loss=constant_op.constant(1.),
eval_metric_ops={
'mean1': mean,
- 'mean2': metrics_lib.mean(variables.Variable(2.) + 1)
+ 'mean2': metrics_lib.mean(variables.VariableV1(2.) + 1)
})
est = estimator.Estimator(model_fn=_model_fn)
@@ -1332,7 +1332,7 @@ class EstimatorEvaluateTest(test.TestCase):
def _model_fn_with_incremental_loss(features, labels, mode):
_, _ = features, labels
- local_weight = variables.Variable(
+ local_weight = variables.VariableV1(
0., name='local_weight', collections=[ops.GraphKeys.LOCAL_VARIABLES])
# Loss will be 2, 4, 6, ...
loss = 2 * state_ops.assign_add(local_weight, 1.)
@@ -1385,7 +1385,7 @@ class EstimatorEvaluateTest(test.TestCase):
def _get_model_fn(val=1):
def _model_fn(features, labels, mode):
del features, labels # unused
- variables.Variable(val, name='weight')
+ variables.VariableV1(val, name='weight')
return model_fn_lib.EstimatorSpec(
mode=mode,
predictions=constant_op.constant([[1.]]),
@@ -1409,7 +1409,7 @@ class EstimatorEvaluateTest(test.TestCase):
def _model_fn_scaffold(features, labels, mode):
_, _ = features, labels
- variables.Variable(1., name='weight')
+ variables.VariableV1(1., name='weight')
self.mock_saver = get_mock_saver()
return model_fn_lib.EstimatorSpec(
mode=mode,
@@ -1603,7 +1603,7 @@ class EstimatorPredictTest(test.TestCase):
def test_no_checkpoint_uses_init(self):
def _model_fn(features, labels, mode, params, config):
del features, labels, params, config
- x = variables.Variable([[3.]], name='x')
+ x = variables.VariableV1([[3.]], name='x')
return model_fn_lib.EstimatorSpec(mode, predictions=math_ops.add(x, 1.))
est = estimator.Estimator(model_fn=_model_fn)
# Expected prediction value is 1 + the value of the Variable that is newly
@@ -1614,7 +1614,7 @@ class EstimatorPredictTest(test.TestCase):
def _make_model_fn(x):
def _variable_creating_and_export_model_fn(features, labels, mode):
_, _ = features, labels
- x_var = variables.Variable([[x]], name='x')
+ x_var = variables.VariableV1([[x]], name='x')
return model_fn_lib.EstimatorSpec(
mode,
predictions=math_ops.add(x_var, 1.),
@@ -1936,7 +1936,7 @@ class EstimatorPredictTest(test.TestCase):
def _model_fn(features, labels, mode):
_, _ = features, labels
- v = variables.Variable([[16.]], name='weight')
+ v = variables.VariableV1([[16.]], name='weight')
prediction = v * 2
return model_fn_lib.EstimatorSpec(
mode,
@@ -1953,7 +1953,7 @@ class EstimatorPredictTest(test.TestCase):
def _model_fn(features, labels, mode):
_, _ = features, labels
- v = variables.Variable([[16.]], name='weight')
+ v = variables.VariableV1([[16.]], name='weight')
prediction = v * 2
return model_fn_lib.EstimatorSpec(
mode,
@@ -1974,7 +1974,7 @@ class EstimatorPredictTest(test.TestCase):
def _model_fn_scaffold(features, labels, mode):
_, _ = features, labels
- variables.Variable(1., name='weight')
+ variables.VariableV1(1., name='weight')
self.mock_saver = get_mock_saver()
return model_fn_lib.EstimatorSpec(
mode=mode,
@@ -2029,7 +2029,7 @@ class EstimatorPredictTest(test.TestCase):
def _model_fn_for_export_tests(features, labels, mode):
_, _ = features, labels
- variables.Variable(1., name='weight')
+ variables.VariableV1(1., name='weight')
scores = constant_op.constant([3.])
classes = constant_op.constant(['wumpus'])
update_global_step = state_ops.assign_add(training.get_global_step(), 1)
@@ -2052,11 +2052,11 @@ def _x_y_input_fn():
def _model_fn_with_x_y(features, labels, mode):
_ = labels
- variables.Variable(1., name='weight')
+ variables.VariableV1(1., name='weight')
scores = constant_op.constant([3.])
classes = constant_op.constant(['wumpus'])
if mode == model_fn_lib.ModeKeys.PREDICT:
- variables.Variable(36., name='name_collision')
+ variables.VariableV1(36., name='name_collision')
return model_fn_lib.EstimatorSpec(
mode,
predictions=constant_op.constant(10.),
@@ -2076,8 +2076,8 @@ def _model_fn_with_x_y(features, labels, mode):
metrics_lib.mean(
features['x'] - features['y'], name='{}mean'.format(prefix))
}
- variables.Variable(1., name='later_var')
- variables.Variable(3., name='name_collision')
+ variables.VariableV1(1., name='later_var')
+ variables.VariableV1(3., name='name_collision')
return model_fn_lib.EstimatorSpec(
mode,
predictions=multiplied,
@@ -2411,9 +2411,9 @@ class EstimatorExportTest(test.TestCase):
def _model_fn_with_predict_only_vars(features, labels, mode):
_, _ = features, labels
if mode == model_fn_lib.ModeKeys.PREDICT:
- variables.Variable(1., name='only_in_predict')
+ variables.VariableV1(1., name='only_in_predict')
else:
- variables.Variable(1., name='otherwise')
+ variables.VariableV1(1., name='otherwise')
prediction = constant_op.constant(1.)
return model_fn_lib.EstimatorSpec(
@@ -2684,7 +2684,7 @@ class EstimatorExportTest(test.TestCase):
def _model_fn_scaffold(features, labels, mode):
_, _ = features, labels
- variables.Variable(1., name='weight')
+ variables.VariableV1(1., name='weight')
self.mock_saver = get_mock_saver()
scores = constant_op.constant([3.])
return model_fn_lib.EstimatorSpec(
@@ -2717,7 +2717,7 @@ class EstimatorExportTest(test.TestCase):
def _model_fn_scaffold(features, labels, mode):
_, _ = features, labels
- variables.Variable(1., name='weight')
+ variables.VariableV1(1., name='weight')
scores = constant_op.constant([3.])
if mode == model_fn_lib.ModeKeys.PREDICT:
@@ -2762,8 +2762,8 @@ class EstimatorExportTest(test.TestCase):
def _model_fn_scaffold(features, labels, mode):
_, _ = features, labels
- my_int = variables.Variable(1, name='my_int',
- collections=[ops.GraphKeys.LOCAL_VARIABLES])
+ my_int = variables.VariableV1(1, name='my_int',
+ collections=[ops.GraphKeys.LOCAL_VARIABLES])
_ = training.get_or_create_steps_per_run_variable()
scores = constant_op.constant([3.])
with ops.control_dependencies([
@@ -2808,8 +2808,8 @@ class EstimatorExportTest(test.TestCase):
def _model_fn_scaffold(features, labels, mode):
_, _ = features, labels
- my_int = variables.Variable(1, name='my_int',
- collections=[ops.GraphKeys.LOCAL_VARIABLES])
+ my_int = variables.VariableV1(1, name='my_int',
+ collections=[ops.GraphKeys.LOCAL_VARIABLES])
scores = constant_op.constant([3.])
with ops.control_dependencies([
variables.local_variables_initializer(),
@@ -3038,7 +3038,7 @@ class EstimatorExportTest(test.TestCase):
def _model_fn(features, labels, mode):
_, _ = features, labels
- variables.Variable(1., name='weight')
+ variables.VariableV1(1., name='weight')
return model_fn_lib.EstimatorSpec(
mode,
predictions=constant_op.constant(10.),
@@ -3081,7 +3081,7 @@ class EstimatorHookOrderingTest(test.TestCase):
"""A graph that generates NaN's for testing."""
del features, labels
- global_step = variables.Variable(
+ global_step = variables.VariableV1(
0, dtype=dtypes.int64, name='global_step')
inc_global_step = state_ops.assign_add(global_step, 1)
nan_const = constant_op.constant(np.nan, dtype=dtypes.float32)
diff --git a/tensorflow/python/estimator/export/export_test.py b/tensorflow/python/estimator/export/export_test.py
index 3eed1ab163..ed3219c49b 100644
--- a/tensorflow/python/estimator/export/export_test.py
+++ b/tensorflow/python/estimator/export/export_test.py
@@ -376,7 +376,7 @@ class ExportTest(test_util.TensorFlowTestCase):
" } "
"} ", example)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sparse_result = sess.run(
serving_input_receiver.features,
feed_dict={
diff --git a/tensorflow/python/estimator/keras.py b/tensorflow/python/estimator/keras.py
index 6b2765be82..7546771ed3 100644
--- a/tensorflow/python/estimator/keras.py
+++ b/tensorflow/python/estimator/keras.py
@@ -21,6 +21,7 @@ from __future__ import print_function
import os
import re
+import six
from tensorflow.python.client import session
from tensorflow.python.estimator import estimator as estimator_lib
@@ -31,6 +32,7 @@ from tensorflow.python.framework import random_seed
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import backend as K
+from tensorflow.python.keras import metrics
from tensorflow.python.keras import models
from tensorflow.python.keras import optimizers
from tensorflow.python.ops import check_ops
@@ -214,25 +216,40 @@ def _convert_keras_metrics_to_estimator(model):
if not getattr(model, 'metrics', None):
return None
- # TODO(psv/fchollet): support stateful metrics
eval_metric_ops = {}
+
+ def get_metric_name(metric):
+ if isinstance(metric, metrics.Metric):
+ return metric.name
+ if callable(metric):
+ return metric.__name__
+ assert isinstance(metric, six.string_types)
+ return metric
+
# When each metric maps to an output
if isinstance(model.metrics, dict):
for i, output_name in enumerate(model.metrics.keys()):
- metric_name = model.metrics[output_name]
- if callable(metric_name):
- metric_name = metric_name.__name__
+      # `metric` is the user-given metric value passed to `compile`. This can
+      # be a metric name (`acc`), a metric function (`binary_accuracy`), or a
+      # metric object (`BinaryAccuracy()`).
+ metric = model.metrics[output_name]
+ metric_name = get_metric_name(metric)
# When some outputs use the same metric
if list(model.metrics.values()).count(metric_name) > 1:
metric_name += '_' + output_name
- eval_metric_ops[metric_name] = metrics_module.mean(
- model.metrics_tensors[i - len(model.metrics)])
+ if isinstance(metric, metrics.Metric):
+ eval_metric_ops[metric_name] = metric
+ else:
+ eval_metric_ops[metric_name] = metrics_module.mean(
+ model.metrics_tensors[i - len(model.metrics)])
else:
- for i, metric_name in enumerate(model.metrics):
- if callable(metric_name):
- metric_name = metric_name.__name__
- eval_metric_ops[metric_name] = metrics_module.mean(
- model.metrics_tensors[i])
+ for i, metric in enumerate(model.metrics):
+ metric_name = get_metric_name(metric)
+ if isinstance(metric, metrics.Metric):
+ eval_metric_ops[metric_name] = metric
+ else:
+ eval_metric_ops[metric_name] = metrics_module.mean(
+ model.metrics_tensors[i])
return eval_metric_ops
diff --git a/tensorflow/python/estimator/keras_test.py b/tensorflow/python/estimator/keras_test.py
index 3758243d7b..288f9b8906 100644
--- a/tensorflow/python/estimator/keras_test.py
+++ b/tensorflow/python/estimator/keras_test.py
@@ -257,7 +257,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
keras_model.compile(
loss='categorical_crossentropy',
optimizer='rmsprop',
- metrics=['mse', keras.metrics.categorical_accuracy])
+ metrics=['mse', keras.metrics.CategoricalAccuracy()])
with self.cached_session():
est_keras = keras_lib.model_to_estimator(
@@ -281,7 +281,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
keras_model.compile(
loss='categorical_crossentropy',
optimizer=rmsprop.RMSPropOptimizer(1e-3),
- metrics=['mse', keras.metrics.categorical_accuracy])
+ metrics=['mse', keras.metrics.CategoricalAccuracy()])
my_hook = MyHook()
with self.cached_session():
@@ -306,7 +306,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
keras_model.compile(
loss='categorical_crossentropy',
optimizer=rmsprop.RMSPropOptimizer(1e-3),
- metrics=['mse', keras.metrics.categorical_accuracy])
+ metrics=['mse', keras.metrics.CategoricalAccuracy()])
my_hook = MyHook()
with self.cached_session():
keras_model.fit(x_train, y_train, epochs=1)
@@ -328,7 +328,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
keras_model.compile(
loss='categorical_crossentropy',
optimizer=rmsprop.RMSPropOptimizer(1e-3),
- metrics=['mse', keras.metrics.categorical_accuracy])
+ metrics=['mse', keras.metrics.CategoricalAccuracy()])
with self.cached_session():
est_keras = keras_lib.model_to_estimator(
@@ -351,7 +351,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
keras_model.compile(
loss='categorical_crossentropy',
optimizer=rmsprop.RMSPropOptimizer(1e-3),
- metrics=['mse', keras.metrics.categorical_accuracy])
+ metrics=['mse', keras.metrics.CategoricalAccuracy()])
with self.cached_session():
est_keras = keras_lib.model_to_estimator(
@@ -370,7 +370,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
keras_model.compile(
loss='categorical_crossentropy',
optimizer=rmsprop.RMSPropOptimizer(1e-3),
- metrics=['mse', keras.metrics.categorical_accuracy])
+ metrics=['mse', keras.metrics.CategoricalAccuracy()])
with self.cached_session():
# Create state
@@ -662,7 +662,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
keras_model.compile(
loss='categorical_crossentropy',
optimizer='rmsprop',
- metrics=['mse', keras.metrics.categorical_accuracy])
+ metrics=['mse', keras.metrics.CategoricalAccuracy()])
tf_config = json.dumps({
'cluster': {
@@ -687,7 +687,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
keras_model.compile(
loss='categorical_crossentropy',
optimizer='rmsprop',
- metrics=['mse', keras.metrics.categorical_accuracy])
+ metrics=['mse', keras.metrics.CategoricalAccuracy()])
gpu_options = config_pb2.GPUOptions(per_process_gpu_memory_fraction=0.3)
sess_config = config_pb2.ConfigProto(gpu_options=gpu_options)
@@ -706,7 +706,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
keras_model.compile(
loss='categorical_crossentropy',
optimizer='rmsprop',
- metrics=['mse', keras.metrics.categorical_accuracy])
+ metrics=['mse', keras.metrics.CategoricalAccuracy()])
with self.cached_session():
est_keras = keras_lib.model_to_estimator(
@@ -736,7 +736,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
keras_model.compile(
loss='categorical_crossentropy',
optimizer='rmsprop',
- metrics=['mse', keras.metrics.categorical_accuracy])
+ metrics=['mse', keras.metrics.CategoricalAccuracy()])
with self.cached_session():
with test.mock.patch.object(tempfile, 'mkdtemp', return_value=_TMP_DIR):
@@ -751,7 +751,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
keras_model.compile(
loss='categorical_crossentropy',
optimizer='rmsprop',
- metrics=['mse', keras.metrics.categorical_accuracy])
+ metrics=['mse', keras.metrics.CategoricalAccuracy()])
with self.cached_session():
with self.assertRaisesRegexp(ValueError, '`model_dir` are set both in '
@@ -765,7 +765,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
keras_model.compile(
loss='categorical_crossentropy',
optimizer=rmsprop.RMSPropOptimizer(1e-3),
- metrics=['mse', keras.metrics.categorical_accuracy])
+ metrics=['mse', keras.metrics.CategoricalAccuracy()])
with self.cached_session():
keras_model.train_on_batch(
np.random.random((10,) + _INPUT_SIZE),
@@ -776,7 +776,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
keras_model.compile(
loss='categorical_crossentropy',
optimizer=SGD(lr=0.0001, momentum=0.9),
- metrics=['mse', keras.metrics.categorical_accuracy])
+ metrics=['mse', keras.metrics.CategoricalAccuracy()])
keras_lib.model_to_estimator(
keras_model=keras_model, config=self._config)
@@ -786,7 +786,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
keras_model.compile(
loss='categorical_crossentropy',
optimizer=optimizer,
- metrics=['mse', keras.metrics.categorical_accuracy])
+ metrics=['mse', keras.metrics.CategoricalAccuracy()])
with self.cached_session() as sess:
keras_model_fn = keras_lib._create_keras_model_fn(keras_model)
global_step = training_util.create_global_step()
diff --git a/tensorflow/python/estimator/model_fn.py b/tensorflow/python/estimator/model_fn.py
index 439cc2e3a4..824789467d 100644
--- a/tensorflow/python/estimator/model_fn.py
+++ b/tensorflow/python/estimator/model_fn.py
@@ -308,6 +308,8 @@ class EstimatorSpec(
for key, value in six.iteritems(eval_metric_ops):
if isinstance(value, Metric):
vars_to_add.update(value.variables)
+ # Convert Metric instances to (value_tensor, update_op) tuple.
+ eval_metric_ops[key] = (value.result(), value.updates[0])
# Remove variables that are in the local variables collection already.
vars_to_add = vars_to_add.difference(local_vars)
for v in vars_to_add:
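Editor's note: a pure-Python sketch of the conversion performed above, using a hypothetical stand-in for the Keras `Metric` class (only the `result()`/`updates` surface that the loop relies on):

```python
class FakeMetric(object):
    """Hypothetical stand-in exposing the Metric surface used above."""

    def __init__(self, value_tensor, update_op):
        self._value = value_tensor
        self.updates = [update_op]

    def result(self):
        return self._value

eval_metric_ops = {'accuracy': FakeMetric('acc_value_t', 'acc_update_op')}
for key, value in list(eval_metric_ops.items()):
    if isinstance(value, FakeMetric):
        # Same rewrite as above: Metric -> (value_tensor, update_op) pair.
        eval_metric_ops[key] = (value.result(), value.updates[0])
print(eval_metric_ops)  # {'accuracy': ('acc_value_t', 'acc_update_op')}
```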
@@ -466,13 +468,13 @@ class _TPUEstimatorSpec(
def _check_is_tensor_or_operation(x, name):
- if not (isinstance(x, ops.Operation) or isinstance(x, ops.Tensor)):
+ if not (isinstance(x, ops.Operation) or ops.is_dense_tensor_like(x)):
raise TypeError('{} must be Operation or Tensor, given: {}'.format(name, x))
def _check_is_tensor(x, tensor_name):
"""Returns `x` if it is a `Tensor`, raises TypeError otherwise."""
- if not isinstance(x, ops.Tensor):
+ if not ops.is_dense_tensor_like(x):
raise TypeError('{} must be Tensor, given: {}'.format(tensor_name, x))
return x
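Editor's note: the relaxed check matters because, in TF 1.x, a `tf.Variable` is not a subclass of `tf.Tensor` even though it behaves like a dense tensor, so the stricter `isinstance` check would reject such values in an `EstimatorSpec`. A small illustration, assuming a TF 1.x graph-mode environment:

```python
import tensorflow as tf  # assumption: TF 1.x graph mode

v = tf.Variable(0.0)
print(isinstance(v, tf.Tensor))   # False: a Variable is not a Tensor subclass
t = tf.convert_to_tensor(v)       # ...but it is dense-tensor-like and converts
print(isinstance(t, tf.Tensor))   # True
```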
diff --git a/tensorflow/python/feature_column/BUILD b/tensorflow/python/feature_column/BUILD
index ac53a84eef..5800b693b4 100644
--- a/tensorflow/python/feature_column/BUILD
+++ b/tensorflow/python/feature_column/BUILD
@@ -156,7 +156,7 @@ py_test(
"//tensorflow/python:variables",
"//tensorflow/python/eager:backprop",
"//tensorflow/python/eager:context",
- "//tensorflow/python/estimator:numpy_io",
+ "//tensorflow/python/estimator:estimator_py",
"//third_party/py/numpy",
],
)
diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py
index 9984379e9d..618e70f3a5 100644
--- a/tensorflow/python/feature_column/feature_column.py
+++ b/tensorflow/python/feature_column/feature_column.py
@@ -170,7 +170,8 @@ def _internal_input_layer(features,
trainable=True,
cols_to_vars=None,
scope=None,
- cols_to_output_tensors=None):
+ cols_to_output_tensors=None,
+ from_template=False):
"""See input_layer. `scope` is a name or variable scope to use."""
feature_columns = _normalize_feature_columns(feature_columns)
@@ -186,10 +187,7 @@ def _internal_input_layer(features,
if ops.GraphKeys.MODEL_VARIABLES not in weight_collections:
weight_collections.append(ops.GraphKeys.MODEL_VARIABLES)
- # a non-None `scope` can allow for variable reuse, when, e.g., this function
- # is wrapped by a `make_template`.
- with variable_scope.variable_scope(
- scope, default_name='input_layer', values=features.values()):
+ def _get_logits(): # pylint: disable=missing-docstring
builder = _LazyBuilder(features)
output_tensors = []
ordered_columns = []
@@ -217,6 +215,16 @@ def _internal_input_layer(features,
_verify_static_batch_size_equality(output_tensors, ordered_columns)
return array_ops.concat(output_tensors, 1)
+ # If we're constructing from `make_template`, which by default adds a
+ # variable scope with the name of the layer, we don't want to add another
+ # `variable_scope` here, as that would break checkpoints.
+ if from_template:
+ return _get_logits()
+ else:
+ with variable_scope.variable_scope(
+ scope, default_name='input_layer', values=features.values()):
+ return _get_logits()
+
@tf_export('feature_column.input_layer')
def input_layer(features,
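Editor's note: the `from_template` flag exists because `make_template` already opens a variable scope named after the layer; opening a second `input_layer` scope inside it would change variable names and break existing checkpoints. A minimal sketch of the template behavior being relied on, assuming TF 1.x and a hypothetical `scoped_layer` function:

```python
import tensorflow as tf  # assumption: TF 1.x graph mode

def scoped_layer(x):
    # Variables created here land under the template's own scope.
    w = tf.get_variable('w', shape=[1], initializer=tf.zeros_initializer())
    return x * w

layer = tf.make_template('feature_column_input_layer', scoped_layer,
                         create_scope_now_=True)
a = layer(tf.constant([1.0]))
b = layer(tf.constant([2.0]))  # second call reuses the same 'w'
print([v.name for v in tf.global_variables()])
# ['feature_column_input_layer/w:0'] -- no extra 'input_layer/' nesting
```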
@@ -301,17 +309,18 @@ class InputLayer(object):
feature_columns,
weight_collections=None,
trainable=True,
- cols_to_vars=None):
+ cols_to_vars=None,
+ name='feature_column_input_layer',
+ create_scope_now=True):
"""See `input_layer`."""
self._feature_columns = feature_columns
self._weight_collections = weight_collections
self._trainable = trainable
self._cols_to_vars = cols_to_vars
+ self._name = name
self._input_layer_template = template.make_template(
- 'feature_column_input_layer',
- _internal_input_layer,
- create_scope_now_=True)
+ self._name, _internal_input_layer, create_scope_now_=create_scope_now)
self._scope = self._input_layer_template.variable_scope
def __call__(self, features):
@@ -321,7 +330,11 @@ class InputLayer(object):
weight_collections=self._weight_collections,
trainable=self._trainable,
cols_to_vars=None,
- scope=self._scope)
+ from_template=True)
+
+ @property
+ def name(self):
+ return self._name
@property
def non_trainable_variables(self):
@@ -2305,7 +2318,7 @@ class _LazyBuilder(object):
# Input_tensor must have rank 1.
if isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
return sparse_ops.sparse_reshape(
- input_tensor, [array_ops.shape(input_tensor)[0], -1])
+ input_tensor, [array_ops.shape(input_tensor)[0], 1])
else:
return array_ops.expand_dims(input_tensor, -1)
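Editor's note: the `[batch_size, -1]` to `[batch_size, 1]` change is about empty batches: with no values, the `-1` dimension cannot be inferred sensibly, while an explicit `1` keeps the expanded tensor rank 2 with shape `[0, 1]`, which is what the new regression tests below assert. A sketch of the reshape, assuming TF 1.x:

```python
import numpy as np
import tensorflow as tf  # assumption: TF 1.x graph mode

# Empty rank-1 SparseTensor, as in the regression test added below.
sp = tf.SparseTensor(indices=np.zeros((0, 1), dtype=np.int64),
                     values=np.array([], dtype=np.float32),
                     dense_shape=[0])
# tf.shape on a SparseTensor returns its dense shape, so this pins the result
# to [batch_size, 1] even when the batch is empty.
expanded = tf.sparse_reshape(sp, [tf.shape(sp)[0], 1])
with tf.Session() as sess:
    print(sess.run(expanded).dense_shape)  # [0 1]
```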
diff --git a/tensorflow/python/feature_column/feature_column_test.py b/tensorflow/python/feature_column/feature_column_test.py
index abb79efa68..1ae510250c 100644
--- a/tensorflow/python/feature_column/feature_column_test.py
+++ b/tensorflow/python/feature_column/feature_column_test.py
@@ -169,6 +169,18 @@ class LazyColumnTest(test.TestCase):
TypeError, '"key" must be either a "str" or "_FeatureColumn".'):
builder.get(NotAFeatureColumn())
+ def test_expand_dim_rank_1_sparse_tensor_empty_batch(self):
+ # empty 1-D sparse tensor:
+ builder = _LazyBuilder(features={'a': sparse_tensor.SparseTensor(
+ indices=np.reshape(np.array([], dtype=np.int64), (0, 1)),
+ dense_shape=[0],
+ values=np.array([]))})
+ with self.cached_session():
+ spv = builder.get('a').eval()
+ self.assertAllEqual(np.array([0, 1], dtype=np.int64), spv.dense_shape)
+ self.assertAllEqual(
+ np.reshape(np.array([], dtype=np.int64), (0, 2)), spv.indices)
+
class NumericColumnTest(test.TestCase):
diff --git a/tensorflow/python/feature_column/feature_column_v2.py b/tensorflow/python/feature_column/feature_column_v2.py
index 28c5c82d2c..a8d5bfb437 100644
--- a/tensorflow/python/feature_column/feature_column_v2.py
+++ b/tensorflow/python/feature_column/feature_column_v2.py
@@ -136,14 +136,11 @@ import six
from tensorflow.python.eager import context
-from tensorflow.python.feature_column import feature_column as fc_old
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
from tensorflow.python.framework import tensor_shape
-from tensorflow.python.keras.engine import training
from tensorflow.python.keras.engine.base_layer import Layer
-from tensorflow.python.layers import base
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
@@ -153,7 +150,6 @@ from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import parsing_ops
-from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variable_scope
@@ -245,28 +241,19 @@ class StateManager(object):
raise NotImplementedError('StateManager.get_resource')
-class _InputLayerStateManager(StateManager):
- """Manages the state of InputLayer."""
+class _StateManagerImpl(StateManager):
+ """Manages the state of FeatureLayer and LinearModel."""
- def __init__(self, layer, feature_columns, trainable):
- """Creates an _InputLayerStateManager object.
+ def __init__(self, layer, trainable):
+ """Creates an _StateManagerImpl object.
Args:
layer: The input layer this state manager is associated with.
- feature_columns: List of feature columns for the input layer
trainable: Whether by default, variables created are trainable or not.
"""
self._trainable = trainable
self._layer = layer
- self._cols_to_vars_map = {}
- self._cols_to_names_map = {}
- for column in sorted(feature_columns, key=lambda x: x.name):
- self._cols_to_vars_map[column] = {}
- base_name = column.name
- if isinstance(column, SharedEmbeddingColumn):
- base_name = column.shared_collection_name
- with variable_scope.variable_scope(base_name) as vs:
- self._cols_to_names_map[column] = _strip_leading_slashes(vs.name)
+ self._cols_to_vars_map = collections.defaultdict(lambda: {})
def create_variable(self,
feature_column,
@@ -277,19 +264,19 @@ class _InputLayerStateManager(StateManager):
initializer=None):
if name in self._cols_to_vars_map[feature_column]:
raise ValueError('Variable already exists.')
- with variable_scope.variable_scope(self._cols_to_names_map[feature_column]):
- var = self._layer.add_variable(
- name=name,
- shape=shape,
- dtype=dtype,
- initializer=initializer,
- trainable=self._trainable and trainable,
- # TODO(rohanj): Get rid of this hack once we have a mechanism for
- # specifying a default partitioner for an entire layer. In that case,
- # the default getter for Layers should work.
- getter=variable_scope.get_variable)
- self._cols_to_vars_map[feature_column][name] = var
- return var
+
+ var = self._layer.add_variable(
+ name=name,
+ shape=shape,
+ dtype=dtype,
+ initializer=initializer,
+ trainable=self._trainable and trainable,
+ # TODO(rohanj): Get rid of this hack once we have a mechanism for
+ # specifying a default partitioner for an entire layer. In that case,
+ # the default getter for Layers should work.
+ getter=variable_scope.get_variable)
+ self._cols_to_vars_map[feature_column][name] = var
+ return var
def get_variable(self, feature_column, name):
if name in self._cols_to_vars_map[feature_column]:
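Editor's note: the simplified manager above no longer pre-computes per-column scope names; it just keeps one lazily-created dict of variables per feature column. A condensed, framework-free sketch of that bookkeeping pattern (the `make_variable` callback stands in for `layer.add_variable`, purely for illustration):

```python
import collections

class TinyStateManager(object):
    """Illustrative sketch only, not the TF class."""

    def __init__(self, make_variable):
        self._make_variable = make_variable
        self._cols_to_vars_map = collections.defaultdict(dict)

    def create_variable(self, feature_column, name, **kwargs):
        if name in self._cols_to_vars_map[feature_column]:
            raise ValueError('Variable already exists.')
        var = self._make_variable(name=name, **kwargs)
        self._cols_to_vars_map[feature_column][name] = var
        return var

    def get_variable(self, feature_column, name):
        return self._cols_to_vars_map[feature_column][name]
```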
@@ -313,12 +300,15 @@ class FeatureLayer(Layer):
keywords_embedded = embedding_column(
categorical_column_with_hash_bucket("keywords", 10K), dimensions=16)
columns = [price, keywords_embedded, ...]
- features = tf.parse_example(..., features=make_parse_example_spec(columns))
feature_layer = FeatureLayer(columns)
+
+ features = tf.parse_example(..., features=make_parse_example_spec(columns))
dense_tensor = feature_layer(features)
for units in [128, 64, 32]:
dense_tensor = tf.layers.dense(dense_tensor, units, tf.nn.relu)
- prediction = tf.layers.dense(dense_tensor, 1)."""
+ prediction = tf.layers.dense(dense_tensor, 1)
+ ```
+ """
def __init__(self,
feature_columns,
@@ -375,8 +365,7 @@ class FeatureLayer(Layer):
super(FeatureLayer, self).__init__(name=name, trainable=trainable, **kwargs)
self._feature_columns = _normalize_feature_columns(feature_columns)
- self._state_manager = _InputLayerStateManager(self, self._feature_columns,
- self.trainable)
+ self._state_manager = _StateManagerImpl(self, self.trainable)
self._shared_state_manager = shared_state_manager
for column in sorted(self._feature_columns, key=lambda x: x.name):
if not isinstance(column, DenseColumn):
@@ -395,7 +384,8 @@ class FeatureLayer(Layer):
column.create_state(self._shared_state_manager)
else:
with variable_scope.variable_scope(None, default_name=self.name):
- column.create_state(self._state_manager)
+ with variable_scope.variable_scope(None, default_name=column.name):
+ column.create_state(self._state_manager)
super(FeatureLayer, self).build(None)
def call(self, features, cols_to_output_tensors=None):
@@ -448,20 +438,18 @@ class FeatureLayer(Layer):
return (input_shape[0], total_elements)
-def linear_model(features,
- feature_columns,
- units=1,
- sparse_combiner='sum',
- weight_collections=None,
- trainable=True,
- cols_to_vars=None):
- """Returns a linear prediction `Tensor` based on given `feature_columns`.
+def _strip_leading_slashes(name):
+ return name.rsplit('/', 1)[-1]
+
- This function generates a weighted sum based on output dimension `units`.
+class LinearModel(Layer):
+ """Produces a linear prediction `Tensor` based on given `feature_columns`.
+
+ This layer generates a weighted sum based on output dimension `units`.
Weighted sum refers to logits in classification problems. It refers to the
prediction itself for linear regression problems.
- Note on supported columns: `linear_model` treats categorical columns as
+ Note on supported columns: `LinearModel` treats categorical columns as
`indicator_column`s. To be specific, assume the input as `SparseTensor` looks
like:
@@ -486,308 +474,189 @@ def linear_model(features,
keywords = categorical_column_with_hash_bucket("keywords", 10K)
keywords_price = crossed_column('keywords', price_buckets, ...)
columns = [price_buckets, keywords, keywords_price ...]
+ linear_model = LinearModel(columns)
+
features = tf.parse_example(..., features=make_parse_example_spec(columns))
- prediction = linear_model(features, columns)
+ prediction = linear_model(features)
```
-
- Args:
- features: A mapping from key to tensors. `_FeatureColumn`s look up via these
- keys. For example `numeric_column('price')` will look at 'price' key in
- this dict. Values are `Tensor` or `SparseTensor` depending on
- corresponding `_FeatureColumn`.
- feature_columns: An iterable containing the FeatureColumns to use as inputs
- to your model. All items should be instances of classes derived from
- `_FeatureColumn`s.
- units: An integer, dimensionality of the output space. Default value is 1.
- sparse_combiner: A string specifying how to reduce if a categorical column
- is multivalent. Except `numeric_column`, almost all columns passed to
- `linear_model` are considered as categorical columns. It combines each
- categorical column independently. Currently "mean", "sqrtn" and "sum" are
- supported, with "sum" the default for linear model. "sqrtn" often achieves
- good accuracy, in particular with bag-of-words columns.
- * "sum": do not normalize features in the column
- * "mean": do l1 normalization on features in the column
- * "sqrtn": do l2 normalization on features in the column
- For example, for two features represented as the categorical columns:
-
- ```python
- # Feature 1
-
- shape = [2, 2]
- {
- [0, 0]: "a"
- [0, 1]: "b"
- [1, 0]: "c"
- }
-
- # Feature 2
-
- shape = [2, 3]
- {
- [0, 0]: "d"
- [1, 0]: "e"
- [1, 1]: "f"
- [1, 2]: "g"
- }
- ```
- with `sparse_combiner` as "mean", the linear model outputs conceptly are:
- ```
- y_0 = 1.0 / 2.0 * ( w_a + w_ b) + w_c + b_0
- y_1 = w_d + 1.0 / 3.0 * ( w_e + w_ f + w_g) + b_1
- ```
- where `y_i` is the output, `b_i` is the bias, and `w_x` is the weight
- assigned to the presence of `x` in the input features.
- weight_collections: A list of collection names to which the Variable will be
- added. Note that, variables will also be added to collections
- `tf.GraphKeys.GLOBAL_VARIABLES` and `ops.GraphKeys.MODEL_VARIABLES`.
- trainable: If `True` also add the variable to the graph collection
- `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
- cols_to_vars: If not `None`, must be a dictionary that will be filled with a
- mapping from `_FeatureColumn` to associated list of `Variable`s. For
- example, after the call, we might have cols_to_vars = {
- _NumericColumn(
- key='numeric_feature1', shape=(1,):
- [<tf.Variable 'linear_model/price2/weights:0' shape=(1, 1)>],
- 'bias': [<tf.Variable 'linear_model/bias_weights:0' shape=(1,)>],
- _NumericColumn(
- key='numeric_feature2', shape=(2,)):
- [<tf.Variable 'linear_model/price1/weights:0' shape=(2, 1)>]}
- If a column creates no variables, its value will be an empty list. Note
- that cols_to_vars will also contain a string key 'bias' that maps to a
- list of Variables.
-
- Returns:
- A `Tensor` which represents predictions/logits of a linear model. Its shape
- is (batch_size, units) and its dtype is `float32`.
-
- Raises:
- ValueError: if an item in `feature_columns` is neither a `_DenseColumn`
- nor `_CategoricalColumn`.
- """
- with variable_scope.variable_scope(None, 'linear_model') as vs:
- model_name = _strip_leading_slashes(vs.name)
- linear_model_layer = _LinearModel(
- feature_columns=feature_columns,
- units=units,
- sparse_combiner=sparse_combiner,
- weight_collections=weight_collections,
- trainable=trainable,
- name=model_name)
- retval = linear_model_layer(features) # pylint: disable=not-callable
- if cols_to_vars is not None:
- cols_to_vars.update(linear_model_layer.cols_to_vars())
- return retval
-
-
-def _add_to_collections(var, weight_collections):
- """Adds a var to the list of weight_collections provided.
-
- Handles the case for partitioned and non-partitioned variables.
-
- Args:
- var: A variable or Partitioned Variable.
- weight_collections: List of collections to add variable to.
- """
- for weight_collection in weight_collections:
- # The layer self.add_variable call already adds it to GLOBAL_VARIABLES.
- if weight_collection == ops.GraphKeys.GLOBAL_VARIABLES:
- continue
- # TODO(rohanj): Explore adding a _get_variable_list method on `Variable`
- # so that we don't have to do this check.
- if isinstance(var, variables.PartitionedVariable):
- for constituent_var in list(var):
- ops.add_to_collection(weight_collection, constituent_var)
- else:
- ops.add_to_collection(weight_collection, var)
-
-
-class _FCLinearWrapper(base.Layer):
- """Wraps a _FeatureColumn in a layer for use in a linear model.
-
- See `linear_model` above.
"""
def __init__(self,
- feature_column,
+ feature_columns,
units=1,
sparse_combiner='sum',
- weight_collections=None,
trainable=True,
name=None,
+ shared_state_manager=None,
**kwargs):
- super(_FCLinearWrapper, self).__init__(
- trainable=trainable, name=name, **kwargs)
- self._feature_column = feature_column
- self._units = units
- self._sparse_combiner = sparse_combiner
- self._weight_collections = weight_collections
+ """Constructs a LinearModel.
- def build(self, _):
- if isinstance(self._feature_column, fc_old._CategoricalColumn): # pylint: disable=protected-access
- weight = self.add_variable(
- name='weights',
- shape=(self._feature_column._num_buckets, self._units), # pylint: disable=protected-access
- initializer=init_ops.zeros_initializer(),
- trainable=self.trainable)
- else:
- num_elements = self._feature_column._variable_shape.num_elements() # pylint: disable=protected-access
- weight = self.add_variable(
- name='weights',
- shape=[num_elements, self._units],
- initializer=init_ops.zeros_initializer(),
- trainable=self.trainable)
- _add_to_collections(weight, self._weight_collections)
- self._weight_var = weight
- self.built = True
-
- def call(self, builder):
- weighted_sum = fc_old._create_weighted_sum( # pylint: disable=protected-access
- column=self._feature_column,
- builder=builder,
- units=self._units,
- sparse_combiner=self._sparse_combiner,
- weight_collections=self._weight_collections,
- trainable=self.trainable,
- weight_var=self._weight_var)
- return weighted_sum
+ Args:
+ feature_columns: An iterable containing the FeatureColumns to use as
+ inputs to your model. All items should be instances of classes derived
+ from `_FeatureColumn`s.
+ units: An integer, dimensionality of the output space. Default value is 1.
+ sparse_combiner: A string specifying how to reduce if a categorical column
+ is multivalent. Except `numeric_column`, almost all columns passed to
+ `linear_model` are considered as categorical columns. It combines each
+ categorical column independently. Currently "mean", "sqrtn" and "sum"
+ are supported, with "sum" the default for linear model. "sqrtn" often
+ achieves good accuracy, in particular with bag-of-words columns.
+ * "sum": do not normalize features in the column
+ * "mean": do l1 normalization on features in the column
+ * "sqrtn": do l2 normalization on features in the column
+ For example, for two features represented as the categorical columns:
+
+ ```python
+ # Feature 1
+
+ shape = [2, 2]
+ {
+ [0, 0]: "a"
+ [0, 1]: "b"
+ [1, 0]: "c"
+ }
+
+ # Feature 2
+
+ shape = [2, 3]
+ {
+ [0, 0]: "d"
+ [1, 0]: "e"
+ [1, 1]: "f"
+ [1, 2]: "g"
+ }
+ ```
+
+ with `sparse_combiner` as "mean", the linear model outputs are conceptually:
+ ```
+ y_0 = 1.0 / 2.0 * (w_a + w_b) + w_c + b_0
+ y_1 = w_d + 1.0 / 3.0 * (w_e + w_f + w_g) + b_1
+ ```
+ where `y_i` is the output, `b_i` is the bias, and `w_x` is the weight
+ assigned to the presence of `x` in the input features.
+ trainable: If `True` also add the variable to the graph collection
+ `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
+ name: Name to give to the Linear Model. All variables and ops created will
+ be scoped by this name.
+ shared_state_manager: SharedEmbeddingStateManager that manages the state
+ of SharedEmbeddingColumns. For more info, look at `FeatureLayer`.
+ **kwargs: Keyword arguments to construct a layer.
+ Raises:
+ ValueError: if an item in `feature_columns` is neither a `DenseColumn`
+ nor `CategoricalColumn`.
+ """
+ super(LinearModel, self).__init__(name=name, trainable=trainable, **kwargs)
-class _BiasLayer(base.Layer):
- """A layer for the bias term.
- """
+ self._feature_columns = _normalize_feature_columns(feature_columns)
+ self._feature_columns = sorted(self._feature_columns, key=lambda x: x.name)
+ for column in self._feature_columns:
+ if not isinstance(column, (DenseColumn, CategoricalColumn)):
+ raise ValueError(
+ 'Items of feature_columns must be either a '
+ 'DenseColumn or CategoricalColumn. Given: {}'.format(column))
- def __init__(self,
- units=1,
- trainable=True,
- weight_collections=None,
- name=None,
- **kwargs):
- super(_BiasLayer, self).__init__(trainable=trainable, name=name, **kwargs)
self._units = units
- self._weight_collections = weight_collections
-
- def build(self, _):
- self._bias_variable = self.add_variable(
- 'bias_weights',
- shape=[self._units],
- initializer=init_ops.zeros_initializer(),
- trainable=self.trainable)
- _add_to_collections(self._bias_variable, self._weight_collections)
- self.built = True
-
- def call(self, _):
- return self._bias_variable
+ self._sparse_combiner = sparse_combiner
+ self._state_manager = _StateManagerImpl(self, self.trainable)
+ self._shared_state_manager = shared_state_manager
+ self._bias_variable = None
-def _get_expanded_variable_list(var_list):
- returned_list = []
- for variable in var_list:
- if (isinstance(variable, variables.Variable) or
- resource_variable_ops.is_resource_variable(variable)):
- returned_list.append(variable) # Single variable case.
- else: # Must be a PartitionedVariable, so convert into a list.
- returned_list.extend(list(variable))
- return returned_list
+ def build(self, _):
+ # Create state for shared embedding columns.
+ for column in self._feature_columns:
+ if isinstance(column, SharedEmbeddingColumn):
+ column.create_state(self._shared_state_manager)
+ # We need variable scopes for now because we want the variable partitioning
+ # information to percolate down. We also use _pure_variable_scope here,
+ # since we want to open a name_scope in the `call` method while creating
+ # the ops.
+ with variable_scope._pure_variable_scope(self.name): # pylint: disable=protected-access
+ for column in self._feature_columns:
+ with variable_scope._pure_variable_scope(column.name): # pylint: disable=protected-access
+ # Create the state for each feature column
+ if not isinstance(column, SharedEmbeddingColumn):
+ column.create_state(self._state_manager)
+
+ # Create a weight variable for each column.
+ if isinstance(column, CategoricalColumn):
+ first_dim = column.num_buckets
+ else:
+ first_dim = column.variable_shape.num_elements()
+ self._state_manager.create_variable(
+ column,
+ name='weights',
+ dtype=dtypes.float32,
+ shape=(first_dim, self._units),
+ initializer=init_ops.zeros_initializer(),
+ trainable=self.trainable)
+
+ # Create a bias variable.
+ self._bias_variable = self.add_variable(
+ name='bias_weights',
+ dtype=dtypes.float32,
+ shape=[self._units],
+ initializer=init_ops.zeros_initializer(),
+ trainable=self.trainable,
+ # TODO(rohanj): Get rid of this hack once we have a mechanism for
+ # specifying a default partitioner for an entire layer. In that case,
+ # the default getter for Layers should work.
+ getter=variable_scope.get_variable)
-def _strip_leading_slashes(name):
- return name.rsplit('/', 1)[-1]
+ super(LinearModel, self).build(None)
+ def call(self, features):
+ """Returns a `Tensor` the represents the predictions of a linear model.
-class _LinearModel(training.Model):
- """Creates a linear model using feature columns.
+ Args:
+ features: A mapping from key to tensors. `_FeatureColumn`s look up via
+ these keys. For example `numeric_column('price')` will look at 'price'
+ key in this dict. Values are `Tensor` or `SparseTensor` depending on
+ corresponding `_FeatureColumn`.
- See `linear_model` for details.
- """
+ Returns:
+ A `Tensor` which represents predictions/logits of a linear model. Its
+ shape is (batch_size, units) and its dtype is `float32`.
- def __init__(self,
- feature_columns,
- units=1,
- sparse_combiner='sum',
- weight_collections=None,
- trainable=True,
- name=None,
- **kwargs):
- super(_LinearModel, self).__init__(name=name, **kwargs)
- self._feature_columns = fc_old._normalize_feature_columns( # pylint: disable=protected-access
- feature_columns)
- self._weight_collections = list(weight_collections or [])
- if ops.GraphKeys.GLOBAL_VARIABLES not in self._weight_collections:
- self._weight_collections.append(ops.GraphKeys.GLOBAL_VARIABLES)
- if ops.GraphKeys.MODEL_VARIABLES not in self._weight_collections:
- self._weight_collections.append(ops.GraphKeys.MODEL_VARIABLES)
-
- column_layers = {}
- for column in sorted(self._feature_columns, key=lambda x: x.name):
- with variable_scope.variable_scope(
- None, default_name=column._var_scope_name) as vs: # pylint: disable=protected-access
- # Having the fully expressed variable scope name ends up doubly
- # expressing the outer scope (scope with which this method was called)
- # in the name of the variable that would get created.
- column_name = _strip_leading_slashes(vs.name)
- column_layer = _FCLinearWrapper(column, units, sparse_combiner,
- self._weight_collections, trainable,
- column_name, **kwargs)
- column_layers[column_name] = column_layer
- self._column_layers = self._add_layers(column_layers)
- self._bias_layer = _BiasLayer(
- units=units,
- trainable=trainable,
- weight_collections=self._weight_collections,
- name='bias_layer',
- **kwargs)
- self._cols_to_vars = {}
-
- def cols_to_vars(self):
- """Returns a dict mapping _FeatureColumns to variables.
-
- See `linear_model` for more information.
- This is not populated till `call` is called i.e. layer is built.
+ Raises:
+ ValueError: If features are not a dictionary.
"""
- return self._cols_to_vars
-
- def call(self, features):
- with variable_scope.variable_scope(self.name):
- for column in self._feature_columns:
- if not isinstance(
- column,
- (
- fc_old._DenseColumn, # pylint: disable=protected-access
- fc_old._CategoricalColumn)): # pylint: disable=protected-access
- raise ValueError(
- 'Items of feature_columns must be either a '
- '_DenseColumn or _CategoricalColumn. Given: {}'.format(column))
- weighted_sums = []
- ordered_columns = []
- builder = fc_old._LazyBuilder(features) # pylint: disable=protected-access
- for layer in sorted(self._column_layers.values(), key=lambda x: x.name):
- column = layer._feature_column # pylint: disable=protected-access
- ordered_columns.append(column)
- weighted_sum = layer(builder)
+ if not isinstance(features, dict):
+ raise ValueError('We expected a dictionary here. Instead we got: ',
+ features)
+ transformation_cache = FeatureTransformationCache(features)
+ weighted_sums = []
+ for column in self._feature_columns:
+ with ops.name_scope(column.name):
+ # All the weights used in the linear model are owned by the state
+ # manager associated with this Linear Model.
+ weight_var = self._state_manager.get_variable(column, 'weights')
+
+ # The embedding weights for the SharedEmbeddingColumn are owned by
+ # the shared_state_manager and so we need to pass that in while
+ # creating the weighted sum. For all other columns, the state is owned
+ # by the Linear Model's state manager.
+ if isinstance(column, SharedEmbeddingColumn):
+ state_manager = self._shared_state_manager
+ else:
+ state_manager = self._state_manager
+ weighted_sum = _create_weighted_sum(
+ column=column,
+ transformation_cache=transformation_cache,
+ state_manager=state_manager,
+ sparse_combiner=self._sparse_combiner,
+ weight_var=weight_var)
weighted_sums.append(weighted_sum)
- self._cols_to_vars[column] = ops.get_collection(
- ops.GraphKeys.GLOBAL_VARIABLES, scope=layer.scope_name)
-
- _verify_static_batch_size_equality(weighted_sums, ordered_columns)
- predictions_no_bias = math_ops.add_n(
- weighted_sums, name='weighted_sum_no_bias')
- predictions = nn_ops.bias_add(
- predictions_no_bias,
- self._bias_layer( # pylint: disable=not-callable
- builder,
- scope=variable_scope.get_variable_scope()), # pylint: disable=not-callable
- name='weighted_sum')
- bias = self._bias_layer.variables[0]
- self._cols_to_vars['bias'] = _get_expanded_variable_list([bias])
- return predictions
- def _add_layers(self, layers):
- # "Magic" required for keras.Model classes to track all the variables in
- # a list of layers.Layer objects.
- # TODO(ashankar): Figure out API so user code doesn't have to do this.
- for name, layer in layers.items():
- setattr(self, 'layer-%s' % name, layer)
- return layers
+ _verify_static_batch_size_equality(weighted_sums, self._feature_columns)
+ predictions_no_bias = math_ops.add_n(
+ weighted_sums, name='weighted_sum_no_bias')
+ predictions = nn_ops.bias_add(
+ predictions_no_bias, self._bias_variable, name='weighted_sum')
+ return predictions
def _transform_features(features, feature_columns, state_manager):
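Editor's note: condensing the updated tests further down, the object-style replacement for `linear_model(features, columns)` builds its weights in `build()` and exposes them via `model.variables`. A sketch, assuming a TF 1.x graph-mode session and the module import used by the tests:

```python
import tensorflow as tf  # assumption: TF 1.x graph mode
from tensorflow.python.feature_column import feature_column_v2 as fc

price = fc.numeric_column('price')
model = fc.LinearModel([price])
predictions = model({'price': [[1.], [5.]]})   # shape (2, 1)
price_var, bias = model.variables              # created lazily in build()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run([price_var.assign([[10.]]), bias.assign([5.])])
    print(sess.run(predictions))               # [[15.], [55.]]
```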
@@ -2045,58 +1914,40 @@ class DenseColumn(FeatureColumn):
pass
-def _create_weighted_sum(column,
- transformation_cache,
- state_manager,
- units,
- sparse_combiner,
- weight_collections,
- trainable,
- weight_var=None):
+def is_feature_column_v2(feature_columns):
+ """Returns True if all feature columns are V2."""
+ for feature_column in feature_columns:
+ if not isinstance(feature_column, FeatureColumn):
+ return False
+ return True
+
+
+def _create_weighted_sum(column, transformation_cache, state_manager,
+ sparse_combiner, weight_var):
"""Creates a weighted sum for a dense/categorical column for linear_model."""
if isinstance(column, CategoricalColumn):
return _create_categorical_column_weighted_sum(
column=column,
transformation_cache=transformation_cache,
state_manager=state_manager,
- units=units,
sparse_combiner=sparse_combiner,
- weight_collections=weight_collections,
- trainable=trainable,
weight_var=weight_var)
else:
return _create_dense_column_weighted_sum(
column=column,
transformation_cache=transformation_cache,
state_manager=state_manager,
- units=units,
- weight_collections=weight_collections,
- trainable=trainable,
weight_var=weight_var)
-def _create_dense_column_weighted_sum(column,
- transformation_cache,
- state_manager,
- units,
- weight_collections,
- trainable,
- weight_var=None):
+def _create_dense_column_weighted_sum(column, transformation_cache,
+ state_manager, weight_var):
"""Create a weighted sum of a dense column for linear_model."""
tensor = column.get_dense_tensor(transformation_cache, state_manager)
num_elements = column.variable_shape.num_elements()
batch_size = array_ops.shape(tensor)[0]
tensor = array_ops.reshape(tensor, shape=(batch_size, num_elements))
- if weight_var is not None:
- weight = weight_var
- else:
- weight = variable_scope.get_variable(
- name='weights',
- shape=[num_elements, units],
- initializer=init_ops.zeros_initializer(),
- trainable=trainable,
- collections=weight_collections)
- return math_ops.matmul(tensor, weight, name='weighted_sum')
+ return math_ops.matmul(tensor, weight_var, name='weighted_sum')
class CategoricalColumn(FeatureColumn):
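Editor's note: with the weight variable now injected by the caller, the dense branch reduces to a reshape followed by a matmul. A numeric sketch using the same values as `test_dense_multi_dimension` below, assuming TF 1.x:

```python
import tensorflow as tf  # assumption: TF 1.x graph mode

dense = tf.constant([[1., 2.], [5., 6.]])   # numeric_column('price', shape=2)
weight_var = tf.Variable([[10.], [100.]])   # (num_elements, units)
batch_size = tf.shape(dense)[0]
flat = tf.reshape(dense, (batch_size, 2))   # (batch_size, num_elements)
weighted_sum = tf.matmul(flat, weight_var)  # -> [[210.], [650.]]
```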
@@ -2137,14 +1988,8 @@ class CategoricalColumn(FeatureColumn):
pass
-def _create_categorical_column_weighted_sum(column,
- transformation_cache,
- state_manager,
- units,
- sparse_combiner,
- weight_collections,
- trainable,
- weight_var=None):
+def _create_categorical_column_weighted_sum(
+ column, transformation_cache, state_manager, sparse_combiner, weight_var):
# pylint: disable=g-doc-return-or-yield,g-doc-args
"""Create a weighted sum of a categorical column for linear_model.
@@ -2183,17 +2028,8 @@ def _create_categorical_column_weighted_sum(column,
weight_tensor = sparse_ops.sparse_reshape(
weight_tensor, [array_ops.shape(weight_tensor)[0], -1])
- if weight_var is not None:
- weight = weight_var
- else:
- weight = variable_scope.get_variable(
- name='weights',
- shape=(column.num_buckets, units),
- initializer=init_ops.zeros_initializer(),
- trainable=trainable,
- collections=weight_collections)
return _safe_embedding_lookup_sparse(
- weight,
+ weight_var,
id_tensor,
sparse_weights=weight_tensor,
combiner=sparse_combiner,
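Editor's note: for categorical columns, the per-bucket weights act as a width-`units` embedding table that is looked up sparsely and combined per example. A sketch with the same ids and expected sums as the hashed-column test above ('marlo' hashes to bucket 3, 'skywalker' and 'omar' both to bucket 2), assuming TF 1.x:

```python
import tensorflow as tf  # assumption: TF 1.x graph mode

weight_var = tf.constant([[1.], [2.], [3.], [4.]])  # (num_buckets, units)
id_tensor = tf.SparseTensor(
    indices=[[0, 0], [1, 0], [1, 1]],
    values=tf.constant([3, 2, 2], dtype=tf.int64),   # bucket ids per example
    dense_shape=[2, 2])
weighted_sum = tf.nn.embedding_lookup_sparse(
    weight_var, id_tensor, sp_weights=None, combiner='sum')
# Example 0: bucket 3 -> 4.; example 1: buckets 2 + 2 -> 3. + 3. = 6.
```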
@@ -2333,7 +2169,7 @@ class FeatureTransformationCache(object):
# Input_tensor must have rank 1.
if isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
return sparse_ops.sparse_reshape(
- input_tensor, [array_ops.shape(input_tensor)[0], -1])
+ input_tensor, [array_ops.shape(input_tensor)[0], 1])
else:
return array_ops.expand_dims(input_tensor, -1)
@@ -2782,6 +2618,12 @@ class SharedEmbeddingStateManager(Layer):
return self._var_dict[name]
+def maybe_create_shared_state_manager(feature_columns):
+ if is_feature_column_v2(feature_columns):
+ return SharedEmbeddingStateManager()
+ return None
+
+
class SharedEmbeddingColumn(
DenseColumn, SequenceDenseColumn,
collections.namedtuple(
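Editor's note: the two helpers added above let callers decide whether a shared state manager is needed at all — only a column list made up entirely of V2 `FeatureColumn`s gets a `SharedEmbeddingStateManager`. A small usage sketch, reusing the test module's import style:

```python
from tensorflow.python.feature_column import feature_column_v2 as fc

columns = [
    fc.numeric_column('price'),
    fc.categorical_column_with_hash_bucket('keywords', 10),
]
assert fc.is_feature_column_v2(columns)
# Returns a SharedEmbeddingStateManager for pure-V2 columns, None otherwise.
shared_state_manager = fc.maybe_create_shared_state_manager(columns)
```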
@@ -2822,6 +2664,10 @@ class SharedEmbeddingColumn(
def create_state(self, state_manager):
"""Creates the shared embedding lookup variable."""
+ if not isinstance(state_manager, SharedEmbeddingStateManager):
+ raise ValueError('Expected state_manager to be of type '
+ 'SharedEmbeddingStateManager. Obtained type: {}'.format(
+ type(state_manager)))
embedding_shape = (self.categorical_column.num_buckets, self.dimension)
state_manager.create_variable(
name=self.shared_collection_name,
diff --git a/tensorflow/python/feature_column/feature_column_v2_test.py b/tensorflow/python/feature_column/feature_column_v2_test.py
index 58168e0f9e..a13a5010e1 100644
--- a/tensorflow/python/feature_column/feature_column_v2_test.py
+++ b/tensorflow/python/feature_column/feature_column_v2_test.py
@@ -31,9 +31,7 @@ from tensorflow.python.client import session
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.estimator.inputs import numpy_io
-from tensorflow.python.feature_column import feature_column as fc_old
from tensorflow.python.feature_column import feature_column_v2 as fc
-from tensorflow.python.feature_column.feature_column_v2 import _LinearModel
from tensorflow.python.feature_column.feature_column_v2 import _transform_features
from tensorflow.python.feature_column.feature_column_v2 import FeatureColumn
from tensorflow.python.feature_column.feature_column_v2 import FeatureLayer
@@ -48,7 +46,6 @@ from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import parsing_ops
-from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.platform import test
@@ -177,6 +174,22 @@ class LazyColumnTest(test.TestCase):
TypeError, '"key" must be either a "str" or "FeatureColumn".'):
transformation_cache.get(NotAFeatureColumn(), None)
+ def test_expand_dim_rank_1_sparse_tensor_empty_batch(self):
+ # empty 1-D sparse tensor:
+ transformation_cache = FeatureTransformationCache(
+ features={
+ 'a':
+ sparse_tensor.SparseTensor(
+ indices=np.reshape(np.array([], dtype=np.int64), (0, 1)),
+ dense_shape=[0],
+ values=np.array([]))
+ })
+ with self.cached_session():
+ spv = transformation_cache.get('a', None).eval()
+ self.assertAllEqual(np.array([0, 1], dtype=np.int64), spv.dense_shape)
+ self.assertAllEqual(
+ np.reshape(np.array([], dtype=np.int64), (0, 2)), spv.indices)
+
class NumericColumnTest(test.TestCase):
@@ -344,26 +357,12 @@ class NumericColumnTest(test.TestCase):
self.assertEqual(a.default_value, ((3., 2.),))
def test_linear_model(self):
- price = fc_old.numeric_column('price')
- with ops.Graph().as_default():
- features = {'price': [[1.], [5.]]}
- predictions = fc.linear_model(features, [price])
- bias = get_linear_model_bias()
- price_var = get_linear_model_column_var(price)
- with _initialized_session() as sess:
- self.assertAllClose([0.], bias.eval())
- self.assertAllClose([[0.]], price_var.eval())
- self.assertAllClose([[0.], [0.]], predictions.eval())
- sess.run(price_var.assign([[10.]]))
- self.assertAllClose([[10.], [50.]], predictions.eval())
-
- def test_keras_linear_model(self):
- price = fc_old.numeric_column('price')
+ price = fc.numeric_column('price')
with ops.Graph().as_default():
features = {'price': [[1.], [5.]]}
- predictions = get_keras_linear_model_predictions(features, [price])
- bias = get_linear_model_bias()
- price_var = get_linear_model_column_var(price)
+ model = fc.LinearModel([price])
+ predictions = model(features)
+ price_var, bias = model.variables
with _initialized_session() as sess:
self.assertAllClose([0.], bias.eval())
self.assertAllClose([[0.]], price_var.eval())
@@ -548,13 +547,13 @@ class BucketizedColumnTest(test.TestCase):
def test_linear_model_one_input_value(self):
"""Tests linear_model() for input with shape=[1]."""
- price = fc_old.numeric_column('price', shape=[1])
- bucketized_price = fc_old.bucketized_column(price, boundaries=[0, 2, 4, 6])
+ price = fc.numeric_column('price', shape=[1])
+ bucketized_price = fc.bucketized_column(price, boundaries=[0, 2, 4, 6])
with ops.Graph().as_default():
features = {'price': [[-1.], [1.], [5.], [6.]]}
- predictions = fc.linear_model(features, [bucketized_price])
- bias = get_linear_model_bias()
- bucketized_price_var = get_linear_model_column_var(bucketized_price)
+ model = fc.LinearModel([bucketized_price])
+ predictions = model(features)
+ bucketized_price_var, bias = model.variables
with _initialized_session() as sess:
self.assertAllClose([0.], bias.eval())
# One weight variable per bucket, all initialized to zero.
@@ -573,13 +572,13 @@ class BucketizedColumnTest(test.TestCase):
def test_linear_model_two_input_values(self):
"""Tests linear_model() for input with shape=[2]."""
- price = fc_old.numeric_column('price', shape=[2])
- bucketized_price = fc_old.bucketized_column(price, boundaries=[0, 2, 4, 6])
+ price = fc.numeric_column('price', shape=[2])
+ bucketized_price = fc.bucketized_column(price, boundaries=[0, 2, 4, 6])
with ops.Graph().as_default():
features = {'price': [[-1., 1.], [5., 6.]]}
- predictions = fc.linear_model(features, [bucketized_price])
- bias = get_linear_model_bias()
- bucketized_price_var = get_linear_model_column_var(bucketized_price)
+ model = fc.LinearModel([bucketized_price])
+ predictions = model(features)
+ bucketized_price_var, bias = model.variables
with _initialized_session() as sess:
self.assertAllClose([0.], bias.eval())
# One weight per bucket per input column, all initialized to zero.
@@ -600,62 +599,6 @@ class BucketizedColumnTest(test.TestCase):
sess.run(bias.assign([1.]))
self.assertAllClose([[81.], [141.]], predictions.eval())
- def test_keras_linear_model_one_input_value(self):
- """Tests _LinearModel for input with shape=[1]."""
- price = fc_old.numeric_column('price', shape=[1])
- bucketized_price = fc_old.bucketized_column(price, boundaries=[0, 2, 4, 6])
- with ops.Graph().as_default():
- features = {'price': [[-1.], [1.], [5.], [6.]]}
- predictions = get_keras_linear_model_predictions(features,
- [bucketized_price])
- bias = get_linear_model_bias()
- bucketized_price_var = get_linear_model_column_var(bucketized_price)
- with _initialized_session() as sess:
- self.assertAllClose([0.], bias.eval())
- # One weight variable per bucket, all initialized to zero.
- self.assertAllClose([[0.], [0.], [0.], [0.], [0.]],
- bucketized_price_var.eval())
- self.assertAllClose([[0.], [0.], [0.], [0.]], predictions.eval())
- sess.run(
- bucketized_price_var.assign([[10.], [20.], [30.], [40.], [50.]]))
- # price -1. is in the 0th bucket, whose weight is 10.
- # price 1. is in the 1st bucket, whose weight is 20.
- # price 5. is in the 3rd bucket, whose weight is 40.
- # price 6. is in the 4th bucket, whose weight is 50.
- self.assertAllClose([[10.], [20.], [40.], [50.]], predictions.eval())
- sess.run(bias.assign([1.]))
- self.assertAllClose([[11.], [21.], [41.], [51.]], predictions.eval())
-
- def test_keras_linear_model_two_input_values(self):
- """Tests _LinearModel for input with shape=[2]."""
- price = fc_old.numeric_column('price', shape=[2])
- bucketized_price = fc_old.bucketized_column(price, boundaries=[0, 2, 4, 6])
- with ops.Graph().as_default():
- features = {'price': [[-1., 1.], [5., 6.]]}
- predictions = get_keras_linear_model_predictions(features,
- [bucketized_price])
- bias = get_linear_model_bias()
- bucketized_price_var = get_linear_model_column_var(bucketized_price)
- with _initialized_session() as sess:
- self.assertAllClose([0.], bias.eval())
- # One weight per bucket per input column, all initialized to zero.
- self.assertAllClose(
- [[0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.]],
- bucketized_price_var.eval())
- self.assertAllClose([[0.], [0.]], predictions.eval())
- sess.run(
- bucketized_price_var.assign([[10.], [20.], [30.], [40.], [50.],
- [60.], [70.], [80.], [90.], [100.]]))
- # 1st example:
- # price -1. is in the 0th bucket, whose weight is 10.
- # price 1. is in the 6th bucket, whose weight is 70.
- # 2nd example:
- # price 5. is in the 3rd bucket, whose weight is 40.
- # price 6. is in the 9th bucket, whose weight is 100.
- self.assertAllClose([[80.], [140.]], predictions.eval())
- sess.run(bias.assign([1.]))
- self.assertAllClose([[81.], [141.]], predictions.eval())
-
class HashedCategoricalColumnTest(test.TestCase):
@@ -836,39 +779,18 @@ class HashedCategoricalColumnTest(test.TestCase):
transformation_cache.get(hashed_sparse, None), id_weight_pair.id_tensor)
def test_linear_model(self):
- wire_column = fc_old.categorical_column_with_hash_bucket('wire', 4)
- self.assertEqual(4, wire_column._num_buckets)
- with ops.Graph().as_default():
- predictions = fc.linear_model({
- wire_column.name: sparse_tensor.SparseTensorValue(
- indices=((0, 0), (1, 0), (1, 1)),
- values=('marlo', 'skywalker', 'omar'),
- dense_shape=(2, 2))
- }, (wire_column,))
- bias = get_linear_model_bias()
- wire_var = get_linear_model_column_var(wire_column)
- with _initialized_session():
- self.assertAllClose((0.,), bias.eval())
- self.assertAllClose(((0.,), (0.,), (0.,), (0.,)), wire_var.eval())
- self.assertAllClose(((0.,), (0.,)), predictions.eval())
- wire_var.assign(((1.,), (2.,), (3.,), (4.,))).eval()
- # 'marlo' -> 3: wire_var[3] = 4
- # 'skywalker' -> 2, 'omar' -> 2: wire_var[2] + wire_var[2] = 3+3 = 6
- self.assertAllClose(((4.,), (6.,)), predictions.eval())
-
- def test_keras_linear_model(self):
- wire_column = fc_old.categorical_column_with_hash_bucket('wire', 4)
- self.assertEqual(4, wire_column._num_buckets)
+ wire_column = fc.categorical_column_with_hash_bucket('wire', 4)
+ self.assertEqual(4, wire_column.num_buckets)
with ops.Graph().as_default():
- predictions = get_keras_linear_model_predictions({
+ model = fc.LinearModel((wire_column,))
+ predictions = model({
wire_column.name:
sparse_tensor.SparseTensorValue(
indices=((0, 0), (1, 0), (1, 1)),
values=('marlo', 'skywalker', 'omar'),
dense_shape=(2, 2))
- }, (wire_column,))
- bias = get_linear_model_bias()
- wire_var = get_linear_model_column_var(wire_column)
+ })
+ wire_var, bias = model.variables
with _initialized_session():
self.assertAllClose((0.,), bias.eval())
self.assertAllClose(((0.,), (0.,), (0.,), (0.,)), wire_var.eval())
@@ -1087,93 +1009,12 @@ class CrossedColumnTest(test.TestCase):
Uses data from test_get_sparse_tesnsors_simple.
"""
- a = fc_old.numeric_column('a', dtype=dtypes.int32, shape=(2,))
- b = fc_old.bucketized_column(a, boundaries=(0, 1))
- crossed = fc_old.crossed_column([b, 'c'], hash_bucket_size=5, hash_key=5)
- with ops.Graph().as_default():
- predictions = fc.linear_model({
- 'a': constant_op.constant(((-1., .5), (.5, 1.))),
- 'c': sparse_tensor.SparseTensor(
- indices=((0, 0), (1, 0), (1, 1)),
- values=['cA', 'cB', 'cC'],
- dense_shape=(2, 2)),
- }, (crossed,))
- bias = get_linear_model_bias()
- crossed_var = get_linear_model_column_var(crossed)
- with _initialized_session() as sess:
- self.assertAllClose((0.,), bias.eval())
- self.assertAllClose(
- ((0.,), (0.,), (0.,), (0.,), (0.,)), crossed_var.eval())
- self.assertAllClose(((0.,), (0.,)), predictions.eval())
- sess.run(crossed_var.assign(((1.,), (2.,), (3.,), (4.,), (5.,))))
- # Expected ids after cross = (1, 0, 1, 3, 4, 2)
- self.assertAllClose(((3.,), (14.,)), predictions.eval())
- sess.run(bias.assign((.1,)))
- self.assertAllClose(((3.1,), (14.1,)), predictions.eval())
-
- def test_linear_model_with_weights(self):
-
- class _TestColumnWithWeights(fc_old._CategoricalColumn):
- """Produces sparse IDs and sparse weights."""
-
- @property
- def name(self):
- return 'test_column'
-
- @property
- def _parse_example_spec(self):
- return {
- self.name: parsing_ops.VarLenFeature(dtypes.int32),
- '{}_weights'.format(self.name): parsing_ops.VarLenFeature(
- dtypes.float32),
- }
-
- @property
- def _num_buckets(self):
- return 5
-
- def _transform_feature(self, inputs):
- return (inputs.get(self.name),
- inputs.get('{}_weights'.format(self.name)))
-
- def _get_sparse_tensors(self, inputs, weight_collections=None,
- trainable=None):
- """Populates both id_tensor and weight_tensor."""
- ids_and_weights = inputs.get(self)
- return fc_old._CategoricalColumn.IdWeightPair(
- id_tensor=ids_and_weights[0], weight_tensor=ids_and_weights[1])
-
- t = _TestColumnWithWeights()
- crossed = fc_old.crossed_column([t, 'c'], hash_bucket_size=5, hash_key=5)
- with ops.Graph().as_default():
- with self.assertRaisesRegexp(
- ValueError,
- 'crossed_column does not support weight_tensor.*{}'.format(t.name)):
- fc.linear_model({
- t.name: sparse_tensor.SparseTensor(
- indices=((0, 0), (1, 0), (1, 1)),
- values=[0, 1, 2],
- dense_shape=(2, 2)),
- '{}_weights'.format(t.name): sparse_tensor.SparseTensor(
- indices=((0, 0), (1, 0), (1, 1)),
- values=[1., 10., 2.],
- dense_shape=(2, 2)),
- 'c': sparse_tensor.SparseTensor(
- indices=((0, 0), (1, 0), (1, 1)),
- values=['cA', 'cB', 'cC'],
- dense_shape=(2, 2)),
- }, (crossed,))
-
- def test_keras_linear_model(self):
- """Tests _LinearModel.
-
- Uses data from test_get_sparse_tesnsors_simple.
- """
- a = fc_old.numeric_column('a', dtype=dtypes.int32, shape=(2,))
- b = fc_old.bucketized_column(a, boundaries=(0, 1))
- crossed = fc_old.crossed_column([b, 'c'], hash_bucket_size=5, hash_key=5)
+ a = fc.numeric_column('a', dtype=dtypes.int32, shape=(2,))
+ b = fc.bucketized_column(a, boundaries=(0, 1))
+ crossed = fc.crossed_column([b, 'c'], hash_bucket_size=5, hash_key=5)
with ops.Graph().as_default():
- predictions = get_keras_linear_model_predictions({
+ model = fc.LinearModel((crossed,))
+ predictions = model({
'a':
constant_op.constant(((-1., .5), (.5, 1.))),
'c':
@@ -1181,13 +1022,12 @@ class CrossedColumnTest(test.TestCase):
indices=((0, 0), (1, 0), (1, 1)),
values=['cA', 'cB', 'cC'],
dense_shape=(2, 2)),
- }, (crossed,))
- bias = get_linear_model_bias()
- crossed_var = get_linear_model_column_var(crossed)
+ })
+ crossed_var, bias = model.variables
with _initialized_session() as sess:
self.assertAllClose((0.,), bias.eval())
- self.assertAllClose(((0.,), (0.,), (0.,), (0.,), (0.,)),
- crossed_var.eval())
+ self.assertAllClose(
+ ((0.,), (0.,), (0.,), (0.,), (0.,)), crossed_var.eval())
self.assertAllClose(((0.,), (0.,)), predictions.eval())
sess.run(crossed_var.assign(((1.,), (2.,), (3.,), (4.,), (5.,))))
# Expected ids after cross = (1, 0, 1, 3, 4, 2)
@@ -1195,9 +1035,9 @@ class CrossedColumnTest(test.TestCase):
sess.run(bias.assign((.1,)))
self.assertAllClose(((3.1,), (14.1,)), predictions.eval())
- def test_keras_linear_model_with_weights(self):
+ def test_linear_model_with_weights(self):
- class _TestColumnWithWeights(fc_old._CategoricalColumn):
+ class _TestColumnWithWeights(fc.CategoricalColumn):
"""Produces sparse IDs and sparse weights."""
@property
@@ -1205,38 +1045,36 @@ class CrossedColumnTest(test.TestCase):
return 'test_column'
@property
- def _parse_example_spec(self):
+ def parse_example_spec(self):
return {
- self.name:
- parsing_ops.VarLenFeature(dtypes.int32),
- '{}_weights'.format(self.name):
- parsing_ops.VarLenFeature(dtypes.float32),
- }
+ self.name: parsing_ops.VarLenFeature(dtypes.int32),
+ '{}_weights'.format(self.name): parsing_ops.VarLenFeature(
+ dtypes.float32),
+ }
@property
- def _num_buckets(self):
+ def num_buckets(self):
return 5
- def _transform_feature(self, inputs):
- return (inputs.get(self.name),
- inputs.get('{}_weights'.format(self.name)))
+ def transform_feature(self, transformation_cache, state_manager):
+ return (transformation_cache.get(self.name, state_manager),
+ transformation_cache.get('{}_weights'.format(self.name),
+ state_manager))
- def _get_sparse_tensors(self,
- inputs,
- weight_collections=None,
- trainable=None):
+ def get_sparse_tensors(self, transformation_cache, state_manager):
"""Populates both id_tensor and weight_tensor."""
- ids_and_weights = inputs.get(self)
- return fc_old._CategoricalColumn.IdWeightPair(
+ ids_and_weights = transformation_cache.get(self, state_manager)
+ return fc.CategoricalColumn.IdWeightPair(
id_tensor=ids_and_weights[0], weight_tensor=ids_and_weights[1])
t = _TestColumnWithWeights()
- crossed = fc_old.crossed_column([t, 'c'], hash_bucket_size=5, hash_key=5)
+ crossed = fc.crossed_column([t, 'c'], hash_bucket_size=5, hash_key=5)
with ops.Graph().as_default():
with self.assertRaisesRegexp(
ValueError,
'crossed_column does not support weight_tensor.*{}'.format(t.name)):
- get_keras_linear_model_predictions({
+ model = fc.LinearModel((crossed,))
+ model({
t.name:
sparse_tensor.SparseTensor(
indices=((0, 0), (1, 0), (1, 1)),
@@ -1252,37 +1090,7 @@ class CrossedColumnTest(test.TestCase):
indices=((0, 0), (1, 0), (1, 1)),
values=['cA', 'cB', 'cC'],
dense_shape=(2, 2)),
- }, (crossed,))
-
-
-def get_linear_model_bias(name='linear_model'):
- with variable_scope.variable_scope(name, reuse=True):
- return variable_scope.get_variable('bias_weights')
-
-
-def get_linear_model_column_var(column, name='linear_model'):
- return ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
- name + '/' + column.name)[0]
-
-
-def get_keras_linear_model_predictions(features,
- feature_columns,
- units=1,
- sparse_combiner='sum',
- weight_collections=None,
- trainable=True,
- cols_to_vars=None):
- keras_linear_model = _LinearModel(
- feature_columns,
- units,
- sparse_combiner,
- weight_collections,
- trainable,
- name='linear_model')
- retval = keras_linear_model(features) # pylint: disable=not-callable
- if cols_to_vars is not None:
- cols_to_vars.update(keras_linear_model.cols_to_vars())
- return retval
+ })
class LinearModelTest(test.TestCase):
@@ -1290,56 +1098,50 @@ class LinearModelTest(test.TestCase):
def test_raises_if_empty_feature_columns(self):
with self.assertRaisesRegexp(ValueError,
'feature_columns must not be empty'):
- fc.linear_model(features={}, feature_columns=[])
+ fc.LinearModel(feature_columns=[])
def test_should_be_feature_column(self):
- with self.assertRaisesRegexp(ValueError, 'must be a _FeatureColumn'):
- fc.linear_model(features={'a': [[0]]}, feature_columns='NotSupported')
+ with self.assertRaisesRegexp(ValueError, 'must be a FeatureColumn'):
+ fc.LinearModel(feature_columns='NotSupported')
def test_should_be_dense_or_categorical_column(self):
- class NotSupportedColumn(fc_old._FeatureColumn):
+ class NotSupportedColumn(fc.FeatureColumn):
@property
def name(self):
return 'NotSupportedColumn'
- def _transform_feature(self, cache):
+ def transform_feature(self, transformation_cache, state_manager):
pass
@property
- def _parse_example_spec(self):
+ def parse_example_spec(self):
pass
with self.assertRaisesRegexp(
- ValueError, 'must be either a _DenseColumn or _CategoricalColumn'):
- fc.linear_model(
- features={'a': [[0]]}, feature_columns=[NotSupportedColumn()])
+ ValueError, 'must be either a DenseColumn or CategoricalColumn'):
+ fc.LinearModel(feature_columns=[NotSupportedColumn()])
def test_does_not_support_dict_columns(self):
with self.assertRaisesRegexp(
ValueError, 'Expected feature_columns to be iterable, found dict.'):
- fc.linear_model(
- features={'a': [[0]]},
- feature_columns={'a': fc_old.numeric_column('a')})
+ fc.LinearModel(feature_columns={'a': fc.numeric_column('a')})
def test_raises_if_duplicate_name(self):
with self.assertRaisesRegexp(
ValueError, 'Duplicate feature column name found for columns'):
- fc.linear_model(
- features={'a': [[0]]},
- feature_columns=[
- fc_old.numeric_column('a'),
- fc_old.numeric_column('a')
- ])
+ fc.LinearModel(
+ feature_columns=[fc.numeric_column('a'),
+ fc.numeric_column('a')])
def test_dense_bias(self):
- price = fc_old.numeric_column('price')
+ price = fc.numeric_column('price')
with ops.Graph().as_default():
features = {'price': [[1.], [5.]]}
- predictions = fc.linear_model(features, [price])
- bias = get_linear_model_bias()
- price_var = get_linear_model_column_var(price)
+ model = fc.LinearModel([price])
+ predictions = model(features)
+ price_var, bias = model.variables
with _initialized_session() as sess:
self.assertAllClose([0.], bias.eval())
sess.run(price_var.assign([[10.]]))
@@ -1347,16 +1149,16 @@ class LinearModelTest(test.TestCase):
self.assertAllClose([[15.], [55.]], predictions.eval())
def test_sparse_bias(self):
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
+ wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
with ops.Graph().as_default():
wire_tensor = sparse_tensor.SparseTensor(
values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3]
indices=[[0, 0], [1, 0], [1, 1]],
dense_shape=[2, 2])
features = {'wire_cast': wire_tensor}
- predictions = fc.linear_model(features, [wire_cast])
- bias = get_linear_model_bias()
- wire_cast_var = get_linear_model_column_var(wire_cast)
+ model = fc.LinearModel([wire_cast])
+ predictions = model(features)
+ wire_cast_var, bias = model.variables
with _initialized_session() as sess:
self.assertAllClose([0.], bias.eval())
self.assertAllClose([[0.], [0.], [0.], [0.]], wire_cast_var.eval())
@@ -1365,18 +1167,17 @@ class LinearModelTest(test.TestCase):
self.assertAllClose([[1005.], [10015.]], predictions.eval())
def test_dense_and_sparse_bias(self):
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
- price = fc_old.numeric_column('price')
+ wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
+ price = fc.numeric_column('price')
with ops.Graph().as_default():
wire_tensor = sparse_tensor.SparseTensor(
values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3]
indices=[[0, 0], [1, 0], [1, 1]],
dense_shape=[2, 2])
features = {'wire_cast': wire_tensor, 'price': [[1.], [5.]]}
- predictions = fc.linear_model(features, [wire_cast, price])
- bias = get_linear_model_bias()
- wire_cast_var = get_linear_model_column_var(wire_cast)
- price_var = get_linear_model_column_var(price)
+ model = fc.LinearModel([wire_cast, price])
+ predictions = model(features)
+ price_var, wire_cast_var, bias = model.variables
with _initialized_session() as sess:
sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]]))
sess.run(bias.assign([5.]))
@@ -1386,38 +1187,36 @@ class LinearModelTest(test.TestCase):
def test_dense_and_sparse_column(self):
"""When the column is both dense and sparse, uses sparse tensors."""
- class _DenseAndSparseColumn(fc_old._DenseColumn, fc_old._CategoricalColumn):
+ class _DenseAndSparseColumn(fc.DenseColumn, fc.CategoricalColumn):
@property
def name(self):
return 'dense_and_sparse_column'
@property
- def _parse_example_spec(self):
+ def parse_example_spec(self):
return {self.name: parsing_ops.VarLenFeature(self.dtype)}
- def _transform_feature(self, inputs):
- return inputs.get(self.name)
+ def transform_feature(self, transformation_cache, state_manager):
+ return transformation_cache.get(self.name, state_manager)
@property
- def _variable_shape(self):
+ def variable_shape(self):
raise ValueError('Should not use this method.')
- def _get_dense_tensor(self, inputs, weight_collections=None,
- trainable=None):
+ def get_dense_tensor(self, transformation_cache, state_manager):
raise ValueError('Should not use this method.')
@property
- def _num_buckets(self):
+ def num_buckets(self):
return 4
- def _get_sparse_tensors(self, inputs, weight_collections=None,
- trainable=None):
+ def get_sparse_tensors(self, transformation_cache, state_manager):
sp_tensor = sparse_tensor.SparseTensor(
indices=[[0, 0], [1, 0], [1, 1]],
values=[2, 0, 3],
dense_shape=[2, 2])
- return fc_old._CategoricalColumn.IdWeightPair(sp_tensor, None)
+ return fc.CategoricalColumn.IdWeightPair(sp_tensor, None)
dense_and_sparse_column = _DenseAndSparseColumn()
with ops.Graph().as_default():
@@ -1426,10 +1225,9 @@ class LinearModelTest(test.TestCase):
indices=[[0, 0], [1, 0], [1, 1]],
dense_shape=[2, 2])
features = {dense_and_sparse_column.name: sp_tensor}
- predictions = fc.linear_model(features, [dense_and_sparse_column])
- bias = get_linear_model_bias()
- dense_and_sparse_column_var = get_linear_model_column_var(
- dense_and_sparse_column)
+ model = fc.LinearModel([dense_and_sparse_column])
+ predictions = model(features)
+ dense_and_sparse_column_var, bias = model.variables
with _initialized_session() as sess:
sess.run(dense_and_sparse_column_var.assign(
[[10.], [100.], [1000.], [10000.]]))
@@ -1437,12 +1235,12 @@ class LinearModelTest(test.TestCase):
self.assertAllClose([[1005.], [10015.]], predictions.eval())
def test_dense_multi_output(self):
- price = fc_old.numeric_column('price')
+ price = fc.numeric_column('price')
with ops.Graph().as_default():
features = {'price': [[1.], [5.]]}
- predictions = fc.linear_model(features, [price], units=3)
- bias = get_linear_model_bias()
- price_var = get_linear_model_column_var(price)
+ model = fc.LinearModel([price], units=3)
+ predictions = model(features)
+ price_var, bias = model.variables
with _initialized_session() as sess:
self.assertAllClose(np.zeros((3,)), bias.eval())
self.assertAllClose(np.zeros((1, 3)), price_var.eval())
@@ -1452,16 +1250,16 @@ class LinearModelTest(test.TestCase):
predictions.eval())
def test_sparse_multi_output(self):
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
+ wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
with ops.Graph().as_default():
wire_tensor = sparse_tensor.SparseTensor(
values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3]
indices=[[0, 0], [1, 0], [1, 1]],
dense_shape=[2, 2])
features = {'wire_cast': wire_tensor}
- predictions = fc.linear_model(features, [wire_cast], units=3)
- bias = get_linear_model_bias()
- wire_cast_var = get_linear_model_column_var(wire_cast)
+ model = fc.LinearModel([wire_cast], units=3)
+ predictions = model(features)
+ wire_cast_var, bias = model.variables
with _initialized_session() as sess:
self.assertAllClose(np.zeros((3,)), bias.eval())
self.assertAllClose(np.zeros((4, 3)), wire_cast_var.eval())
@@ -1474,18 +1272,19 @@ class LinearModelTest(test.TestCase):
predictions.eval())
def test_dense_multi_dimension(self):
- price = fc_old.numeric_column('price', shape=2)
+ price = fc.numeric_column('price', shape=2)
with ops.Graph().as_default():
features = {'price': [[1., 2.], [5., 6.]]}
- predictions = fc.linear_model(features, [price])
- price_var = get_linear_model_column_var(price)
+ model = fc.LinearModel([price])
+ predictions = model(features)
+ price_var, _ = model.variables
with _initialized_session() as sess:
self.assertAllClose([[0.], [0.]], price_var.eval())
sess.run(price_var.assign([[10.], [100.]]))
self.assertAllClose([[210.], [650.]], predictions.eval())
def test_sparse_multi_rank(self):
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
+ wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
with ops.Graph().as_default():
wire_tensor = array_ops.sparse_placeholder(dtypes.string)
wire_value = sparse_tensor.SparseTensorValue(
@@ -1493,8 +1292,9 @@ class LinearModelTest(test.TestCase):
indices=[[0, 0, 0], [0, 1, 0], [1, 0, 0], [1, 0, 1]],
dense_shape=[2, 2, 2])
features = {'wire_cast': wire_tensor}
- predictions = fc.linear_model(features, [wire_cast])
- wire_cast_var = get_linear_model_column_var(wire_cast)
+ model = fc.LinearModel([wire_cast])
+ predictions = model(features)
+ wire_cast_var, _ = model.variables
with _initialized_session() as sess:
self.assertAllClose(np.zeros((4, 1)), wire_cast_var.eval())
self.assertAllClose(
@@ -1506,25 +1306,24 @@ class LinearModelTest(test.TestCase):
predictions.eval(feed_dict={wire_tensor: wire_value}))
def test_sparse_combiner(self):
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
+ wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
with ops.Graph().as_default():
wire_tensor = sparse_tensor.SparseTensor(
values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3]
indices=[[0, 0], [1, 0], [1, 1]],
dense_shape=[2, 2])
features = {'wire_cast': wire_tensor}
- predictions = fc.linear_model(
- features, [wire_cast], sparse_combiner='mean')
- bias = get_linear_model_bias()
- wire_cast_var = get_linear_model_column_var(wire_cast)
+ model = fc.LinearModel([wire_cast], sparse_combiner='mean')
+ predictions = model(features)
+ wire_cast_var, bias = model.variables
with _initialized_session() as sess:
sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]]))
sess.run(bias.assign([5.]))
self.assertAllClose([[1005.], [5010.]], predictions.eval())
def test_sparse_combiner_with_negative_weights(self):
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
- wire_cast_weights = fc_old.weighted_categorical_column(wire_cast, 'weights')
+ wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
+ wire_cast_weights = fc.weighted_categorical_column(wire_cast, 'weights')
with ops.Graph().as_default():
wire_tensor = sparse_tensor.SparseTensor(
@@ -1535,22 +1334,21 @@ class LinearModelTest(test.TestCase):
'wire_cast': wire_tensor,
'weights': constant_op.constant([[1., 1., -1.0]])
}
- predictions = fc.linear_model(
- features, [wire_cast_weights], sparse_combiner='sum')
- bias = get_linear_model_bias()
- wire_cast_var = get_linear_model_column_var(wire_cast)
+ model = fc.LinearModel([wire_cast_weights], sparse_combiner='sum')
+ predictions = model(features)
+ wire_cast_var, bias = model.variables
with _initialized_session() as sess:
sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]]))
sess.run(bias.assign([5.]))
self.assertAllClose([[1005.], [-9985.]], predictions.eval())
def test_dense_multi_dimension_multi_output(self):
- price = fc_old.numeric_column('price', shape=2)
+ price = fc.numeric_column('price', shape=2)
with ops.Graph().as_default():
features = {'price': [[1., 2.], [5., 6.]]}
- predictions = fc.linear_model(features, [price], units=3)
- bias = get_linear_model_bias()
- price_var = get_linear_model_column_var(price)
+ model = fc.LinearModel([price], units=3)
+ predictions = model(features)
+ price_var, bias = model.variables
with _initialized_session() as sess:
self.assertAllClose(np.zeros((3,)), bias.eval())
self.assertAllClose(np.zeros((2, 3)), price_var.eval())
@@ -1560,21 +1358,22 @@ class LinearModelTest(test.TestCase):
predictions.eval())
def test_raises_if_shape_mismatch(self):
- price = fc_old.numeric_column('price', shape=2)
+ price = fc.numeric_column('price', shape=2)
with ops.Graph().as_default():
features = {'price': [[1.], [5.]]}
with self.assertRaisesRegexp(
Exception,
r'Cannot reshape a tensor with 2 elements to shape \[2,2\]'):
- fc.linear_model(features, [price])
+ model = fc.LinearModel([price])
+ model(features)
def test_dense_reshaping(self):
- price = fc_old.numeric_column('price', shape=[1, 2])
+ price = fc.numeric_column('price', shape=[1, 2])
with ops.Graph().as_default():
features = {'price': [[[1., 2.]], [[5., 6.]]]}
- predictions = fc.linear_model(features, [price])
- bias = get_linear_model_bias()
- price_var = get_linear_model_column_var(price)
+ model = fc.LinearModel([price])
+ predictions = model(features)
+ price_var, bias = model.variables
with _initialized_session() as sess:
self.assertAllClose([0.], bias.eval())
self.assertAllClose([[0.], [0.]], price_var.eval())
@@ -1583,17 +1382,16 @@ class LinearModelTest(test.TestCase):
self.assertAllClose([[210.], [650.]], predictions.eval())
def test_dense_multi_column(self):
- price1 = fc_old.numeric_column('price1', shape=2)
- price2 = fc_old.numeric_column('price2')
+ price1 = fc.numeric_column('price1', shape=2)
+ price2 = fc.numeric_column('price2')
with ops.Graph().as_default():
features = {
'price1': [[1., 2.], [5., 6.]],
'price2': [[3.], [4.]]
}
- predictions = fc.linear_model(features, [price1, price2])
- bias = get_linear_model_bias()
- price1_var = get_linear_model_column_var(price1)
- price2_var = get_linear_model_column_var(price2)
+ model = fc.LinearModel([price1, price2])
+ predictions = model(features)
+ price1_var, price2_var, bias = model.variables
with _initialized_session() as sess:
self.assertAllClose([0.], bias.eval())
self.assertAllClose([[0.], [0.]], price1_var.eval())
@@ -1604,115 +1402,55 @@ class LinearModelTest(test.TestCase):
sess.run(bias.assign([7.]))
self.assertAllClose([[3217.], [4657.]], predictions.eval())
- def test_fills_cols_to_vars(self):
- price1 = fc_old.numeric_column('price1', shape=2)
- price2 = fc_old.numeric_column('price2')
- with ops.Graph().as_default():
- features = {'price1': [[1., 2.], [5., 6.]], 'price2': [[3.], [4.]]}
- cols_to_vars = {}
- fc.linear_model(features, [price1, price2], cols_to_vars=cols_to_vars)
- bias = get_linear_model_bias()
- price1_var = get_linear_model_column_var(price1)
- price2_var = get_linear_model_column_var(price2)
- self.assertAllEqual(cols_to_vars['bias'], [bias])
- self.assertAllEqual(cols_to_vars[price1], [price1_var])
- self.assertAllEqual(cols_to_vars[price2], [price2_var])
-
- def test_fills_cols_to_vars_partitioned_variables(self):
- price1 = fc_old.numeric_column('price1', shape=2)
- price2 = fc_old.numeric_column('price2', shape=3)
- with ops.Graph().as_default():
- features = {
- 'price1': [[1., 2.], [6., 7.]],
- 'price2': [[3., 4., 5.], [8., 9., 10.]]
- }
- cols_to_vars = {}
- with variable_scope.variable_scope(
- 'linear',
- partitioner=partitioned_variables.fixed_size_partitioner(2, axis=0)):
- fc.linear_model(features, [price1, price2], cols_to_vars=cols_to_vars)
- with _initialized_session():
- self.assertEqual([0.], cols_to_vars['bias'][0].eval())
- # Partitioning shards the [2, 1] price1 var into 2 [1, 1] Variables.
- self.assertAllEqual([[0.]], cols_to_vars[price1][0].eval())
- self.assertAllEqual([[0.]], cols_to_vars[price1][1].eval())
- # Partitioning shards the [3, 1] price2 var into a [2, 1] Variable and
- # a [1, 1] Variable.
- self.assertAllEqual([[0.], [0.]], cols_to_vars[price2][0].eval())
- self.assertAllEqual([[0.]], cols_to_vars[price2][1].eval())
-
- def test_dense_collection(self):
- price = fc_old.numeric_column('price')
- with ops.Graph().as_default() as g:
- features = {'price': [[1.], [5.]]}
- fc.linear_model(features, [price], weight_collections=['my-vars'])
- my_vars = g.get_collection('my-vars')
- bias = get_linear_model_bias()
- price_var = get_linear_model_column_var(price)
- self.assertIn(bias, my_vars)
- self.assertIn(price_var, my_vars)
-
- def test_sparse_collection(self):
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
- with ops.Graph().as_default() as g:
- wire_tensor = sparse_tensor.SparseTensor(
- values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
- features = {'wire_cast': wire_tensor}
- fc.linear_model(
- features, [wire_cast], weight_collections=['my-vars'])
- my_vars = g.get_collection('my-vars')
- bias = get_linear_model_bias()
- wire_cast_var = get_linear_model_column_var(wire_cast)
- self.assertIn(bias, my_vars)
- self.assertIn(wire_cast_var, my_vars)
-
def test_dense_trainable_default(self):
- price = fc_old.numeric_column('price')
+ price = fc.numeric_column('price')
with ops.Graph().as_default() as g:
features = {'price': [[1.], [5.]]}
- fc.linear_model(features, [price])
- bias = get_linear_model_bias()
- price_var = get_linear_model_column_var(price)
+ model = fc.LinearModel([price])
+ model(features)
+ price_var, bias = model.variables
trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
self.assertIn(bias, trainable_vars)
self.assertIn(price_var, trainable_vars)
def test_sparse_trainable_default(self):
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
+ wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
with ops.Graph().as_default() as g:
wire_tensor = sparse_tensor.SparseTensor(
values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
features = {'wire_cast': wire_tensor}
- fc.linear_model(features, [wire_cast])
+ model = fc.LinearModel([wire_cast])
+ model(features)
trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
- bias = get_linear_model_bias()
- wire_cast_var = get_linear_model_column_var(wire_cast)
+ wire_cast_var, bias = model.variables
self.assertIn(bias, trainable_vars)
self.assertIn(wire_cast_var, trainable_vars)
def test_dense_trainable_false(self):
- price = fc_old.numeric_column('price')
+ price = fc.numeric_column('price')
with ops.Graph().as_default() as g:
features = {'price': [[1.], [5.]]}
- fc.linear_model(features, [price], trainable=False)
+ model = fc.LinearModel([price], trainable=False)
+ model(features)
trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
self.assertEqual([], trainable_vars)
def test_sparse_trainable_false(self):
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
+ wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
with ops.Graph().as_default() as g:
wire_tensor = sparse_tensor.SparseTensor(
values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
features = {'wire_cast': wire_tensor}
- fc.linear_model(features, [wire_cast], trainable=False)
+ model = fc.LinearModel([wire_cast], trainable=False)
+ model(features)
trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
self.assertEqual([], trainable_vars)
def test_column_order(self):
- price_a = fc_old.numeric_column('price_a')
- price_b = fc_old.numeric_column('price_b')
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
- with ops.Graph().as_default() as g:
+ price_a = fc.numeric_column('price_a')
+ price_b = fc.numeric_column('price_b')
+ wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
+ with ops.Graph().as_default():
features = {
'price_a': [[1.]],
'price_b': [[3.]],
@@ -1720,15 +1458,15 @@ class LinearModelTest(test.TestCase):
sparse_tensor.SparseTensor(
values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
}
- fc.linear_model(
- features, [price_a, wire_cast, price_b],
- weight_collections=['my-vars'])
- my_vars = g.get_collection('my-vars')
+ model = fc.LinearModel([price_a, wire_cast, price_b])
+ model(features)
+
+ my_vars = model.variables
self.assertIn('price_a', my_vars[0].name)
self.assertIn('price_b', my_vars[1].name)
self.assertIn('wire_cast', my_vars[2].name)
- with ops.Graph().as_default() as g:
+ with ops.Graph().as_default():
features = {
'price_a': [[1.]],
'price_b': [[3.]],
@@ -1736,17 +1474,45 @@ class LinearModelTest(test.TestCase):
sparse_tensor.SparseTensor(
values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
}
- fc.linear_model(
- features, [wire_cast, price_b, price_a],
- weight_collections=['my-vars'])
- my_vars = g.get_collection('my-vars')
+ model = fc.LinearModel([wire_cast, price_b, price_a])
+ model(features)
+
+ my_vars = model.variables
self.assertIn('price_a', my_vars[0].name)
self.assertIn('price_b', my_vars[1].name)
self.assertIn('wire_cast', my_vars[2].name)
+ def test_variable_names(self):
+ price1 = fc.numeric_column('price1')
+ dense_feature = fc.numeric_column('dense_feature')
+ dense_feature_bucketized = fc.bucketized_column(
+ dense_feature, boundaries=[0.])
+ some_sparse_column = fc.categorical_column_with_hash_bucket(
+ 'sparse_feature', hash_bucket_size=5)
+ some_embedding_column = fc.embedding_column(
+ some_sparse_column, dimension=10)
+ all_cols = [price1, dense_feature_bucketized, some_embedding_column]
+
+ with ops.Graph().as_default():
+ model = fc.LinearModel(all_cols)
+ features = {
+ 'price1': [[3.], [4.]],
+ 'dense_feature': [[-1.], [4.]],
+ 'sparse_feature': [['a'], ['x']],
+ }
+ model(features)
+ variable_names = [var.name for var in model.variables]
+ self.assertItemsEqual([
+ 'linear_model/dense_feature_bucketized/weights:0',
+ 'linear_model/price1/weights:0',
+ 'linear_model/sparse_feature_embedding/embedding_weights:0',
+ 'linear_model/sparse_feature_embedding/weights:0',
+ 'linear_model/bias_weights:0',
+ ], variable_names)
+
def test_static_batch_size_mismatch(self):
- price1 = fc_old.numeric_column('price1')
- price2 = fc_old.numeric_column('price2')
+ price1 = fc.numeric_column('price1')
+ price2 = fc.numeric_column('price2')
with ops.Graph().as_default():
features = {
'price1': [[1.], [5.], [7.]], # batchsize = 3
@@ -1755,12 +1521,13 @@ class LinearModelTest(test.TestCase):
with self.assertRaisesRegexp(
ValueError,
'Batch size \(first dimension\) of each feature must be same.'): # pylint: disable=anomalous-backslash-in-string
- fc.linear_model(features, [price1, price2])
+ model = fc.LinearModel([price1, price2])
+ model(features)
def test_subset_of_static_batch_size_mismatch(self):
- price1 = fc_old.numeric_column('price1')
- price2 = fc_old.numeric_column('price2')
- price3 = fc_old.numeric_column('price3')
+ price1 = fc.numeric_column('price1')
+ price2 = fc.numeric_column('price2')
+ price3 = fc.numeric_column('price3')
with ops.Graph().as_default():
features = {
'price1': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 3
@@ -1770,17 +1537,19 @@ class LinearModelTest(test.TestCase):
with self.assertRaisesRegexp(
ValueError,
'Batch size \(first dimension\) of each feature must be same.'): # pylint: disable=anomalous-backslash-in-string
- fc.linear_model(features, [price1, price2, price3])
+ model = fc.LinearModel([price1, price2, price3])
+ model(features)
def test_runtime_batch_size_mismatch(self):
- price1 = fc_old.numeric_column('price1')
- price2 = fc_old.numeric_column('price2')
+ price1 = fc.numeric_column('price1')
+ price2 = fc.numeric_column('price2')
with ops.Graph().as_default():
features = {
'price1': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 3
'price2': [[3.], [4.]] # batchsize = 2
}
- predictions = fc.linear_model(features, [price1, price2])
+ model = fc.LinearModel([price1, price2])
+ predictions = model(features)
with _initialized_session() as sess:
with self.assertRaisesRegexp(errors.OpError,
'must have the same size and shape'):
@@ -1788,14 +1557,15 @@ class LinearModelTest(test.TestCase):
predictions, feed_dict={features['price1']: [[1.], [5.], [7.]]})
def test_runtime_batch_size_matches(self):
- price1 = fc_old.numeric_column('price1')
- price2 = fc_old.numeric_column('price2')
+ price1 = fc.numeric_column('price1')
+ price2 = fc.numeric_column('price2')
with ops.Graph().as_default():
features = {
'price1': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 2
'price2': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 2
}
- predictions = fc.linear_model(features, [price1, price2])
+ model = fc.LinearModel([price1, price2])
+ predictions = model(features)
with _initialized_session() as sess:
sess.run(
predictions,
@@ -1805,14 +1575,14 @@ class LinearModelTest(test.TestCase):
})
def test_with_numpy_input_fn(self):
- price = fc_old.numeric_column('price')
- price_buckets = fc_old.bucketized_column(
+ price = fc.numeric_column('price')
+ price_buckets = fc.bucketized_column(
price, boundaries=[
0.,
10.,
100.,
])
- body_style = fc_old.categorical_column_with_vocabulary_list(
+ body_style = fc.categorical_column_with_vocabulary_list(
'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
input_fn = numpy_io.numpy_input_fn(
@@ -1823,15 +1593,14 @@ class LinearModelTest(test.TestCase):
batch_size=2,
shuffle=False)
features = input_fn()
- net = fc.linear_model(features, [price_buckets, body_style])
+ model = fc.LinearModel([price_buckets, body_style])
+ net = model(features)
# self.assertEqual(1 + 3 + 5, net.shape[1])
with _initialized_session() as sess:
coord = coordinator.Coordinator()
threads = queue_runner_impl.start_queue_runners(sess, coord=coord)
- bias = get_linear_model_bias()
- price_buckets_var = get_linear_model_column_var(price_buckets)
- body_style_var = get_linear_model_column_var(body_style)
+ body_style_var, price_buckets_var, bias = model.variables
sess.run(price_buckets_var.assign([[10.], [100.], [1000.], [10000.]]))
sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
@@ -1843,14 +1612,14 @@ class LinearModelTest(test.TestCase):
coord.join(threads)
def test_with_1d_sparse_tensor(self):
- price = fc_old.numeric_column('price')
- price_buckets = fc_old.bucketized_column(
+ price = fc.numeric_column('price')
+ price_buckets = fc.bucketized_column(
price, boundaries=[
0.,
10.,
100.,
])
- body_style = fc_old.categorical_column_with_vocabulary_list(
+ body_style = fc.categorical_column_with_vocabulary_list(
'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
# Provides 1-dim tensor and dense tensor.
@@ -1864,11 +1633,10 @@ class LinearModelTest(test.TestCase):
self.assertEqual(1, features['price'].shape.ndims)
self.assertEqual(1, features['body-style'].dense_shape.get_shape()[0])
- net = fc.linear_model(features, [price_buckets, body_style])
+ model = fc.LinearModel([price_buckets, body_style])
+ net = model(features)
with _initialized_session() as sess:
- bias = get_linear_model_bias()
- price_buckets_var = get_linear_model_column_var(price_buckets)
- body_style_var = get_linear_model_column_var(body_style)
+ body_style_var, price_buckets_var, bias = model.variables
sess.run(price_buckets_var.assign([[10.], [100.], [1000.], [10000.]]))
sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
@@ -1877,16 +1645,16 @@ class LinearModelTest(test.TestCase):
self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]], sess.run(net))
def test_with_1d_unknown_shape_sparse_tensor(self):
- price = fc_old.numeric_column('price')
- price_buckets = fc_old.bucketized_column(
+ price = fc.numeric_column('price')
+ price_buckets = fc.bucketized_column(
price, boundaries=[
0.,
10.,
100.,
])
- body_style = fc_old.categorical_column_with_vocabulary_list(
+ body_style = fc.categorical_column_with_vocabulary_list(
'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
- country = fc_old.categorical_column_with_vocabulary_list(
+ country = fc.categorical_column_with_vocabulary_list(
'country', vocabulary_list=['US', 'JP', 'CA'])
# Provides 1-dim tensor and dense tensor.
@@ -1905,10 +1673,9 @@ class LinearModelTest(test.TestCase):
dense_shape=(2,))
country_data = np.array(['US', 'CA'])
- net = fc.linear_model(features, [price_buckets, body_style, country])
- bias = get_linear_model_bias()
- price_buckets_var = get_linear_model_column_var(price_buckets)
- body_style_var = get_linear_model_column_var(body_style)
+ model = fc.LinearModel([price_buckets, body_style, country])
+ net = model(features)
+ body_style_var, _, price_buckets_var, bias = model.variables
with _initialized_session() as sess:
sess.run(price_buckets_var.assign([[10.], [100.], [1000.], [10000.]]))
sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
@@ -1924,7 +1691,7 @@ class LinearModelTest(test.TestCase):
}))
def test_with_rank_0_feature(self):
- price = fc_old.numeric_column('price')
+ price = fc.numeric_column('price')
features = {
'price': constant_op.constant(0),
}
@@ -1932,29 +1699,31 @@ class LinearModelTest(test.TestCase):
# Static rank 0 should fail
with self.assertRaisesRegexp(ValueError, 'Feature .* cannot have rank 0'):
- fc.linear_model(features, [price])
+ model = fc.LinearModel([price])
+ model(features)
# Dynamic rank 0 should fail
features = {
'price': array_ops.placeholder(dtypes.float32),
}
- net = fc.linear_model(features, [price])
+ model = fc.LinearModel([price])
+ net = model(features)
self.assertEqual(1, net.shape[1])
with _initialized_session() as sess:
with self.assertRaisesOpError('Feature .* cannot have rank 0'):
sess.run(net, feed_dict={features['price']: np.array(1)})
def test_multiple_linear_models(self):
- price = fc_old.numeric_column('price')
+ price = fc.numeric_column('price')
with ops.Graph().as_default():
features1 = {'price': [[1.], [5.]]}
features2 = {'price': [[2.], [10.]]}
- predictions1 = fc.linear_model(features1, [price])
- predictions2 = fc.linear_model(features2, [price])
- bias1 = get_linear_model_bias(name='linear_model')
- bias2 = get_linear_model_bias(name='linear_model_1')
- price_var1 = get_linear_model_column_var(price, name='linear_model')
- price_var2 = get_linear_model_column_var(price, name='linear_model_1')
+ model1 = fc.LinearModel([price])
+ model2 = fc.LinearModel([price])
+ predictions1 = model1(features1)
+ predictions2 = model2(features2)
+ price_var1, bias1 = model1.variables
+ price_var2, bias2 = model2.variables
with _initialized_session() as sess:
self.assertAllClose([0.], bias1.eval())
sess.run(price_var1.assign([[10.]]))
@@ -1966,664 +1735,6 @@ class LinearModelTest(test.TestCase):
self.assertAllClose([[25.], [105.]], predictions2.eval())
-class _LinearModelTest(test.TestCase):
-
- def test_raises_if_empty_feature_columns(self):
- with self.assertRaisesRegexp(ValueError,
- 'feature_columns must not be empty'):
- get_keras_linear_model_predictions(features={}, feature_columns=[])
-
- def test_should_be_feature_column(self):
- with self.assertRaisesRegexp(ValueError, 'must be a _FeatureColumn'):
- get_keras_linear_model_predictions(
- features={'a': [[0]]}, feature_columns='NotSupported')
-
- def test_should_be_dense_or_categorical_column(self):
-
- class NotSupportedColumn(fc_old._FeatureColumn):
-
- @property
- def name(self):
- return 'NotSupportedColumn'
-
- def _transform_feature(self, cache):
- pass
-
- @property
- def _parse_example_spec(self):
- pass
-
- with self.assertRaisesRegexp(
- ValueError, 'must be either a _DenseColumn or _CategoricalColumn'):
- get_keras_linear_model_predictions(
- features={'a': [[0]]}, feature_columns=[NotSupportedColumn()])
-
- def test_does_not_support_dict_columns(self):
- with self.assertRaisesRegexp(
- ValueError, 'Expected feature_columns to be iterable, found dict.'):
- fc.linear_model(
- features={'a': [[0]]},
- feature_columns={'a': fc_old.numeric_column('a')})
-
- def test_raises_if_duplicate_name(self):
- with self.assertRaisesRegexp(
- ValueError, 'Duplicate feature column name found for columns'):
- get_keras_linear_model_predictions(
- features={'a': [[0]]},
- feature_columns=[
- fc_old.numeric_column('a'),
- fc_old.numeric_column('a')
- ])
-
- def test_dense_bias(self):
- price = fc_old.numeric_column('price')
- with ops.Graph().as_default():
- features = {'price': [[1.], [5.]]}
- predictions = get_keras_linear_model_predictions(features, [price])
- bias = get_linear_model_bias()
- price_var = get_linear_model_column_var(price)
- with _initialized_session() as sess:
- self.assertAllClose([0.], bias.eval())
- sess.run(price_var.assign([[10.]]))
- sess.run(bias.assign([5.]))
- self.assertAllClose([[15.], [55.]], predictions.eval())
-
- def test_sparse_bias(self):
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
- with ops.Graph().as_default():
- wire_tensor = sparse_tensor.SparseTensor(
- values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3]
- indices=[[0, 0], [1, 0], [1, 1]],
- dense_shape=[2, 2])
- features = {'wire_cast': wire_tensor}
- predictions = get_keras_linear_model_predictions(features, [wire_cast])
- bias = get_linear_model_bias()
- wire_cast_var = get_linear_model_column_var(wire_cast)
- with _initialized_session() as sess:
- self.assertAllClose([0.], bias.eval())
- self.assertAllClose([[0.], [0.], [0.], [0.]], wire_cast_var.eval())
- sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]]))
- sess.run(bias.assign([5.]))
- self.assertAllClose([[1005.], [10015.]], predictions.eval())
-
- def test_dense_and_sparse_bias(self):
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
- price = fc_old.numeric_column('price')
- with ops.Graph().as_default():
- wire_tensor = sparse_tensor.SparseTensor(
- values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3]
- indices=[[0, 0], [1, 0], [1, 1]],
- dense_shape=[2, 2])
- features = {'wire_cast': wire_tensor, 'price': [[1.], [5.]]}
- predictions = get_keras_linear_model_predictions(features,
- [wire_cast, price])
- bias = get_linear_model_bias()
- wire_cast_var = get_linear_model_column_var(wire_cast)
- price_var = get_linear_model_column_var(price)
- with _initialized_session() as sess:
- sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]]))
- sess.run(bias.assign([5.]))
- sess.run(price_var.assign([[10.]]))
- self.assertAllClose([[1015.], [10065.]], predictions.eval())
-
- def test_dense_and_sparse_column(self):
- """When the column is both dense and sparse, uses sparse tensors."""
-
- class _DenseAndSparseColumn(fc_old._DenseColumn, fc_old._CategoricalColumn):
-
- @property
- def name(self):
- return 'dense_and_sparse_column'
-
- @property
- def _parse_example_spec(self):
- return {self.name: parsing_ops.VarLenFeature(self.dtype)}
-
- def _transform_feature(self, inputs):
- return inputs.get(self.name)
-
- @property
- def _variable_shape(self):
- raise ValueError('Should not use this method.')
-
- def _get_dense_tensor(self,
- inputs,
- weight_collections=None,
- trainable=None):
- raise ValueError('Should not use this method.')
-
- @property
- def _num_buckets(self):
- return 4
-
- def _get_sparse_tensors(self,
- inputs,
- weight_collections=None,
- trainable=None):
- sp_tensor = sparse_tensor.SparseTensor(
- indices=[[0, 0], [1, 0], [1, 1]],
- values=[2, 0, 3],
- dense_shape=[2, 2])
- return fc_old._CategoricalColumn.IdWeightPair(sp_tensor, None)
-
- dense_and_sparse_column = _DenseAndSparseColumn()
- with ops.Graph().as_default():
- sp_tensor = sparse_tensor.SparseTensor(
- values=['omar', 'stringer', 'marlo'],
- indices=[[0, 0], [1, 0], [1, 1]],
- dense_shape=[2, 2])
- features = {dense_and_sparse_column.name: sp_tensor}
- predictions = get_keras_linear_model_predictions(
- features, [dense_and_sparse_column])
- bias = get_linear_model_bias()
- dense_and_sparse_column_var = get_linear_model_column_var(
- dense_and_sparse_column)
- with _initialized_session() as sess:
- sess.run(
- dense_and_sparse_column_var.assign([[10.], [100.], [1000.],
- [10000.]]))
- sess.run(bias.assign([5.]))
- self.assertAllClose([[1005.], [10015.]], predictions.eval())
-
- def test_dense_multi_output(self):
- price = fc_old.numeric_column('price')
- with ops.Graph().as_default():
- features = {'price': [[1.], [5.]]}
- predictions = get_keras_linear_model_predictions(
- features, [price], units=3)
- bias = get_linear_model_bias()
- price_var = get_linear_model_column_var(price)
- with _initialized_session() as sess:
- self.assertAllClose(np.zeros((3,)), bias.eval())
- self.assertAllClose(np.zeros((1, 3)), price_var.eval())
- sess.run(price_var.assign([[10., 100., 1000.]]))
- sess.run(bias.assign([5., 6., 7.]))
- self.assertAllClose([[15., 106., 1007.], [55., 506., 5007.]],
- predictions.eval())
-
- def test_sparse_multi_output(self):
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
- with ops.Graph().as_default():
- wire_tensor = sparse_tensor.SparseTensor(
- values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3]
- indices=[[0, 0], [1, 0], [1, 1]],
- dense_shape=[2, 2])
- features = {'wire_cast': wire_tensor}
- predictions = get_keras_linear_model_predictions(
- features, [wire_cast], units=3)
- bias = get_linear_model_bias()
- wire_cast_var = get_linear_model_column_var(wire_cast)
- with _initialized_session() as sess:
- self.assertAllClose(np.zeros((3,)), bias.eval())
- self.assertAllClose(np.zeros((4, 3)), wire_cast_var.eval())
- sess.run(
- wire_cast_var.assign([[10., 11., 12.], [100., 110., 120.],
- [1000., 1100.,
- 1200.], [10000., 11000., 12000.]]))
- sess.run(bias.assign([5., 6., 7.]))
- self.assertAllClose([[1005., 1106., 1207.], [10015., 11017., 12019.]],
- predictions.eval())
-
- def test_dense_multi_dimension(self):
- price = fc_old.numeric_column('price', shape=2)
- with ops.Graph().as_default():
- features = {'price': [[1., 2.], [5., 6.]]}
- predictions = get_keras_linear_model_predictions(features, [price])
- price_var = get_linear_model_column_var(price)
- with _initialized_session() as sess:
- self.assertAllClose([[0.], [0.]], price_var.eval())
- sess.run(price_var.assign([[10.], [100.]]))
- self.assertAllClose([[210.], [650.]], predictions.eval())
-
- def test_sparse_multi_rank(self):
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
- with ops.Graph().as_default():
- wire_tensor = array_ops.sparse_placeholder(dtypes.string)
- wire_value = sparse_tensor.SparseTensorValue(
- values=['omar', 'stringer', 'marlo', 'omar'], # hashed = [2, 0, 3, 2]
- indices=[[0, 0, 0], [0, 1, 0], [1, 0, 0], [1, 0, 1]],
- dense_shape=[2, 2, 2])
- features = {'wire_cast': wire_tensor}
- predictions = get_keras_linear_model_predictions(features, [wire_cast])
- wire_cast_var = get_linear_model_column_var(wire_cast)
- with _initialized_session() as sess:
- self.assertAllClose(np.zeros((4, 1)), wire_cast_var.eval())
- self.assertAllClose(
- np.zeros((2, 1)),
- predictions.eval(feed_dict={wire_tensor: wire_value}))
- sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]]))
- self.assertAllClose(
- [[1010.], [11000.]],
- predictions.eval(feed_dict={wire_tensor: wire_value}))
-
- def test_sparse_combiner(self):
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
- with ops.Graph().as_default():
- wire_tensor = sparse_tensor.SparseTensor(
- values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3]
- indices=[[0, 0], [1, 0], [1, 1]],
- dense_shape=[2, 2])
- features = {'wire_cast': wire_tensor}
- predictions = get_keras_linear_model_predictions(
- features, [wire_cast], sparse_combiner='mean')
- bias = get_linear_model_bias()
- wire_cast_var = get_linear_model_column_var(wire_cast)
- with _initialized_session() as sess:
- sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]]))
- sess.run(bias.assign([5.]))
- self.assertAllClose([[1005.], [5010.]], predictions.eval())
-
- def test_dense_multi_dimension_multi_output(self):
- price = fc_old.numeric_column('price', shape=2)
- with ops.Graph().as_default():
- features = {'price': [[1., 2.], [5., 6.]]}
- predictions = get_keras_linear_model_predictions(
- features, [price], units=3)
- bias = get_linear_model_bias()
- price_var = get_linear_model_column_var(price)
- with _initialized_session() as sess:
- self.assertAllClose(np.zeros((3,)), bias.eval())
- self.assertAllClose(np.zeros((2, 3)), price_var.eval())
- sess.run(price_var.assign([[1., 2., 3.], [10., 100., 1000.]]))
- sess.run(bias.assign([2., 3., 4.]))
- self.assertAllClose([[23., 205., 2007.], [67., 613., 6019.]],
- predictions.eval())
-
- def test_raises_if_shape_mismatch(self):
- price = fc_old.numeric_column('price', shape=2)
- with ops.Graph().as_default():
- features = {'price': [[1.], [5.]]}
- with self.assertRaisesRegexp(
- Exception,
- r'Cannot reshape a tensor with 2 elements to shape \[2,2\]'):
- get_keras_linear_model_predictions(features, [price])
-
- def test_dense_reshaping(self):
- price = fc_old.numeric_column('price', shape=[1, 2])
- with ops.Graph().as_default():
- features = {'price': [[[1., 2.]], [[5., 6.]]]}
- predictions = get_keras_linear_model_predictions(features, [price])
- bias = get_linear_model_bias()
- price_var = get_linear_model_column_var(price)
- with _initialized_session() as sess:
- self.assertAllClose([0.], bias.eval())
- self.assertAllClose([[0.], [0.]], price_var.eval())
- self.assertAllClose([[0.], [0.]], predictions.eval())
- sess.run(price_var.assign([[10.], [100.]]))
- self.assertAllClose([[210.], [650.]], predictions.eval())
-
- def test_dense_multi_column(self):
- price1 = fc_old.numeric_column('price1', shape=2)
- price2 = fc_old.numeric_column('price2')
- with ops.Graph().as_default():
- features = {'price1': [[1., 2.], [5., 6.]], 'price2': [[3.], [4.]]}
- predictions = get_keras_linear_model_predictions(features,
- [price1, price2])
- bias = get_linear_model_bias()
- price1_var = get_linear_model_column_var(price1)
- price2_var = get_linear_model_column_var(price2)
- with _initialized_session() as sess:
- self.assertAllClose([0.], bias.eval())
- self.assertAllClose([[0.], [0.]], price1_var.eval())
- self.assertAllClose([[0.]], price2_var.eval())
- self.assertAllClose([[0.], [0.]], predictions.eval())
- sess.run(price1_var.assign([[10.], [100.]]))
- sess.run(price2_var.assign([[1000.]]))
- sess.run(bias.assign([7.]))
- self.assertAllClose([[3217.], [4657.]], predictions.eval())
-
- def test_fills_cols_to_vars(self):
- price1 = fc_old.numeric_column('price1', shape=2)
- price2 = fc_old.numeric_column('price2')
- with ops.Graph().as_default():
- features = {'price1': [[1., 2.], [5., 6.]], 'price2': [[3.], [4.]]}
- cols_to_vars = {}
- get_keras_linear_model_predictions(
- features, [price1, price2], cols_to_vars=cols_to_vars)
- bias = get_linear_model_bias()
- price1_var = get_linear_model_column_var(price1)
- price2_var = get_linear_model_column_var(price2)
- self.assertAllEqual(cols_to_vars['bias'], [bias])
- self.assertAllEqual(cols_to_vars[price1], [price1_var])
- self.assertAllEqual(cols_to_vars[price2], [price2_var])
-
- def test_fills_cols_to_vars_partitioned_variables(self):
- price1 = fc_old.numeric_column('price1', shape=2)
- price2 = fc_old.numeric_column('price2', shape=3)
- with ops.Graph().as_default():
- features = {
- 'price1': [[1., 2.], [6., 7.]],
- 'price2': [[3., 4., 5.], [8., 9., 10.]]
- }
- cols_to_vars = {}
- with variable_scope.variable_scope(
- 'linear',
- partitioner=partitioned_variables.fixed_size_partitioner(2, axis=0)):
- get_keras_linear_model_predictions(
- features, [price1, price2], cols_to_vars=cols_to_vars)
- with _initialized_session():
- self.assertEqual([0.], cols_to_vars['bias'][0].eval())
- # Partitioning shards the [2, 1] price1 var into 2 [1, 1] Variables.
- self.assertAllEqual([[0.]], cols_to_vars[price1][0].eval())
- self.assertAllEqual([[0.]], cols_to_vars[price1][1].eval())
- # Partitioning shards the [3, 1] price2 var into a [2, 1] Variable and
- # a [1, 1] Variable.
- self.assertAllEqual([[0.], [0.]], cols_to_vars[price2][0].eval())
- self.assertAllEqual([[0.]], cols_to_vars[price2][1].eval())
-
- def test_dense_collection(self):
- price = fc_old.numeric_column('price')
- with ops.Graph().as_default() as g:
- features = {'price': [[1.], [5.]]}
- get_keras_linear_model_predictions(
- features, [price], weight_collections=['my-vars'])
- my_vars = g.get_collection('my-vars')
- bias = get_linear_model_bias()
- price_var = get_linear_model_column_var(price)
- self.assertIn(bias, my_vars)
- self.assertIn(price_var, my_vars)
-
- def test_sparse_collection(self):
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
- with ops.Graph().as_default() as g:
- wire_tensor = sparse_tensor.SparseTensor(
- values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
- features = {'wire_cast': wire_tensor}
- get_keras_linear_model_predictions(
- features, [wire_cast], weight_collections=['my-vars'])
- my_vars = g.get_collection('my-vars')
- bias = get_linear_model_bias()
- wire_cast_var = get_linear_model_column_var(wire_cast)
- self.assertIn(bias, my_vars)
- self.assertIn(wire_cast_var, my_vars)
-
- def test_dense_trainable_default(self):
- price = fc_old.numeric_column('price')
- with ops.Graph().as_default() as g:
- features = {'price': [[1.], [5.]]}
- get_keras_linear_model_predictions(features, [price])
- bias = get_linear_model_bias()
- price_var = get_linear_model_column_var(price)
- trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
- self.assertIn(bias, trainable_vars)
- self.assertIn(price_var, trainable_vars)
-
- def test_sparse_trainable_default(self):
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
- with ops.Graph().as_default() as g:
- wire_tensor = sparse_tensor.SparseTensor(
- values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
- features = {'wire_cast': wire_tensor}
- get_keras_linear_model_predictions(features, [wire_cast])
- trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
- bias = get_linear_model_bias()
- wire_cast_var = get_linear_model_column_var(wire_cast)
- self.assertIn(bias, trainable_vars)
- self.assertIn(wire_cast_var, trainable_vars)
-
- def test_dense_trainable_false(self):
- price = fc_old.numeric_column('price')
- with ops.Graph().as_default() as g:
- features = {'price': [[1.], [5.]]}
- get_keras_linear_model_predictions(features, [price], trainable=False)
- trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
- self.assertEqual([], trainable_vars)
-
- def test_sparse_trainable_false(self):
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
- with ops.Graph().as_default() as g:
- wire_tensor = sparse_tensor.SparseTensor(
- values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
- features = {'wire_cast': wire_tensor}
- get_keras_linear_model_predictions(features, [wire_cast], trainable=False)
- trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
- self.assertEqual([], trainable_vars)
-
- def test_column_order(self):
- price_a = fc_old.numeric_column('price_a')
- price_b = fc_old.numeric_column('price_b')
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
- with ops.Graph().as_default() as g:
- features = {
- 'price_a': [[1.]],
- 'price_b': [[3.]],
- 'wire_cast':
- sparse_tensor.SparseTensor(
- values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
- }
- get_keras_linear_model_predictions(
- features, [price_a, wire_cast, price_b],
- weight_collections=['my-vars'])
- my_vars = g.get_collection('my-vars')
- self.assertIn('price_a', my_vars[0].name)
- self.assertIn('price_b', my_vars[1].name)
- self.assertIn('wire_cast', my_vars[2].name)
-
- with ops.Graph().as_default() as g:
- features = {
- 'price_a': [[1.]],
- 'price_b': [[3.]],
- 'wire_cast':
- sparse_tensor.SparseTensor(
- values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
- }
- get_keras_linear_model_predictions(
- features, [wire_cast, price_b, price_a],
- weight_collections=['my-vars'])
- my_vars = g.get_collection('my-vars')
- self.assertIn('price_a', my_vars[0].name)
- self.assertIn('price_b', my_vars[1].name)
- self.assertIn('wire_cast', my_vars[2].name)
-
- def test_static_batch_size_mismatch(self):
- price1 = fc_old.numeric_column('price1')
- price2 = fc_old.numeric_column('price2')
- with ops.Graph().as_default():
- features = {
- 'price1': [[1.], [5.], [7.]], # batchsize = 3
- 'price2': [[3.], [4.]] # batchsize = 2
- }
- with self.assertRaisesRegexp(
- ValueError,
- 'Batch size \(first dimension\) of each feature must be same.'): # pylint: disable=anomalous-backslash-in-string
- get_keras_linear_model_predictions(features, [price1, price2])
-
- def test_subset_of_static_batch_size_mismatch(self):
- price1 = fc_old.numeric_column('price1')
- price2 = fc_old.numeric_column('price2')
- price3 = fc_old.numeric_column('price3')
- with ops.Graph().as_default():
- features = {
- 'price1': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 3
- 'price2': [[3.], [4.]], # batchsize = 2
- 'price3': [[3.], [4.], [5.]] # batchsize = 3
- }
- with self.assertRaisesRegexp(
- ValueError,
- 'Batch size \(first dimension\) of each feature must be same.'): # pylint: disable=anomalous-backslash-in-string
- get_keras_linear_model_predictions(features, [price1, price2, price3])
-
- def test_runtime_batch_size_mismatch(self):
- price1 = fc_old.numeric_column('price1')
- price2 = fc_old.numeric_column('price2')
- with ops.Graph().as_default():
- features = {
- 'price1': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 3
- 'price2': [[3.], [4.]] # batchsize = 2
- }
- predictions = get_keras_linear_model_predictions(features,
- [price1, price2])
- with _initialized_session() as sess:
- with self.assertRaisesRegexp(errors.OpError,
- 'must have the same size and shape'):
- sess.run(
- predictions, feed_dict={features['price1']: [[1.], [5.], [7.]]})
-
- def test_runtime_batch_size_matches(self):
- price1 = fc_old.numeric_column('price1')
- price2 = fc_old.numeric_column('price2')
- with ops.Graph().as_default():
- features = {
- 'price1': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 2
- 'price2': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 2
- }
- predictions = get_keras_linear_model_predictions(features,
- [price1, price2])
- with _initialized_session() as sess:
- sess.run(
- predictions,
- feed_dict={
- features['price1']: [[1.], [5.]],
- features['price2']: [[1.], [5.]],
- })
-
- def test_with_numpy_input_fn(self):
- price = fc_old.numeric_column('price')
- price_buckets = fc_old.bucketized_column(
- price, boundaries=[
- 0.,
- 10.,
- 100.,
- ])
- body_style = fc_old.categorical_column_with_vocabulary_list(
- 'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
-
- input_fn = numpy_io.numpy_input_fn(
- x={
- 'price': np.array([-1., 2., 13., 104.]),
- 'body-style': np.array(['sedan', 'hardtop', 'wagon', 'sedan']),
- },
- batch_size=2,
- shuffle=False)
- features = input_fn()
- net = get_keras_linear_model_predictions(features,
- [price_buckets, body_style])
- # self.assertEqual(1 + 3 + 5, net.shape[1])
- with _initialized_session() as sess:
- coord = coordinator.Coordinator()
- threads = queue_runner_impl.start_queue_runners(sess, coord=coord)
-
- bias = get_linear_model_bias()
- price_buckets_var = get_linear_model_column_var(price_buckets)
- body_style_var = get_linear_model_column_var(body_style)
-
- sess.run(price_buckets_var.assign([[10.], [100.], [1000.], [10000.]]))
- sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
- sess.run(bias.assign([5.]))
-
- self.assertAllClose([[10 - 1000 + 5.], [100 - 10 + 5.]], sess.run(net))
-
- coord.request_stop()
- coord.join(threads)
-
- def test_with_1d_sparse_tensor(self):
- price = fc_old.numeric_column('price')
- price_buckets = fc_old.bucketized_column(
- price, boundaries=[
- 0.,
- 10.,
- 100.,
- ])
- body_style = fc_old.categorical_column_with_vocabulary_list(
- 'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
-
- # Provides 1-dim tensor and dense tensor.
- features = {
- 'price':
- constant_op.constant([
- -1.,
- 12.,
- ]),
- 'body-style':
- sparse_tensor.SparseTensor(
- indices=((0,), (1,)),
- values=('sedan', 'hardtop'),
- dense_shape=(2,)),
- }
- self.assertEqual(1, features['price'].shape.ndims)
- self.assertEqual(1, features['body-style'].dense_shape.get_shape()[0])
-
- net = get_keras_linear_model_predictions(features,
- [price_buckets, body_style])
- with _initialized_session() as sess:
- bias = get_linear_model_bias()
- price_buckets_var = get_linear_model_column_var(price_buckets)
- body_style_var = get_linear_model_column_var(body_style)
-
- sess.run(price_buckets_var.assign([[10.], [100.], [1000.], [10000.]]))
- sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
- sess.run(bias.assign([5.]))
-
- self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]], sess.run(net))
-
- def test_with_1d_unknown_shape_sparse_tensor(self):
- price = fc_old.numeric_column('price')
- price_buckets = fc_old.bucketized_column(
- price, boundaries=[
- 0.,
- 10.,
- 100.,
- ])
- body_style = fc_old.categorical_column_with_vocabulary_list(
- 'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
- country = fc_old.categorical_column_with_vocabulary_list(
- 'country', vocabulary_list=['US', 'JP', 'CA'])
-
- # Provides 1-dim tensor and dense tensor.
- features = {
- 'price': array_ops.placeholder(dtypes.float32),
- 'body-style': array_ops.sparse_placeholder(dtypes.string),
- 'country': array_ops.placeholder(dtypes.string),
- }
- self.assertIsNone(features['price'].shape.ndims)
- self.assertIsNone(features['body-style'].get_shape().ndims)
-
- price_data = np.array([-1., 12.])
- body_style_data = sparse_tensor.SparseTensorValue(
- indices=((0,), (1,)), values=('sedan', 'hardtop'), dense_shape=(2,))
- country_data = np.array(['US', 'CA'])
-
- net = get_keras_linear_model_predictions(
- features, [price_buckets, body_style, country])
- bias = get_linear_model_bias()
- price_buckets_var = get_linear_model_column_var(price_buckets)
- body_style_var = get_linear_model_column_var(body_style)
- with _initialized_session() as sess:
- sess.run(price_buckets_var.assign([[10.], [100.], [1000.], [10000.]]))
- sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
- sess.run(bias.assign([5.]))
-
- self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]],
- sess.run(
- net,
- feed_dict={
- features['price']: price_data,
- features['body-style']: body_style_data,
- features['country']: country_data
- }))
-
- def test_with_rank_0_feature(self):
- price = fc_old.numeric_column('price')
- features = {
- 'price': constant_op.constant(0),
- }
- self.assertEqual(0, features['price'].shape.ndims)
-
- # Static rank 0 should fail
- with self.assertRaisesRegexp(ValueError, 'Feature .* cannot have rank 0'):
- get_keras_linear_model_predictions(features, [price])
-
- # Dynamic rank 0 should fail
- features = {
- 'price': array_ops.placeholder(dtypes.float32),
- }
- net = get_keras_linear_model_predictions(features, [price])
- self.assertEqual(1, net.shape[1])
- with _initialized_session() as sess:
- with self.assertRaisesOpError('Feature .* cannot have rank 0'):
- sess.run(net, feed_dict={features['price']: np.array(1)})
-
-
class FeatureLayerTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes()
@@ -3723,47 +2834,22 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
id_weight_pair.id_tensor.eval())
def test_linear_model(self):
- wire_column = fc_old.categorical_column_with_vocabulary_file(
- key='wire',
- vocabulary_file=self._wire_vocabulary_file_name,
- vocabulary_size=self._wire_vocabulary_size,
- num_oov_buckets=1)
- self.assertEqual(4, wire_column._num_buckets)
- with ops.Graph().as_default():
- predictions = fc.linear_model({
- wire_column.name: sparse_tensor.SparseTensorValue(
- indices=((0, 0), (1, 0), (1, 1)),
- values=('marlo', 'skywalker', 'omar'),
- dense_shape=(2, 2))
- }, (wire_column,))
- bias = get_linear_model_bias()
- wire_var = get_linear_model_column_var(wire_column)
- with _initialized_session():
- self.assertAllClose((0.,), bias.eval())
- self.assertAllClose(((0.,), (0.,), (0.,), (0.,)), wire_var.eval())
- self.assertAllClose(((0.,), (0.,)), predictions.eval())
- wire_var.assign(((1.,), (2.,), (3.,), (4.,))).eval()
- # 'marlo' -> 2: wire_var[2] = 3
- # 'skywalker' -> 3, 'omar' -> 0: wire_var[3] + wire_var[0] = 4+1 = 5
- self.assertAllClose(((3.,), (5.,)), predictions.eval())
-
- def test_keras_linear_model(self):
- wire_column = fc_old.categorical_column_with_vocabulary_file(
+ wire_column = fc.categorical_column_with_vocabulary_file(
key='wire',
vocabulary_file=self._wire_vocabulary_file_name,
vocabulary_size=self._wire_vocabulary_size,
num_oov_buckets=1)
- self.assertEqual(4, wire_column._num_buckets)
+ self.assertEqual(4, wire_column.num_buckets)
with ops.Graph().as_default():
- predictions = get_keras_linear_model_predictions({
+ model = fc.LinearModel((wire_column,))
+ predictions = model({
wire_column.name:
sparse_tensor.SparseTensorValue(
indices=((0, 0), (1, 0), (1, 1)),
values=('marlo', 'skywalker', 'omar'),
dense_shape=(2, 2))
- }, (wire_column,))
- bias = get_linear_model_bias()
- wire_var = get_linear_model_column_var(wire_column)
+ })
+ wire_var, bias = model.variables
with _initialized_session():
self.assertAllClose((0.,), bias.eval())
self.assertAllClose(((0.,), (0.,), (0.,), (0.,)), wire_var.eval())
@@ -4124,45 +3210,21 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
id_weight_pair.id_tensor.eval())
def test_linear_model(self):
- wire_column = fc_old.categorical_column_with_vocabulary_list(
- key='aaa',
- vocabulary_list=('omar', 'stringer', 'marlo'),
- num_oov_buckets=1)
- self.assertEqual(4, wire_column._num_buckets)
- with ops.Graph().as_default():
- predictions = fc.linear_model({
- wire_column.name: sparse_tensor.SparseTensorValue(
- indices=((0, 0), (1, 0), (1, 1)),
- values=('marlo', 'skywalker', 'omar'),
- dense_shape=(2, 2))
- }, (wire_column,))
- bias = get_linear_model_bias()
- wire_var = get_linear_model_column_var(wire_column)
- with _initialized_session():
- self.assertAllClose((0.,), bias.eval())
- self.assertAllClose(((0.,), (0.,), (0.,), (0.,)), wire_var.eval())
- self.assertAllClose(((0.,), (0.,)), predictions.eval())
- wire_var.assign(((1.,), (2.,), (3.,), (4.,))).eval()
- # 'marlo' -> 2: wire_var[2] = 3
- # 'skywalker' -> 3, 'omar' -> 0: wire_var[3] + wire_var[0] = 4+1 = 5
- self.assertAllClose(((3.,), (5.,)), predictions.eval())
-
- def test_keras_linear_model(self):
- wire_column = fc_old.categorical_column_with_vocabulary_list(
+ wire_column = fc.categorical_column_with_vocabulary_list(
key='aaa',
vocabulary_list=('omar', 'stringer', 'marlo'),
num_oov_buckets=1)
- self.assertEqual(4, wire_column._num_buckets)
+ self.assertEqual(4, wire_column.num_buckets)
with ops.Graph().as_default():
- predictions = get_keras_linear_model_predictions({
+ model = fc.LinearModel((wire_column,))
+ predictions = model({
wire_column.name:
sparse_tensor.SparseTensorValue(
indices=((0, 0), (1, 0), (1, 1)),
values=('marlo', 'skywalker', 'omar'),
dense_shape=(2, 2))
- }, (wire_column,))
- bias = get_linear_model_bias()
- wire_var = get_linear_model_column_var(wire_column)
+ })
+ wire_var, bias = model.variables
with _initialized_session():
self.assertAllClose((0.,), bias.eval())
self.assertAllClose(((0.,), (0.,), (0.,), (0.,)), wire_var.eval())
@@ -4382,39 +3444,18 @@ class IdentityCategoricalColumnTest(test.TestCase):
}))
def test_linear_model(self):
- column = fc_old.categorical_column_with_identity(key='aaa', num_buckets=3)
- self.assertEqual(3, column.num_buckets)
- with ops.Graph().as_default():
- predictions = fc.linear_model({
- column.name: sparse_tensor.SparseTensorValue(
- indices=((0, 0), (1, 0), (1, 1)),
- values=(0, 2, 1),
- dense_shape=(2, 2))
- }, (column,))
- bias = get_linear_model_bias()
- weight_var = get_linear_model_column_var(column)
- with _initialized_session():
- self.assertAllClose((0.,), bias.eval())
- self.assertAllClose(((0.,), (0.,), (0.,)), weight_var.eval())
- self.assertAllClose(((0.,), (0.,)), predictions.eval())
- weight_var.assign(((1.,), (2.,), (3.,))).eval()
- # weight_var[0] = 1
- # weight_var[2] + weight_var[1] = 3+2 = 5
- self.assertAllClose(((1.,), (5.,)), predictions.eval())
-
- def test_keras_linear_model(self):
- column = fc_old.categorical_column_with_identity(key='aaa', num_buckets=3)
+ column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
self.assertEqual(3, column.num_buckets)
with ops.Graph().as_default():
- predictions = get_keras_linear_model_predictions({
+ model = fc.LinearModel((column,))
+ predictions = model({
column.name:
sparse_tensor.SparseTensorValue(
indices=((0, 0), (1, 0), (1, 1)),
values=(0, 2, 1),
dense_shape=(2, 2))
- }, (column,))
- bias = get_linear_model_bias()
- weight_var = get_linear_model_column_var(column)
+ })
+ weight_var, bias = model.variables
with _initialized_session():
self.assertAllClose((0.,), bias.eval())
self.assertAllClose(((0.,), (0.,), (0.,)), weight_var.eval())
@@ -4640,27 +3681,8 @@ class IndicatorColumnTest(test.TestCase):
self.assertAllEqual([[0., 1., 1.]], indicator_tensor.eval())
def test_linear_model(self):
- animal = fc_old.indicator_column(
- fc_old.categorical_column_with_identity('animal', num_buckets=4))
- with ops.Graph().as_default():
- features = {
- 'animal':
- sparse_tensor.SparseTensor(
- indices=[[0, 0], [0, 1]], values=[1, 2], dense_shape=[1, 2])
- }
-
- predictions = fc.linear_model(features, [animal])
- weight_var = get_linear_model_column_var(animal)
- with _initialized_session():
- # All should be zero-initialized.
- self.assertAllClose([[0.], [0.], [0.], [0.]], weight_var.eval())
- self.assertAllClose([[0.]], predictions.eval())
- weight_var.assign([[1.], [2.], [3.], [4.]]).eval()
- self.assertAllClose([[2. + 3.]], predictions.eval())
-
- def test_keras_linear_model(self):
- animal = fc_old.indicator_column(
- fc_old.categorical_column_with_identity('animal', num_buckets=4))
+ animal = fc.indicator_column(
+ fc.categorical_column_with_identity('animal', num_buckets=4))
with ops.Graph().as_default():
features = {
'animal':
@@ -4668,8 +3690,9 @@ class IndicatorColumnTest(test.TestCase):
indices=[[0, 0], [0, 1]], values=[1, 2], dense_shape=[1, 2])
}
- predictions = get_keras_linear_model_predictions(features, [animal])
- weight_var = get_linear_model_column_var(animal)
+ model = fc.LinearModel([animal])
+ predictions = model(features)
+ weight_var, _ = model.variables
with _initialized_session():
# All should be zero-initialized.
self.assertAllClose([[0.], [0.], [0.], [0.]], weight_var.eval())
@@ -5121,17 +4144,16 @@ class EmbeddingColumnTest(test.TestCase):
return zeros_embedding_values
# Build columns.
- categorical_column = fc_old.categorical_column_with_identity(
+ categorical_column = fc.categorical_column_with_identity(
key='aaa', num_buckets=vocabulary_size)
- embedding_column = fc_old.embedding_column(
+ embedding_column = fc.embedding_column(
categorical_column,
dimension=embedding_dimension,
initializer=_initializer)
with ops.Graph().as_default():
- predictions = fc.linear_model({
- categorical_column.name: sparse_input
- }, (embedding_column,))
+ model = fc.LinearModel((embedding_column,))
+ predictions = model({categorical_column.name: sparse_input})
expected_var_names = (
'linear_model/bias_weights:0',
'linear_model/aaa_embedding/weights:0',
@@ -5173,82 +4195,6 @@ class EmbeddingColumnTest(test.TestCase):
# = [4*7 + 6*11, 4*2 + 6*3.5, 4*0 + 6*0, 4*3 + 6*5] = [94, 29, 0, 42]
self.assertAllClose(((94.,), (29.,), (0.,), (42.,)), predictions.eval())
- def test_keras_linear_model(self):
- # Inputs.
- batch_size = 4
- vocabulary_size = 3
- sparse_input = sparse_tensor.SparseTensorValue(
- # example 0, ids [2]
- # example 1, ids [0, 1]
- # example 2, ids []
- # example 3, ids [1]
- indices=((0, 0), (1, 0), (1, 4), (3, 0)),
- values=(2, 0, 1, 1),
- dense_shape=(batch_size, 5))
-
- # Embedding variable.
- embedding_dimension = 2
- embedding_shape = (vocabulary_size, embedding_dimension)
- zeros_embedding_values = np.zeros(embedding_shape)
-
- def _initializer(shape, dtype, partition_info):
- self.assertAllEqual(embedding_shape, shape)
- self.assertEqual(dtypes.float32, dtype)
- self.assertIsNone(partition_info)
- return zeros_embedding_values
-
- # Build columns.
- categorical_column = fc_old.categorical_column_with_identity(
- key='aaa', num_buckets=vocabulary_size)
- embedding_column = fc_old.embedding_column(
- categorical_column,
- dimension=embedding_dimension,
- initializer=_initializer)
-
- with ops.Graph().as_default():
- predictions = get_keras_linear_model_predictions({
- categorical_column.name: sparse_input
- }, (embedding_column,))
- expected_var_names = (
- 'linear_model/bias_weights:0',
- 'linear_model/aaa_embedding/weights:0',
- 'linear_model/aaa_embedding/embedding_weights:0',
- )
- self.assertItemsEqual(
- expected_var_names,
- [v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
- trainable_vars = {
- v.name: v
- for v in ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
- }
- self.assertItemsEqual(expected_var_names, trainable_vars.keys())
- bias = trainable_vars['linear_model/bias_weights:0']
- embedding_weights = trainable_vars[
- 'linear_model/aaa_embedding/embedding_weights:0']
- linear_weights = trainable_vars['linear_model/aaa_embedding/weights:0']
- with _initialized_session():
- # Predictions with all zero weights.
- self.assertAllClose(np.zeros((1,)), bias.eval())
- self.assertAllClose(zeros_embedding_values, embedding_weights.eval())
- self.assertAllClose(
- np.zeros((embedding_dimension, 1)), linear_weights.eval())
- self.assertAllClose(np.zeros((batch_size, 1)), predictions.eval())
-
- # Predictions with all non-zero weights.
- embedding_weights.assign((
- (1., 2.), # id 0
- (3., 5.), # id 1
- (7., 11.) # id 2
- )).eval()
- linear_weights.assign(((4.,), (6.,))).eval()
- # example 0, ids [2], embedding[0] = [7, 11]
- # example 1, ids [0, 1], embedding[1] = mean([1, 2] + [3, 5]) = [2, 3.5]
- # example 2, ids [], embedding[2] = [0, 0]
- # example 3, ids [1], embedding[3] = [3, 5]
- # sum(embeddings * linear_weights)
- # = [4*7 + 6*11, 4*2 + 6*3.5, 4*0 + 6*0, 4*3 + 6*5] = [94, 29, 0, 42]
- self.assertAllClose(((94.,), (29.,), (0.,), (42.,)), predictions.eval())
-
def test_feature_layer(self):
# Inputs.
vocabulary_size = 3
@@ -5749,27 +4695,31 @@ class SharedEmbeddingColumnTest(test.TestCase):
return zeros_embedding_values
# Build columns.
- categorical_column_a = fc_old.categorical_column_with_identity(
+ categorical_column_a = fc.categorical_column_with_identity(
key='aaa', num_buckets=vocabulary_size)
- categorical_column_b = fc_old.categorical_column_with_identity(
+ categorical_column_b = fc.categorical_column_with_identity(
key='bbb', num_buckets=vocabulary_size)
- embedding_column_a, embedding_column_b = fc_old.shared_embedding_columns(
+ embedding_column_a, embedding_column_b = fc.shared_embedding_columns_v2(
[categorical_column_a, categorical_column_b],
dimension=embedding_dimension,
initializer=_initializer)
with ops.Graph().as_default():
- predictions = fc.linear_model({
+ model = fc.LinearModel(
+ (embedding_column_a, embedding_column_b),
+ shared_state_manager=fc.SharedEmbeddingStateManager())
+ predictions = model({
categorical_column_a.name: input_a,
- categorical_column_b.name: input_b,
- }, (embedding_column_a, embedding_column_b))
+ categorical_column_b.name: input_b
+ })
+
# Linear weights do not follow the column name. But this is a rare use
# case, and fixing it would add too much complexity to the code.
expected_var_names = (
'linear_model/bias_weights:0',
- 'linear_model/aaa_bbb_shared_embedding/weights:0',
- 'linear_model/aaa_bbb_shared_embedding/embedding_weights:0',
- 'linear_model/aaa_bbb_shared_embedding_1/weights:0',
+ 'linear_model/aaa_shared_embedding/weights:0',
+ 'shared_embedding_state_manager/aaa_bbb_shared_embedding:0',
+ 'linear_model/bbb_shared_embedding/weights:0',
)
self.assertItemsEqual(
expected_var_names,
@@ -5781,102 +4731,11 @@ class SharedEmbeddingColumnTest(test.TestCase):
self.assertItemsEqual(expected_var_names, trainable_vars.keys())
bias = trainable_vars['linear_model/bias_weights:0']
embedding_weights = trainable_vars[
- 'linear_model/aaa_bbb_shared_embedding/embedding_weights:0']
+ 'shared_embedding_state_manager/aaa_bbb_shared_embedding:0']
linear_weights_a = trainable_vars[
- 'linear_model/aaa_bbb_shared_embedding/weights:0']
+ 'linear_model/aaa_shared_embedding/weights:0']
linear_weights_b = trainable_vars[
- 'linear_model/aaa_bbb_shared_embedding_1/weights:0']
- with _initialized_session():
- # Predictions with all zero weights.
- self.assertAllClose(np.zeros((1,)), bias.eval())
- self.assertAllClose(zeros_embedding_values, embedding_weights.eval())
- self.assertAllClose(
- np.zeros((embedding_dimension, 1)), linear_weights_a.eval())
- self.assertAllClose(
- np.zeros((embedding_dimension, 1)), linear_weights_b.eval())
- self.assertAllClose(np.zeros((batch_size, 1)), predictions.eval())
-
- # Predictions with all non-zero weights.
- embedding_weights.assign((
- (1., 2.), # id 0
- (3., 5.), # id 1
- (7., 11.) # id 2
- )).eval()
- linear_weights_a.assign(((4.,), (6.,))).eval()
- # example 0, ids [2], embedding[0] = [7, 11]
- # example 1, ids [0, 1], embedding[1] = mean([1, 2] + [3, 5]) = [2, 3.5]
- # sum(embeddings * linear_weights)
- # = [4*7 + 6*11, 4*2 + 6*3.5] = [94, 29]
- linear_weights_b.assign(((3.,), (5.,))).eval()
- # example 0, ids [0], embedding[0] = [1, 2]
- # example 1, ids [], embedding[1] = 0, 0]
- # sum(embeddings * linear_weights)
- # = [3*1 + 5*2, 3*0 +5*0] = [13, 0]
- self.assertAllClose([[94. + 13.], [29.]], predictions.eval())
-
- def test_keras_linear_model(self):
- # Inputs.
- batch_size = 2
- vocabulary_size = 3
- # -1 values are ignored.
- input_a = np.array([
- [2, -1, -1], # example 0, ids [2]
- [0, 1, -1]
- ]) # example 1, ids [0, 1]
- input_b = np.array([
- [0, -1, -1], # example 0, ids [0]
- [-1, -1, -1]
- ]) # example 1, ids []
-
- # Embedding variable.
- embedding_dimension = 2
- embedding_shape = (vocabulary_size, embedding_dimension)
- zeros_embedding_values = np.zeros(embedding_shape)
-
- def _initializer(shape, dtype, partition_info):
- self.assertAllEqual(embedding_shape, shape)
- self.assertEqual(dtypes.float32, dtype)
- self.assertIsNone(partition_info)
- return zeros_embedding_values
-
- # Build columns.
- categorical_column_a = fc_old.categorical_column_with_identity(
- key='aaa', num_buckets=vocabulary_size)
- categorical_column_b = fc_old.categorical_column_with_identity(
- key='bbb', num_buckets=vocabulary_size)
- embedding_column_a, embedding_column_b = fc_old.shared_embedding_columns(
- [categorical_column_a, categorical_column_b],
- dimension=embedding_dimension,
- initializer=_initializer)
-
- with ops.Graph().as_default():
- predictions = get_keras_linear_model_predictions({
- categorical_column_a.name: input_a,
- categorical_column_b.name: input_b,
- }, (embedding_column_a, embedding_column_b))
- # Linear weights do not follow the column name. But this is a rare use
- # case, and fixing it would add too much complexity to the code.
- expected_var_names = (
- 'linear_model/bias_weights:0',
- 'linear_model/aaa_bbb_shared_embedding/weights:0',
- 'linear_model/aaa_bbb_shared_embedding/embedding_weights:0',
- 'linear_model/aaa_bbb_shared_embedding_1/weights:0',
- )
- self.assertItemsEqual(
- expected_var_names,
- [v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
- trainable_vars = {
- v.name: v
- for v in ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
- }
- self.assertItemsEqual(expected_var_names, trainable_vars.keys())
- bias = trainable_vars['linear_model/bias_weights:0']
- embedding_weights = trainable_vars[
- 'linear_model/aaa_bbb_shared_embedding/embedding_weights:0']
- linear_weights_a = trainable_vars[
- 'linear_model/aaa_bbb_shared_embedding/weights:0']
- linear_weights_b = trainable_vars[
- 'linear_model/aaa_bbb_shared_embedding_1/weights:0']
+ 'linear_model/bbb_shared_embedding/weights:0']
with _initialized_session():
# Predictions with all zero weights.
self.assertAllClose(np.zeros((1,)), bias.eval())
@@ -6275,13 +5134,14 @@ class WeightedCategoricalColumnTest(test.TestCase):
dense_shape=(2, 2)),
weight_tensor.eval())
- def test_keras_linear_model(self):
- column = fc_old.weighted_categorical_column(
- categorical_column=fc_old.categorical_column_with_identity(
+ def test_linear_model(self):
+ column = fc.weighted_categorical_column(
+ categorical_column=fc.categorical_column_with_identity(
key='ids', num_buckets=3),
weight_feature_key='values')
with ops.Graph().as_default():
- predictions = get_keras_linear_model_predictions({
+ model = fc.LinearModel((column,))
+ predictions = model({
'ids':
sparse_tensor.SparseTensorValue(
indices=((0, 0), (1, 0), (1, 1)),
@@ -6292,9 +5152,8 @@ class WeightedCategoricalColumnTest(test.TestCase):
indices=((0, 0), (1, 0), (1, 1)),
values=(.5, 1., .1),
dense_shape=(2, 2))
- }, (column,))
- bias = get_linear_model_bias()
- weight_var = get_linear_model_column_var(column)
+ })
+ weight_var, bias = model.variables
with _initialized_session():
self.assertAllClose((0.,), bias.eval())
self.assertAllClose(((0.,), (0.,), (0.,)), weight_var.eval())
@@ -6305,15 +5164,16 @@ class WeightedCategoricalColumnTest(test.TestCase):
# = 3*1 + 2*.1 = 3+.2 = 3.2
self.assertAllClose(((.5,), (3.2,)), predictions.eval())
- def test_keras_linear_model_mismatched_shape(self):
- column = fc_old.weighted_categorical_column(
- categorical_column=fc_old.categorical_column_with_identity(
+ def test_linear_model_mismatched_shape(self):
+ column = fc.weighted_categorical_column(
+ categorical_column=fc.categorical_column_with_identity(
key='ids', num_buckets=3),
weight_feature_key='values')
with ops.Graph().as_default():
- with self.assertRaisesRegexp(ValueError,
- r'Dimensions.*are not compatible'):
- get_keras_linear_model_predictions({
+ with self.assertRaisesRegexp(
+ ValueError, r'Dimensions.*are not compatible'):
+ model = fc.LinearModel((column,))
+ model({
'ids':
sparse_tensor.SparseTensorValue(
indices=((0, 0), (1, 0), (1, 1)),
@@ -6324,122 +5184,23 @@ class WeightedCategoricalColumnTest(test.TestCase):
indices=((0, 0), (0, 1), (1, 0), (1, 1)),
values=(.5, 11., 1., .1),
dense_shape=(2, 2))
- }, (column,))
+ })
- def test_keras_linear_model_mismatched_dense_values(self):
- column = fc_old.weighted_categorical_column(
- categorical_column=fc_old.categorical_column_with_identity(
- key='ids', num_buckets=3),
- weight_feature_key='values')
- with ops.Graph().as_default():
- predictions = get_keras_linear_model_predictions(
- {
- 'ids':
- sparse_tensor.SparseTensorValue(
- indices=((0, 0), (1, 0), (1, 1)),
- values=(0, 2, 1),
- dense_shape=(2, 2)),
- 'values': ((.5,), (1.,))
- }, (column,),
- sparse_combiner='mean')
- # Disabling the constant folding optimizer here since it changes the
- # error message differently on CPU and GPU.
- config = config_pb2.ConfigProto()
- config.graph_options.rewrite_options.constant_folding = (
- rewriter_config_pb2.RewriterConfig.OFF)
- with _initialized_session(config):
- with self.assertRaisesRegexp(errors.OpError, 'Incompatible shapes'):
- predictions.eval()
-
- def test_keras_linear_model_mismatched_dense_shape(self):
- column = fc_old.weighted_categorical_column(
- categorical_column=fc_old.categorical_column_with_identity(
+ def test_linear_model_mismatched_dense_values(self):
+ column = fc.weighted_categorical_column(
+ categorical_column=fc.categorical_column_with_identity(
key='ids', num_buckets=3),
weight_feature_key='values')
with ops.Graph().as_default():
- predictions = get_keras_linear_model_predictions({
+ model = fc.LinearModel((column,), sparse_combiner='mean')
+ predictions = model({
'ids':
sparse_tensor.SparseTensorValue(
indices=((0, 0), (1, 0), (1, 1)),
values=(0, 2, 1),
dense_shape=(2, 2)),
- 'values': ((.5,), (1.,), (.1,))
- }, (column,))
- bias = get_linear_model_bias()
- weight_var = get_linear_model_column_var(column)
- with _initialized_session():
- self.assertAllClose((0.,), bias.eval())
- self.assertAllClose(((0.,), (0.,), (0.,)), weight_var.eval())
- self.assertAllClose(((0.,), (0.,)), predictions.eval())
- weight_var.assign(((1.,), (2.,), (3.,))).eval()
- # weight_var[0] * weights[0, 0] = 1 * .5 = .5
- # weight_var[2] * weights[1, 0] + weight_var[1] * weights[1, 1]
- # = 3*1 + 2*.1 = 3+.2 = 3.2
- self.assertAllClose(((.5,), (3.2,)), predictions.eval())
-
- def test_linear_model(self):
- column = fc_old.weighted_categorical_column(
- categorical_column=fc_old.categorical_column_with_identity(
- key='ids', num_buckets=3),
- weight_feature_key='values')
- with ops.Graph().as_default():
- predictions = fc.linear_model({
- 'ids': sparse_tensor.SparseTensorValue(
- indices=((0, 0), (1, 0), (1, 1)),
- values=(0, 2, 1),
- dense_shape=(2, 2)),
- 'values': sparse_tensor.SparseTensorValue(
- indices=((0, 0), (1, 0), (1, 1)),
- values=(.5, 1., .1),
- dense_shape=(2, 2))
- }, (column,))
- bias = get_linear_model_bias()
- weight_var = get_linear_model_column_var(column)
- with _initialized_session():
- self.assertAllClose((0.,), bias.eval())
- self.assertAllClose(((0.,), (0.,), (0.,)), weight_var.eval())
- self.assertAllClose(((0.,), (0.,)), predictions.eval())
- weight_var.assign(((1.,), (2.,), (3.,))).eval()
- # weight_var[0] * weights[0, 0] = 1 * .5 = .5
- # weight_var[2] * weights[1, 0] + weight_var[1] * weights[1, 1]
- # = 3*1 + 2*.1 = 3+.2 = 3.2
- self.assertAllClose(((.5,), (3.2,)), predictions.eval())
-
- def test_linear_model_mismatched_shape(self):
- column = fc_old.weighted_categorical_column(
- categorical_column=fc_old.categorical_column_with_identity(
- key='ids', num_buckets=3),
- weight_feature_key='values')
- with ops.Graph().as_default():
- with self.assertRaisesRegexp(
- ValueError, r'Dimensions.*are not compatible'):
- fc.linear_model({
- 'ids': sparse_tensor.SparseTensorValue(
- indices=((0, 0), (1, 0), (1, 1)),
- values=(0, 2, 1),
- dense_shape=(2, 2)),
- 'values': sparse_tensor.SparseTensorValue(
- indices=((0, 0), (0, 1), (1, 0), (1, 1)),
- values=(.5, 11., 1., .1),
- dense_shape=(2, 2))
- }, (column,))
-
- def test_linear_model_mismatched_dense_values(self):
- column = fc_old.weighted_categorical_column(
- categorical_column=fc_old.categorical_column_with_identity(
- key='ids', num_buckets=3),
- weight_feature_key='values')
- with ops.Graph().as_default():
- predictions = fc.linear_model(
- {
- 'ids':
- sparse_tensor.SparseTensorValue(
- indices=((0, 0), (1, 0), (1, 1)),
- values=(0, 2, 1),
- dense_shape=(2, 2)),
- 'values': ((.5,), (1.,))
- }, (column,),
- sparse_combiner='mean')
+ 'values': ((.5,), (1.,))
+ })
# Disabling the constant folding optimizer here since it changes the
# error message differently on CPU and GPU.
config = config_pb2.ConfigProto()
@@ -6450,20 +5211,21 @@ class WeightedCategoricalColumnTest(test.TestCase):
predictions.eval()
def test_linear_model_mismatched_dense_shape(self):
- column = fc_old.weighted_categorical_column(
- categorical_column=fc_old.categorical_column_with_identity(
+ column = fc.weighted_categorical_column(
+ categorical_column=fc.categorical_column_with_identity(
key='ids', num_buckets=3),
weight_feature_key='values')
with ops.Graph().as_default():
- predictions = fc.linear_model({
- 'ids': sparse_tensor.SparseTensorValue(
- indices=((0, 0), (1, 0), (1, 1)),
- values=(0, 2, 1),
- dense_shape=(2, 2)),
+ model = fc.LinearModel((column,))
+ predictions = model({
+ 'ids':
+ sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(0, 2, 1),
+ dense_shape=(2, 2)),
'values': ((.5,), (1.,), (.1,))
- }, (column,))
- bias = get_linear_model_bias()
- weight_var = get_linear_model_column_var(column)
+ })
+ weight_var, bias = model.variables
with _initialized_session():
self.assertAllClose((0.,), bias.eval())
self.assertAllClose(((0.,), (0.,), (0.,)), weight_var.eval())
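The feature-column test hunks above all migrate to the same pattern: build an `fc.LinearModel` from the columns, call it on a feature dict, and read weights from `model.variables` instead of the old `get_linear_model_bias`/`get_linear_model_column_var` collection helpers. A minimal sketch of that pattern, assuming the `feature_column_v2` module path these tests use:

```python
# Sketch only: the import paths and graph-mode setup are assumptions based on
# the test code in this change, not a definitive recipe.
from tensorflow.python.feature_column import feature_column_v2 as fc
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor

column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)

with ops.Graph().as_default():
  model = fc.LinearModel((column,))
  predictions = model({
      'aaa': sparse_tensor.SparseTensorValue(
          indices=((0, 0), (1, 0), (1, 1)),
          values=(0, 2, 1),
          dense_shape=(2, 2))
  })
  # Weights are exposed on the model object rather than via graph collections.
  weight_var, bias = model.variables
```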
diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py
index a8aef3a009..f287289bd0 100644
--- a/tensorflow/python/framework/function.py
+++ b/tensorflow/python/framework/function.py
@@ -762,13 +762,12 @@ class _FuncGraph(ops.Graph):
if handle_data:
handle_data = handle_data.SerializeToString()
else:
- handle_data = c_api.GetResourceHandleShapeAndType(
- tensor.graph._c_graph, tensor._as_tf_output())
+ handle_data = c_api.GetHandleShapeAndType(tensor.graph._c_graph,
+ tensor._as_tf_output())
if handle_data:
- c_api.SetResourceHandleShapeAndType(ph.graph._c_graph,
- ph._as_tf_output(),
- compat.as_bytes(handle_data))
+ c_api.SetHandleShapeAndType(ph.graph._c_graph, ph._as_tf_output(),
+ compat.as_bytes(handle_data))
else:
ph._handle_data = tensor._handle_data
# pylint: enable=protected-access
@@ -1097,6 +1096,21 @@ def _from_library(lib):
return initialized.values()
+def _get_experimental_kwarg_as_attr(attr_name, value):
+ """Creates an AttrValue for a python object."""
+ if isinstance(value, bool):
+ return attr_value_pb2.AttrValue(b=value)
+ elif isinstance(value, int):
+ return attr_value_pb2.AttrValue(i=value)
+ elif isinstance(value, float):
+ return attr_value_pb2.AttrValue(f=value)
+ elif isinstance(value, str):
+ return attr_value_pb2.AttrValue(s=compat.as_bytes(value))
+ else:
+ raise ValueError("Unsupported attribute type for %s with type %s" %
+ (attr_name, type(value)))
+
+
def _parse_kwargs_as_attrs(func_name, **kwargs):
"""Parses **kwargs into a node's attributes."""
attrs = {}
@@ -1123,7 +1137,7 @@ def _parse_kwargs_as_attrs(func_name, **kwargs):
kwargs_keys = list(kwargs.keys())
for key in kwargs_keys:
if key.startswith("experimental_"):
- attrs[key] = attr_value_pb2.AttrValue(s=compat.as_bytes(kwargs[key]))
+ attrs[key] = _get_experimental_kwarg_as_attr(key, kwargs[key])
del kwargs[key]
if kwargs:
diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py
index 903768a039..87f567db0e 100644
--- a/tensorflow/python/framework/function_test.py
+++ b/tensorflow/python/framework/function_test.py
@@ -113,7 +113,7 @@ class FunctionTest(test.TestCase):
return a
with ops.Graph().as_default():
- var = variables.Variable([18.0])
+ var = variables.VariableV1([18.0])
call = MyIdentityFunc(var._ref()) # pylint: disable=protected-access
self.assertEqual("MyIdentity", call.op.name)
for cfg in _OptimizerOptions():
@@ -1331,12 +1331,33 @@ class FunctionsFromProtos(test.TestCase):
def testExperimentalAttrs(self):
@function.Defun(dtypes.int32, experimental_tag="tag_value")
- def FunctionWithAttr(i):
+ def FunctionWithStrAttr(i):
return array_ops.identity(i)
- self.assertTrue("experimental_tag" in FunctionWithAttr.definition.attr)
- self.assertEqual(FunctionWithAttr.definition.attr["experimental_tag"].s,
+ @function.Defun(dtypes.int32, experimental_tag=123)
+ def FunctionWithIntAttr(i):
+ return array_ops.identity(i)
+
+ @function.Defun(dtypes.int32, experimental_tag=123.0)
+ def FunctionWithFloatAttr(i):
+ return array_ops.identity(i)
+
+ @function.Defun(dtypes.int32, experimental_tag=True)
+ def FunctionWithBoolAttr(i):
+ return array_ops.identity(i)
+
+ self.assertTrue("experimental_tag" in FunctionWithStrAttr.definition.attr)
+ self.assertEqual(FunctionWithStrAttr.definition.attr["experimental_tag"].s,
b"tag_value")
+ self.assertTrue("experimental_tag" in FunctionWithIntAttr.definition.attr)
+ self.assertEqual(FunctionWithIntAttr.definition.attr["experimental_tag"].i,
+ 123)
+ self.assertTrue("experimental_tag" in FunctionWithFloatAttr.definition.attr)
+ self.assertEqual(
+ FunctionWithFloatAttr.definition.attr["experimental_tag"].f, 123.0)
+ self.assertTrue("experimental_tag" in FunctionWithBoolAttr.definition.attr)
+ self.assertEqual(FunctionWithBoolAttr.definition.attr["experimental_tag"].b,
+ True)
@test_util.with_c_shapes
diff --git a/tensorflow/python/framework/graph_util_test.py b/tensorflow/python/framework/graph_util_test.py
index 2dafb94ba7..563a177dd0 100644
--- a/tensorflow/python/framework/graph_util_test.py
+++ b/tensorflow/python/framework/graph_util_test.py
@@ -104,13 +104,13 @@ class DeviceFunctionsTest(test.TestCase):
def testNestedDeviceFunctions(self):
with ops.Graph().as_default():
- var_0 = variables.Variable(0)
+ var_0 = variables.VariableV1(0)
with ops.device(test_device_func_pin_variable_to_cpu):
- var_1 = variables.Variable(1)
+ var_1 = variables.VariableV1(1)
with ops.device(lambda op: "/device:GPU:0"):
- var_2 = variables.Variable(2)
+ var_2 = variables.VariableV1(2)
with ops.device("/device:GPU:0"): # Implicit merging device function.
- var_3 = variables.Variable(3)
+ var_3 = variables.VariableV1(3)
self.assertDeviceEqual(var_0.device, None)
self.assertDeviceEqual(var_1.device, "/device:CPU:0")
diff --git a/tensorflow/python/framework/load_library.py b/tensorflow/python/framework/load_library.py
index 535c6017f5..908a5f521e 100644
--- a/tensorflow/python/framework/load_library.py
+++ b/tensorflow/python/framework/load_library.py
@@ -18,14 +18,18 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import errno
import hashlib
import imp
+import os
+import platform
import sys
import threading # pylint: disable=unused-import
from tensorflow.core.framework import op_def_pb2
from tensorflow.core.lib.core import error_codes_pb2 # pylint: disable=unused-import
from tensorflow.python import pywrap_tensorflow as py_tf
+from tensorflow.python.lib.io import file_io
from tensorflow.python.util import compat
from tensorflow.python.util.tf_export import tf_export
@@ -98,3 +102,64 @@ def load_file_system_library(library_filename):
RuntimeError: when unable to load the library.
"""
py_tf.TF_LoadLibrary(library_filename)
+
+
+def _is_shared_object(filename):
+  """Check whether a file is a shared object, using only its extension."""
+ if platform.system() == 'Linux':
+ if filename.endswith('.so'):
+ return True
+ else:
+ index = filename.rfind('.so.')
+ if index == -1:
+ return False
+ else:
+ # A shared object with the API version in filename
+ return filename[index + 4].isdecimal()
+ elif platform.system() == 'Darwin':
+ return filename.endswith('.dylib')
+ elif platform.system() == 'Windows':
+ return filename.endswith('.dll')
+ else:
+ return False
+
+
+@tf_export('load_library')
+def load_library(library_location):
+ """Loads a TensorFlow plugin.
+
+ "library_location" can be a path to a specific shared object, or a folder.
+  If it is a folder, all shared objects that are named "libtfkernel*" will be
+ loaded. When the library is loaded, kernels registered in the library via the
+ `REGISTER_*` macros are made available in the TensorFlow process.
+
+ Args:
+ library_location: Path to the plugin or the folder of plugins.
+ Relative or absolute filesystem path to a dynamic library file or folder.
+
+ Returns:
+ None
+
+ Raises:
+ OSError: When the file to be loaded is not found.
+ RuntimeError: when unable to load the library.
+ """
+ if file_io.file_exists(library_location):
+ if file_io.is_directory(library_location):
+ directory_contents = file_io.list_directory(library_location)
+
+ kernel_libraries = [
+ os.path.join(library_location, f) for f in directory_contents
+ if _is_shared_object(f)]
+ else:
+ kernel_libraries = [library_location]
+
+ for lib in kernel_libraries:
+ py_tf.TF_LoadLibrary(lib)
+
+ else:
+ raise OSError(
+ errno.ENOENT,
+ 'The file or folder to load kernel libraries from does not exist.',
+ library_location)
+
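The docstring above covers both call forms of the new endpoint. A short usage sketch; the paths are purely hypothetical:

```python
import tensorflow as tf

# Load one specific plugin shared object (hypothetical path).
tf.load_library('/tmp/plugins/libtfkernel_my_op.so')

# Or point at a folder: each shared object inside it (the docstring above
# describes these as the "libtfkernel*" plugins) is loaded, making the kernels
# registered via REGISTER_* macros available in this TensorFlow process.
tf.load_library('/tmp/plugins')
```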
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 343f52fe8f..8bb177939e 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -2532,8 +2532,8 @@ def _set_shape_and_handle_data_for_outputs_c_api(op):
output._shape_val = output._c_api_shape()
# Set the resource handle data for compatibility with the Python shape
# inference code.
- serialized = c_api.GetResourceHandleShapeAndType(op._graph._c_graph,
- output._as_tf_output())
+ serialized = c_api.GetHandleShapeAndType(op._graph._c_graph, # pylint: disable=protected-access
+ output._as_tf_output())
if serialized:
output._handle_data = (
cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData
diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py
index d59adf3d48..c3a3437743 100644
--- a/tensorflow/python/framework/ops_test.py
+++ b/tensorflow/python/framework/ops_test.py
@@ -2142,8 +2142,8 @@ class InitScopeTest(test_util.TensorFlowTestCase):
def function_with_variables():
with ops.init_scope():
- v = resource_variable_ops.ResourceVariable(3)
- return v.assign_add(1)
+ self.v = resource_variable_ops.ResourceVariable(3)
+ return self.v.assign_add(1)
with context.eager_mode():
# Each invocation of function_with_variables recreates a variable.
@@ -2188,13 +2188,13 @@ class InitScopeTest(test_util.TensorFlowTestCase):
def inner_function():
with ops.init_scope():
- v = resource_variable_ops.ResourceVariable(1)
- return v.assign_add(2)
+ self.v = resource_variable_ops.ResourceVariable(1)
+ return self.v.assign_add(2)
def outer_function(inner=None):
with ops.init_scope():
- v0 = resource_variable_ops.ResourceVariable(0)
- return v0.assign_add(1) + inner()
+ self.v0 = resource_variable_ops.ResourceVariable(0)
+ return self.v0.assign_add(1) + inner()
with context.eager_mode():
# Each invocation of outer_function recreates variables.
diff --git a/tensorflow/python/framework/subscribe_test.py b/tensorflow/python/framework/subscribe_test.py
index 1d594e4078..cab426844d 100644
--- a/tensorflow/python/framework/subscribe_test.py
+++ b/tensorflow/python/framework/subscribe_test.py
@@ -212,8 +212,8 @@ class SubscribeTest(test_util.TensorFlowTestCase):
def testSubscribeVariable(self):
"""Confirm that variables can be subscribed."""
- v1 = variables.Variable(0.0)
- v2 = variables.Variable(4.0)
+ v1 = variables.VariableV1(0.0)
+ v2 = variables.VariableV1(4.0)
add = math_ops.add(v1, v2)
assign_v1 = v1.assign(3.0)
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index b7398238f5..6673bc5561 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -25,6 +25,7 @@ import contextlib
import gc
import itertools
import math
+import os
import random
import re
import tempfile
@@ -401,11 +402,14 @@ def with_c_shapes(cls):
return cls
-def enable_cond_v2(fn):
- """Decorator for enabling CondV2 on a test.
+def enable_control_flow_v2(fn):
+ """Decorator for enabling CondV2 and WhileV2 on a test.
- Note this enables using CondV2 after running the test class's setup/teardown
- methods.
+ Note this enables using CondV2 and WhileV2 after running the test class's
+ setup/teardown methods.
+
+ In addition to this, callers must import the while_v2 module in order to set
+ the _while_v2 module in control_flow_ops.
Args:
fn: the function to be wrapped
@@ -415,21 +419,56 @@ def enable_cond_v2(fn):
"""
def wrapper(*args, **kwargs):
- prev_value = control_flow_ops.ENABLE_COND_V2
+ enable_cond_v2_old = control_flow_ops.ENABLE_COND_V2
+ enable_while_v2_old = control_flow_ops.ENABLE_WHILE_V2
control_flow_ops.ENABLE_COND_V2 = True
+ control_flow_ops.ENABLE_WHILE_V2 = True
try:
fn(*args, **kwargs)
finally:
- control_flow_ops.ENABLE_COND_V2 = prev_value
+ control_flow_ops.ENABLE_COND_V2 = enable_cond_v2_old
+ control_flow_ops.ENABLE_WHILE_V2 = enable_while_v2_old
return wrapper
-def with_cond_v2(cls):
- """Adds methods that call original methods but with CondV2 enabled.
+def with_control_flow_v2(cls):
+ """Adds methods that call original methods with WhileV2 and CondV2 enabled.
- Note this enables CondV2 in new methods after running the test class's
- setup method.
+ Note this enables CondV2 and WhileV2 in new methods after running the test
+ class's setup method.
+
+ In addition to this, callers must import the while_v2 module in order to set
+ the _while_v2 module in control_flow_ops.
+
+ If a test function has _disable_control_flow_v2 attr set to True (using the
+ @disable_control_flow_v2 decorator), the v2 function is not generated for it.
+
+ Example:
+
+ @test_util.with_control_flow_v2
+ class ControlFlowTest(test.TestCase):
+
+ def testEnabledForV2(self):
+ ...
+
+ @test_util.disable_control_flow_v2("b/xyzabc")
+ def testDisabledForV2(self):
+ ...
+
+ Generated class:
+ class ControlFlowTest(test.TestCase):
+
+ def testEnabledForV2(self):
+ ...
+
+ def testEnabledForV2WithControlFlowV2(self):
+ // Enable V2 flags.
+ testEnabledForV2(self)
+ // Restore V2 flags.
+
+ def testDisabledForV2(self):
+ ...
Args:
cls: class to decorate
@@ -437,15 +476,33 @@ def with_cond_v2(cls):
Returns:
cls with new test methods added
"""
- if control_flow_ops.ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_WHILE_V2 and control_flow_ops.ENABLE_COND_V2:
return cls
for name, value in cls.__dict__.copy().items():
- if callable(value) and name.startswith("test"):
- setattr(cls, name + "WithCondV2", enable_cond_v2(value))
+ if (callable(value) and name.startswith("test") and
+ not getattr(value, "_disable_control_flow_v2", False)):
+ setattr(cls, name + "WithControlFlowV2", enable_control_flow_v2(value))
return cls
+def disable_control_flow_v2(unused_msg):
+ """Decorator for a function in a with_control_flow_v2 enabled test class.
+
+ Blocks the function from being run with v2 control flow ops.
+
+ Args:
+ unused_msg: Reason for disabling.
+
+ Returns:
+ The wrapped function with _disable_control_flow_v2 attr set to True.
+ """
+ def wrapper(func):
+ func._disable_control_flow_v2 = True
+ return func
+ return wrapper
+
+
def assert_no_new_pyobjects_executing_eagerly(f):
"""Decorator for asserting that no new Python objects persist after a test.
@@ -868,6 +925,19 @@ def device(use_gpu):
yield
+class CapturedWrites(object):
+ """A utility class to load the captured writes made to a stream."""
+
+ def __init__(self, capture_location):
+ self.capture_location = capture_location
+
+ def contents(self):
+ """Get the captured writes as a single string."""
+ with open(self.capture_location) as tmp_file:
+ output_data = "".join(tmp_file.readlines())
+ return output_data
+
+
class ErrorLoggingSession(session.Session):
"""Wrapper around a Session that logs errors in run().
"""
@@ -876,7 +946,11 @@ class ErrorLoggingSession(session.Session):
try:
return super(ErrorLoggingSession, self).run(*args, **kwargs)
except Exception as e: # pylint: disable=broad-except
- logging.error(str(e))
+      # Note: disable logging for OutOfRangeError, which is used as the signal
+      # of completion in tf.data pipelines and would otherwise make the output
+      # of tf.data tests hard to read.
+ if not isinstance(e, errors.OutOfRangeError):
+ logging.error(str(e))
raise
@@ -934,6 +1008,52 @@ class TensorFlowTestCase(googletest.TestCase):
self._tempdir = tempfile.mkdtemp(dir=googletest.GetTempDir())
return self._tempdir
+ @contextlib.contextmanager
+ def captureWritesToStream(self, stream):
+ """A context manager that captures the writes to a given stream.
+
+ This context manager captures all writes to a given stream inside of a
+ `CapturedWrites` object. When this context manager is created, it yields
+ the `CapturedWrites` object. The captured contents can be accessed by
+ calling `.contents()` on the `CapturedWrites`.
+
+ For this function to work, the stream must have a file descriptor that
+ can be modified using `os.dup` and `os.dup2`, and the stream must support
+ a `.flush()` method. The default python sys.stdout and sys.stderr are
+ examples of this. Note that this does not work in Colab or Jupyter
+ notebooks, because those use alternate stdout streams.
+
+ Example:
+ ```python
+ class MyOperatorTest(test_util.TensorFlowTestCase):
+ def testMyOperator(self):
+ input = [1.0, 2.0, 3.0, 4.0, 5.0]
+ with self.captureWritesToStream(sys.stdout) as captured:
+ result = MyOperator(input).eval()
+ self.assertStartsWith(captured.contents(), "This was printed.")
+ ```
+
+ Args:
+ stream: The stream whose writes should be captured. This
+        stream must have a file descriptor, support writing via that
+ file descriptor, and must have a `.flush()` method.
+
+ Yields:
+ A `CapturedWrites` object that contains all writes to the specified stream
+ made during this context.
+ """
+ stream.flush()
+ fd = stream.fileno()
+ tmp_file_path = tempfile.mktemp(dir=self.get_temp_dir())
+ tmp_file = open(tmp_file_path, "w")
+ orig_fd = os.dup(fd)
+ os.dup2(tmp_file.fileno(), fd)
+ try:
+ yield CapturedWrites(tmp_file_path)
+ finally:
+ tmp_file.close()
+ os.dup2(orig_fd, fd)
+
def _AssertProtoEquals(self, a, b, msg=None):
"""Asserts that a and b are the same proto.
@@ -1337,35 +1457,36 @@ class TensorFlowTestCase(googletest.TestCase):
b.shape)
self.assertEqual(a.shape, b.shape, shape_mismatch_msg)
+ msgs = [msg]
if not np.allclose(a, b, rtol=rtol, atol=atol):
- # Prints more details than np.testing.assert_allclose.
+ # Adds more details to np.testing.assert_allclose.
#
# NOTE: numpy.allclose (and numpy.testing.assert_allclose)
# checks whether two arrays are element-wise equal within a
# tolerance. The relative difference (rtol * abs(b)) and the
# absolute difference atol are added together to compare against
# the absolute difference between a and b. Here, we want to
- # print out which elements violate such conditions.
+      # tell the user which elements violate such conditions.
cond = np.logical_or(
np.abs(a - b) > atol + rtol * np.abs(b),
np.isnan(a) != np.isnan(b))
if a.ndim:
x = a[np.where(cond)]
y = b[np.where(cond)]
- print("not close where = ", np.where(cond))
+ msgs.append("not close where = {}".format(np.where(cond)))
else:
# np.where is broken for scalars
x, y = a, b
- print("not close lhs = ", x)
- print("not close rhs = ", y)
- print("not close dif = ", np.abs(x - y))
- print("not close tol = ", atol + rtol * np.abs(y))
- print("dtype = %s, shape = %s" % (a.dtype, a.shape))
+ msgs.append("not close lhs = {}".format(x))
+ msgs.append("not close rhs = {}".format(y))
+ msgs.append("not close dif = {}".format(np.abs(x - y)))
+ msgs.append("not close tol = {}".format(atol + rtol * np.abs(y)))
+ msgs.append("dtype = {}, shape = {}".format(a.dtype, a.shape))
# TODO(xpan): There seems to be a bug:
# tensorflow/compiler/tests:binary_ops_test pass with float32
# nan even though the equal_nan is False by default internally.
np.testing.assert_allclose(
- a, b, rtol=rtol, atol=atol, err_msg=msg, equal_nan=True)
+ a, b, rtol=rtol, atol=atol, err_msg="\n".join(msgs), equal_nan=True)
def _assertAllCloseRecursive(self,
a,
@@ -1547,19 +1668,20 @@ class TensorFlowTestCase(googletest.TestCase):
np.float16, np.float32, np.float64, dtypes.bfloat16.as_numpy_dtype
]):
same = np.logical_or(same, np.logical_and(np.isnan(a), np.isnan(b)))
+ msgs = [msg]
if not np.all(same):
- # Prints more details than np.testing.assert_array_equal.
+ # Adds more details to np.testing.assert_array_equal.
diff = np.logical_not(same)
if a.ndim:
x = a[np.where(diff)]
y = b[np.where(diff)]
- print("not equal where = ", np.where(diff))
+ msgs.append("not equal where = {}".format(np.where(diff)))
else:
# np.where is broken for scalars
x, y = a, b
- print("not equal lhs = ", x)
- print("not equal rhs = ", y)
- np.testing.assert_array_equal(a, b, err_msg=msg)
+ msgs.append("not equal lhs = {}".format(x))
+ msgs.append("not equal rhs = {}".format(y))
+ np.testing.assert_array_equal(a, b, err_msg="\n".join(msgs))
def assertAllGreater(self, a, comparison_target):
"""Assert element values are all greater than a target value.
@@ -1874,6 +1996,8 @@ class TensorFlowTestCase(googletest.TestCase):
rewriter_config_pb2.RewriterConfig.OFF)
config.graph_options.rewrite_options.arithmetic_optimization = (
rewriter_config_pb2.RewriterConfig.OFF)
+ config.graph_options.rewrite_options.pin_to_host_optimization = (
+ rewriter_config_pb2.RewriterConfig.OFF)
return config
return ErrorLoggingSession(graph=graph, config=prepare_config(config))
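Beyond the docstring example for `with_control_flow_v2` above, the renamed per-method decorator can also be applied directly. A minimal sketch, assuming the `while_v2` module lives at `tensorflow.python.ops.while_v2` as the docstrings here imply:

```python
from tensorflow.python.framework import test_util
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import while_v2  # pylint: disable=unused-import
from tensorflow.python.platform import test


class MyControlFlowTest(test.TestCase):

  @test_util.enable_control_flow_v2
  def testRunsUnderV2(self):
    # Both flags are flipped on for the duration of this method and restored
    # in the decorator's finally block afterwards.
    self.assertTrue(control_flow_ops.ENABLE_COND_V2)
    self.assertTrue(control_flow_ops.ENABLE_WHILE_V2)
```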
diff --git a/tensorflow/python/framework/test_util_test.py b/tensorflow/python/framework/test_util_test.py
index c4f8fa9108..22189afa59 100644
--- a/tensorflow/python/framework/test_util_test.py
+++ b/tensorflow/python/framework/test_util_test.py
@@ -268,6 +268,11 @@ class TestUtilTest(test_util.TensorFlowTestCase):
self.assertAllClose(7, 7 + 1e-5)
@test_util.run_in_graph_and_eager_modes
+ def testAllCloseList(self):
+ with self.assertRaisesRegexp(AssertionError, r"not close dif"):
+ self.assertAllClose([0], [1])
+
+ @test_util.run_in_graph_and_eager_modes
def testAllCloseDictToNonDict(self):
with self.assertRaisesRegexp(ValueError, r"Can't compare dict to non-dict"):
self.assertAllClose(1, {"a": 1})
@@ -452,6 +457,9 @@ class TestUtilTest(test_util.TensorFlowTestCase):
self.assertAllEqual([120] * 3, k)
self.assertAllEqual([20] * 3, j)
+ with self.assertRaisesRegexp(AssertionError, r"not equal lhs"):
+ self.assertAllEqual([0] * 3, k)
+
@test_util.run_in_graph_and_eager_modes
def testAssertNotAllClose(self):
# Test with arrays
diff --git a/tensorflow/python/grappler/item_test.py b/tensorflow/python/grappler/item_test.py
index c40de9da0a..d3d96c646c 100644
--- a/tensorflow/python/grappler/item_test.py
+++ b/tensorflow/python/grappler/item_test.py
@@ -110,7 +110,7 @@ class ItemTest(test.TestCase):
def testColocationContraints(self):
with ops.Graph().as_default() as g:
c = constant_op.constant([10])
- v = variables.Variable([3], dtype=dtypes.int32)
+ v = variables.VariableV1([3], dtype=dtypes.int32)
i = gen_array_ops.ref_identity(v)
a = state_ops.assign(i, c)
train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
diff --git a/tensorflow/python/grappler/memory_optimizer_test.py b/tensorflow/python/grappler/memory_optimizer_test.py
index b658edff2d..03b42f6453 100644
--- a/tensorflow/python/grappler/memory_optimizer_test.py
+++ b/tensorflow/python/grappler/memory_optimizer_test.py
@@ -39,8 +39,8 @@ class MemoryOptimizerSwapTest(test.TestCase):
def testNoSwapping(self):
"""Make sure the graph is preserved when there is nothing to swap."""
- a = variables.Variable(10, name='a')
- b = variables.Variable(20, name='b')
+ a = variables.VariableV1(10, name='a')
+ b = variables.VariableV1(20, name='b')
c = math_ops.add_n([a, b], name='c')
d = math_ops.add_n([b, c], name='d')
train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
@@ -60,8 +60,8 @@ class MemoryOptimizerSwapTest(test.TestCase):
def testSimpleSwap(self):
"""Check that the swap annotations are followed."""
- a = variables.Variable(10, name='a')
- b = variables.Variable(20, name='b')
+ a = variables.VariableV1(10, name='a')
+ b = variables.VariableV1(20, name='b')
c = math_ops.add_n([a, b], name='c')
d = math_ops.add_n([b, c], name='d')
train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
@@ -244,7 +244,7 @@ class MemoryOptimizerRecomputeTest(test.TestCase):
init_op_name=init_op_name,
train_op_name=train_op_name,
loss_op_name=loss_op_name)
- self.assertAllClose(original_loss, memory_optimized_loss, rtol=1e-4)
+ self.assertAllClose(original_loss, memory_optimized_loss, rtol=1e-2)
def _annotated_graph(self):
graph = ops.Graph()
diff --git a/tensorflow/python/grappler/tf_optimizer_test.py b/tensorflow/python/grappler/tf_optimizer_test.py
index 5a9afe7257..eca0f67982 100644
--- a/tensorflow/python/grappler/tf_optimizer_test.py
+++ b/tensorflow/python/grappler/tf_optimizer_test.py
@@ -57,7 +57,7 @@ class PyWrapOptimizeGraphTest(test.TestCase):
def testKeepNodes(self):
g = ops.Graph()
with g.as_default():
- a1 = variables.Variable(
+ a1 = variables.VariableV1(
1.0) # Must be preserved since it's in the collection 'variables'.
a2 = constant_op.constant(0, shape=[50, 50], name='keep')
ops.add_to_collection('a2', a2) # Explicitly add to collection.
diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD
index b521b1430d..4a72c4b3f3 100755
--- a/tensorflow/python/keras/BUILD
+++ b/tensorflow/python/keras/BUILD
@@ -381,12 +381,11 @@ py_test(
],
)
-py_test(
+cuda_py_test(
name = "embeddings_test",
size = "medium",
srcs = ["layers/embeddings_test.py"],
- srcs_version = "PY2AND3",
- deps = [
+ additional_deps = [
":keras",
"//tensorflow/python:client_testlib",
],
diff --git a/tensorflow/python/keras/applications/__init__.py b/tensorflow/python/keras/applications/__init__.py
index a8b6d55e41..c35cdb15a4 100644
--- a/tensorflow/python/keras/applications/__init__.py
+++ b/tensorflow/python/keras/applications/__init__.py
@@ -63,7 +63,8 @@ def keras_modules_injection(base_fun):
def wrapper(*args, **kwargs):
if hasattr(keras_applications, 'get_submodules_from_kwargs'):
kwargs['backend'] = backend
- kwargs['layers'] = layers
+ if 'layers' not in kwargs:
+ kwargs['layers'] = layers
kwargs['models'] = models
kwargs['utils'] = utils
return base_fun(*args, **kwargs)
diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py
index 5e1722ba20..584facc859 100644
--- a/tensorflow/python/keras/backend.py
+++ b/tensorflow/python/keras/backend.py
@@ -367,18 +367,26 @@ def learning_phase():
Returns:
Learning phase (scalar integer tensor or Python integer).
"""
- if context.executing_eagerly():
- if _DUMMY_EAGER_GRAPH not in _GRAPH_LEARNING_PHASES:
- # Fallback to inference mode as default.
- return 0
- return _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH]
+ with ops.init_scope():
+ # We always check & set the learning phase inside the init_scope,
+ # otherwise the wrong default_graph will be used to look up the learning
+ # phase inside of functions & defuns.
+ #
+ # This is because functions & defuns (both in graph & in eager mode)
+ # will always execute non-eagerly using a function-specific default
+ # subgraph.
+ if context.executing_eagerly():
+ if _DUMMY_EAGER_GRAPH not in _GRAPH_LEARNING_PHASES:
+ # Fallback to inference mode as default.
+ return 0
+ return _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH]
- graph = ops.get_default_graph()
- if graph not in _GRAPH_LEARNING_PHASES:
- phase = array_ops.placeholder_with_default(
- False, shape=(), name='keras_learning_phase')
- _GRAPH_LEARNING_PHASES[graph] = phase
- return _GRAPH_LEARNING_PHASES[graph]
+ graph = ops.get_default_graph()
+ if graph not in _GRAPH_LEARNING_PHASES:
+ phase = array_ops.placeholder_with_default(
+ False, shape=(), name='keras_learning_phase')
+ _GRAPH_LEARNING_PHASES[graph] = phase
+ return _GRAPH_LEARNING_PHASES[graph]
@tf_export('keras.backend.set_learning_phase')
@@ -394,10 +402,11 @@ def set_learning_phase(value):
global _GRAPH_LEARNING_PHASES # pylint: disable=global-variable-not-assigned
if value not in {0, 1}:
raise ValueError('Expected learning phase to be 0 or 1.')
- if context.executing_eagerly():
- _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH] = value
- else:
- _GRAPH_LEARNING_PHASES[ops.get_default_graph()] = value
+ with ops.init_scope():
+ if context.executing_eagerly():
+ _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH] = value
+ else:
+ _GRAPH_LEARNING_PHASES[ops.get_default_graph()] = value
@tf_contextlib.contextmanager
@@ -423,10 +432,11 @@ def learning_phase_scope(value):
yield value
finally:
# Restore learning phase to initial value.
- if context.executing_eagerly():
- _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH] = previous_value
- else:
- _GRAPH_LEARNING_PHASES[ops.get_default_graph()] = previous_value
+ with ops.init_scope():
+ if context.executing_eagerly():
+ _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH] = previous_value
+ else:
+ _GRAPH_LEARNING_PHASES[ops.get_default_graph()] = previous_value
@tf_export('keras.backend.get_session')
@@ -685,10 +695,8 @@ def track_tf_optimizer(tf_optimizer):
if context.executing_eagerly():
return
graph = ops.get_default_graph()
- if graph not in _GRAPH_TF_OPTIMIZERS:
- _GRAPH_TF_OPTIMIZERS[graph] = set()
- _GRAPH_TF_OPTIMIZERS[graph].add(tf_optimizer)
-
+ optimizers = _GRAPH_TF_OPTIMIZERS.setdefault(graph, weakref.WeakSet())
+ optimizers.add(tf_optimizer)
def track_variable(v):
"""Tracks the given variable for initialization."""
@@ -696,14 +704,14 @@ def track_variable(v):
return
graph = v.graph if hasattr(v, 'graph') else ops.get_default_graph()
if graph not in _GRAPH_VARIABLES:
- _GRAPH_VARIABLES[graph] = set()
+ _GRAPH_VARIABLES[graph] = weakref.WeakSet()
_GRAPH_VARIABLES[graph].add(v)
def _get_variables(graph=None):
"""Returns variables corresponding to the given graph for initialization."""
assert not context.executing_eagerly()
- variables = _GRAPH_VARIABLES.get(graph, set())
+ variables = _GRAPH_VARIABLES.setdefault(graph, weakref.WeakSet())
for opt in _GRAPH_TF_OPTIMIZERS.get(graph, set()):
variables.update(opt.optimizer.variables())
return variables
@@ -1503,12 +1511,8 @@ def batch_dot(x, y, axes=None):
out = math_ops.reduce_sum(
math_ops.multiply(array_ops.transpose(x, [1, 0]), y), axes[1])
else:
- if axes is not None:
- adj_x = None if axes[0] == ndim(x) - 1 else True
- adj_y = True if axes[1] == ndim(y) - 1 else None
- else:
- adj_x = None
- adj_y = None
+ adj_x = None if axes[0] == ndim(x) - 1 else True
+ adj_y = True if axes[1] == ndim(y) - 1 else None
out = math_ops.matmul(x, y, adjoint_a=adj_x, adjoint_b=adj_y)
if diff:
if x_ndim > y_ndim:
diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py
index befe82f4ec..6dfbbf3694 100644
--- a/tensorflow/python/keras/callbacks.py
+++ b/tensorflow/python/keras/callbacks.py
@@ -360,7 +360,10 @@ class BaseLogger(Callback):
def on_batch_end(self, batch, logs=None):
logs = logs or {}
batch_size = logs.get('size', 0)
- self.seen += batch_size
+    # In case of distribution strategy we can potentially run multiple steps
+    # at the same time, so we should account for that in the `seen` calculation.
+ num_steps = logs.get('num_steps', 1)
+ self.seen += batch_size * num_steps
for k, v in logs.items():
if k in self.stateful_metrics:
@@ -448,10 +451,13 @@ class ProgbarLogger(Callback):
def on_batch_end(self, batch, logs=None):
logs = logs or {}
batch_size = logs.get('size', 0)
+    # In case of distribution strategy we can potentially run multiple steps
+    # at the same time, so we should account for that in the `seen` calculation.
+ num_steps = logs.get('num_steps', 1)
if self.use_steps:
- self.seen += 1
+ self.seen += num_steps
else:
- self.seen += batch_size
+ self.seen += batch_size * num_steps
for k in self.params['metrics']:
if k in logs:
@@ -1068,7 +1074,7 @@ class TensorBoard(Callback):
logs = logs or {}
batch_logs = {('batch_' + k): v
for k, v in logs.items()
- if k not in ['batch', 'size']}
+ if k not in ['batch', 'size', 'num_steps']}
self._write_custom_summaries(self._total_batches_seen, batch_logs)
self._total_batches_seen += 1
@@ -1092,7 +1098,7 @@ class TensorBoard(Callback):
# batch number as Tensorboard summaries
logs = {('epoch_' + k): v
for k, v in logs.items()
- if k not in ['batch', 'size']}
+ if k not in ['batch', 'size', 'num_steps']}
self._write_custom_summaries(epoch, logs)
# pop the histogram summary op after each epoch
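To make the bookkeeping concrete, this is the arithmetic the new `num_steps` entry drives in `BaseLogger` and `ProgbarLogger`; the log values below are made up for illustration:

```python
# A batch-end `logs` dict as a distribution strategy might report it when it
# runs several steps at once; the numbers are illustrative only.
logs = {'size': 32, 'num_steps': 4, 'loss': 0.12}

seen = 0
batch_size = logs.get('size', 0)
num_steps = logs.get('num_steps', 1)  # defaults to 1 in the single-step case
seen += batch_size * num_steps        # 128 samples counted for this batch event
```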
diff --git a/tensorflow/python/keras/callbacks_test.py b/tensorflow/python/keras/callbacks_test.py
index b6fae19823..467bc4cdc4 100644
--- a/tensorflow/python/keras/callbacks_test.py
+++ b/tensorflow/python/keras/callbacks_test.py
@@ -30,6 +30,7 @@ import numpy as np
from tensorflow.core.framework import summary_pb2
from tensorflow.python import keras
+from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import test_util
from tensorflow.python.keras import testing_utils
@@ -1222,6 +1223,45 @@ class KerasCallbacksTest(test.TestCase):
callbacks=cbks,
epochs=1)
+ def test_fit_generator_with_callback(self):
+
+ class TestCallback(keras.callbacks.Callback):
+ def set_model(self, model):
+        # Check the model's graph for the optimizer operations that
+        # _make_train_function adds under a named scope for the optimizer.
+        # This ensures the full model is populated before the set_model
+        # callback is called.
+ optimizer_name_scope = 'training/' + model.optimizer.__class__.__name__
+ graph_def = ops.get_default_graph().as_graph_def()
+ for node in graph_def.node:
+ if node.name.startswith(optimizer_name_scope):
+ return
+ raise RuntimeError('The optimizer operations are not present in the '
+ 'model graph when the Callback.set_model function '
+ 'is called')
+ np.random.seed(1337)
+
+ def generator():
+ x = np.random.randn(10, 100).astype(np.float32)
+ y = np.random.randn(10, 10).astype(np.float32)
+ while True:
+ yield x, y
+
+ with self.cached_session():
+ model = testing_utils.get_small_sequential_mlp(
+ num_hidden=10, num_classes=10, input_dim=100)
+ model.compile(
+ loss='categorical_crossentropy',
+ optimizer='sgd',
+ metrics=['accuracy'])
+ model.fit_generator(
+ generator(),
+ steps_per_epoch=2,
+ epochs=1,
+ validation_data=generator(),
+ validation_steps=2,
+ callbacks=[TestCallback()],
+ verbose=0)
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py
index cb19a412a2..a75ce30d31 100644
--- a/tensorflow/python/keras/engine/base_layer.py
+++ b/tensorflow/python/keras/engine/base_layer.py
@@ -20,6 +20,7 @@ from __future__ import print_function
import collections as collections_lib
import enum # pylint: disable=g-bad-import-order
+import functools
import inspect # Necessary supplement to tf_inspect to deal with variadic args.
import numpy as np
@@ -160,9 +161,13 @@ class Layer(checkpointable.CheckpointableBase):
self._trainable_weights = []
self._non_trainable_weights = []
self._updates = []
- # When executing eagerly, _losses is a list of zero-argument lambdas which
- # return tensors. When using graph execution, _losses is a list of ops.
+ # A list of zero-argument lambdas which return Tensors, used for variable
+ # regularizers.
+ self._callable_losses = []
+ # A list of Tensors containing activity regularizers and losses manually
+ # added through `add_loss`. Empty when executing eagerly.
self._losses = []
+ self._in_call = False # Flag for error checking in add_loss
self._dtype = None if dtype is None else dtypes.as_dtype(dtype).name
self._call_fn_args = function_utils.fn_args(self.call)
self._compute_previous_mask = ('mask' in self._call_fn_args or
@@ -359,20 +364,20 @@ class Layer(checkpointable.CheckpointableBase):
def losses(self):
"""Losses which are associated with this `Layer`.
- Note that when executing eagerly, getting this property evaluates
- regularizers. When using graph execution, variable regularization ops have
- already been created and are simply returned here.
+ Variable regularization tensors are created when this property is accessed,
+ so it is eager safe: accessing `losses` under a `tf.GradientTape` will
+ propagate gradients back to the corresponding variables.
Returns:
A list of tensors.
"""
- if context.executing_eagerly():
- # _losses may only contain variable regularization losses when executing
- # eagerly, and they have been saved as lambdas to be executed when
- # requested.
- return [regularizer() for regularizer in self._losses]
- else:
- return self._losses
+ collected_losses = []
+ collected_losses.extend(self._losses)
+ for regularizer in self._callable_losses:
+ loss_tensor = regularizer()
+ if loss_tensor is not None:
+ collected_losses.append(loss_tensor)
+ return collected_losses
@doc_controls.for_subclass_implementers
def add_loss(self, losses, inputs=None):
@@ -393,7 +398,9 @@ class Layer(checkpointable.CheckpointableBase):
from `Layer.call()`).
Arguments:
- losses: Loss tensor, or list/tuple of tensors.
+ losses: Loss tensor, or list/tuple of tensors. Rather than tensors, losses
+ may also be zero-argument callables which create a loss tensor. Only
+ callable losses are supported when executing eagerly.
inputs: If anything other than None is passed, it signals the losses
are conditional on some of the layer's inputs,
and thus they should only be run where these inputs are available.
@@ -403,29 +410,45 @@ class Layer(checkpointable.CheckpointableBase):
(e.g. weight regularization losses).
Raises:
- RuntimeError: If called in Eager mode.
+ RuntimeError: If called in Eager mode with a `Tensor` rather than a
+ callable, or if `inputs` is not None.
"""
- if context.executing_eagerly():
- # TODO(fchollet): it should be possible (and highly desirable) to support
- # `add_loss` in eager mode. This allows great convenience and flexibility
- # in defining custom losses on the fly (e.g. in VAEs).
- # Simply appending the loss value to `self._losses`
- # is the correct behavior.
- # The only caveat is that we need to force the user to only call
- # `add_loss` from inside a model or Layer's `call` method
- # (otherwise the loss computation cannot be backproped through).
- raise RuntimeError('Layer.add_loss not supported in Eager mode.')
-
+ executing_eagerly = context.executing_eagerly()
+ if executing_eagerly:
+ if inputs is not None:
+ raise RuntimeError(
+ 'Activity regularization (via the "inputs" argument to '
+ 'Layer.add_loss) is not supported when executing eagerly. Consider '
+ 'returning activity regularization losses from a Model\'s call() '
+ 'method.')
+ if getattr(self, '_in_call', False):
+ # TODO(psv): Support activity regularization and a way to reset losses.
+ raise RuntimeError(
+ 'Adding losses inside a Layer\'s call() method is not currently '
+ 'supported when executing eagerly. Please file a feature request '
+ 'if you need this limitation lifted.')
losses = generic_utils.to_list(losses)
- losses = [ops.convert_to_tensor(loss, dtype=backend.floatx())
- if not tensor_util.is_tensor(loss) else loss for loss in losses]
- self._losses += losses
- if inputs is None:
- for loss in losses:
- loss._unconditional_loss = True # pylint: disable=protected-access
- else:
- for loss in losses:
- loss._unconditional_loss = False # pylint: disable=protected-access
+
+ def _tag_unconditional(loss):
+ if callable(loss):
+ loss = loss()
+ if loss is None:
+ return None # Will be filtered out when computing the .losses property
+ if not tensor_util.is_tensor(loss):
+ loss = ops.convert_to_tensor(loss, dtype=backend.floatx())
+ loss._unconditional_loss = (inputs is None) # pylint: disable=protected-access
+ return loss
+
+ for loss in losses:
+ if callable(loss):
+ self._callable_losses.append(
+ functools.partial(_tag_unconditional, loss))
+ else:
+ if executing_eagerly:
+ raise RuntimeError(
+ 'Layer.add_loss only supported for zero-argument lambdas when '
+ 'executing eagerly.')
+ self._losses.append(_tag_unconditional(loss))
def get_losses_for(self, inputs):
"""Retrieves losses relevant to a specific set of inputs.
@@ -599,56 +622,20 @@ class Layer(checkpointable.CheckpointableBase):
return variable
def _handle_weight_regularization(self, name, variable, regularizer):
- # `init_graph` should point to the graph in which variable initialization
- # will occur; it should be None if and only if initialization will take
- # place in the eager context.
- init_graph = None
- if not context.executing_eagerly():
- default_graph = ops.get_default_graph()
- if default_graph.building_function:
- with ops.init_scope():
- # Retrieve the variables from the graph into which variables
- # will be lifted; if initialization ops will be lifted into
- # the eager context, then there is nothing to retrieve, since variable
- # collections are not supported when eager execution is enabled.
- if not context.executing_eagerly():
- init_graph = ops.get_default_graph()
- else:
- # Initialization ops will not be lifted out of the default graph.
- init_graph = default_graph
-
- if init_graph is not None: # pylint: disable=protected-access
- # The variable was created and initialized in a graph.
- if regularizer:
- if isinstance(variable, tf_variables.PartitionedVariable):
- for v in variable:
- with ops.colocate_with(v.op):
- with ops.name_scope(name + '/Regularizer'):
- regularization = regularizer(v)
- if regularization is not None:
- self.add_loss(regularization)
- else:
- with ops.colocate_with(variable.op):
- with ops.name_scope(name + '/Regularizer'):
- regularization = regularizer(variable)
- if regularization is not None:
- self.add_loss(regularization)
- elif regularizer: # initialization took place in an eager context
- if isinstance(variable, tf_variables.PartitionedVariable):
- raise RuntimeError(
- 'Partitioned variable regularization is not yet '
- 'supported when executing eagerly. File a feature request'
- 'if this is important to you.')
- # Save a zero-argument lambda which runs the regularizer on the
- # variable, to be executed when `Layer.losses` is requested.
- # This makes losses responsive to variable updates when executing
- # eagerly.
- #
- # TODO(akshayka): Do the same for graphs as well, so that losses
- # collected in a while_loop can be run outside its control flow
- # context and so that losses won't be swallowed up by graph functions
- # (i.e., `.losses()` should always create regularizers).
- self._losses.append(lambda: regularizer(variable))
+ """Create lambdas which compute regularization losses."""
+
+ def _loss_for_variable(v):
+ """Creates a regularization loss `Tensor` for variable `v`."""
+ with ops.colocate_with(v):
+ with ops.name_scope(name + '/Regularizer'):
+ regularization = regularizer(v)
+ return regularization
+
+ if isinstance(variable, tf_variables.PartitionedVariable):
+ for v in variable:
+ self.add_loss(functools.partial(_loss_for_variable, v))
+ else:
+ self.add_loss(functools.partial(_loss_for_variable, variable))
def _handle_activity_regularization(self, inputs, outputs):
# Apply activity regularization.
@@ -766,7 +753,9 @@ class Layer(checkpointable.CheckpointableBase):
self._assert_input_compatibility(inputs)
if not in_deferred_mode:
+ self._in_call = True
outputs = self.call(inputs, *args, **kwargs)
+ self._in_call = False
if outputs is None:
raise ValueError('A layer\'s `call` method should return a Tensor '
'or a list of Tensors, not None (layer: ' +
@@ -1972,7 +1961,9 @@ def make_variable(name,
if use_resource is None:
use_resource = True
- v = tf_variables.Variable(
+ # TODO(apassos,rohanj) figure out how to remove collections from here so we
+ # can remove the V1.
+ v = tf_variables.VariableV1(
initial_value=init_val,
name=name,
trainable=trainable,
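
Illustrative sketch (not part of the diff): the reworked add_loss above accepts zero-argument callables, which are stored in _callable_losses and re-evaluated every time Layer.losses is read. The snippet below is only a usage sketch and assumes an ordinary Dense layer with an l2 kernel regularizer.

import tensorflow as tf

layer = tf.keras.layers.Dense(4, kernel_regularizer=tf.keras.regularizers.l2(0.01))
layer.build((None, 8))

# Zero-argument callable: no `inputs` argument is passed, so the evaluated
# tensor is tagged with `_unconditional_loss = True` when `.losses` is read.
layer.add_loss(lambda: tf.reduce_sum(tf.abs(layer.kernel)))

print(layer.losses)  # the l2 regularizer loss plus the lambda above, as tensors
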
diff --git a/tensorflow/python/keras/engine/distributed_training_utils.py b/tensorflow/python/keras/engine/distributed_training_utils.py
index b28df75493..39341a931b 100644
--- a/tensorflow/python/keras/engine/distributed_training_utils.py
+++ b/tensorflow/python/keras/engine/distributed_training_utils.py
@@ -18,6 +18,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.client import session as session_module
+from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import backend as K
@@ -293,12 +294,14 @@ def configure_and_create_session(distribution_strategy):
K.set_session(session)
-def validate_inputs(x, y):
+def validate_inputs(x, y, distribution_strategy):
"""Validate inputs when using DistributionStrategy.
Args:
x: Model Inputs.
y: Model Targets.
+ distribution_strategy: The DistributionStrategy with which the model is
+ compiled.
Raises:
ValueError: if input is not a Dataset or a numpy array.
@@ -319,6 +322,17 @@ def validate_inputs(x, y):
'Iterator. You must pass a Dataset object or a numpy '
'array as input.')
+ if distribution_strategy.__class__.__name__ == 'TPUStrategy':
+ for i in [x, y]:
+ if isinstance(i, dataset_ops.Dataset):
+ shapes = nest.flatten(i.output_shapes)
+        if any([not s.is_fully_defined() for s in shapes]):
+          raise ValueError(
+              'Using TPUs currently requires fully defined shapes. Either use '
+              'set_shape() on the input tensors or use '
+              'dataset.batch(..., drop_remainder=True). '
+              'Found unknown shapes {} in input {}.'.format(shapes, i))
+
def get_input_batch_params(first_x_value, batch_size, current_strategy):
"""Calculate the number of batches and steps/steps_per_epoch.
diff --git a/tensorflow/python/keras/engine/saving_test.py b/tensorflow/python/keras/engine/saving_test.py
index 148dd23be7..02d99d5d69 100644
--- a/tensorflow/python/keras/engine/saving_test.py
+++ b/tensorflow/python/keras/engine/saving_test.py
@@ -370,6 +370,13 @@ class TestWholeModelSaving(test.TestCase):
y = np.random.random((1, 3, 3))
model.train_on_batch(x, y)
new_model.train_on_batch(x, y)
+
+ x = np.random.random((1, 3))
+ y = np.random.random((1, 3, 3))
+ eval_out = model.evaluate(x, y)
+ eval_out2 = new_model.evaluate(x, y)
+ self.assertArrayNear(eval_out, eval_out2, 0.001)
+
out = model.predict(x)
out2 = new_model.predict(x)
self.assertAllClose(out, out2, atol=1e-05)
diff --git a/tensorflow/python/keras/engine/topology_test.py b/tensorflow/python/keras/engine/topology_test.py
index 061db8ee34..a0da96334b 100644
--- a/tensorflow/python/keras/engine/topology_test.py
+++ b/tensorflow/python/keras/engine/topology_test.py
@@ -915,7 +915,7 @@ class TopologyConstructionTest(test.TestCase):
def test_constant_initializer_with_numpy(self):
- with self.test_session():
+ with self.cached_session():
initializer = keras.initializers.Constant(np.ones((3, 2)))
model = keras.models.Sequential()
model.add(keras.layers.Dense(2, input_shape=(3,),
diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py
index dc464c02b6..5091cac836 100644
--- a/tensorflow/python/keras/engine/training.py
+++ b/tensorflow/python/keras/engine/training.py
@@ -383,27 +383,31 @@ class Model(Network):
"""
# Validate that arguments passed by the user to `compile` are supported by
# DistributionStrategy.
- if distribute and not isinstance(
- optimizer, (tf_optimizer_module.Optimizer, optimizers.TFOptimizer)):
- raise NotImplementedError('Only TF native optimizers are supported with '
- 'DistributionStrategy.')
- if distribute and context.executing_eagerly():
- raise NotImplementedError('DistributionStrategy is not supported in '
- 'Eager mode.')
- if distribute and sample_weight_mode:
- raise NotImplementedError('sample_weight_mode is not supported with '
- 'DistributionStrategy.')
- if distribute and weighted_metrics:
- raise NotImplementedError('weighted_metrics is not supported with '
- 'DistributionStrategy.')
- if distribute and target_tensors:
- raise ValueError('target_tensors is not supported with '
- 'DistributionStrategy.')
+ if distribute:
+ if not isinstance(
+ optimizer, (tf_optimizer_module.Optimizer, optimizers.TFOptimizer)):
+ raise NotImplementedError(
+ 'optimizer must be an instance of '
+ 'tf.train.Optimizer, not a %s' % type(optimizer))
+ if context.executing_eagerly():
+ raise NotImplementedError('DistributionStrategy is not supported '
+ 'when eager execution is enabled.')
+ if sample_weight_mode:
+ raise NotImplementedError('sample_weight_mode is not supported with '
+ 'DistributionStrategy.')
+ if weighted_metrics:
+ raise NotImplementedError('weighted_metrics is not supported with '
+ 'DistributionStrategy.')
+ if target_tensors:
+ raise ValueError('target_tensors is not supported with '
+ 'DistributionStrategy.')
loss = loss or {}
if context.executing_eagerly() and not isinstance(
optimizer, (tf_optimizer_module.Optimizer, optimizers.TFOptimizer)):
- raise ValueError('Only TF native optimizers are supported in Eager mode.')
+ raise ValueError(
+ 'optimizer must be an instance of tf.train.Optimizer, not '
+ 'a %s' % type(optimizer))
self.optimizer = optimizers.get(optimizer)
# We've disabled automatic dependency tracking for this method, but do want
@@ -422,8 +426,9 @@ class Model(Network):
# Set DistributionStrategy specific parameters.
self._distribution_strategy = distribute
+ # Reset the value of grouped_model
+ self._grouped_model = None
if self._distribution_strategy is not None:
- self._grouped_model = None
distributed_training_utils.configure_and_create_session(
self._distribution_strategy)
if not self.built:
@@ -445,7 +450,8 @@ class Model(Network):
for name in self.output_names:
if name not in loss:
logging.warning(
- 'Output "' + name + '" missing from loss dictionary. We assume '
+ 'Output "' + name +
+ '" missing from loss dictionary. We assume '
'this was done on purpose. The fit and evaluate APIs will not be '
'expecting any data to be passed to "' + name + '".')
loss_functions.append(losses.get(loss.get(name)))
@@ -641,12 +647,6 @@ class Model(Network):
skip_target_indices=skip_target_indices,
sample_weights=self.sample_weights)
- # If using distribution strategy and stateful_metrics, raise an error
- # since we currently don't support stateful metrics.
- if self._distribution_strategy is not None and self.stateful_metric_names:
- raise NotImplementedError('Stateful metrics are not supported with '
- 'DistributionStrategy.')
-
# Prepare gradient updates and state updates.
self.total_loss = total_loss
@@ -851,7 +851,8 @@ class Model(Network):
# able to clone a Dataset on multiple workers we can remove this lambda.
result = self._distribution_strategy.distribute_dataset(lambda: x)
iterator = result.make_initializable_iterator()
- K.get_session().run(iterator.initializer)
+ with self._distribution_strategy.scope():
+ K.get_session().run(iterator.initializer)
training_utils.validate_iterator_input(x, y, sample_weight,
validation_split)
@@ -1515,7 +1516,8 @@ class Model(Network):
if self._distribution_strategy:
distributed_training_utils.validate_callbacks(callbacks)
- distributed_training_utils.validate_inputs(x, y)
+ distributed_training_utils.validate_inputs(
+ x, y, self._distribution_strategy)
first_x_value = nest.flatten(x)[0]
if not steps_per_epoch and isinstance(first_x_value, np.ndarray):
@@ -1557,7 +1559,8 @@ class Model(Network):
# Validate and standardize validation data.
if self._distribution_strategy:
- distributed_training_utils.validate_inputs(val_x, val_y)
+ distributed_training_utils.validate_inputs(
+ val_x, val_y, self._distribution_strategy)
first_valx_value = nest.flatten(val_x)[0]
if not validation_steps and isinstance(first_valx_value, np.ndarray):
validation_steps = distributed_training_utils.get_input_batch_params(
@@ -1731,7 +1734,8 @@ class Model(Network):
# Validate and standardize user data.
if self._distribution_strategy:
- distributed_training_utils.validate_inputs(x, y)
+ distributed_training_utils.validate_inputs(
+ x, y, self._distribution_strategy)
first_x_value = nest.flatten(x)[0]
if isinstance(first_x_value, np.ndarray) and not steps:
steps = distributed_training_utils.get_input_batch_params(
@@ -1846,7 +1850,8 @@ class Model(Network):
# `MirroredStrategy`.
if hasattr(self._distribution_strategy, '_prefetch_on_device'):
self._distribution_strategy._prefetch_on_device = False # pylint: disable=protected-access
- distributed_training_utils.validate_inputs(x, None)
+ distributed_training_utils.validate_inputs(
+ x, None, self._distribution_strategy)
first_x_value = nest.flatten(x)[0]
if isinstance(first_x_value, np.ndarray) and not steps:
steps = distributed_training_utils.get_input_batch_params(
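
Illustrative sketch (not part of the diff): under the restructured validation above, `compile(distribute=...)` only accepts a TF-native optimizer and must run with graph execution, while `sample_weight_mode`, `weighted_metrics`, and `target_tensors` must stay unset. The `tf.contrib.distribute.MirroredStrategy` path below is an assumption about the contrib API of this TF generation, not something stated by the patch.

import tensorflow as tf

strategy = tf.contrib.distribute.MirroredStrategy()
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])

# A Keras optimizer string, or any of sample_weight_mode / weighted_metrics /
# target_tensors, would raise the errors added in this hunk.
model.compile(optimizer=tf.train.GradientDescentOptimizer(0.1),
              loss='mse',
              distribute=strategy)
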
diff --git a/tensorflow/python/keras/engine/training_distributed.py b/tensorflow/python/keras/engine/training_distributed.py
index 53291c3956..a6470458d2 100644
--- a/tensorflow/python/keras/engine/training_distributed.py
+++ b/tensorflow/python/keras/engine/training_distributed.py
@@ -20,11 +20,13 @@ from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import errors
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import callbacks as cbks
from tensorflow.python.keras import optimizers
from tensorflow.python.keras.engine import distributed_training_utils
+from tensorflow.python.keras import metrics as metrics_module
from tensorflow.python.keras.utils.generic_utils import Progbar
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variable_scope
@@ -110,96 +112,99 @@ def fit_loop(
dataset_targets = distributed_training_utils.flatten_perdevice_values(
current_strategy, targets)
- # Create a train function that is composed of all the parameters above.
- distributed_train_function = K.Function(
- all_inputs, all_outputs,
- updates=all_updates,
- name='distributed_train_function',
- **all_session_args)
-
- # We need to set sample_weights to None since there are sample weight
- # placeholders that are created with default values.
- sample_weights = [None for _ in range(len(model.outputs) *
- current_strategy.num_towers)]
- if model.uses_learning_phase and not isinstance(K.learning_phase(), int):
- ins = dataset_inputs + dataset_targets + sample_weights + [1]
- else:
- ins = dataset_inputs + dataset_targets
+ # Create a train function that is composed of all the parameters above.
+ distributed_train_function = K.Function(
+ all_inputs, all_outputs,
+ updates=all_updates,
+ name='distributed_train_function',
+ **all_session_args)
- do_validation = False
- if validation_steps:
- do_validation = True
+ # We need to set sample_weights to None since there are sample weight
+ # placeholders that are created with default values.
+ sample_weights = [None for _ in range(len(model.outputs) *
+ current_strategy.num_towers)]
+ if model.uses_learning_phase and not isinstance(K.learning_phase(), int):
+ ins = dataset_inputs + dataset_targets + sample_weights + [1]
+ else:
+ ins = dataset_inputs + dataset_targets
- # Copy the weights from the original model to each of the replicated models.
- orig_model_weights = model.get_weights()
- with current_strategy.scope():
+ do_validation = False
+ if validation_steps:
+ do_validation = True
+
+ # Copy the weights from the original model to each of the replicated models.
+ orig_model_weights = model.get_weights()
distributed_model = current_strategy.unwrap(model._grouped_model)[0]
distributed_training_utils.set_weights(
current_strategy, distributed_model, orig_model_weights)
- callbacks = cbks.configure_callbacks(
- callbacks,
- model,
- do_validation=do_validation,
- val_inputs=None,
- val_targets=None,
- epochs=epochs,
- steps_per_epoch=steps_per_epoch,
- verbose=verbose)
- out_labels = model.metrics_names or []
- callbacks.on_train_begin()
-
- assert steps_per_epoch is not None
-
- for epoch in range(initial_epoch, epochs):
- callbacks.on_epoch_begin(epoch)
- epoch_logs = {}
- for step_index in range(steps_per_epoch):
- batch_logs = {'batch': step_index, 'size': 1}
- callbacks.on_batch_begin(step_index, batch_logs)
- try:
- outs = distributed_train_function(ins)
- except errors.OutOfRangeError:
- logging.warning('Your dataset iterator ran out of data; '
- 'interrupting training. Make sure that your dataset '
- 'can generate at least `steps_per_epoch * epochs` '
- 'batches (in this case, %d batches).' %
- steps_per_epoch * epochs)
- break
-
- if not isinstance(outs, list):
- outs = [outs]
-
- outs = _aggregate_metrics_across_towers(
- current_strategy.num_towers, out_labels, outs)
- for l, o in zip(out_labels, outs):
- batch_logs[l] = o
- callbacks.on_batch_end(step_index, batch_logs)
+ callbacks = cbks.configure_callbacks(
+ callbacks,
+ model,
+ do_validation=do_validation,
+ val_inputs=None,
+ val_targets=None,
+ epochs=epochs,
+ steps_per_epoch=steps_per_epoch,
+ verbose=verbose)
+ out_labels = model.metrics_names or []
+ callbacks.on_train_begin()
+
+ assert steps_per_epoch is not None
+
+ for epoch in range(initial_epoch, epochs):
+ # Reset stateful metrics
+ for m in model.stateful_metric_functions:
+ m.reset_states()
+ callbacks.on_epoch_begin(epoch)
+ epoch_logs = {}
+ for step_index in range(steps_per_epoch):
+ batch_logs = {'batch': step_index, 'size': 1}
+ callbacks.on_batch_begin(step_index, batch_logs)
+ try:
+ outs = distributed_train_function(ins)
+ except errors.OutOfRangeError:
+ logging.warning('Your dataset iterator ran out of data; '
+ 'interrupting training. Make sure that your dataset '
+ 'can generate at least `steps_per_epoch * epochs` '
+                          'batches (in this case, %d batches).' %
+                          (steps_per_epoch * epochs))
+ break
+
+ if not isinstance(outs, list):
+ outs = [outs]
+
+ outs = _aggregate_metrics_across_towers(current_strategy.num_towers,
+ out_labels,
+ model.stateful_metric_names,
+ outs)
+ for l, o in zip(out_labels, outs):
+ batch_logs[l] = o
+ callbacks.on_batch_end(step_index, batch_logs)
+ if callbacks.model.stop_training:
+ break
+ if do_validation:
+ val_outs = test_loop(
+ model,
+ val_iterator,
+ steps=validation_steps,
+ verbose=0)
+ if not isinstance(val_outs, list):
+ val_outs = [val_outs]
+ # Same labels assumed.
+ for l, o in zip(out_labels, val_outs):
+ epoch_logs['val_' + l] = o
+
+ callbacks.on_epoch_end(epoch, epoch_logs)
if callbacks.model.stop_training:
break
- if do_validation:
- val_outs = test_loop(
- model,
- val_iterator,
- steps=validation_steps,
- verbose=0)
- if not isinstance(val_outs, list):
- val_outs = [val_outs]
- # Same labels assumed.
- for l, o in zip(out_labels, val_outs):
- epoch_logs['val_' + l] = o
-
- callbacks.on_epoch_end(epoch, epoch_logs)
- if callbacks.model.stop_training:
- break
- callbacks.on_train_end()
+ callbacks.on_train_end()
- # Copy the weights back from the replicated model to the original model.
- with current_strategy.scope():
+ # Copy the weights back from the replicated model to the original model.
updated_weights = current_strategy.unwrap(
model._grouped_model)[0].get_weights()
model.set_weights(updated_weights)
- return model.history
+ return model.history
def _experimental_fit_loop(
@@ -232,8 +237,6 @@ def _experimental_fit_loop(
"""
current_strategy = model._distribution_strategy
- # TODO(priyag): Add validation that shapes are fully defined for TPU case.
-
K.get_session().run(current_strategy.initialize())
def _per_device_train_function(model):
@@ -292,11 +295,16 @@ def _experimental_fit_loop(
for name, tensor in zip(model.metrics_names[1:], model.metrics_tensors):
initial_loop_values[name] = array_ops.zeros(tensor.shape, tensor.dtype)
+ if steps_per_epoch is None:
+ raise ValueError('steps_per_epoch should be specified in the fit call.')
+ steps_per_run_var = K.variable(
+ value=min(steps_per_epoch, current_strategy.steps_per_run),
+ dtype='int32',
+ name='steps_per_run_var')
+
with current_strategy.scope():
- # TODO(priyag, sourabhbajaj): Adjust steps_per_run appropriately based on
- # steps_per_epoch and number of epochs.
ctx = current_strategy.run_steps_on_dataset(
- step_fn, iterator, iterations=current_strategy.steps_per_run,
+ step_fn, iterator, iterations=steps_per_run_var,
initial_loop_values=initial_loop_values)
train_op = ctx.run_op
@@ -308,14 +316,6 @@ def _experimental_fit_loop(
distributed_model = current_strategy.unwrap(model._grouped_model)[0]
distributed_training_utils.set_weights(
current_strategy, distributed_model, orig_model_weights)
-
- assert steps_per_epoch is not None
-
- # TODO(sourabhbajaj): Convert this into a proper validation function
- if callbacks:
- raise NotImplementedError(
- 'Callbacks are not supported with TPUStrategy right now.')
-
callbacks = cbks.configure_callbacks(
callbacks,
model,
@@ -326,17 +326,26 @@ def _experimental_fit_loop(
steps_per_epoch=steps_per_epoch,
verbose=verbose)
# TODO(priyag, sourabhbajaj): Add callbacks support for per step callback
- # TODO(priyag, sourabhbajaj): Fix the number of steps run with steps_per_run
# TODO(priyag, sourabhbajaj): Add validation.
+
+  # Compute how many steps to run on the device in each call to
+  # `run_steps_on_dataset`.
+ steps_to_run = [current_strategy.steps_per_run] * (
+ steps_per_epoch // current_strategy.steps_per_run)
+ if steps_per_epoch % current_strategy.steps_per_run:
+ steps_to_run.append(steps_per_epoch % current_strategy.steps_per_run)
+
callbacks.on_train_begin()
for epoch in range(initial_epoch, epochs):
callbacks.on_epoch_begin(epoch)
epoch_logs = {}
- for step_index in range(0, steps_per_epoch, current_strategy.steps_per_run):
- # TODO(sourabhbajaj): Replace size with a combination of steps_per_run
- # and batch_size
- batch_logs = {'batch': step_index, 'size': 1}
+ step_index = 0
+ prev_step_count = None
+ for step_count in steps_to_run:
+ batch_logs = {'batch': step_index, 'size': 1, 'num_steps': step_count}
callbacks.on_batch_begin(step_index, batch_logs)
+ if prev_step_count is None or step_count != prev_step_count:
+ steps_per_run_var.load(step_count, K.get_session())
+ prev_step_count = step_count
try:
_, outputs = K.get_session().run([train_op, output_tensors])
except errors.OutOfRangeError:
@@ -349,6 +358,7 @@ def _experimental_fit_loop(
batch_logs.update(outputs)
callbacks.on_batch_end(step_index, batch_logs)
+ step_index = step_index + step_count
if callbacks.model.stop_training:
break
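
Worked example (not part of the diff, numbers made up): the chunking above splits an epoch into groups of `steps_per_run`, plus one shorter trailing group for the remainder, and `step_index` advances by the size of each group.

steps_per_epoch = 10
steps_per_run = 4

steps_to_run = [steps_per_run] * (steps_per_epoch // steps_per_run)
if steps_per_epoch % steps_per_run:
  steps_to_run.append(steps_per_epoch % steps_per_run)

print(steps_to_run)  # [4, 4, 2]; step_index advances 0 -> 4 -> 8
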
@@ -416,54 +426,65 @@ def test_loop(model, iterator, verbose=0, steps=None):
dataset_targets = distributed_training_utils.flatten_perdevice_values(
current_strategy, targets)
- distributed_test_function = K.Function(
- all_inputs, all_outputs,
- updates=all_updates,
- name='distributed_test_function',
- **all_session_args)
-
- # We need to set sample_weights to None since there are sample weight
- # placeholders that are created with default values.
- sample_weights = [None for _ in range(len(model.outputs) *
- current_strategy.num_towers)]
- if model.uses_learning_phase and not isinstance(K.learning_phase(), int):
- ins = dataset_inputs + dataset_targets + sample_weights + [0]
- else:
- ins = dataset_inputs + dataset_targets
+ distributed_test_function = K.Function(
+ all_inputs, all_outputs,
+ updates=all_updates,
+ name='distributed_test_function',
+ **all_session_args)
- outs = []
- if verbose == 1:
- progbar = Progbar(target=steps)
+ # We need to set sample_weights to None since there are sample weight
+ # placeholders that are created with default values.
+ sample_weights = [None for _ in range(len(model.outputs) *
+ current_strategy.num_towers)]
+ if model.uses_learning_phase and not isinstance(K.learning_phase(), int):
+ ins = dataset_inputs + dataset_targets + sample_weights + [0]
+ else:
+ ins = dataset_inputs + dataset_targets
- # Copy the weights from the original model to each of the replicated models.
- orig_model_weights = model.get_weights()
- with current_strategy.scope():
+ for m in model.stateful_metric_functions:
+ m.reset_states()
+ stateful_metric_indices = [
+ i for i, name in enumerate(model.metrics_names)
+ if str(name) in model.stateful_metric_names
+ ]
+
+ outs = []
+ if verbose == 1:
+ progbar = Progbar(target=steps)
+
+ # Copy the weights from the original model to each of the replicated models.
+ orig_model_weights = model.get_weights()
distributed_model = current_strategy.unwrap(model._grouped_model)[0]
distributed_training_utils.set_weights(
current_strategy, distributed_model, orig_model_weights)
- assert steps is not None
- for step in range(steps):
- batch_outs = distributed_test_function(ins)
- batch_outs = _aggregate_metrics_across_towers(
- current_strategy.num_towers, model.metrics_names, batch_outs)
- if isinstance(batch_outs, list):
- if step == 0:
- outs = [0.] * len(batch_outs)
- for i, batch_out in enumerate(batch_outs):
- outs[i] += batch_out
- else:
- if step == 0:
- outs.append(0.)
- outs[0] += batch_outs
- if verbose >= 1:
- progbar.update(step + 1)
- for i in range(len(outs)):
- outs[i] /= steps
+ assert steps is not None
+ for step in range(steps):
+ batch_outs = distributed_test_function(ins)
+ batch_outs = _aggregate_metrics_across_towers(
+ current_strategy.num_towers, model.metrics_names,
+ model.stateful_metric_names, batch_outs)
+ if isinstance(batch_outs, list):
+ if step == 0:
+ outs = [0.] * len(batch_outs)
+ for i, batch_out in enumerate(batch_outs):
+ if i in stateful_metric_indices:
+ outs[i] = batch_out
+ else:
+ outs[i] += batch_out
+ else:
+ if step == 0:
+ outs.append(0.)
+ outs[0] += batch_outs
+ if verbose >= 1:
+ progbar.update(step + 1)
+ for i in range(len(outs)):
+ if i not in stateful_metric_indices:
+ outs[i] /= steps
- if len(outs) == 1:
- return outs[0]
- return outs
+ if len(outs) == 1:
+ return outs[0]
+ return outs
def _experimental_test_loop(model, iterator, verbose=0, steps=None):
@@ -624,51 +645,50 @@ def predict_loop(model, iterator, verbose=0, steps=None):
dataset_inputs = distributed_training_utils.flatten_perdevice_values(
current_strategy, inputs)
- distributed_predict_function = K.Function(
- all_inputs, all_outputs,
- updates=all_updates,
- name='distributed_predict_function',
- **all_session_args)
+ distributed_predict_function = K.Function(
+ all_inputs, all_outputs,
+ updates=all_updates,
+ name='distributed_predict_function',
+ **all_session_args)
- if model.uses_learning_phase and not isinstance(K.learning_phase(), int):
- ins = dataset_inputs + [0]
- else:
- ins = dataset_inputs
+ if model.uses_learning_phase and not isinstance(K.learning_phase(), int):
+ ins = dataset_inputs + [0]
+ else:
+ ins = dataset_inputs
- if verbose == 1:
- progbar = Progbar(target=steps)
+ if verbose == 1:
+ progbar = Progbar(target=steps)
- # Copy the weights from the original model to each of the replicated models.
- orig_model_weights = model.get_weights()
- with current_strategy.scope():
+ # Copy the weights from the original model to each of the replicated models.
+ orig_model_weights = model.get_weights()
distributed_model = current_strategy.unwrap(model._grouped_model)[0]
distributed_training_utils.set_weights(
current_strategy, distributed_model, orig_model_weights)
- if steps is not None:
- # Since we do not know how many samples we will see, we cannot pre-allocate
- # the returned Numpy arrays. Instead, we store one array per batch seen
- # and concatenate them upon returning.
- unconcatenated_outs = []
- for step in range(steps):
- batch_outs = distributed_predict_function(ins)
- if not isinstance(batch_outs, list):
- batch_outs = [batch_outs]
- if step == 0:
- for _ in batch_outs:
- unconcatenated_outs.append([])
- # TODO(anjalisridhar): Should combine the outputs from multiple towers
- # correctly here.
- for i, batch_out in enumerate(batch_outs):
- unconcatenated_outs[i].append(batch_out)
- if verbose >= 1:
- progbar.update(step + 1)
- if len(unconcatenated_outs) == 1:
- return np.concatenate(unconcatenated_outs[0], axis=0)
- return [
- np.concatenate(unconcatenated_outs[i], axis=0)
- for i in range(len(unconcatenated_outs))
- ]
+ if steps is not None:
+ # Since we do not know how many samples we will see, we cannot
+ # pre-allocate the returned Numpy arrays. Instead, we store one array per
+ # batch seen and concatenate them upon returning.
+ unconcatenated_outs = []
+ for step in range(steps):
+ batch_outs = distributed_predict_function(ins)
+ if not isinstance(batch_outs, list):
+ batch_outs = [batch_outs]
+ if step == 0:
+ for _ in batch_outs:
+ unconcatenated_outs.append([])
+ # TODO(anjalisridhar): Should combine the outputs from multiple towers
+ # correctly here.
+ for i, batch_out in enumerate(batch_outs):
+ unconcatenated_outs[i].append(batch_out)
+ if verbose >= 1:
+ progbar.update(step + 1)
+ if len(unconcatenated_outs) == 1:
+ return np.concatenate(unconcatenated_outs[0], axis=0)
+ return [
+ np.concatenate(unconcatenated_outs[i], axis=0)
+ for i in range(len(unconcatenated_outs))
+ ]
def _experimental_predict_loop(model, iterator, verbose=0, steps=None):
@@ -742,8 +762,9 @@ def _experimental_predict_loop(model, iterator, verbose=0, steps=None):
for name, tensor in zip(model.output_names, model.outputs):
# TODO(priyag): This is a workaround as we do not know the batch dimension
# of the model's output at this point.
- tensor.shape.dims = [batch_dimension] + tensor.shape.dims[1:]
- initial_loop_values[name] = array_ops.zeros(tensor.shape, tensor.dtype)
+ shape = tensor_shape.TensorShape(tensor.shape.dims)
+ shape.dims = [batch_dimension] + shape.dims[1:]
+ initial_loop_values[name] = array_ops.zeros(shape, tensor.dtype)
with current_strategy.scope():
# TODO(priyag, sourabhbajaj): Support steps_per_run if/when we add outfeed.
@@ -809,10 +830,10 @@ def _clone_and_build_model(model, inputs=None, targets=None):
cloned_model.compile(
optimizer,
model.loss,
- metrics=model.metrics,
+ metrics=metrics_module.clone_metrics(model.metrics),
loss_weights=model.loss_weights,
sample_weight_mode=model.sample_weight_mode,
- weighted_metrics=model.weighted_metrics,
+ weighted_metrics=metrics_module.clone_metrics(model.weighted_metrics),
target_tensors=targets)
return cloned_model
@@ -827,8 +848,9 @@ def clone_model_on_towers(
model._make_callback_model()
-def _aggregate_metrics_across_towers(num_devices, out_labels, outs):
- """Aggregate metrics values across all towers.
+def _aggregate_metrics_across_towers(num_devices, out_labels,
+ stateful_metric_names, outs):
+  """Aggregates stateless metric values across towers.
When using `MirroredStrategy`, the number of towers is equal to the
number of devices over which training is distributed. This may not always be
@@ -837,6 +859,7 @@ def _aggregate_metrics_across_towers(num_devices, out_labels, outs):
Args:
num_devices: Number of devices over which the model is being distributed.
out_labels: The list of metric names passed to `compile`.
+ stateful_metric_names: List of stateful metric names on the model.
outs: The output from all the towers.
Returns:
@@ -851,10 +874,16 @@ def _aggregate_metrics_across_towers(num_devices, out_labels, outs):
# Each label in `out_labels` corresponds to one set of metrics. The
# number of metric values corresponds to the number of devices. We
# currently take the mean of the values.
- for _ in out_labels[1:]:
- m = np.mean(outs[current_index:current_index + num_devices])
- merged_output.append(m)
- current_index += num_devices
+ for metric_name in out_labels[1:]:
+ if metric_name in stateful_metric_names:
+ # For stateful metrics, we get one aggregated result value.
+ merged_output.append(outs[current_index])
+ current_index += 1
+ else:
+ m = np.mean(outs[current_index:current_index + num_devices])
+ merged_output.append(m)
+ current_index += num_devices
+
return merged_output
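
Toy re-run of the aggregation rule above (not part of the diff; names and values are illustrative): stateless metrics are averaged over the per-tower values, while a stateful metric already contributes a single aggregated value per step and is passed through unchanged.

import numpy as np

def aggregate(num_devices, out_labels, stateful_metric_names, outs):
  merged = [outs[0]]  # the loss is already reduced across towers
  idx = 1
  for name in out_labels[1:]:
    if name in stateful_metric_names:
      merged.append(outs[idx])
      idx += 1
    else:
      merged.append(np.mean(outs[idx:idx + num_devices]))
      idx += num_devices
  return merged

print(aggregate(2, ['loss', 'acc', 'true_positives'], ['true_positives'],
                [0.3, 0.8, 0.9, 12.0]))  # -> [0.3, ~0.85, 12.0]
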
diff --git a/tensorflow/python/keras/engine/training_eager.py b/tensorflow/python/keras/engine/training_eager.py
index 939a7f2356..fb71bf2596 100644
--- a/tensorflow/python/keras/engine/training_eager.py
+++ b/tensorflow/python/keras/engine/training_eager.py
@@ -186,7 +186,7 @@ def iterator_fit_loop(model,
# make sure either x,y or x,y,sample_weights is provided
if (not isinstance(inputs.output_shapes, (list, tuple)) or
len(inputs.output_shapes) not in (2, 3)):
- raise ValueError('Please provide either inputs and targets'
+ raise ValueError('Please provide either inputs and targets '
'or inputs, targets, and sample_weights')
for step_index in range(steps_per_epoch):
diff --git a/tensorflow/python/keras/engine/training_eager_test.py b/tensorflow/python/keras/engine/training_eager_test.py
index db7ccb181f..1f5176c4d7 100644
--- a/tensorflow/python/keras/engine/training_eager_test.py
+++ b/tensorflow/python/keras/engine/training_eager_test.py
@@ -192,6 +192,20 @@ class CorrectnessTest(test.TestCase):
history = model.fit(iterator, epochs=1, steps_per_epoch=10)
self.assertEqual(np.around(history.history['loss'][-1], decimals=4), 0.6173)
+ def test_no_loss_in_call(self):
+
+ class HasLoss(keras.layers.Layer):
+
+ def call(self, x):
+ self.add_loss(x)
+ return x
+
+ layer = HasLoss()
+ with self.assertRaises(RuntimeError):
+ layer(1.)
+
+ with ops.Graph().as_default():
+ layer(1.)
if __name__ == '__main__':
ops.enable_eager_execution()
diff --git a/tensorflow/python/keras/engine/training_generator.py b/tensorflow/python/keras/engine/training_generator.py
index 413c1f4fba..2e074699da 100644
--- a/tensorflow/python/keras/engine/training_generator.py
+++ b/tensorflow/python/keras/engine/training_generator.py
@@ -21,6 +21,7 @@ from __future__ import print_function
import numpy as np
+from tensorflow.python.eager import context
from tensorflow.python.keras import callbacks as cbks
from tensorflow.python.keras.utils.data_utils import GeneratorEnqueuer
from tensorflow.python.keras.utils.data_utils import OrderedEnqueuer
@@ -48,6 +49,10 @@ def fit_generator(model,
epoch = initial_epoch
do_validation = bool(validation_data)
+ if not context.executing_eagerly():
+ model._make_train_function()
+ if do_validation:
+ model._make_test_function()
is_sequence = isinstance(generator, Sequence)
if not is_sequence and use_multiprocessing and workers > 1:
@@ -233,6 +238,9 @@ def evaluate_generator(model,
use_multiprocessing=False,
verbose=0):
"""See docstring for `Model.evaluate_generator`."""
+ if not context.executing_eagerly():
+ model._make_test_function()
+
if hasattr(model, 'metrics'):
for m in model.stateful_metric_functions:
m.reset_states()
@@ -342,6 +350,9 @@ def predict_generator(model,
use_multiprocessing=False,
verbose=0):
"""See docstring for `Model.predict_generator`."""
+ if not context.executing_eagerly():
+ model._make_test_function()
+
steps_done = 0
wait_time = 0.01
all_outs = []
diff --git a/tensorflow/python/keras/engine/training_test.py b/tensorflow/python/keras/engine/training_test.py
index 30be4131a4..54ad74c08b 100644
--- a/tensorflow/python/keras/engine/training_test.py
+++ b/tensorflow/python/keras/engine/training_test.py
@@ -27,6 +27,7 @@ import numpy as np
from tensorflow.python import keras
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import context
+from tensorflow.python.eager import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util as tf_test_util
@@ -2427,6 +2428,17 @@ class TestTrainingWithMetrics(test.TestCase):
scores = model.train_on_batch(x, y, sample_weight=w)
self.assertArrayNear(scores, [0.2, 0.8, 0.8], 0.1)
+ def test_losses_in_defun(self):
+ with context.eager_mode():
+ layer = keras.layers.Dense(1, kernel_regularizer='l1')
+ layer(array_ops.ones([1, 10]))
+
+ @function.defun
+ def get_losses():
+ return layer.losses
+
+ self.assertAllEqual(self.evaluate(layer.losses),
+ self.evaluate(get_losses()))
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/keras/layers/advanced_activations.py b/tensorflow/python/keras/layers/advanced_activations.py
index 4ab786a184..a2385dfdbb 100644
--- a/tensorflow/python/keras/layers/advanced_activations.py
+++ b/tensorflow/python/keras/layers/advanced_activations.py
@@ -314,7 +314,9 @@ class ReLU(Layer):
'cannot be negative value: ' + str(negative_slope))
self.support_masking = True
- self.max_value = K.cast_to_floatx(max_value)
+ if max_value is not None:
+ max_value = K.cast_to_floatx(max_value)
+ self.max_value = max_value
self.negative_slope = K.cast_to_floatx(negative_slope)
self.threshold = K.cast_to_floatx(threshold)
diff --git a/tensorflow/python/keras/layers/advanced_activations_test.py b/tensorflow/python/keras/layers/advanced_activations_test.py
index b020b6e730..c41087be0a 100644
--- a/tensorflow/python/keras/layers/advanced_activations_test.py
+++ b/tensorflow/python/keras/layers/advanced_activations_test.py
@@ -67,6 +67,14 @@ class AdvancedActivationsTest(test.TestCase):
testing_utils.layer_test(keras.layers.ReLU,
kwargs={'max_value': 10},
input_shape=(2, 3, 4))
+ x = keras.backend.ones((3, 4))
+ # Test that we use `leaky_relu` when appropriate in graph mode.
+ self.assertTrue(
+ 'LeakyRelu' in keras.layers.ReLU(negative_slope=0.2)(x).name)
+ # Test that we use `relu` when appropriate in graph mode.
+ self.assertTrue('Relu' in keras.layers.ReLU()(x).name)
+ # Test that we use `relu6` when appropriate in graph mode.
+ self.assertTrue('Relu6' in keras.layers.ReLU(max_value=6)(x).name)
def test_relu_with_invalid_arg(self):
with self.assertRaisesRegexp(
diff --git a/tensorflow/python/keras/layers/core.py b/tensorflow/python/keras/layers/core.py
index 4032202986..efa21955e6 100644
--- a/tensorflow/python/keras/layers/core.py
+++ b/tensorflow/python/keras/layers/core.py
@@ -671,22 +671,34 @@ class Lambda(Layer):
if mask is not None:
self.supports_masking = True
self.mask = mask
- if output_shape is None:
- self._output_shape = None
- elif isinstance(output_shape, (tuple, list)):
- self._output_shape = tuple(output_shape)
- else:
- if not callable(output_shape):
- raise TypeError('In Lambda, `output_shape` '
- 'must be a list, a tuple, or a function.')
- self._output_shape = output_shape
+ if (output_shape is not None and not isinstance(output_shape,
+ (tuple, list)) and
+ not callable(output_shape)):
+ raise TypeError('In Lambda, `output_shape` '
+ 'must be a list, a tuple, or a function.')
+ # Convert a list representing a single shape into a tuple.
+ if (isinstance(output_shape, list) and isinstance(output_shape[0],
+ (int, type(None)))):
+ output_shape = tuple(output_shape)
+ self._output_shape = output_shape
@tf_utils.shape_type_conversion
def compute_output_shape(self, input_shape):
if self._output_shape is None:
if context.executing_eagerly():
- raise NotImplementedError
- x = K.placeholder(shape=input_shape)
+ # Make use of existing autocomputation for Eager mode but provide
+ # Lambda-specific error message.
+ try:
+ return super(Lambda, self).compute_output_shape(input_shape)
+ except NotImplementedError:
+ raise NotImplementedError('We could not automatically infer '
+ 'the static shape of the Lambda\'s output.'
+ ' Please specify the `output_shape` for'
+ ' this Lambda.')
+ if isinstance(input_shape, list):
+ x = [K.placeholder(shape=shape) for shape in input_shape]
+ else:
+ x = K.placeholder(shape=input_shape)
x = self.call(x)
if isinstance(x, list):
return [tensor_shape.TensorShape(K.int_shape(x_elem)) for x_elem in x]
@@ -697,16 +709,27 @@ class Lambda(Layer):
num_samples = input_shape[0][0]
else:
num_samples = input_shape[0] if input_shape else None
- return tensor_shape.TensorShape((num_samples,) +
- tuple(self._output_shape))
+ # List here represents multiple outputs.
+ if isinstance(self._output_shape, list):
+ return [
+ tensor_shape.TensorShape((num_samples,) + tuple(single_shape))
+ for single_shape in self._output_shape
+ ]
+ return tensor_shape.TensorShape((num_samples,) + self._output_shape)
else:
shape = self._output_shape(input_shape)
if not isinstance(shape, (list, tuple)):
raise ValueError(
'`output_shape` function must return a tuple or a list of tuples.')
+ # List here can represent multiple outputs or single output.
if isinstance(shape, list):
- if isinstance(shape[0], int) or shape[0] is None:
+ # Convert list representing single output into a tuple.
+ if isinstance(shape[0], (int, type(None))):
shape = tuple(shape)
+ else:
+ return [
+ tensor_shape.TensorShape(single_shape) for single_shape in shape
+ ]
return tensor_shape.TensorShape(shape)
def call(self, inputs, mask=None):
diff --git a/tensorflow/python/keras/layers/core_test.py b/tensorflow/python/keras/layers/core_test.py
index 1df1d575b1..f0fea1f65c 100644
--- a/tensorflow/python/keras/layers/core_test.py
+++ b/tensorflow/python/keras/layers/core_test.py
@@ -252,6 +252,51 @@ class CoreLayersTest(test.TestCase):
l(keras.backend.variable(np.ones((1, 1))))
self.assertEqual('lambda', l.get_config()['output_shape_type'])
+ @tf_test_util.run_in_graph_and_eager_modes
+ def test_lambda_output_shape_autocalculate_multiple_inputs(self):
+
+ def lambda_fn(x):
+ return math_ops.matmul(x[0], x[1])
+
+ l = keras.layers.Lambda(lambda_fn)
+ output_shape = l.compute_output_shape([(10, 10), (10, 20)])
+ self.assertAllEqual((10, 20), output_shape)
+
+ @tf_test_util.run_in_graph_and_eager_modes
+ def test_lambda_output_shape_list_multiple_outputs(self):
+
+ def lambda_fn(x):
+ return x
+
+ l = keras.layers.Lambda(lambda_fn, output_shape=[(10,), (20,)])
+ output_shape = l.compute_output_shape([(10, 10), (10, 20)])
+ self.assertAllEqual([(10, 10), (10, 20)], output_shape)
+
+ @tf_test_util.run_in_graph_and_eager_modes
+ def test_lambda_output_shape_tuple_with_none(self):
+
+ def lambda_fn(x):
+ return x
+
+ l = keras.layers.Lambda(lambda_fn, output_shape=(None, 10))
+ output_shape = l.compute_output_shape((5, 10, 20))
+ # Dimension(None) != Dimension(None), so check
+ # str representations for equality.
+ self.assertAllEqual(('5', '?', '10'), tuple([str(s) for s in output_shape]))
+
+ @tf_test_util.run_in_graph_and_eager_modes
+ def test_lambda_output_shape_function_multiple_outputs(self):
+
+ def lambda_fn(x):
+ return x
+
+ def output_shape_fn(input_shape):
+ return input_shape
+
+ l = keras.layers.Lambda(lambda_fn, output_shape=output_shape_fn)
+ output_shape = l.compute_output_shape([(10, 10), (10, 20)])
+ self.assertAllEqual([(10, 10), (10, 20)], output_shape)
+
def test_lambda_config_serialization(self):
with self.cached_session():
# test serialization with output_shape and output_shape_type
diff --git a/tensorflow/python/keras/layers/embeddings.py b/tensorflow/python/keras/layers/embeddings.py
index 629a9ec9a1..c6df5f2e26 100644
--- a/tensorflow/python/keras/layers/embeddings.py
+++ b/tensorflow/python/keras/layers/embeddings.py
@@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.python.eager import context
+from tensorflow.python.framework import ops
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import constraints
from tensorflow.python.keras import initializers
@@ -117,12 +119,27 @@ class Embedding(Layer):
@tf_utils.shape_type_conversion
def build(self, input_shape):
- self.embeddings = self.add_weight(
- shape=(self.input_dim, self.output_dim),
- initializer=self.embeddings_initializer,
- name='embeddings',
- regularizer=self.embeddings_regularizer,
- constraint=self.embeddings_constraint)
+ # Note: most sparse optimizers do not have GPU kernels defined. When
+ # building graphs, the placement algorithm is able to place variables on CPU
+ # since it knows all kernels using the variable only exist on CPU.
+ # When eager execution is enabled, the placement decision has to be made
+    # right now. We check for the presence of GPUs to avoid complicating the
+    # TPU codepaths, which can handle sparse optimizers.
+ if context.executing_eagerly() and context.context().num_gpus():
+ with ops.device('cpu:0'):
+ self.embeddings = self.add_weight(
+ shape=(self.input_dim, self.output_dim),
+ initializer=self.embeddings_initializer,
+ name='embeddings',
+ regularizer=self.embeddings_regularizer,
+ constraint=self.embeddings_constraint)
+ else:
+ self.embeddings = self.add_weight(
+ shape=(self.input_dim, self.output_dim),
+ initializer=self.embeddings_initializer,
+ name='embeddings',
+ regularizer=self.embeddings_regularizer,
+ constraint=self.embeddings_constraint)
self.built = True
def compute_mask(self, inputs, mask=None):
diff --git a/tensorflow/python/keras/layers/embeddings_test.py b/tensorflow/python/keras/layers/embeddings_test.py
index cab176ee34..2e42e403aa 100644
--- a/tensorflow/python/keras/layers/embeddings_test.py
+++ b/tensorflow/python/keras/layers/embeddings_test.py
@@ -21,9 +21,11 @@ from __future__ import print_function
import numpy as np
from tensorflow.python import keras
+from tensorflow.python.eager import backprop
from tensorflow.python.framework import test_util as tf_test_util
from tensorflow.python.keras import testing_utils
from tensorflow.python.platform import test
+from tensorflow.python.training import adagrad
class EmbeddingTest(test.TestCase):
@@ -78,6 +80,17 @@ class EmbeddingTest(test.TestCase):
outputs = keras.backend.eval(layer(inputs))
self.assertAllClose(outputs, [[[1, 1], [2, 2], [1, 1]]])
+ @tf_test_util.run_in_graph_and_eager_modes()
+ def test_eager_gpu_cpu(self):
+ l = keras.layers.Embedding(output_dim=2, input_dim=2)
+ l.build((None, 2))
+ inputs = keras.backend.constant([[0, 1, 0]], dtype='int32')
+ with backprop.GradientTape() as tape:
+ output = l(inputs)
+ gs = tape.gradient(output, l.weights)
+ opt = adagrad.AdagradOptimizer(0.1)
+ opt.apply_gradients(zip(gs, l.weights))
+ self.assertAllEqual(len(gs), 1)
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/keras/metrics.py b/tensorflow/python/keras/metrics.py
index fd3c39cf2e..f4e8419eb0 100644
--- a/tensorflow/python/keras/metrics.py
+++ b/tensorflow/python/keras/metrics.py
@@ -71,6 +71,22 @@ def check_is_tensor_or_operation(x, name):
name, x))
+def clone_metric(metric):
+ """Returns a clone of the metric if stateful, otherwise returns it as is."""
+ if isinstance(metric, Metric):
+ return metric.__class__.from_config(metric.get_config())
+ return metric
+
+
+def clone_metrics(metrics):
+ """Clones the given metric list/dict."""
+ if metrics is None:
+ return None
+ if isinstance(metrics, dict):
+ return {key: clone_metric(value) for key, value in metrics.items()}
+ return [clone_metric(metric) for metric in metrics]
+
+
def update_state_wrapper(update_state_fn):
"""Decorator to wrap metric `update_state()` with `add_update()`.
@@ -199,7 +215,6 @@ def squeeze_or_expand_dimensions(y_pred, y_true, sample_weight):
# squeeze last dim of `y_pred` or `y_true` if their rank differs by 1
y_true, y_pred = confusion_matrix.remove_squeezable_dimensions(
y_true, y_pred)
- y_pred.get_shape().assert_is_compatible_with(y_true.get_shape())
if sample_weight is None:
return y_pred, y_true, None
@@ -342,19 +357,14 @@ class Metric(Layer):
# weak reference. This is to remove reference cycle that is created here.
# This is not an issue in python versions > 3.
if context.executing_eagerly():
- update_state = weakmethod(obj.update_state)
- else:
- update_state = function.defun(obj.update_state)
+ obj.update_state = weakmethod(obj.update_state)
obj.update_state = weakmethod(
- types.MethodType(update_state_wrapper(update_state), obj))
+ types.MethodType(update_state_wrapper(obj.update_state), obj))
result = weakmethod(obj.result)
obj.result = weakmethod(types.MethodType(result_wrapper(result), obj))
else:
- # Converting update_state_fn() into a graph function, so that
- # we can return a single op that performs all of the variable updates.
- defuned_update_state_fn = function.defun(obj.update_state)
obj.update_state = types.MethodType(
- update_state_wrapper(defuned_update_state_fn), obj)
+ update_state_wrapper(obj.update_state), obj)
obj.result = types.MethodType(result_wrapper(obj.result), obj)
return obj
@@ -475,6 +485,9 @@ class Mean(Metric):
Args:
values: Per-example value.
sample_weight: Optional weighting of each example. Defaults to 1.
+
+ Returns:
+ Update op.
"""
values = math_ops.cast(values, self._dtype)
if sample_weight is None:
@@ -501,8 +514,9 @@ class Mean(Metric):
values = math_ops.reduce_sum(values)
# Update state variables
- state_ops.assign_add(self.total, values)
- state_ops.assign_add(self.count, num_values)
+ update_total_op = state_ops.assign_add(self.total, values)
+ update_count_op = state_ops.assign_add(self.count, num_values)
+ return control_flow_ops.group(update_total_op, update_count_op)
def result(self):
return safe_div(self.total, self.count)
@@ -536,6 +550,9 @@ class MeanMetricWrapper(Mean):
sample_weight: Optional weighting of each example. Defaults to 1. Can be
a `Tensor` whose rank is either 0, or the same rank as `y_true`,
and must be broadcastable to `y_true`.
+
+ Returns:
+ Update op.
"""
y_true = math_ops.cast(y_true, self._dtype)
y_pred = math_ops.cast(y_pred, self._dtype)
@@ -543,7 +560,7 @@ class MeanMetricWrapper(Mean):
y_pred, y_true, sample_weight)
matches = self._fn(y_true, y_pred, **self._fn_kwargs)
- super(MeanMetricWrapper, self).update_state(
+ return super(MeanMetricWrapper, self).update_state(
matches, sample_weight=sample_weight)
def get_config(self):
@@ -600,6 +617,23 @@ class CategoricalAccuracy(MeanMetricWrapper):
categorical_accuracy, name, dtype=dtype)
+class SparseCategoricalAccuracy(MeanMetricWrapper):
+ """Calculates how often predictions matches integer labels.
+
+ This metric creates two local variables, `total` and `count` that are used to
+ compute the frequency with which `y_pred` matches `y_true`. This frequency is
+ ultimately returned as `sparse categorical accuracy`: an idempotent operation
+ that simply divides `total` by `count`.
+
+ If `sample_weight` is `None`, weights default to 1.
+ Use `sample_weight` of 0 to mask values.
+ """
+
+ def __init__(self, name='sparse_categorical_accuracy', dtype=None):
+ super(SparseCategoricalAccuracy, self).__init__(
+ sparse_categorical_accuracy, name, dtype=dtype)
+
+
@tf_export('keras.metrics.binary_accuracy')
def binary_accuracy(y_true, y_pred, threshold=0.5):
threshold = math_ops.cast(threshold, y_pred.dtype)
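
Illustrative sketch (not part of the diff): clone_metrics above re-creates stateful Metric instances from their config so that each replicated model gets its own metric variables, while plain strings and functions pass through unchanged. The snippet uses the CategoricalAccuracy class already defined in this module and is only a usage sketch.

from tensorflow.python.keras import metrics as metrics_module

original = ['accuracy', metrics_module.CategoricalAccuracy(name='cat_acc')]
cloned = metrics_module.clone_metrics(original)

assert cloned[0] == 'accuracy'           # strings are returned as-is
assert cloned[1] is not original[1]      # stateful metric is a fresh instance
assert cloned[1].name == original[1].name
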
diff --git a/tensorflow/python/keras/models.py b/tensorflow/python/keras/models.py
index 41c5e3cccf..b04b4df257 100644
--- a/tensorflow/python/keras/models.py
+++ b/tensorflow/python/keras/models.py
@@ -20,6 +20,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.keras import backend as K
+from tensorflow.python.keras import metrics as metrics_module
from tensorflow.python.keras import optimizers
from tensorflow.python.keras.engine import saving
from tensorflow.python.keras.engine import sequential
@@ -290,7 +291,9 @@ def _in_place_subclassed_model_reset(model):
if isinstance(value, Layer):
attributes_cache[name] = value
assert value in model._layers
- elif isinstance(value, (list, tuple)) and name not in ('layers', '_layers'):
+ elif isinstance(
+ value, (list, tuple)) and name not in ('layers', '_layers',
+ 'stateful_metric_functions'):
# Handle case: list/tuple of layers (also tracked by the Network API).
if value and all(isinstance(val, Layer) for val in value):
raise ValueError('We do not support the use of list-of-layers '
@@ -466,10 +469,10 @@ def clone_and_build_model(
clone.compile(
optimizer,
model.loss,
- metrics=model.metrics,
+ metrics=metrics_module.clone_metrics(model.metrics),
loss_weights=model.loss_weights,
sample_weight_mode=model.sample_weight_mode,
- weighted_metrics=model.weighted_metrics,
+ weighted_metrics=metrics_module.clone_metrics(model.weighted_metrics),
target_tensors=target_tensors)
return clone
diff --git a/tensorflow/python/keras/optimizers_test.py b/tensorflow/python/keras/optimizers_test.py
index 8d7493462e..9664f09fff 100644
--- a/tensorflow/python/keras/optimizers_test.py
+++ b/tensorflow/python/keras/optimizers_test.py
@@ -18,10 +18,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import gc
+import weakref
+
import numpy as np
from tensorflow.python import keras
from tensorflow.python.eager import context
+from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.keras import testing_utils
from tensorflow.python.platform import test
@@ -156,6 +160,19 @@ class KerasOptimizersTest(test.TestCase):
with self.assertRaises(NotImplementedError):
optimizer.from_config(None)
+ def test_optimizer_garbage_collection(self):
+ graph = ops.Graph()
+ with graph.as_default():
+ optimizer = keras.optimizers.TFOptimizer(AdamOptimizer(0.01))
+ keras.backend.track_tf_optimizer(optimizer)
+ optimizer_weak = weakref.ref(optimizer)
+ graph_weak = weakref.ref(graph)
+ del graph, optimizer
+ gc.collect()
+ # Check that the weak references are dead now.
+ self.assertIs(graph_weak(), None)
+ self.assertIs(optimizer_weak(), None)
+
@test_util.run_in_graph_and_eager_modes
def test_tfoptimizer_iterations(self):
with self.cached_session():
diff --git a/tensorflow/python/keras/utils/multi_gpu_utils_test.py b/tensorflow/python/keras/utils/multi_gpu_utils_test.py
index c7e94998b4..3d0351a11f 100644
--- a/tensorflow/python/keras/utils/multi_gpu_utils_test.py
+++ b/tensorflow/python/keras/utils/multi_gpu_utils_test.py
@@ -48,7 +48,7 @@ class TestMultiGPUModel(test.TestCase):
if not check_if_compatible_devices(gpus=gpus):
return
- with self.test_session():
+ with self.cached_session():
model = keras.models.Sequential()
model.add(keras.layers.Dense(hidden_dim,
input_shape=(input_dim,)))
@@ -78,7 +78,7 @@ class TestMultiGPUModel(test.TestCase):
if not check_if_compatible_devices(gpus=gpus):
return
- with self.test_session():
+ with self.cached_session():
input_a = keras.Input((input_dim_a,))
input_b = keras.Input((input_dim_b,))
a = keras.layers.Dense(hidden_dim)(input_a)
@@ -105,7 +105,7 @@ class TestMultiGPUModel(test.TestCase):
if not check_if_compatible_devices(gpus=2):
return
- with self.test_session():
+ with self.cached_session():
input_shape = (1000, 10)
model = keras.models.Sequential()
model.add(keras.layers.Dense(10,
@@ -144,7 +144,7 @@ class TestMultiGPUModel(test.TestCase):
if not check_if_compatible_devices(gpus=gpus):
return
- with self.test_session():
+ with self.cached_session():
input_shape = (num_samples,) + shape
x_train = np.random.randint(0, 255, input_shape)
y_train = np.random.randint(0, num_classes, (input_shape[0],))
@@ -186,7 +186,7 @@ class TestMultiGPUModel(test.TestCase):
if not check_if_compatible_devices(gpus=gpus):
return
- with self.test_session():
+ with self.cached_session():
inputs = keras.Input((4, 3))
init_state = keras.Input((3,))
outputs = keras.layers.SimpleRNN(
diff --git a/tensorflow/python/keras/wrappers/scikit_learn_test.py b/tensorflow/python/keras/wrappers/scikit_learn_test.py
index c322efdedf..f904290803 100644
--- a/tensorflow/python/keras/wrappers/scikit_learn_test.py
+++ b/tensorflow/python/keras/wrappers/scikit_learn_test.py
@@ -102,7 +102,7 @@ def assert_regression_works(reg):
class ScikitLearnAPIWrapperTest(test.TestCase):
def test_classify_build_fn(self):
- with self.test_session():
+ with self.cached_session():
clf = keras.wrappers.scikit_learn.KerasClassifier(
build_fn=build_fn_clf,
hidden_dim=HIDDEN_DIM,
@@ -118,7 +118,7 @@ class ScikitLearnAPIWrapperTest(test.TestCase):
def __call__(self, hidden_dim):
return build_fn_clf(hidden_dim)
- with self.test_session():
+ with self.cached_session():
clf = keras.wrappers.scikit_learn.KerasClassifier(
build_fn=ClassBuildFnClf(),
hidden_dim=HIDDEN_DIM,
@@ -134,7 +134,7 @@ class ScikitLearnAPIWrapperTest(test.TestCase):
def __call__(self, hidden_dim):
return build_fn_clf(hidden_dim)
- with self.test_session():
+ with self.cached_session():
clf = InheritClassBuildFnClf(
build_fn=None,
hidden_dim=HIDDEN_DIM,
@@ -144,7 +144,7 @@ class ScikitLearnAPIWrapperTest(test.TestCase):
assert_classification_works(clf)
def test_regression_build_fn(self):
- with self.test_session():
+ with self.cached_session():
reg = keras.wrappers.scikit_learn.KerasRegressor(
build_fn=build_fn_reg,
hidden_dim=HIDDEN_DIM,
@@ -160,7 +160,7 @@ class ScikitLearnAPIWrapperTest(test.TestCase):
def __call__(self, hidden_dim):
return build_fn_reg(hidden_dim)
- with self.test_session():
+ with self.cached_session():
reg = keras.wrappers.scikit_learn.KerasRegressor(
build_fn=ClassBuildFnReg(),
hidden_dim=HIDDEN_DIM,
@@ -176,7 +176,7 @@ class ScikitLearnAPIWrapperTest(test.TestCase):
def __call__(self, hidden_dim):
return build_fn_reg(hidden_dim)
- with self.test_session():
+ with self.cached_session():
reg = InheritClassBuildFnReg(
build_fn=None,
hidden_dim=HIDDEN_DIM,
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index 6bba99b9e7..65b9e04ed9 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -538,6 +538,21 @@ tf_py_test(
)
tf_py_test(
+ name = "logging_ops_logging_level_test",
+ size = "small",
+ srcs = ["logging_ops_logging_level_test.py"],
+ additional_deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:logging_ops",
+ ],
+ tags = [
+ "no_windows",
+ ],
+)
+
+tf_py_test(
name = "logging_ops_test",
size = "small",
srcs = ["logging_ops_test.py"],
@@ -961,6 +976,19 @@ tf_py_test(
)
tf_py_test(
+ name = "string_format_op_test",
+ size = "small",
+ srcs = ["string_format_op_test.py"],
+ additional_deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:string_ops",
+ "//tensorflow/python:math_ops",
+ ],
+)
+
+tf_py_test(
name = "string_join_op_test",
size = "small",
srcs = ["string_join_op_test.py"],
@@ -1069,6 +1097,18 @@ tf_py_test(
],
)
+tf_py_test(
+ name = "unicode_script_op_test",
+ size = "small",
+ srcs = ["unicode_script_op_test.py"],
+ additional_deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:string_ops",
+ ],
+)
+
cuda_py_test(
name = "topk_op_test",
size = "small",
@@ -1440,7 +1480,7 @@ cuda_py_test(
name = "control_flow_ops_py_test",
# TODO(b/70473603): change this back to "small" once the C API is
# permanently enabled
- size = "medium",
+ size = "large",
srcs = ["control_flow_ops_py_test.py"],
additional_deps = [
"//third_party/py/numpy",
@@ -1472,6 +1512,7 @@ cuda_py_test(
"//tensorflow/python:util",
"//tensorflow/python:variable_scope",
"//tensorflow/python:variables",
+ "//tensorflow/python:while_v2",
],
)
@@ -1635,6 +1676,18 @@ cuda_py_test(
)
cuda_py_test(
+ name = "extract_volume_patches_op_test",
+ size = "small",
+ srcs = ["extract_volume_patches_op_test.py"],
+ additional_deps = [
+ "//third_party/py/numpy",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ ],
+)
+
+cuda_py_test(
name = "functional_ops_test",
size = "small",
srcs = ["functional_ops_test.py"],
@@ -2799,6 +2852,46 @@ cuda_py_test(
)
cuda_py_test(
+ name = "cwise_ops_binary_test",
+ size = "medium",
+ srcs = ["cwise_ops_binary_test.py"],
+ additional_deps = [
+ "//third_party/py/numpy",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:gradients",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:math_ops_gen",
+ "//tensorflow/python:nn_grad",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:variables",
+ ],
+ shard_count = 50,
+)
+
+cuda_py_test(
+ name = "cwise_ops_unary_test",
+ size = "medium",
+ srcs = ["cwise_ops_unary_test.py"],
+ additional_deps = [
+ "//third_party/py/numpy",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:gradients",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:math_ops_gen",
+ "//tensorflow/python:nn_grad",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:variables",
+ ],
+ shard_count = 50,
+)
+
+cuda_py_test(
name = "embedding_ops_test",
size = "medium",
srcs = ["embedding_ops_test.py"],
@@ -3164,3 +3257,25 @@ tf_py_test(
grpc_enabled = True,
tags = ["no_gpu"], # TODO(b/111656070)
)
+
+cuda_py_test(
+ name = "while_v2_test",
+ size = "medium",
+ srcs = ["while_v2_test.py"],
+ additional_deps = [
+ "@absl_py//absl/testing:parameterized",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:gradients_impl",
+ "//tensorflow/python:list_ops",
+ "//tensorflow/python:tf_optimizer",
+ "//tensorflow/python:while_v2",
+ ],
+ grpc_enabled = True,
+)
diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py
index 573bb8614f..c5547b19be 100644
--- a/tensorflow/python/kernel_tests/array_ops_test.py
+++ b/tensorflow/python/kernel_tests/array_ops_test.py
@@ -1001,14 +1001,14 @@ class SliceAssignTest(test_util.TensorFlowTestCase):
errors.FailedPreconditionError,
"Attempting to use uninitialized value Variable"):
with self.cached_session() as sess:
- v = variables.Variable([1, 2])
+ v = variables.VariableV1([1, 2])
sess.run(v[:].assign([1, 2]))
def testTypeError(self):
init_val = constant_op.constant([1, 2], dtype=dtypes.int32)
too_small_val = constant_op.constant([3, 4], dtype=dtypes.int8)
too_large_val = constant_op.constant([3, 4], dtype=dtypes.int64)
- v = variables.Variable(init_val)
+ v = variables.VariableV1(init_val)
with self.assertRaises(TypeError):
v[:].assign(too_small_val)
with self.assertRaises(TypeError):
@@ -1276,5 +1276,203 @@ class SnapshotOpTest(test_util.TensorFlowTestCase):
self.assertAllEqual(y.eval(), [0, 1, 2, 3])
+@test_util.run_all_in_graph_and_eager_modes
+class SortedSearchTest(test_util.TensorFlowTestCase):
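+  # Each of the tests below checks array_ops.searchsorted against
+  # np.searchsorted: side="right" exercises the upper-bound cases and
+  # side="left" the lower-bound cases, for float and int inputs of
+  # various shapes.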
+
+ def testUpperBoundFloatHandCoded(self):
+ cdf = np.array([0, .2, .5, .6, .8, 1.], dtype=np.float32)
+ arr = np.array([.04, .99, .53, .58, .31, .01, .79, .8, .21],
+ dtype=np.float32)
+ result = np.searchsorted(cdf, arr, side="right")
+ tf_result = self.evaluate(array_ops.searchsorted(cdf, arr, side="right"))
+ self.assertAllEqual(result, tf_result)
+
+ def testUpperBoundFloatRandomNd(self):
+ dim_size = 7
+ for d in range(1, 5):
+ shape = [dim_size] * d
+ cdf = np.cumsum(
+ np.random.uniform(size=shape).astype(np.float32), axis=(d - 1))
+ arr = np.random.uniform(size=shape).astype(np.float32) * dim_size
+
+ tf_result = self.evaluate(array_ops.searchsorted(cdf, arr, side="right"))
+
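+      # The numpy reference flattens the leading dimensions and searches
+      # row by row, since searchsorted operates along the innermost axis.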
+ cdf = cdf.reshape([-1, dim_size])
+ arr = arr.reshape([-1, dim_size])
+ result = np.zeros(arr.shape, dtype=np.int32)
+ for i in range(dim_size**(d - 1)):
+ result[i, :] = np.searchsorted(cdf[i, :], arr[i, :], side="right")
+
+ result = result.reshape(shape)
+
+ self.assertAllEqual(result, tf_result)
+
+ def testUpperBoundFloatUneven(self):
+ batch_size = 7
+ size_search_array = 1000
+ size_values = 47
+ cdf = np.cumsum(
+ np.random.uniform(size=[batch_size, size_search_array]).astype(
+ np.float32),
+ axis=1)
+ arr = np.random.uniform(size=[batch_size, size_values]).astype(
+ np.float32) * size_search_array
+
+ tf_result = self.evaluate(array_ops.searchsorted(cdf, arr, side="right"))
+
+ result = np.zeros(arr.shape, dtype=np.int32)
+ for i in range(batch_size):
+ result[i, :] = np.searchsorted(cdf[i, :], arr[i, :], side="right")
+
+ self.assertAllEqual(result, tf_result)
+
+ def testLowerBoundFloatHandCoded(self):
+ cdf = np.array([0, .2, .5, .6, .8, 1.], dtype=np.float32)
+ arr = np.array([.04, .99, .53, .58, .31, .01, .79, .8, .21],
+ dtype=np.float32)
+ result = np.searchsorted(cdf, arr, side="left")
+ tf_result = self.evaluate(array_ops.searchsorted(cdf, arr, side="left"))
+ self.assertAllEqual(result, tf_result)
+
+ def testLowerBoundFloatRandomNd(self):
+ dim_size = 7
+ for d in range(1, 5):
+ shape = [dim_size] * d
+ cdf = np.cumsum(
+ np.random.uniform(size=shape).astype(np.float32), axis=(d - 1))
+ arr = np.random.uniform(size=shape).astype(np.float32) * dim_size
+
+ tf_result = self.evaluate(array_ops.searchsorted(cdf, arr, side="left"))
+
+ cdf = cdf.reshape([-1, dim_size])
+ arr = arr.reshape([-1, dim_size])
+ result = np.zeros(arr.shape, dtype=np.int32)
+ for i in range(dim_size**(d - 1)):
+ result[i, :] = np.searchsorted(cdf[i, :], arr[i, :], side="left")
+
+ result = result.reshape(shape)
+
+ self.assertAllEqual(result, tf_result)
+
+ def testLowerBoundFloatUneven(self):
+ batch_size = 7
+ size_search_array = 1000
+ size_values = 47
+ cdf = np.cumsum(
+ np.random.uniform(size=[batch_size, size_search_array]).astype(
+ np.float32),
+ axis=1)
+ arr = np.random.uniform(size=[batch_size, size_values]).astype(
+ np.float32) * size_search_array
+
+ tf_result = self.evaluate(array_ops.searchsorted(cdf, arr, side="left"))
+
+ result = np.zeros(arr.shape, dtype=np.int32)
+ for i in range(batch_size):
+ result[i, :] = np.searchsorted(cdf[i, :], arr[i, :], side="left")
+
+ self.assertAllEqual(result, tf_result)
+
+ def testUpperBoundIntHandCoded(self):
+ cdf = np.array([0, 20, 50, 60, 80, 100], dtype=np.int64)
+ arr = np.array([4, 99, 53, 58, 31, 1, 79, 8, 21], dtype=np.int64)
+ result = np.searchsorted(cdf, arr, side="right")
+ tf_result = self.evaluate(array_ops.searchsorted(cdf, arr, side="right"))
+ self.assertAllEqual(result, tf_result)
+
+ def testUpperBoundIntRandomNd(self):
+ dim_size = 7
+ for d in range(1, 5):
+ shape = [dim_size] * d
+ cdf = np.cumsum(
+ np.random.randint(low=0, high=10, size=shape).astype(np.int64),
+ axis=(d - 1))
+ arr = np.random.randint(
+ low=0, high=10 * dim_size, size=shape).astype(np.int64)
+
+ tf_result = self.evaluate(array_ops.searchsorted(cdf, arr, side="right"))
+
+ cdf = cdf.reshape([-1, dim_size])
+ arr = arr.reshape([-1, dim_size])
+ result = np.zeros(arr.shape, dtype=np.int32)
+ for i in range(dim_size**(d - 1)):
+ result[i, :] = np.searchsorted(cdf[i, :], arr[i, :], side="right")
+
+ result = result.reshape(shape)
+
+ self.assertAllEqual(result, tf_result)
+
+ def testUpperBoundIntUneven(self):
+ batch_size = 7
+ size_search_array = 1000
+ size_values = 47
+ cdf = np.cumsum(
+ np.random.randint(low=0, high=10,
+ size=[batch_size,
+ size_search_array]).astype(np.int64),
+ axis=1)
+ arr = np.random.randint(
+ low=0, high=10 * size_search_array, size=[batch_size,
+ size_values]).astype(np.int64)
+
+ tf_result = self.evaluate(array_ops.searchsorted(cdf, arr, side="right"))
+
+ result = np.zeros(arr.shape, dtype=np.int32)
+ for i in range(batch_size):
+ result[i, :] = np.searchsorted(cdf[i, :], arr[i, :], side="right")
+
+ self.assertAllEqual(result, tf_result)
+
+ def testLowerBoundIntHandCoded(self):
+ cdf = np.array([0, 20, 50, 60, 80, 100], dtype=np.int64)
+ arr = np.array([4, 99, 53, 58, 31, 1, 79, 8, 21], dtype=np.int64)
+ result = np.searchsorted(cdf, arr, side="left")
+ tf_result = self.evaluate(array_ops.searchsorted(cdf, arr, side="left"))
+ self.assertAllEqual(result, tf_result)
+
+ def testLowerBoundIntRandomNd(self):
+ dim_size = 7
+ for d in range(1, 5):
+ shape = [dim_size] * d
+ cdf = np.cumsum(
+ np.random.randint(low=0, high=10, size=shape).astype(np.int64),
+ axis=(d - 1))
+ arr = np.random.randint(
+ low=0, high=10 * dim_size, size=shape).astype(np.int64)
+
+ tf_result = self.evaluate(array_ops.searchsorted(cdf, arr, side="left"))
+
+ cdf = cdf.reshape([-1, dim_size])
+ arr = arr.reshape([-1, dim_size])
+ result = np.zeros(arr.shape, dtype=np.int32)
+ for i in range(dim_size**(d - 1)):
+ result[i, :] = np.searchsorted(cdf[i, :], arr[i, :], side="left")
+
+ result = result.reshape(shape)
+
+ self.assertAllEqual(result, tf_result)
+
+ def testLowerBoundIntUneven(self):
+ batch_size = 7
+ size_search_array = 1000
+ size_values = 47
+ cdf = np.cumsum(
+ np.random.randint(low=0, high=10,
+ size=[batch_size,
+ size_search_array]).astype(np.int64),
+ axis=1)
+ arr = np.random.randint(
+ low=0, high=10 * size_search_array, size=[batch_size,
+ size_values]).astype(np.int64)
+
+ tf_result = self.evaluate(array_ops.searchsorted(cdf, arr, side="left"))
+
+ result = np.zeros(arr.shape, dtype=np.int32)
+ for i in range(batch_size):
+ result[i, :] = np.searchsorted(cdf[i, :], arr[i, :], side="left")
+
+ self.assertAllEqual(result, tf_result)
+
+
if __name__ == "__main__":
test_lib.main()
diff --git a/tensorflow/python/kernel_tests/basic_gpu_test.py b/tensorflow/python/kernel_tests/basic_gpu_test.py
index e651fa0070..67e8618198 100644
--- a/tensorflow/python/kernel_tests/basic_gpu_test.py
+++ b/tensorflow/python/kernel_tests/basic_gpu_test.py
@@ -260,7 +260,7 @@ class GpuMultiSessionMemoryTest(test_util.TensorFlowTestCase):
threads = []
results = []
for _ in xrange(n_threads):
- session = self.test_session(graph=ops.Graph(), use_gpu=True)
+ session = self.session(graph=ops.Graph(), use_gpu=True)
results.append(set())
args = (session, results[-1])
threads.append(threading.Thread(target=self._run_session, args=args))
diff --git a/tensorflow/python/kernel_tests/boosted_trees/prediction_ops_test.py b/tensorflow/python/kernel_tests/boosted_trees/prediction_ops_test.py
index dee96102fb..7cdc67f83f 100644
--- a/tensorflow/python/kernel_tests/boosted_trees/prediction_ops_test.py
+++ b/tensorflow/python/kernel_tests/boosted_trees/prediction_ops_test.py
@@ -445,6 +445,78 @@ class TrainingPredictionOpsTest(test_util.TensorFlowTestCase):
# change= 0.1(1.14+7.0-7.0)
self.assertAllClose([[1], [0.114]], logits_updates)
+ def testCategoricalSplits(self):
+ """Tests the training prediction work for categorical splits."""
+ with self.cached_session() as session:
+ tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
+ text_format.Merge(
+ """
+ trees {
+ nodes {
+ categorical_split {
+ feature_id: 1
+ value: 2
+ left_id: 1
+ right_id: 2
+ }
+ }
+ nodes {
+ categorical_split {
+ feature_id: 0
+ value: 13
+ left_id: 3
+ right_id: 4
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 7.0
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 5.0
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 6.0
+ }
+ }
+ }
+ tree_weights: 1.0
+ tree_metadata {
+ is_finalized: true
+ }
+ """, tree_ensemble_config)
+
+      # Create an ensemble with a single tree containing categorical splits.
+ tree_ensemble = boosted_trees_ops.TreeEnsemble(
+ 'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
+ tree_ensemble_handle = tree_ensemble.resource_handle
+ resources.initialize_resources(resources.shared_resources()).run()
+
+ feature_0_values = [13, 1, 3]
+ feature_1_values = [2, 2, 1]
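+      # A categorical split routes an example to left_id when its feature
+      # value equals the split value and to right_id otherwise, so these
+      # examples should land on nodes 3, 4 and 2 (see the expected ids below).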
+
+ # No previous cached values.
+ cached_tree_ids = [0, 0, 0]
+ cached_node_ids = [0, 0, 0]
+
+ # Grow tree ensemble.
+ predict_op = boosted_trees_ops.training_predict(
+ tree_ensemble_handle,
+ cached_tree_ids=cached_tree_ids,
+ cached_node_ids=cached_node_ids,
+ bucketized_features=[feature_0_values, feature_1_values],
+ logits_dimension=1)
+
+ logits_updates, new_tree_ids, new_node_ids = session.run(predict_op)
+
+ self.assertAllClose([0, 0, 0], new_tree_ids)
+ self.assertAllClose([3, 4, 2], new_node_ids)
+ self.assertAllClose([[5.], [6.], [7.]], logits_updates)
+
def testCachedPredictionFromTheSameTreeWithPostPrunedNodes(self):
"""Tests that prediction based on previous node in the tree works."""
with self.cached_session() as session:
@@ -924,10 +996,229 @@ class PredictionOpsTest(test_util.TensorFlowTestCase):
logits = session.run(predict_op)
self.assertAllClose(expected_logits, logits)
+ def testCategoricalSplits(self):
+ """Tests the predictions work for categorical splits."""
+ with self.cached_session() as session:
+ tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
+ text_format.Merge(
+ """
+ trees {
+ nodes {
+ categorical_split {
+ feature_id: 1
+ value: 2
+ left_id: 1
+ right_id: 2
+ }
+ }
+ nodes {
+ categorical_split {
+ feature_id: 0
+ value: 13
+ left_id: 3
+ right_id: 4
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 7.0
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 5.0
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 6.0
+ }
+ }
+ }
+ tree_weights: 1.0
+ """, tree_ensemble_config)
+
+      # Create an ensemble with a single tree containing categorical splits.
+ tree_ensemble = boosted_trees_ops.TreeEnsemble(
+ 'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
+ tree_ensemble_handle = tree_ensemble.resource_handle
+ resources.initialize_resources(resources.shared_resources()).run()
+
+ feature_0_values = [13, 1, 3]
+ feature_1_values = [2, 2, 1]
+
+ expected_logits = [[5.], [6.], [7.]]
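+      # example 0: feature_1 == 2 -> left, feature_0 == 13 -> left  -> leaf 5.0
+      # example 1: feature_1 == 2 -> left, feature_0 != 13 -> right -> leaf 6.0
+      # example 2: feature_1 != 2 -> right                          -> leaf 7.0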
+
+ # Prediction should work fine.
+ predict_op = boosted_trees_ops.predict(
+ tree_ensemble_handle,
+ bucketized_features=[feature_0_values, feature_1_values],
+ logits_dimension=1)
+
+ logits = session.run(predict_op)
+ self.assertAllClose(expected_logits, logits)
+
class FeatureContribsOpsTest(test_util.TensorFlowTestCase):
"""Tests feature contribs ops for model understanding."""
+ def testContribsForOnlyABiasNode(self):
+ """Tests case when, after training, only left with a bias node.
+
+ For example, this could happen if the final ensemble contains one tree that
+ got pruned up to the root.
+ """
+ with self.cached_session() as session:
+ tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
+ text_format.Merge(
+ """
+ trees {
+ nodes {
+ leaf {
+ scalar: 1.72
+ }
+ }
+ }
+ tree_weights: 0.1
+ tree_metadata: {
+ num_layers_grown: 0
+ }
+ """, tree_ensemble_config)
+
+ tree_ensemble = boosted_trees_ops.TreeEnsemble(
+ 'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
+ tree_ensemble_handle = tree_ensemble.resource_handle
+ resources.initialize_resources(resources.shared_resources()).run()
+
+ # All features are unused.
+ feature_0_values = [36, 32]
+ feature_1_values = [13, -29]
+ feature_2_values = [11, 27]
+
+      # Expected logits are computed by traversing the logit path; with a
+      # single leaf-only tree the path holds just the bias term.
+ bias = 1.72 * 0.1 # Root node of tree_0.
+ expected_feature_ids = ((), ())
+ expected_logits_paths = ((bias,), (bias,))
+
+ bucketized_features = [
+ feature_0_values, feature_1_values, feature_2_values
+ ]
+
+ debug_op = boosted_trees_ops.example_debug_outputs(
+ tree_ensemble_handle,
+ bucketized_features=bucketized_features,
+ logits_dimension=1)
+
+ serialized_examples_debug_outputs = session.run(debug_op)
+ feature_ids = []
+ logits_paths = []
+ for example in serialized_examples_debug_outputs:
+ example_debug_outputs = boosted_trees_pb2.DebugOutput()
+ example_debug_outputs.ParseFromString(example)
+ feature_ids.append(example_debug_outputs.feature_ids)
+ logits_paths.append(example_debug_outputs.logits_path)
+
+ self.assertAllClose(feature_ids, expected_feature_ids)
+ self.assertAllClose(logits_paths, expected_logits_paths)
+
+ def testContribsMultipleTreeWhenFirstTreeIsABiasNode(self):
+ """Tests case when, after training, first tree contains only a bias node."""
+ with self.cached_session() as session:
+ tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
+ text_format.Merge(
+ """
+ trees {
+ nodes {
+ leaf {
+ scalar: 1.72
+ }
+ }
+ }
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 2
+ threshold: 26
+ left_id: 1
+ right_id: 2
+ }
+ }
+ nodes {
+ bucketized_split {
+ feature_id: 0
+ threshold: 50
+ left_id: 3
+ right_id: 4
+ }
+ metadata {
+ original_leaf: {scalar: 5.5}
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 7.0
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 5.0
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 6.0
+ }
+ }
+ }
+ tree_weights: 1.
+ tree_weights: 0.1
+ tree_metadata: {
+ num_layers_grown: 0
+ }
+ tree_metadata: {
+ num_layers_grown: 1
+ }
+ """, tree_ensemble_config)
+
+ tree_ensemble = boosted_trees_ops.TreeEnsemble(
+ 'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
+ tree_ensemble_handle = tree_ensemble.resource_handle
+ resources.initialize_resources(resources.shared_resources()).run()
+
+ feature_0_values = [36, 32]
+ feature_1_values = [13, -29] # Unused feature.
+ feature_2_values = [11, 27]
+
+ # Expected logits are computed by traversing the logit path and
+ # subtracting child logits from parent logits.
+ expected_feature_ids = ((2, 0), (2,))
+ # bias = 1.72 * 1. # Root node of tree_0.
+ # example_0 : (bias, 0.1 * 5.5 + bias, 0.1 * 5. + bias)
+ # example_1 : (bias, 0.1 * 7. + bias )
+ expected_logits_paths = ((1.72, 2.27, 2.22), (1.72, 2.42))
+
+ bucketized_features = [
+ feature_0_values, feature_1_values, feature_2_values
+ ]
+
+ debug_op = boosted_trees_ops.example_debug_outputs(
+ tree_ensemble_handle,
+ bucketized_features=bucketized_features,
+ logits_dimension=1)
+
+ serialized_examples_debug_outputs = session.run(debug_op)
+ feature_ids = []
+ logits_paths = []
+ for example in serialized_examples_debug_outputs:
+ example_debug_outputs = boosted_trees_pb2.DebugOutput()
+ example_debug_outputs.ParseFromString(example)
+ feature_ids.append(example_debug_outputs.feature_ids)
+ logits_paths.append(example_debug_outputs.logits_path)
+
+ self.assertAllClose(feature_ids, expected_feature_ids)
+ self.assertAllClose(logits_paths, expected_logits_paths)
+
def testContribsMultipleTree(self):
"""Tests that the contribs work when we have multiple trees."""
with self.cached_session() as session:
@@ -1018,11 +1309,14 @@ class FeatureContribsOpsTest(test_util.TensorFlowTestCase):
tree_weights: 0.2
tree_weights: 1.0
tree_metadata: {
- num_layers_grown: 1}
+ num_layers_grown: 1
+ }
tree_metadata: {
- num_layers_grown: 2}
+ num_layers_grown: 2
+ }
tree_metadata: {
- num_layers_grown: 1}
+ num_layers_grown: 1
+ }
""", tree_ensemble_config)
tree_ensemble = boosted_trees_ops.TreeEnsemble(
diff --git a/tensorflow/python/kernel_tests/boosted_trees/quantile_ops_test.py b/tensorflow/python/kernel_tests/boosted_trees/quantile_ops_test.py
index c71b8df4ad..e0d46bae83 100644
--- a/tensorflow/python/kernel_tests/boosted_trees/quantile_ops_test.py
+++ b/tensorflow/python/kernel_tests/boosted_trees/quantile_ops_test.py
@@ -78,7 +78,7 @@ class QuantileOpsTest(test_util.TensorFlowTestCase):
self.num_quantiles = constant_op.constant(3, dtype=dtypes.int64)
def testBasicQuantileBucketsSingleResource(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
quantile_accumulator_handle = self.create_resource("floats", self.eps,
self.max_elements, 2)
resources.initialize_resources(resources.shared_resources()).run()
@@ -102,7 +102,7 @@ class QuantileOpsTest(test_util.TensorFlowTestCase):
self.assertAllClose(self._feature_1_quantiles, quantiles[1].eval())
def testBasicQuantileBucketsMultipleResources(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
quantile_accumulator_handle_0 = self.create_resource("float_0", self.eps,
self.max_elements)
quantile_accumulator_handle_1 = self.create_resource("float_1", self.eps,
diff --git a/tensorflow/python/kernel_tests/broadcast_to_ops_test.py b/tensorflow/python/kernel_tests/broadcast_to_ops_test.py
index bd2339f31d..09c325f2bc 100644
--- a/tensorflow/python/kernel_tests/broadcast_to_ops_test.py
+++ b/tensorflow/python/kernel_tests/broadcast_to_ops_test.py
@@ -90,7 +90,7 @@ class BroadcastToTest(test_util.TensorFlowTestCase):
x = constant_op.constant(1, dtype=dtypes.float32)
v = array_ops.broadcast_to(x, [2, 4, 3])
out = 2 * v
- with self.test_session():
+ with self.cached_session():
err = gradient_checker.compute_gradient_error(x, x.get_shape(),
out, out.get_shape())
self.assertLess(err, 1e-4)
@@ -100,7 +100,7 @@ class BroadcastToTest(test_util.TensorFlowTestCase):
dtype=dtypes.float32)
v = array_ops.broadcast_to(x, [2, 5, 3])
out = 2 * v
- with self.test_session():
+ with self.cached_session():
err = gradient_checker.compute_gradient_error(x, x.get_shape(),
out, out.get_shape())
self.assertLess(err, 1e-4)
@@ -110,7 +110,7 @@ class BroadcastToTest(test_util.TensorFlowTestCase):
dtype=dtypes.float32)
v = array_ops.broadcast_to(x, [5, 2, 3])
out = 2 * v
- with self.test_session():
+ with self.cached_session():
err = gradient_checker.compute_gradient_error(x, x.get_shape(),
out, out.get_shape())
self.assertLess(err, 1e-4)
@@ -119,7 +119,7 @@ class BroadcastToTest(test_util.TensorFlowTestCase):
x = constant_op.constant([[1, 2, 3], [4, 5, 6]], dtype=dtypes.float32)
v = array_ops.broadcast_to(x, [5, 4, 6])
out = 2 * v
- with self.test_session():
+ with self.cached_session():
err = gradient_checker.compute_gradient_error(x, x.get_shape(),
out, out.get_shape())
self.assertLess(err, 1e-4)
diff --git a/tensorflow/python/kernel_tests/check_ops_test.py b/tensorflow/python/kernel_tests/check_ops_test.py
index 27a674e223..bd4011d58e 100644
--- a/tensorflow/python/kernel_tests/check_ops_test.py
+++ b/tensorflow/python/kernel_tests/check_ops_test.py
@@ -785,7 +785,7 @@ class EnsureShapeTest(test.TestCase):
derived = math_ops.divide(placeholder, 3, name="MyDivide")
derived = check_ops.ensure_shape(derived, (3, 3, 3))
feed_val = [[1], [2]]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaisesWithPredicateMatch(
errors.InvalidArgumentError,
r"Shape of tensor MyDivide \[2,1\] is not compatible with "
@@ -797,7 +797,7 @@ class EnsureShapeTest(test.TestCase):
derived = placeholder / 3
derived = check_ops.ensure_shape(derived, (None, None, 3))
feed_val = [[1], [2]]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaisesWithPredicateMatch(
errors.InvalidArgumentError,
r"Shape of tensor [A-Za-z_]* \[2,1\] is not compatible with "
@@ -809,7 +809,7 @@ class EnsureShapeTest(test.TestCase):
derived = placeholder / 3
derived = check_ops.ensure_shape(derived, (2, 1))
feed_val = [[1], [2]]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(derived, feed_dict={placeholder: feed_val})
def testEnsuresDynamicShape_WithUnknownDims(self):
@@ -817,7 +817,7 @@ class EnsureShapeTest(test.TestCase):
derived = placeholder / 3
derived = check_ops.ensure_shape(derived, (None, None))
feed_val = [[1], [2]]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(derived, feed_dict={placeholder: feed_val})
def testGradient(self):
@@ -826,7 +826,7 @@ class EnsureShapeTest(test.TestCase):
gradient = gradients.gradients(derived, placeholder)
feed_val = [[4.0], [-1.0]]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
gradient_values, = sess.run(gradient, feed_dict={placeholder: feed_val})
expected = [[1.0], [1.0]]
diff --git a/tensorflow/python/kernel_tests/cond_v2_test.py b/tensorflow/python/kernel_tests/cond_v2_test.py
index a1efecf28a..377c041675 100644
--- a/tensorflow/python/kernel_tests/cond_v2_test.py
+++ b/tensorflow/python/kernel_tests/cond_v2_test.py
@@ -41,7 +41,7 @@ class CondV2Test(test.TestCase):
def _testCond(self, true_fn, false_fn, train_vals, feed_dict=None):
if not feed_dict:
feed_dict = {}
- with self.test_session(graph=ops.get_default_graph()) as sess:
+ with self.session(graph=ops.get_default_graph()) as sess:
pred = array_ops.placeholder(dtypes.bool, name="pred")
expected = control_flow_ops.cond(pred, true_fn, false_fn, name="expected")
@@ -131,7 +131,7 @@ class CondV2Test(test.TestCase):
def false_fn():
return x + 1
- return cond_v2.cond_v2(pred, true_fn, false_fn, name=name)[0].op
+ return cond_v2.cond_v2(pred, true_fn, false_fn, name=name).op
def testDefaultName(self):
with ops.Graph().as_default():
@@ -382,7 +382,7 @@ class CondV2Test(test.TestCase):
with ops.Graph().as_default():
grads, pred_outer, pred_inner = build_graph()
- with self.test_session(graph=ops.get_default_graph()) as sess:
+ with self.session(graph=ops.get_default_graph()) as sess:
self.assertSequenceEqual(
sess.run(grads, {
pred_outer: True,
@@ -445,7 +445,7 @@ class CondV2Test(test.TestCase):
with ops.Graph().as_default():
grads, pred_outer, pred_inner = build_graph()
- with self.test_session(graph=ops.get_default_graph()) as sess:
+ with self.session(graph=ops.get_default_graph()) as sess:
self.assertSequenceEqual(
sess.run(grads, {
pred_outer: True,
@@ -504,7 +504,7 @@ class CondV2Test(test.TestCase):
with ops.Graph().as_default():
grads, pred_outer, pred_inner = build_graph()
- with self.test_session(graph=ops.get_default_graph()) as sess:
+ with self.session(graph=ops.get_default_graph()) as sess:
self.assertSequenceEqual(
sess.run(grads, {
pred_outer: True,
@@ -569,12 +569,11 @@ class CondV2Test(test.TestCase):
ops.add_to_collection("pred", pred)
cond = cond_v2.cond_v2(pred, true_fn, false_fn, name="cond")
- for c in cond:
- ops.add_to_collection("cond", c)
+ ops.add_to_collection("cond", cond)
meta_graph = saver.export_meta_graph()
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
saver.import_meta_graph(meta_graph)
x = ops.get_collection("x")[0]
pred = ops.get_collection("pred")[0]
@@ -598,7 +597,7 @@ class CondV2Test(test.TestCase):
def testLowering(self):
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
out_cond = self._createCond("cond")
run_options = config_pb2.RunOptions(output_partition_graphs=True)
@@ -624,7 +623,7 @@ class CondV2Test(test.TestCase):
"An `If` op was found, but it should be lowered.")
def testLoweringDisabledInXLA(self):
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
# Build the cond_v2 in an XLA context
xla_context = control_flow_ops.XLAControlFlowContext()
xla_context.Enter()
@@ -661,7 +660,7 @@ class CondV2CollectionTest(test.TestCase):
def testCollectionIntValueAccessInCond(self):
"""Read values from graph collections inside of cond_v2."""
with ops.Graph().as_default() as g:
- with self.test_session(graph=g):
+ with self.session(graph=g):
x = 2
y = 5
ops.add_to_collection("x", x)
@@ -672,12 +671,12 @@ class CondV2CollectionTest(test.TestCase):
return math_ops.add(x_const, y_const)
cnd = cond_v2.cond_v2(True, fn, fn)
- self.assertEquals(cnd[0].eval(), 7)
+ self.assertEquals(cnd.eval(), 7)
def testCollectionTensorValueAccessInCond(self):
"""Read tensors from collections inside of cond_v2 & use them."""
with ops.Graph().as_default() as g:
- with self.test_session(graph=g):
+ with self.session(graph=g):
x = constant_op.constant(2)
y = constant_op.constant(5)
ops.add_to_collection("x", x)
@@ -689,12 +688,12 @@ class CondV2CollectionTest(test.TestCase):
return math_ops.add(x_read, y_read)
cnd = cond_v2.cond_v2(math_ops.less(x, y), fn, fn)
- self.assertEquals(cnd[0].eval(), 7)
+ self.assertEquals(cnd.eval(), 7)
def testCollectionIntValueWriteInCond(self):
"""Make sure Int writes to collections work inside of cond_v2."""
with ops.Graph().as_default() as g:
- with self.test_session(graph=g):
+ with self.session(graph=g):
x = constant_op.constant(2)
y = constant_op.constant(5)
def true_fn():
@@ -709,7 +708,7 @@ class CondV2CollectionTest(test.TestCase):
cnd = cond_v2.cond_v2(
True, true_fn,
false_fn)
- self.assertEquals(cnd[0].eval(), 14)
+ self.assertEquals(cnd.eval(), 14)
read_z_collection = ops.get_collection("z")
self.assertEquals(read_z_collection, [7])
@@ -725,7 +724,7 @@ class CondV2ContainerTest(test.TestCase):
"""
self.skipTest("b/113048653")
with ops.Graph().as_default() as g:
- with self.test_session(graph=g):
+ with self.session(graph=g):
v0 = variables.Variable([0])
q0 = data_flow_ops.FIFOQueue(1, dtypes.float32)
@@ -782,10 +781,10 @@ class CondV2ContainerTest(test.TestCase):
with ops.container("l1"):
cnd_true = cond_v2.cond_v2(True, true_fn, false_fn)
- self.assertEquals(cnd_true[0].eval(), 2)
+ self.assertEquals(cnd_true.eval(), 2)
cnd_false = cond_v2.cond_v2(False, true_fn, false_fn)
- self.assertEquals(cnd_false[0].eval(), 6)
+ self.assertEquals(cnd_false.eval(), 6)
v4 = variables.Variable([3])
q4 = data_flow_ops.FIFOQueue(1, dtypes.float32)
@@ -802,7 +801,7 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase):
def testColocateWithBeforeCond(self):
with ops.Graph().as_default() as g:
- with self.test_session(graph=g):
+ with self.session(graph=g):
a = constant_op.constant([2.0], name="a")
b = constant_op.constant([2.0], name="b")
@@ -813,7 +812,7 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase):
return c
with ops.colocate_with(a.op):
- self.assertEquals(cond_v2.cond_v2(True, fn, fn)[0].eval(), 3)
+ self.assertEquals(cond_v2.cond_v2(True, fn, fn).eval(), 3)
def fn2():
c = constant_op.constant(3.0)
@@ -822,11 +821,11 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase):
with ops.colocate_with(a.op):
with ops.colocate_with(b.op):
- self.assertEquals(cond_v2.cond_v2(True, fn2, fn2)[0].eval(), 3)
+ self.assertEquals(cond_v2.cond_v2(True, fn2, fn2).eval(), 3)
def testColocateWithInAndOutOfCond(self):
with ops.Graph().as_default() as g:
- with self.test_session(graph=g):
+ with self.session(graph=g):
a = constant_op.constant([2.0], name="a")
b = constant_op.constant([2.0], name="b")
@@ -838,7 +837,7 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase):
return c
with ops.colocate_with(a.op):
- self.assertEquals(cond_v2.cond_v2(True, fn2, fn2)[0].eval(), 3)
+ self.assertEquals(cond_v2.cond_v2(True, fn2, fn2).eval(), 3)
d = constant_op.constant([2.0], name="d")
self.assertEqual([b"loc:@a"], d.op.colocation_groups())
@@ -859,7 +858,7 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase):
with ops.colocate_with(b.op):
c = math_ops.add(a, a, name="c")
return c
- out_cond_2 = cond_v2.cond_v2(True, fn, fn)[0]
+ out_cond_2 = cond_v2.cond_v2(True, fn, fn)
run_options = config_pb2.RunOptions(output_partition_graphs=True)
run_metadata = config_pb2.RunMetadata()
@@ -873,14 +872,15 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase):
def testDeviceBeforeCond(self):
with ops.Graph().as_default() as g:
- with self.test_session(graph=g):
+ with self.session(graph=g):
+
def fn():
c = constant_op.constant(3.0)
self.assertEqual("/device:CPU:0", c.op.device)
return c
with ops.device("/device:CPU:0"):
- self.assertEquals(cond_v2.cond_v2(True, fn, fn)[0].eval(), 3)
+ self.assertEquals(cond_v2.cond_v2(True, fn, fn).eval(), 3)
def fn2():
c = constant_op.constant(3.0)
@@ -888,7 +888,7 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase):
return c
with ops.device("/device:GPU:0"):
- self.assertEquals(cond_v2.cond_v2(True, fn2, fn2)[0].eval(), 3)
+ self.assertEquals(cond_v2.cond_v2(True, fn2, fn2).eval(), 3)
def testDeviceInAndOutOfCond(self):
with ops.Graph().as_default() as g:
@@ -902,7 +902,7 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase):
return c
with ops.device("/device:CPU:0"):
- self.assertEquals(cond_v2.cond_v2(True, fn2, fn2)[0].eval(), 3)
+ self.assertEquals(cond_v2.cond_v2(True, fn2, fn2).eval(), 3)
d = constant_op.constant(4.0)
self.assertEqual("/device:CPU:0", d.op.device)
@@ -921,7 +921,7 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase):
with ops.device("/device:CPU:0"):
a = constant_op.constant([2.0], name="a")
- out_cond_2 = cond_v2.cond_v2(True, fn, fn)[0]
+ out_cond_2 = cond_v2.cond_v2(True, fn, fn)
run_options = config_pb2.RunOptions(output_partition_graphs=True)
run_metadata = config_pb2.RunMetadata()
diff --git a/tensorflow/python/kernel_tests/conditional_accumulator_test.py b/tensorflow/python/kernel_tests/conditional_accumulator_test.py
index 262352a9af..97ab23fe49 100644
--- a/tensorflow/python/kernel_tests/conditional_accumulator_test.py
+++ b/tensorflow/python/kernel_tests/conditional_accumulator_test.py
@@ -272,7 +272,7 @@ class ConditionalAccumulatorTest(test.TestCase):
self.assertEqual(15.0, val)
def testAccumulatorTakeGradSum(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.ConditionalAccumulator(
dtypes_lib.float32,
name="Q",
@@ -349,7 +349,7 @@ class ConditionalAccumulatorTest(test.TestCase):
self.assertEqual(elems_ave + 0.0, val)
def testAccumulatorRepeatedTakeGradSum(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.ConditionalAccumulator(
dtypes_lib.float32,
name="Q",
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index ebeabcfe1a..d91a848e01 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -23,7 +23,6 @@ from __future__ import print_function
import collections
import math
import time
-import unittest
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
@@ -63,6 +62,7 @@ from tensorflow.python.ops import script_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
+from tensorflow.python.ops import while_v2 # pylint: disable=unused-import
# pylint: disable=unused-import
import tensorflow.python.ops.tensor_array_grad
# pylint: enable=unused-import
@@ -125,12 +125,12 @@ def isum(s, maximum_iterations=None):
return r_s
-@test_util.with_cond_v2
+@test_util.with_control_flow_v2
class ControlFlowTest(test.TestCase):
def testRefIdentity(self):
with self.cached_session():
- v = variables.Variable(7)
+ v = variables.VariableV1(7)
v = control_flow_ops._Identity(v)
op = state_ops.assign(v, 9)
@@ -142,7 +142,7 @@ class ControlFlowTest(test.TestCase):
def testRefEnter(self):
with self.cached_session():
- v = variables.Variable(7)
+ v = variables.VariableV1(7)
enter_v = control_flow_ops._Enter(v, "foo_1", is_constant=True)
nine = constant_op.constant(9)
@@ -155,7 +155,7 @@ class ControlFlowTest(test.TestCase):
def testRefSwitch(self):
with self.cached_session():
- v = variables.Variable(7)
+ v = variables.VariableV1(7)
p = constant_op.constant(True)
v1 = control_flow_ops._SwitchRefOrTensor(v._ref(), p) # pylint: disable=protected-access
@@ -332,10 +332,8 @@ class ControlFlowTest(test.TestCase):
with self.assertRaisesOpError("has inputs from different frames"):
res.eval(feed_dict={data: 1.0})
+ @test_util.disable_control_flow_v2("b/113294340")
def testCondBool(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113296297")
-
values = constant_op.constant(10)
fn1 = lambda: math_ops.add(values, 1)
fn2 = lambda: math_ops.subtract(values, 1)
@@ -366,6 +364,7 @@ class ControlFlowTest(test.TestCase):
"has been marked as not fetchable"):
sess.run(t, feed_dict={x: 3})
+ @test_util.disable_control_flow_v2("Not relevant")
def testFeedable(self):
with self.cached_session() as sess:
c = constant_op.constant(2)
@@ -383,10 +382,8 @@ class ControlFlowTest(test.TestCase):
with self.assertRaisesRegexp(ValueError, "may not be fed"):
sess.run(r, feed_dict={t: 3})
+ @test_util.disable_control_flow_v2("b/113296180 (IndexedSlices)")
def testCondIndexedSlices(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113296180")
-
with self.cached_session():
values = constant_op.constant(10)
indices = constant_op.constant(0)
@@ -401,10 +398,8 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(11, val)
self.assertAllEqual(0, ind)
+ @test_util.disable_control_flow_v2("b/113296161 (SparseTensors)")
def testCondSparseTensor(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113296161 (SparseTensors)")
-
with self.cached_session():
values = constant_op.constant([2.0, 4.0], name="values")
indices = constant_op.constant(
@@ -422,8 +417,6 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(r.values.get_shape(), (2,))
def testCondResource(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/111124878 (don't return tuple)")
with self.cached_session():
rv = resource_variable_ops.ResourceVariable(True)
@@ -437,10 +430,8 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(1.0, control_flow_ops.cond(rv, case, lambda: t).eval())
+ @test_util.disable_control_flow_v2("b/113293074")
def testCondIndexedSlicesDifferentTypes(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113293074")
-
with self.cached_session():
values = constant_op.constant(10)
i_32 = ops.convert_to_tensor(0, name="one", dtype=dtypes.int32)
@@ -484,15 +475,12 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(11, result)
def testCond_1(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/111124878 (don't return tuple)")
self._testCond_1(use_gpu=False)
- self._testCond_1(use_gpu=True)
+ # TODO(b/116526896): Enable GPU tests.
+ # self._testCond_1(use_gpu=True)
def testCond_2(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/111124878 (don't return tuple)")
with self.cached_session():
x = constant_op.constant(10)
@@ -503,8 +491,6 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(9, result)
def testCond_3(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/111124878 (don't return tuple)")
with self.cached_session():
x = constant_op.constant(10)
@@ -517,10 +503,8 @@ class ControlFlowTest(test.TestCase):
result = r.eval()
self.assertAllEqual(12, result)
+ @test_util.disable_control_flow_v2("b/113324949 (ref vars)")
def testCond_4(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113324949 (ref vars)")
-
with self.cached_session():
v1 = variables.Variable(7)
v2 = variables.Variable(7)
@@ -556,8 +540,6 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(4, count.eval())
def testCond_6(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/111124878 (don't return tuple)")
with self.cached_session():
v1 = variables.Variable([7])
@@ -583,8 +565,6 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual([11, 12], sess.run(r))
def testCondRef(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/111124878 (don't return tuple)")
with self.cached_session():
x = gen_state_ops.variable(
@@ -598,10 +578,8 @@ class ControlFlowTest(test.TestCase):
r = control_flow_ops.cond(constant_op.constant(False), true_fn, false_fn)
self.assertAllEqual([2.0], r.eval())
+ @test_util.disable_control_flow_v2("b/79881896 (control deps)")
def testCondWithControl(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/79881896")
-
with self.cached_session():
control_holder = array_ops.placeholder(dtypes.float32, shape=())
a = constant_op.constant(3)
@@ -640,10 +618,9 @@ class ControlFlowTest(test.TestCase):
merged_op = control_flow_ops.merge([assign_v, orig_v])
self.assertAllEqual([1.0], sess.run(merged_op.output))
+ @test_util.disable_control_flow_v2(
+ "b/112477618 (Operation returned from cond)")
def testCondSwitchIdentity(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/112477618 (Operation returned from cond)")
-
# Make sure the recv identity is not removed by optimization.
with session.Session(config=opt_cfg()) as sess:
pred = constant_op.constant(True)
@@ -657,10 +634,9 @@ class ControlFlowTest(test.TestCase):
r = control_flow_ops.cond(pred, fn1, fn2)
sess.run(r)
+ @test_util.disable_control_flow_v2(
+ "b/112477618 (Operation returned from cond)")
def testCondRecvIdentity(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/112477618 (Operation returned from cond)")
-
# Make sure the switch identity is not removed by optimization.
with session.Session(config=opt_cfg()) as sess:
with ops.device(test.gpu_device_name()):
@@ -676,10 +652,8 @@ class ControlFlowTest(test.TestCase):
r = control_flow_ops.cond(pred, fn1, fn2)
sess.run(r)
+ @test_util.disable_control_flow_v2("b/113346829 (gpu failure)")
def testCondGrad_1(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113346829 (gpu failure)")
-
graph = ops.Graph()
with graph.as_default():
x = constant_op.constant(10.0, name="x")
@@ -705,10 +679,9 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(42.0, grad.eval(feed_dict={c: 1}))
self.assertAllEqual(3.0, grad.eval(feed_dict={c: 3}))
+ @test_util.disable_control_flow_v2(
+ "b/110550782 (gradient w.r.t external variable)")
def testCondGrad_3(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/110550782 (gradient w.r.t external variable)")
-
with self.cached_session():
c = array_ops.placeholder(dtypes.int32, shape=[])
ox = constant_op.constant(10.0)
@@ -740,10 +713,8 @@ class ControlFlowTest(test.TestCase):
result = gradients_impl.gradients(z, x)[0]
self.assertEqual(1.0, result.eval())
+ @test_util.disable_control_flow_v2("b/113327884")
def testCondGrad_Gather(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113327884")
-
with self.cached_session() as sess:
v1 = variables.Variable([1.0, 42.0])
c = array_ops.placeholder(dtypes.int32, shape=[])
@@ -767,6 +738,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(dense_gv, [0.0, 2.0])
# Microbenchmark: 256,000 iterations/s.
+ @test_util.disable_control_flow_v2("b/116630618 (Times out)")
def testWhile_1(self):
with self.cached_session():
n = constant_op.constant(0)
@@ -775,6 +747,7 @@ class ControlFlowTest(test.TestCase):
r = control_flow_ops.while_loop(c, b, [n], parallel_iterations=20)
self.assertEqual(10000, r.eval())
+ @test_util.disable_control_flow_v2("b/79881896 (control deps)")
def testWhileExternalControlDependencies(self):
with self.cached_session():
v = variables.Variable(0.0)
@@ -790,6 +763,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(result.eval(), 2)
self.assertAllEqual(v.eval(), 1.0)
+ @test_util.disable_control_flow_v2("b/79881896 (control deps)")
def testWhileExternalControlDependenciesNoInput(self):
with self.cached_session():
v = variables.Variable(0.0)
@@ -805,9 +779,10 @@ class ControlFlowTest(test.TestCase):
result.eval()
self.assertAllEqual(v.eval(), 1.0)
+ @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
def testWhileWithRefs_1(self):
with self.cached_session() as sess:
- x = variables.Variable(0)._ref() # pylint: disable=protected-access
+ x = variables.VariableV1(0)._ref() # pylint: disable=protected-access
i = constant_op.constant(0)
c = lambda i, x: math_ops.less(i, 100)
@@ -835,18 +810,22 @@ class ControlFlowTest(test.TestCase):
r = isum(s)
self.assertAllEqual(45, r.eval())
+ @test_util.disable_control_flow_v2("b/115776323 (max_iters)")
def testWhileWithMaximumIterations(self):
with self.cached_session():
s = constant_op.constant([1, 2, 3, 4, 5])
r = isum(s, maximum_iterations=3)
self.assertAllEqual([1 + 3, 2 + 3, 3 + 3, 4 + 3, 5 + 3], r.eval())
+ @test_util.disable_control_flow_v2("b/116339888 (non-tensor loop var)")
def testWhileWithMaximumIterationsAndSingleArgument(self):
with self.cached_session():
r = control_flow_ops.while_loop(
lambda i: i < 3, lambda i: i + 1, [0], maximum_iterations=1)
self.assertEqual(1, r.eval())
+ @test_util.disable_control_flow_v2(
+ "b/116248044 (nested), b/115920078 (gradients)")
def testSingleNestedMaximumIterationsWhileLoopGradientInXLAContext(self):
v = constant_op.constant(1.0)
@@ -872,6 +851,7 @@ class ControlFlowTest(test.TestCase):
# Should execute without issue.
self.assertEqual(3, self.evaluate(loop_execute))
+ @test_util.disable_control_flow_v2("b/116248044 (nested while_loop)")
def testInvalidMaximumIterationsWhileLoopGradientInXLAContext(self):
v = constant_op.constant(1.0)
@@ -915,10 +895,8 @@ class ControlFlowTest(test.TestCase):
r"context '.*' \(currently defined in '.*'\)"):
_ = gradients_impl.gradients(loop_with_maxiter, v)
+ @test_util.disable_control_flow_v2("b/115776323 (max_iters)")
def testInvalidMaximumIterationsFromSiblingContextWhileLoopInXLAContext(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113294340 (enable while_v2)")
-
v = constant_op.constant(1.0)
def create_while_loop():
@@ -950,6 +928,8 @@ class ControlFlowTest(test.TestCase):
r"while loop context '' \(currently defined in 'cond/.+'\)"):
_ = gradients_impl.gradients(loop, v)
+ @test_util.disable_control_flow_v2(
+ "b/116248044 (nesting), b/115776323 (max_iters)")
def testNestedWhileLoopWithMaxItersFromOuterContextInXLAContext(self):
v = constant_op.constant(1.0)
@@ -1059,6 +1039,7 @@ class ControlFlowTest(test.TestCase):
result = r[3].eval()
self.assertAllEqual(42, result)
+ @test_util.disable_control_flow_v2("b/116283162 (shape_invariants)")
def testWhile_5(self):
with self.cached_session():
@@ -1083,6 +1064,7 @@ class ControlFlowTest(test.TestCase):
result = r[2].eval()
self.assertAllEqual(np.array([0, 1, 2, 3, 4, 5, 6]), result)
+ @test_util.disable_control_flow_v2("b/116338794 (buffer_reuse)")
def testBufferForwarding(self):
run_options = config_pb2.RunOptions(
trace_level=config_pb2.RunOptions.FULL_TRACE)
@@ -1133,6 +1115,7 @@ class ControlFlowTest(test.TestCase):
self._testWhile_Gpu_1(use_gpu=False)
self._testWhile_Gpu_1(use_gpu=True)
+ @test_util.disable_control_flow_v2("b/116283162 (shape_invariants)")
def testWhileShape(self):
with self.cached_session():
i = constant_op.constant(0)
@@ -1150,6 +1133,7 @@ class ControlFlowTest(test.TestCase):
r = r[1] * array_ops.ones([8, 8])
self.assertAllEqual(np.ones((8, 8)), r.eval())
+ @test_util.disable_control_flow_v2("b/116339888 (non-tensor loop var)")
def testWhileWithNonTensorInput_Scalar(self):
with self.cached_session():
n = 0
@@ -1158,6 +1142,7 @@ class ControlFlowTest(test.TestCase):
r = control_flow_ops.while_loop(c, b, [n], parallel_iterations=20)
self.assertEqual(10000, r.eval())
+ @test_util.disable_control_flow_v2("b/116339888 (non-tensor loop var)")
def testWhileWithNonTensorInput_Vector(self):
with self.cached_session():
n = np.array([0]) # Note, [0] would not work here; that is a list
@@ -1166,6 +1151,7 @@ class ControlFlowTest(test.TestCase):
r = control_flow_ops.while_loop(c, b, [n], parallel_iterations=20)
self.assertEqual([10000], r.eval())
+ @test_util.disable_control_flow_v2("b/116283162 (shape_invariants)")
def testWhileShapeInference(self):
with self.cached_session():
i = constant_op.constant(0)
@@ -1180,7 +1166,7 @@ class ControlFlowTest(test.TestCase):
r = control_flow_ops.while_loop(
c, b, [i, m],
[i.get_shape(), tensor_shape.TensorShape([None, 2])])
- self.assertTrue(r[1].get_shape()[0].value is None)
+ self.assertIsNone(r[1].get_shape()[0].value)
self.assertEqual(r[1].get_shape()[1], tensor_shape.Dimension(2))
with self.assertRaisesRegexp(
@@ -1191,6 +1177,7 @@ class ControlFlowTest(test.TestCase):
r"tf.while_loop to specify a less-specific shape."):
r = control_flow_ops.while_loop(c, b, [i, m])
+ @test_util.disable_control_flow_v2("b/116328420 (SparseTensor)")
def testWhileShapeInferenceSparseTensor(self):
with self.cached_session():
values = constant_op.constant([2.0, 4.0], name="values")
@@ -1222,6 +1209,7 @@ class ControlFlowTest(test.TestCase):
c, b, [i, x],
[i.get_shape(), tensor_shape.TensorShape([5])])
+ @test_util.disable_control_flow_v2("b/116282023 (IndexedSlices)")
def testWhileShapeInferenceIndexedSlices(self):
with self.cached_session():
values = constant_op.constant([[2.0, 4.0], [3.0, 5.0]], name="values")
@@ -1276,6 +1264,7 @@ class ControlFlowTest(test.TestCase):
r = control_flow_ops.while_loop(c, b, [n])
self.assertEqual(225, r.eval())
+ @test_util.disable_control_flow_v2("b/116248044 (nested while)")
def testNestedWhile_1(self):
self._testNestedWhile_1(use_gpu=False)
self._testNestedWhile_1(use_gpu=True)
@@ -1308,6 +1297,7 @@ class ControlFlowTest(test.TestCase):
outer_c, outer_b, [s0], parallel_iterations=1)
self.assertEqual(1048576.0, r.eval())
+ @test_util.disable_control_flow_v2("b/116248044 (nested while)")
def testNestedWhile_2(self):
self._testNestedWhile_2(use_gpu=False)
self._testNestedWhile_2(use_gpu=True)
@@ -1361,6 +1351,7 @@ class ControlFlowTest(test.TestCase):
lambda x: x < 10, lambda x: x + array_ops.identity(c), [x0])
self.assertEqual(10, sess.run(r, {b: True}))
+ @test_util.disable_control_flow_v2("b/79881896 (control_deps)")
def testWhileWithControl_5(self):
with self.cached_session() as sess:
b = array_ops.placeholder(dtypes.bool)
@@ -1375,9 +1366,6 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(10, sess.run(r, {b: True}))
def testWhileCondWithControl(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113294377 (unknown shape)")
-
# Ensure that no control edges by an outer control dependency context are
# added to nodes inside cond/while contexts.
with self.cached_session() as sess:
@@ -1391,10 +1379,8 @@ class ControlFlowTest(test.TestCase):
(constant_op.constant(5),))
self.assertEqual(0, sess.run(loop))
+ @test_util.disable_control_flow_v2("b/113324949 (ref vars)")
def testWhileCondWithControl_1(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113324949 (ref vars)")
-
with self.cached_session():
v = variable_scope.get_variable(
"v", [], initializer=init_ops.constant_initializer(2))
@@ -1416,9 +1402,8 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(4, r.eval())
self.assertAllClose(65536.0, v.eval())
+ @test_util.disable_control_flow_v2("b/113324949 (ref vars)")
def testWhileCondExitControl(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113294340 (enable while_v2)")
with self.cached_session():
v = variables.Variable(1)
@@ -1443,8 +1428,6 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(99, v.eval())
def testCondWhile_1(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/111124878 (don't return tuple)")
with self.cached_session():
n = ops.convert_to_tensor(0, name="n")
@@ -1456,8 +1439,6 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(10, r.eval())
def testCondWhile_2(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/111124878 (don't return tuple)")
with self.cached_session():
n = ops.convert_to_tensor(0)
@@ -1469,9 +1450,6 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(10, r.eval())
def _testCondWhile_3(self, use_gpu):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113294340 (enable while_v2)")
-
with self.test_session(use_gpu=use_gpu) as sess:
p = array_ops.placeholder(dtypes.bool)
n = constant_op.constant(0.0)
@@ -1488,18 +1466,17 @@ class ControlFlowTest(test.TestCase):
lambda: control_flow_ops.while_loop(c, b, [n]),
lambda: math_ops.multiply(n, 2.0))
r1 = gradients_impl.gradients(r, [n])
- self.assertEqual(10, sess.run(r, {p: True}))
+ self.assertEqual(10., sess.run(r, {p: True}))
self.assertEqual([1.0], sess.run(r1, {p: True}))
self.assertEqual(0.0, sess.run(r, {p: False}))
self.assertEqual([2.0], sess.run(r1, {p: False}))
+ @test_util.disable_control_flow_v2("b/116743589")
def testCondWhile_3(self):
self._testCondWhile_3(use_gpu=False)
self._testCondWhile_3(use_gpu=True)
def testWhileCond_1(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113294377 (unknown shape)")
with self.cached_session():
i = ops.convert_to_tensor(0, name="i")
@@ -1516,8 +1493,6 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(10, r.eval())
def testWhileCond_2(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113294377 (unknown shape)")
with self.cached_session():
n = ops.convert_to_tensor(0, name="n")
@@ -1527,8 +1502,6 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(10, r.eval())
def testWhileCond_3(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113294377 (unknown shape)")
with self.cached_session():
n = ops.convert_to_tensor(0)
@@ -1543,6 +1516,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(10, r.eval())
# NOTE: It is ok to have parallel_iterations > 1
+ @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
def testWhileUpdateVariable_1(self):
with self.cached_session():
select = variables.Variable([3.0, 4.0, 5.0])
@@ -1565,6 +1539,7 @@ class ControlFlowTest(test.TestCase):
result = select.eval()
self.assertAllClose(np.array([10.0, 10.0, 10.0]), result)
+ @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
def testWhileUpdateVariable_2(self):
with self.cached_session():
select1 = variables.Variable([3.0, 4.0, 5.0])
@@ -1591,6 +1566,7 @@ class ControlFlowTest(test.TestCase):
result2 = select2.eval()
self.assertAllClose(np.array([10.0, 10.0, 10.0]), result2)
+ @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
def testWhileUpdateVariable_3(self):
with self.cached_session():
select = variables.Variable([3.0, 4.0, 5.0])
@@ -1612,7 +1588,7 @@ class ControlFlowTest(test.TestCase):
result = r[1].eval()
self.assertAllClose(np.array([10.0, 10.0, 10.0]), result)
- # b/24814703
+ @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
def testWhileUpdateVariable_4(self):
with self.cached_session():
var_a = variables.Variable(0, name="a")
@@ -1640,7 +1616,7 @@ class ControlFlowTest(test.TestCase):
lpa.eval() # Run the loop
self.assertEqual(10, var_b.eval())
- # b/24736492
+ @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
def testWhileUpdateVariable_5(self):
with self.cached_session():
# Create some variables.
@@ -1670,7 +1646,7 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(10, var_a.eval())
self.assertEqual(10, var_b.eval())
- # b/24814668
+ @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
def testWhileUpdateVariable_6(self):
with self.cached_session():
# Create some variables.
@@ -1700,6 +1676,7 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(55, var_b.eval())
self.assertEqual(10, var_a.eval())
+ @test_util.disable_control_flow_v2("b/116742472 (resource accumulator)")
def testWhileQueue_1(self):
with self.cached_session():
q = data_flow_ops.FIFOQueue(-1, dtypes.int32)
@@ -1718,6 +1695,7 @@ class ControlFlowTest(test.TestCase):
for i in xrange(10):
self.assertEqual([i], q.dequeue().eval())
+ @test_util.disable_control_flow_v2("b/116283162 (shape_invariants)")
def testWhileStack_1(self):
with self.cached_session():
s = gen_data_flow_ops.stack_v2(-1, dtypes.int32, stack_name="foo")
@@ -1783,9 +1761,10 @@ class ControlFlowTest(test.TestCase):
else:
self.assertFalse(gpu_dev_name in dev)
- with self.test_session(graph=graph) as sess:
+ with self.session(graph=graph) as sess:
self.assertAllClose(1024.0, sess.run(r))
+ @test_util.disable_control_flow_v2("b/116351701 (colocation)")
def testWhileGrad_ColocateGradients(self):
self._testWhileGrad_ColocateGradients(colocate=False)
self._testWhileGrad_ColocateGradients(colocate=True)
@@ -1801,6 +1780,7 @@ class ControlFlowTest(test.TestCase):
r = gradients_impl.gradients(r, v)[0]
self.assertAllClose(1024.0, r.eval())
+ @test_util.disable_control_flow_v2("b/116283162 (shape_invariants)")
def testWhileGrad_Shape(self):
with self.cached_session():
x = array_ops.placeholder(dtypes.float32, shape=[None])
@@ -1872,8 +1852,6 @@ class ControlFlowTest(test.TestCase):
self._testWhileGrad_Mul(use_gpu=True, p_iters=10)
def _testNestedWhileCondWhileGrad(self, use_gpu):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113294377 (unknown shape)")
with self.test_session(use_gpu=use_gpu):
v = constant_op.constant(1.0)
@@ -1896,10 +1874,12 @@ class ControlFlowTest(test.TestCase):
r = gradients_impl.gradients(r, v)[0]
self.assertAllClose(512.0, r.eval())
+ @test_util.disable_control_flow_v2("b/116248044 (nested while)")
def testNestedWhileCondWhileGrad(self):
self._testNestedWhileCondWhileGrad(use_gpu=False)
self._testNestedWhileCondWhileGrad(use_gpu=True)
+ @test_util.disable_control_flow_v2("b/116823782")
def testWhileGrad_Variable(self):
with self.cached_session():
a = variables.Variable(3.0)
@@ -1913,8 +1893,6 @@ class ControlFlowTest(test.TestCase):
self.assertAllClose(216.0, r[0].eval())
def testWhileGradInCond(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/110550782 (gradient w.r.t external variable)")
with self.cached_session():
n = ops.convert_to_tensor(1.0, name="n")
@@ -1930,6 +1908,7 @@ class ControlFlowTest(test.TestCase):
r = control_flow_ops.cond(math_ops.less(1, 2), fn1, lambda: x)
self.assertAllClose(9.0, r.eval(feed_dict={x: 1.0}))
+ @test_util.disable_control_flow_v2("b/116340060")
def testGradInWhileWrtInitialLoopVal(self):
with self.cached_session():
x = array_ops.placeholder(dtypes.float32, shape=(), name="x")
@@ -1947,6 +1926,7 @@ class ControlFlowTest(test.TestCase):
"loop invariants or wrt the input parameters to the loop body."):
control_flow_ops.while_loop(lambda i, x: i < 3, body, [0, y])
+ @test_util.disable_control_flow_v2("b/116248044 (nested while)")
def testWhileGradInWhile(self):
with self.cached_session():
n = ops.convert_to_tensor(1.0, name="n")
@@ -1963,9 +1943,8 @@ class ControlFlowTest(test.TestCase):
[tensor_shape.unknown_shape()])
self.assertAllClose(9.0, r.eval(feed_dict={x: 1.0}))
+ @test_util.disable_control_flow_v2("b/116248044 (nested while)")
def testCondGradInNestedWhiles(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113346829 (gpu failure)")
def outer_body(i, x):
_, x = control_flow_ops.while_loop(
@@ -1983,6 +1962,7 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(i_val, 3)
self.assertAllClose(x_val, 1.0)
+ @test_util.disable_control_flow_v2("b/116255781 (flat_args)")
def testWhile_NestedInput(self):
with self.cached_session() as sess:
named = collections.namedtuple("named", ("a", "b"))
@@ -2010,6 +1990,7 @@ class ControlFlowTest(test.TestCase):
self.assertEqual([100.0, 1.0, 102.0, 3.0, 4.0 + 100 * 2.0],
sess.run(r_flattened))
+ @test_util.disable_control_flow_v2("b/116255781(flat_args)")
def testWhile_NestedBadArityFails(self):
with self.cached_session():
named = collections.namedtuple("named", ("a", "b"))
@@ -2068,6 +2049,7 @@ class ControlFlowTest(test.TestCase):
r = gradients_impl.gradients([rx], x)
self.assertAllClose(1024.0, r[0].eval())
+ @test_util.disable_control_flow_v2("b/116355153 (back_prop flag)")
def testWhileGrad_NoGradient(self):
with self.cached_session():
v = constant_op.constant(2.0, name="v")
@@ -2078,6 +2060,7 @@ class ControlFlowTest(test.TestCase):
r = gradients_impl.gradients(r, v)
self.assertAllClose(1.0, r[0].eval())
+ @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
def testWhileGrad_NoDependency(self):
with self.cached_session() as sess:
variable = variables.Variable(array_ops.ones([2, 3]))
@@ -2191,10 +2174,12 @@ class ControlFlowTest(test.TestCase):
r = gradients_impl.gradients(r, v)[0]
self.assertAllClose(8.0, r.eval())
+ @test_util.disable_control_flow_v2("b/116248044 (nested)")
def testNestedWhileGrad_Simple(self):
self._testNestedWhileGrad_Simple(use_gpu=False)
self._testNestedWhileGrad_Simple(use_gpu=True)
+ @test_util.disable_control_flow_v2("b/116248044 (nested)")
def testNestedWhileGrad_SerialInner(self):
with self.cached_session():
v = constant_op.constant(1.0)
@@ -2218,6 +2203,7 @@ class ControlFlowTest(test.TestCase):
r = gradients_impl.gradients(r, v)[0]
self.assertAllClose(256.0, r.eval())
+ @test_util.disable_control_flow_v2("b/116248044 (nested)")
def testNestedWhileGrad_ParallelInner(self):
with self.cached_session():
v = constant_op.constant(1.0)
@@ -2241,6 +2227,8 @@ class ControlFlowTest(test.TestCase):
r = gradients_impl.gradients(r, v)[0]
self.assertAllClose(512.0, r.eval())
+ @test_util.disable_control_flow_v2(
+ "Nested loops and TensorArrays not supported")
def testNestedWhileGrad_ParallelIterations(self):
# Make sure the stack pushes and pops of an inner loop are executed in
# the sequential order of the iterations of its outer loop.
@@ -2279,13 +2267,12 @@ class ControlFlowTest(test.TestCase):
r = gradients_impl.gradients(r, v)[0]
self.assertAllClose(1024.0, r.eval())
+ @test_util.disable_control_flow_v2("b/116272044 (cond_in_while)")
def testWhileCondGrad_Simple(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113294377 (unknown shape)")
-
self._testWhileCondGrad_Simple(use_gpu=False)
self._testWhileCondGrad_Simple(use_gpu=True)
+ @test_util.disable_control_flow_v2("b/116272044 (cond_in_while)")
def testWhileCondGrad_UnknownShape(self):
with self.cached_session() as sess:
v = array_ops.placeholder(dtypes.float32)
@@ -2303,6 +2290,7 @@ class ControlFlowTest(test.TestCase):
r = sess.run(r, feed_dict={v: 2.0})
self.assertAllClose(1024.0, r)
+ @test_util.disable_control_flow_v2("b/116283162 (shape_invariants)")
def testWhileGrad_Concat(self):
with self.cached_session() as sess:
x = variable_scope.get_variable("x", initializer=[[1., 2.]])
@@ -2326,9 +2314,10 @@ class ControlFlowTest(test.TestCase):
sess.run(op)
self.assertAllClose([[0.98000002, 1.98000002]], sess.run(x))
+ @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
def testWhileWithRefsWithGradients_1(self):
with self.cached_session() as sess:
- x = variables.Variable(0.)._ref() # pylint: disable=protected-access
+ x = variables.VariableV1(0.)._ref() # pylint: disable=protected-access
i = constant_op.constant(0)
c = lambda i, x: math_ops.less(i, 10)
@@ -2340,7 +2329,7 @@ class ControlFlowTest(test.TestCase):
r = control_flow_ops.while_loop(c, body, [i, x], parallel_iterations=5)
- grad_ys = [variables.Variable(73)._ref()] # pylint: disable=protected-access
+ grad_ys = [variables.VariableV1(73)._ref()] # pylint: disable=protected-access
grad = gradients_impl.gradients([r[1]], [x], grad_ys=grad_ys)
variables.global_variables_initializer().run()
@@ -2354,6 +2343,7 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(0, value_x)
self.assertEqual(73, value_x_grad)
+ @test_util.disable_control_flow_v2("b/116282023 (IndexedSlices)")
def testWhileGrad_IndexedSlices(self):
with self.cached_session():
values = constant_op.constant([2.0, 4.0], name="values")
@@ -2375,6 +2365,7 @@ class ControlFlowTest(test.TestCase):
r = gradients_impl.gradients(r.values, values)[0]
self.assertAllClose(np.array([1024.0, 1024.0]), r.eval())
+ @test_util.disable_control_flow_v2("b/116328420 (SparseTensor)")
def testWhileGrad_SparseTensor(self):
with self.cached_session():
values = constant_op.constant([2.0, 4.0], name="values")
@@ -2397,6 +2388,7 @@ class ControlFlowTest(test.TestCase):
r = gradients_impl.gradients(r.values, values)[0]
self.assertAllClose(np.array([1024.0, 1024.0]), r.eval())
+ @test_util.disable_control_flow_v2("b/115920078 (gradients)")
def testCallGradInLoop(self):
with self.cached_session() as sess:
i0 = constant_op.constant(0)
@@ -2416,6 +2408,8 @@ class ControlFlowTest(test.TestCase):
c, b, [i0, constant_op.constant(0.0)])
self.assertAllClose(600.0, sess.run(output_grad)[1])
+ @test_util.disable_control_flow_v2(
+ "b/116255781 (flat_args), b/115660901 (TensorArray)")
def testWhileAndTensorArray(self):
with self.cached_session() as sess:
param = constant_op.constant(2.0)
@@ -2520,6 +2514,7 @@ class ControlFlowTest(test.TestCase):
all_ops = x.graph.get_operations()
self.assertFalse(any([name in op.name for op in all_ops]))
+ @test_util.disable_control_flow_v2("b/116255781 (flat args)")
def testWhileGradGradFail(self):
theta = variables.Variable(initial_value=1.)
@@ -2549,6 +2544,7 @@ class ControlFlowTest(test.TestCase):
r = gradients_impl.gradients(r, y)[0]
self.assertEqual(388.0, r.eval())
+ @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
def testWhileGradientWithNontrainablePath1(self):
q = variables.Variable([7., 8.])
@@ -2566,6 +2562,7 @@ class ControlFlowTest(test.TestCase):
sess.run(q.initializer)
self.assertAllClose([0., 0.], sess.run(dy_dq))
+ @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
def testWhileGradientWithNontrainablePath2(self):
q = variables.Variable([7., 8.])
@@ -2583,6 +2580,7 @@ class ControlFlowTest(test.TestCase):
sess.run(q.initializer)
self.assertAllClose([1., 1.], sess.run(dy_dq))
+ @test_util.disable_control_flow_v2("b/115920078 (gradients)")
def testIssue16504(self):
c = constant_op.constant(np.arange(100), dtype=dtypes.float32)
w = variables.Variable(
@@ -2606,6 +2604,7 @@ class ControlFlowTest(test.TestCase):
grad, = gradients_impl.gradients(w, c)
self.assertIsNotNone(grad)
+ @test_util.disable_control_flow_v2("b/116270461 (resource)")
def testStopGradMultiFlows(self):
with self.cached_session():
@@ -2633,8 +2632,6 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(5.0, result.eval())
def testOneValueCond(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/111124878 (don't return tuple)")
with self.cached_session():
c = array_ops.placeholder(dtypes.int32, shape=[])
@@ -2651,8 +2648,6 @@ class ControlFlowTest(test.TestCase):
self.assertEqual([2], i.eval(feed_dict={c: 0}))
def testExampleCond(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/111124878 (don't return tuple)")
with self.cached_session():
x = ops.convert_to_tensor([-2.0, 2.0], name="x")
@@ -2668,10 +2663,9 @@ class ControlFlowTest(test.TestCase):
self.assertAllClose(4.0, i.eval(feed_dict={d: 1}))
self.assertAllClose(2.0 * math.sqrt(2), i.eval(feed_dict={d: 2}))
+ @test_util.disable_control_flow_v2(
+ "b/112477618 (Operation returned from cond)")
def testCase(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/112477618 (Operation returned from cond)")
-
with self.cached_session():
x = constant_op.constant(1)
y = constant_op.constant(2)
@@ -2723,10 +2717,9 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(r6.eval(), 0)
+ @test_util.disable_control_flow_v2(
+ "b/112477618 (Operation returned from cond)")
def testCaseSideEffects(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/112477618 (Operation returned from cond)")
-
with self.cached_session() as sess:
v0 = variables.Variable(-1)
v1 = variables.Variable(-1)
@@ -2761,10 +2754,8 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(0, r0.eval())
self.assertAllEqual(sess.run([v0, v1, v2]), [0, -1, -1])
+ @test_util.disable_control_flow_v2("b/113324949 (ref vars)")
def testOneOpCond(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113324949 (ref vars)")
-
with self.cached_session():
v = variables.Variable(0)
c = ops.convert_to_tensor(0)
@@ -2794,7 +2785,7 @@ class ControlFlowTest(test.TestCase):
def testWithOpsDependencies(self):
with self.cached_session() as sess:
- v = variables.Variable(0.0)
+ v = variables.VariableV1(0.0)
c = constant_op.constant(10)
# Fetching v directly will result in an uninitialized error
@@ -2817,7 +2808,7 @@ class ControlFlowTest(test.TestCase):
def testWithTensorDependencies(self):
with self.cached_session():
- v = variables.Variable(0.0)
+ v = variables.VariableV1(0.0)
c1 = constant_op.constant(10)
c2 = constant_op.constant(20)
@@ -2843,7 +2834,7 @@ class ControlFlowTest(test.TestCase):
def testWithIndexedSlicesDependencies(self):
with self.cached_session():
- v = variables.Variable(
+ v = variables.VariableV1(
np.array([[0.0, 1.0], [10.0, 11.0], [20.0, 21.0]]).astype(np.float32))
v_at_1 = ops.IndexedSlices(v, constant_op.constant([1]))
gather_v_at_1 = array_ops.gather(v_at_1.values, v_at_1.indices)
@@ -2866,18 +2857,18 @@ class ControlFlowTest(test.TestCase):
with ops.Graph().as_default():
# device set on tensor => same device on dep.
with ops.device("/job:ps"):
- vd = variables.Variable([0.0])
+ vd = variables.VariableV1([0.0])
with_vd_dep = control_flow_ops.with_dependencies([vd.initializer], vd)
self.assertTrue("/job:ps" in with_vd_dep.device)
# No device set on tensor => no device on dep.
- vnod = variables.Variable([0.0])
+ vnod = variables.VariableV1([0.0])
with_vnod_dep = control_flow_ops.with_dependencies([vnod.initializer],
vnod)
self.assertDeviceEqual(None, with_vnod_dep.device)
# device set on tensor, default device on graph => default device on dep.
- vdef = variables.Variable([0.0], name="vdef")
+ vdef = variables.VariableV1([0.0], name="vdef")
with ops.device("/job:worker/device:GPU:1"):
with_vdef_dep = control_flow_ops.with_dependencies([vdef.initializer],
vdef)
@@ -2887,8 +2878,8 @@ class ControlFlowTest(test.TestCase):
def testGroup(self):
with self.cached_session() as sess:
- v1 = variables.Variable([0.0])
- v2 = variables.Variable([1.0])
+ v1 = variables.VariableV1([0.0])
+ v2 = variables.VariableV1([1.0])
# Group init1 and init2 and run.
init = control_flow_ops.group(v1.initializer, v2.initializer)
@@ -2970,29 +2961,29 @@ class ControlFlowTest(test.TestCase):
p1 = array_ops.placeholder(dtypes.float32)
p2 = array_ops.placeholder(dtypes.float32)
p3 = array_ops.placeholder(dtypes.float32)
- v1 = variables.Variable(p1, validate_shape=False)
- v2 = variables.Variable(p2, validate_shape=False)
- v3 = variables.Variable(p3, validate_shape=False)
+ v1 = variables.VariableV1(p1, validate_shape=False)
+ v2 = variables.VariableV1(p2, validate_shape=False)
+ v3 = variables.VariableV1(p3, validate_shape=False)
self.assertIs(None, v1.get_shape().ndims)
s = control_flow_ops.ref_select(index, [v1, v2, v3])
self.assertIs(None, s.get_shape().ndims)
# All inputs known but different.
- v1 = variables.Variable([[1, 2]])
- v2 = variables.Variable([[2], [1]])
+ v1 = variables.VariableV1([[1, 2]])
+ v2 = variables.VariableV1([[2], [1]])
s = control_flow_ops.ref_select(index, [v1, v2])
self.assertIs(None, s.get_shape().ndims)
# All inputs known and same.
- v1 = variables.Variable([[1, 2]])
- v2 = variables.Variable([[1, 2]])
+ v1 = variables.VariableV1([[1, 2]])
+ v2 = variables.VariableV1([[1, 2]])
s = control_flow_ops.ref_select(index, [v1, v2])
self.assertEqual([1, 2], s.get_shape())
# Possibly the same but not guaranteed.
- v1 = variables.Variable([[1., 2.]])
+ v1 = variables.VariableV1([[1., 2.]])
p2 = array_ops.placeholder(dtypes.float32, shape=[None, 2])
- v2 = variables.Variable(p2, validate_shape=False)
+ v2 = variables.VariableV1(p2, validate_shape=False)
s = control_flow_ops.ref_select(index, [v1, v2])
self.assertEqual(None, s.get_shape())
@@ -3046,9 +3037,11 @@ class ControlFlowTest(test.TestCase):
r = gradients_impl.gradients(r, x)[0]
self.assertEqual(r.eval(), 524288.0)
- self.assertEqual(
- len([op for op in x.graph.get_operations() if op.type == "StackV2"]),
- 1)
+ # while_v2 does not have stacks.
+ if not control_flow_ops.ENABLE_WHILE_V2:
+ self.assertEqual(
+ len([op for op in x.graph.get_operations() if op.type == "StackV2"
+ ]), 1)
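The guarded check above counts graph operations by type before asserting. As a minimal standalone sketch of that pattern (public TF 1.x graph API only; the helper name is illustrative, not part of this change):

import tensorflow as tf

def count_ops_of_type(graph, op_type):
  # Count how many operations in `graph` were registered with the given op type.
  return len([op for op in graph.get_operations() if op.type == op_type])

with tf.Graph().as_default() as g:
  a = tf.constant(2.0)
  b = a * a  # creates one "Mul" op in g
  assert count_ops_of_type(g, "Mul") == 1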
class ControlFlowContextCheckTest(test.TestCase):
@@ -3175,11 +3168,11 @@ class TupleTest(test.TestCase):
def testTensors(self):
for v1_first in [True, False]:
with self.cached_session():
- v1 = variables.Variable([1.0])
+ v1 = variables.VariableV1([1.0])
add1 = math_ops.add(
control_flow_ops.with_dependencies([v1.initializer], v1._ref()), # pylint: disable=protected-access
2.0)
- v2 = variables.Variable([10.0])
+ v2 = variables.VariableV1([10.0])
add2 = math_ops.add(
control_flow_ops.with_dependencies([v2.initializer], v2._ref()), # pylint: disable=protected-access
20.0)
@@ -3205,14 +3198,14 @@ class TupleTest(test.TestCase):
def testIndexedSlices(self):
for v1_first in [True, False]:
with self.cached_session():
- v1 = variables.Variable(
+ v1 = variables.VariableV1(
np.array([[0.0, 1.0], [10.0, 11.0], [20.0, 21.0]]).astype(
np.float32))
v1_at_1 = ops.IndexedSlices(
control_flow_ops.with_dependencies([v1.initializer], v1._ref()), # pylint: disable=protected-access
constant_op.constant([1]))
- v2 = variables.Variable(
+ v2 = variables.VariableV1(
np.array([[0.1, 1.1], [10.1, 11.1], [20.1, 21.1]]).astype(
np.float32))
v2_at_1 = ops.IndexedSlices(
@@ -3244,7 +3237,7 @@ class TupleTest(test.TestCase):
def testAcceptTensorsAsControlInputs(self):
with self.cached_session():
- var = variables.Variable(0)
+ var = variables.VariableV1(0)
assign = state_ops.assign(var, 1)
t, = control_flow_ops.tuple(
[constant_op.constant(0)], control_inputs=[assign])
@@ -3408,7 +3401,7 @@ class WhileOpBenchmark(test.Benchmark):
name="unroll_same_device", iters=iters, wall_time=duration)
-@test_util.with_cond_v2
+@test_util.with_control_flow_v2
class EagerTest(test.TestCase):
def testCond(self):
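The disable_control_flow_v2 annotations used throughout this file wrap a test so it is skipped when the v2 control-flow implementation is active. A minimal sketch of such a decorator follows; it is illustrative only, not the actual test_util implementation, and the module-level flag is a stand-in for the real feature switch.

import functools
import unittest

CONTROL_FLOW_V2_ENABLED = False  # stand-in for the real feature flag

def disable_control_flow_v2_sketch(reason):
  def decorator(test_method):
    @functools.wraps(test_method)
    def wrapper(self, *args, **kwargs):
      if CONTROL_FLOW_V2_ENABLED:
        raise unittest.SkipTest(reason)  # skip instead of failing under v2
      return test_method(self, *args, **kwargs)
    return wrapper
  return decorator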
diff --git a/tensorflow/python/kernel_tests/cwise_ops_binary_test.py b/tensorflow/python/kernel_tests/cwise_ops_binary_test.py
new file mode 100644
index 0000000000..8028f93a8c
--- /dev/null
+++ b/tensorflow/python/kernel_tests/cwise_ops_binary_test.py
@@ -0,0 +1,878 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Functional tests for binary coefficient-wise operations."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes as dtypes_lib
+from tensorflow.python.framework import errors_impl
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import gradient_checker
+from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_grad # pylint: disable=unused-import
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging
+
+_ADD = lambda x, y: x + y
+_SUB = lambda x, y: x - y
+_MUL = lambda x, y: x * y
+_POW = lambda x, y: x**y
+_TRUEDIV = lambda x, y: x / y
+_FLOORDIV = lambda x, y: x // y
+_MOD = lambda x, y: x % y
+
+
+# TODO(zongheng): it'd be great to factor out this function and various random
+# SparseTensor gen funcs.
+def _sparsify(x, thresh=0.5, index_dtype=np.int64):
+ x[x < thresh] = 0
+
+ non_zero = np.where(x)
+ x_indices = np.vstack(non_zero).astype(index_dtype).T
+ x_values = x[non_zero]
+ x_shape = x.shape
+
+ return sparse_tensor.SparseTensor(
+ indices=x_indices, values=x_values, dense_shape=x_shape), x_values
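A hypothetical usage example (not part of the new file): _sparsify zeroes out entries below the threshold and returns both a SparseTensor and the surviving values for NumPy-side reference computations.

dense = np.arange(12, dtype=np.float64).reshape(3, 4) / 12.0
sp_t, kept_values = _sparsify(dense.copy(), thresh=0.5)
# sp_t.indices, sp_t.values and sp_t.dense_shape describe the entries >= 0.5;
# kept_values is the matching NumPy array, e.g. for np_func(kept_values) checks.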
+
+
+def _default_tolerance(dtype):
+ """Returns a sensible default tolerance for comparing results of a given type.
+
+ Args:
+ dtype: A datatype.
+ """
+ if dtype == np.float16:
+ return 5e-3
+ elif dtype in (np.float32, np.complex64):
+ return 1e-3
+ elif dtype in (np.float64, np.complex128):
+ return 1e-5
+ else:
+ return None # Fail fast for unexpected types
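A hypothetical use of _default_tolerance (illustrative, not taken from the file) is to pick the comparison tolerance from the dtype under test:

tol = _default_tolerance(np.float32)  # -> 1e-3
np.testing.assert_allclose(np.float32(1.4142135),
                           np.sqrt(np.float32(2.0)),
                           rtol=tol, atol=tol)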
+
+
+class BinaryOpTest(test.TestCase):
+
+ def _compareCpu(self, x, y, np_func, tf_func, also_compare_variables=False):
+ np_ans = np_func(x, y)
+ with self.test_session(use_gpu=False):
+ inx = ops.convert_to_tensor(x)
+ iny = ops.convert_to_tensor(y)
+ out = tf_func(inx, iny)
+ tf_cpu = out.eval()
+ # Test that the op takes precedence over numpy operators.
+ np_left = tf_func(x, iny).eval()
+ np_right = tf_func(inx, y).eval()
+
+ if also_compare_variables:
+ var_x = variables.Variable(x)
+ var_y = variables.Variable(y)
+ variables.global_variables_initializer().run()
+ print(type(x), type(y), type(var_x), type(var_y))
+ print(type(tf_func(x, var_y)), type(tf_func(var_x, y)))
+ np_var_left = tf_func(x, var_y).eval()
+ np_var_right = tf_func(var_x, y).eval()
+
+ if np_ans.dtype != np.object:
+ self.assertAllClose(np_ans, tf_cpu)
+ self.assertAllClose(np_ans, np_left)
+ self.assertAllClose(np_ans, np_right)
+ if also_compare_variables:
+ self.assertAllClose(np_ans, np_var_left)
+ self.assertAllClose(np_ans, np_var_right)
+ self.assertShapeEqual(np_ans, out)
+
+ _GRAD_TOL = {
+ dtypes_lib.float16: 1e-3,
+ dtypes_lib.float32: 1e-3,
+ dtypes_lib.complex64: 1e-2,
+ dtypes_lib.float64: 1e-5,
+ dtypes_lib.complex128: 1e-4
+ }
+
+ def _compareGradientX(self,
+ x,
+ y,
+ np_func,
+ tf_func,
+ numeric_gradient_type=None):
+ z = np_func(x, y)
+ zs = list(z.shape)
+ with self.cached_session():
+ inx = ops.convert_to_tensor(x)
+ iny = ops.convert_to_tensor(y)
+ if x.dtype in (np.float32, np.float64):
+ out = 1.1 * tf_func(inx, iny)
+ else:
+ out = tf_func(inx, iny)
+ xs = list(x.shape)
+ jacob_t, jacob_n = gradient_checker.compute_gradient(
+ inx, xs, out, zs, x_init_value=x)
+ if numeric_gradient_type is not None:
+ xf = x.astype(numeric_gradient_type)
+ yf = y.astype(numeric_gradient_type)
+ inxf = ops.convert_to_tensor(xf)
+ inyf = ops.convert_to_tensor(yf)
+ outf = tf_func(inxf, inyf)
+ _, jacob_n = gradient_checker.compute_gradient(
+ inxf, xs, outf, zs, x_init_value=xf, delta=1e-3)
+ jacob_n = jacob_n.astype(x.dtype)
+ tol = self._GRAD_TOL[dtypes_lib.as_dtype(x.dtype)]
+ self.assertAllClose(jacob_t, jacob_n, rtol=tol, atol=tol)
+
+ def _compareGradientY(self,
+ x,
+ y,
+ np_func,
+ tf_func,
+ numeric_gradient_type=None):
+ z = np_func(x, y)
+ zs = list(z.shape)
+ with self.cached_session():
+ inx = ops.convert_to_tensor(x)
+ iny = ops.convert_to_tensor(y)
+ if x.dtype in (np.float32, np.float64):
+ out = 1.1 * tf_func(inx, iny)
+ else:
+ out = tf_func(inx, iny)
+ ys = list(np.shape(y))
+ jacob_t, jacob_n = gradient_checker.compute_gradient(
+ iny, ys, out, zs, x_init_value=y)
+ if numeric_gradient_type is not None:
+ xf = x.astype(numeric_gradient_type)
+ yf = y.astype(numeric_gradient_type)
+ inxf = ops.convert_to_tensor(xf)
+ inyf = ops.convert_to_tensor(yf)
+ outf = tf_func(inxf, inyf)
+ _, jacob_n = gradient_checker.compute_gradient(
+ inyf, ys, outf, zs, x_init_value=yf)
+ jacob_n = jacob_n.astype(x.dtype)
+ tol = self._GRAD_TOL[dtypes_lib.as_dtype(x.dtype)]
+ self.assertAllClose(jacob_t, jacob_n, rtol=tol, atol=tol)
+
+ def _compareGpu(self, x, y, np_func, tf_func):
+ np_ans = np_func(x, y)
+ with self.test_session(force_gpu=test_util.is_gpu_available()):
+ inx = ops.convert_to_tensor(x)
+ iny = ops.convert_to_tensor(y)
+ out = tf_func(inx, iny)
+ tf_gpu = out.eval()
+ self.assertAllClose(np_ans, tf_gpu)
+ self.assertShapeEqual(np_ans, out)
+ # TODO(zhifengc/ke): make gradient checker work on GPU.
+
+ def _compareBoth(self, x, y, np_func, tf_func, also_compare_variables=False):
+ self._compareCpu(x, y, np_func, tf_func, also_compare_variables)
+ if x.dtype in (np.float16, np.float32, np.float64, np.complex64,
+ np.complex128):
+ if tf_func not in (_FLOORDIV, math_ops.floordiv, math_ops.zeta,
+ math_ops.polygamma):
+ self._compareGradientX(x, y, np_func, tf_func)
+ self._compareGradientY(x, y, np_func, tf_func)
+ if tf_func in (math_ops.zeta, math_ops.polygamma):
+ # These methods only support gradients in the second parameter
+ self._compareGradientY(x, y, np_func, tf_func)
+ self._compareGpu(x, y, np_func, tf_func)
+
+ def testFloatBasic(self):
+ x = np.linspace(-5, 20, 15).reshape(1, 3, 5).astype(np.float32)
+ y = np.linspace(20, -5, 15).reshape(1, 3, 5).astype(np.float32)
+ self._compareBoth(x, y, np.add, math_ops.add, also_compare_variables=True)
+ self._compareBoth(x, y, np.subtract, math_ops.subtract)
+ self._compareBoth(x, y, np.multiply, math_ops.multiply)
+ self._compareBoth(x, y + 0.1, np.true_divide, math_ops.truediv)
+ self._compareBoth(x, y + 0.1, np.floor_divide, math_ops.floordiv)
+ self._compareBoth(x, y, np.add, _ADD)
+ self._compareBoth(x, y, np.subtract, _SUB)
+ self._compareBoth(x, y, np.multiply, _MUL)
+ self._compareBoth(x, y + 0.1, np.true_divide, _TRUEDIV)
+ self._compareBoth(x, y + 0.1, np.floor_divide, _FLOORDIV)
+ self._compareBoth(x, y, np.arctan2, math_ops.atan2)
+ x1 = np.random.randn(5, 6).astype(np.float32)
+ x2 = np.random.randn(5, 6).astype(np.float32)
+ # Remove tiny values--atan2 gradients are flaky near the origin.
+ x1[np.abs(x1) < 0.05] = 0.05 * np.sign(x1[np.abs(x1) < 0.05])
+ x2[np.abs(x2) < 0.05] = 0.05 * np.sign(x2[np.abs(x2) < 0.05])
+ self._compareBoth(x1, x2, np.arctan2, math_ops.atan2)
+ try:
+ from scipy import special # pylint: disable=g-import-not-at-top
+ a_pos_small = np.linspace(0.1, 2, 15).reshape(1, 3, 5).astype(np.float32)
+ x_pos_small = np.linspace(0.1, 10, 15).reshape(1, 3, 5).astype(np.float32)
+ self._compareBoth(a_pos_small, x_pos_small, special.gammainc,
+ math_ops.igamma)
+ self._compareBoth(a_pos_small, x_pos_small, special.gammaincc,
+ math_ops.igammac)
+ # Need x > 1
+ self._compareBoth(x_pos_small + 1, a_pos_small, special.zeta,
+ math_ops.zeta)
+ n_small = np.arange(0, 15).reshape(1, 3, 5).astype(np.float32)
+ self._compareBoth(n_small, x_pos_small, special.polygamma,
+ math_ops.polygamma)
+ except ImportError as e:
+ tf_logging.warn("Cannot test special functions: %s" % str(e))
+
+ def testFloatDifferentShapes(self):
+ x = np.array([1, 2, 3, 4]).reshape(2, 2).astype(np.float32)
+ y = np.array([1, 2]).reshape(2, 1).astype(np.float32)
+ with self.cached_session() as sess:
+ inx = ops.convert_to_tensor(x)
+ iny = ops.convert_to_tensor(y)
+ s = math_ops.reduce_sum(inx * iny)
+ gx, gy = sess.run(gradients_impl.gradients(s, [inx, iny]))
+ # gx is simply the broadcasted y
+ self.assertAllEqual(gx,
+ np.array([1, 1, 2, 2]).reshape(2, 2).astype(np.float32))
+ # gy is x summed across its columns (one entry per row)
+ self.assertAllEqual(gy, np.array([3, 7]).reshape(2, 1).astype(np.float32))
+
+ def testFloatVariableOverload(self):
+ x = np.array([1, 2, 3, 4]).reshape(2, 2).astype(np.int32)
+ y = np.array([1, 2]).reshape(2, 1).astype(np.int32)
+ var_x = variables.Variable(x)
+ var_y = variables.Variable(y)
+ with self.cached_session() as sess:
+ sess.run([var_x.initializer, var_y.initializer])
+ left_result = (var_x * y).eval()
+ right_result = (x * var_y).eval()
+ np_result = x * y
+ self.assertAllEqual(np_result, left_result)
+ self.assertAllEqual(np_result, right_result)
+
+ def testDoubleBasic(self):
+ x = np.linspace(-5, 20, 15).reshape(1, 3, 5).astype(np.float64)
+ y = np.linspace(20, -5, 15).reshape(1, 3, 5).astype(np.float64)
+ self._compareBoth(x, y, np.add, math_ops.add)
+ self._compareBoth(x, y, np.subtract, math_ops.subtract)
+ self._compareBoth(x, y, np.multiply, math_ops.multiply)
+ self._compareBoth(x, y + 0.1, np.true_divide, math_ops.truediv)
+ self._compareBoth(x, y + 0.1, np.floor_divide, math_ops.floordiv)
+ self._compareBoth(x, y, np.add, _ADD)
+ self._compareBoth(x, y, np.subtract, _SUB)
+ self._compareBoth(x, y, np.multiply, _MUL)
+ self._compareBoth(x, y + 0.1, np.true_divide, _TRUEDIV)
+ self._compareBoth(x, y + 0.1, np.floor_divide, _FLOORDIV)
+ self._compareBoth(x, y, np.arctan2, math_ops.atan2)
+ x1 = np.random.randn(7, 4).astype(np.float64)
+ x2 = np.random.randn(7, 4).astype(np.float64)
+ # Remove tiny values--atan2 gradients are flaky near the origin.
+ x1[np.abs(x1) < 0.5] = 0.5 * np.sign(x1[np.abs(x1) < 0.5])
+ x2[np.abs(x2) < 0.5] = 0.5 * np.sign(x2[np.abs(x2) < 0.5])
+ self._compareBoth(x1, x2, np.arctan2, math_ops.atan2)
+ try:
+ from scipy import special # pylint: disable=g-import-not-at-top
+ a_pos_small = np.linspace(0.1, 2, 15).reshape(1, 3, 5).astype(np.float32)
+ x_pos_small = np.linspace(0.1, 10, 15).reshape(1, 3, 5).astype(np.float32)
+ self._compareBoth(a_pos_small, x_pos_small, special.gammainc,
+ math_ops.igamma)
+ self._compareBoth(a_pos_small, x_pos_small, special.gammaincc,
+ math_ops.igammac)
+ except ImportError as e:
+ tf_logging.warn("Cannot test special functions: %s" % str(e))
+
+ def testUint8Basic(self):
+ x = np.arange(1, 13, 2).reshape(1, 3, 2).astype(np.uint8)
+ y = np.arange(1, 7, 1).reshape(1, 3, 2).astype(np.uint8)
+ self._compareBoth(x, y, np.add, math_ops.add)
+
+ def testInt8Basic(self):
+ x = np.arange(1, 13, 2).reshape(1, 3, 2).astype(np.int8)
+ y = np.arange(1, 7, 1).reshape(1, 3, 2).astype(np.int8)
+ self._compareBoth(x, y, np.multiply, math_ops.multiply)
+ self._compareBoth(x, y, np.multiply, _MUL)
+
+ def testInt16Basic(self):
+ x = np.arange(1, 13, 2).reshape(1, 3, 2).astype(np.int16)
+ y = np.arange(1, 7, 1).reshape(1, 3, 2).astype(np.int16)
+ self._compareBoth(x, y, np.multiply, math_ops.multiply)
+ self._compareBoth(x, y, np.multiply, _MUL)
+
+ def testUint16Basic(self):
+ x = np.arange(1, 13, 2).reshape(1, 3, 2).astype(np.uint16)
+ y = np.arange(1, 7, 1).reshape(1, 3, 2).astype(np.uint16)
+ self._compareBoth(x, y, np.multiply, math_ops.multiply)
+ self._compareBoth(x, y, np.multiply, _MUL)
+ self._compareBoth(x, y, np.true_divide, math_ops.truediv)
+ self._compareBoth(x, y, np.floor_divide, math_ops.floordiv)
+ self._compareBoth(x, y, np.true_divide, _TRUEDIV)
+ self._compareBoth(x, y, np.floor_divide, _FLOORDIV)
+
+ def testInt32Basic(self):
+ x = np.arange(1, 13, 2).reshape(1, 3, 2).astype(np.int32)
+ y = np.arange(1, 7, 1).reshape(1, 3, 2).astype(np.int32)
+ self._compareBoth(x, y, np.add, math_ops.add)
+ self._compareBoth(x, y, np.subtract, math_ops.subtract)
+ self._compareBoth(x, y, np.multiply, math_ops.multiply)
+ self._compareBoth(x, y, np.true_divide, math_ops.truediv)
+ self._compareBoth(x, y, np.floor_divide, math_ops.floordiv)
+ self._compareBoth(x, y, np.mod, math_ops.mod)
+ self._compareBoth(x, y, np.add, _ADD)
+ self._compareBoth(x, y, np.subtract, _SUB)
+ self._compareBoth(x, y, np.multiply, _MUL)
+ self._compareBoth(x, y, np.true_divide, _TRUEDIV)
+ self._compareBoth(x, y, np.floor_divide, _FLOORDIV)
+ self._compareBoth(x, y, np.mod, _MOD)
+ # _compareBoth tests on GPU only for floating point types, so test
+ # _MOD for int32 on GPU by calling _compareGpu
+ self._compareGpu(x, y, np.mod, _MOD)
+
+ def testInt64Basic(self):
+ x = np.arange(1 << 40, 13 << 40, 2 << 40).reshape(1, 3, 2).astype(np.int64)
+ y = np.arange(1, 7, 1).reshape(1, 3, 2).astype(np.int64)
+ self._compareBoth(x, y, np.subtract, math_ops.subtract)
+ self._compareBoth(x, y, np.multiply, math_ops.multiply)
+ self._compareBoth(x, y, np.true_divide, math_ops.truediv)
+ self._compareBoth(x, y, np.floor_divide, math_ops.floordiv)
+ self._compareBoth(x, y, np.mod, math_ops.mod)
+ self._compareBoth(x, y, np.subtract, _SUB)
+ self._compareBoth(x, y, np.multiply, _MUL)
+ self._compareBoth(x, y, np.true_divide, _TRUEDIV)
+ self._compareBoth(x, y, np.floor_divide, _FLOORDIV)
+ self._compareBoth(x, y, np.mod, _MOD)
+
+ def testComplex64Basic(self):
+ x = np.complex(1, 1) * np.linspace(-10, 10, 6).reshape(1, 3, 2).astype(
+ np.complex64)
+ y = np.complex(1, 1) * np.linspace(20, -20, 6).reshape(1, 3, 2).astype(
+ np.complex64)
+ self._compareBoth(x, y, np.add, math_ops.add)
+ self._compareBoth(x, y, np.subtract, math_ops.subtract)
+ self._compareBoth(x, y, np.multiply, math_ops.multiply)
+ self._compareBoth(x, y + 0.1, np.true_divide, math_ops.truediv)
+ self._compareBoth(x, y, np.add, _ADD)
+ self._compareBoth(x, y, np.subtract, _SUB)
+ self._compareBoth(x, y, np.multiply, _MUL)
+ self._compareBoth(x, y + 0.1, np.true_divide, _TRUEDIV)
+
+ def testComplex128Basic(self):
+ x = np.complex(1, 1) * np.linspace(-10, 10, 6).reshape(1, 3, 2).astype(
+ np.complex128)
+ y = np.complex(1, 1) * np.linspace(20, -20, 6).reshape(1, 3, 2).astype(
+ np.complex128)
+ self._compareBoth(x, y, np.add, math_ops.add)
+ self._compareBoth(x, y, np.subtract, math_ops.subtract)
+ self._compareBoth(x, y, np.multiply, math_ops.multiply)
+ self._compareBoth(x, y + 0.1, np.true_divide, math_ops.truediv)
+ self._compareBoth(x, y, np.add, _ADD)
+ self._compareBoth(x, y, np.subtract, _SUB)
+ self._compareBoth(x, y, np.multiply, _MUL)
+ self._compareBoth(x, y + 0.1, np.true_divide, _TRUEDIV)
+
+ def testStringComparison(self):
+ x = np.array([["abc", "bh"], ["c", ""]])
+ y = np.array([["abc", "bh"], ["def", "hi"]])
+ with self.test_session(use_gpu=False) as sess:
+ cmp_eq = math_ops.equal(x, y)
+ cmp_not_eq = math_ops.not_equal(x, y)
+ values = sess.run([cmp_eq, cmp_not_eq])
+ self.assertAllEqual([[True, True], [False, False]], values[0])
+ self.assertAllEqual([[False, False], [True, True]], values[1])
+
+ def testString(self):
+ x = np.array([["x_0_0", "x_0_1", "x_0_2"], ["x_1_0", "x_1_1", "x_1_2"],
+ ["x_2_0", "x_2_1", "x_2_2"]],
+ dtype=np.object)
+ y = np.array([["y_0_0", "y_0_1", "y_0_2"], ["y_1_0", "y_1_1", "y_1_2"],
+ ["y_2_0", "y_2_1", "y_2_2"]],
+ dtype=np.object)
+ z = np.array([["z_0", "z_1", "z_2"]], dtype=np.object)
+ w = np.array("w", dtype=np.object)
+ self._compareCpu(x, y, _ADD, _ADD)
+ self._compareCpu(x, z, _ADD, _ADD)
+ self._compareCpu(x, w, _ADD, _ADD)
+ self._compareCpu(z, w, _ADD, _ADD)
+
+ def _compareBCast(self, xs, ys, dtype, np_func, tf_func):
+ if dtype in (np.complex64, np.complex128):
+ x = (1 + np.linspace(0, 2 + 3j, np.prod(xs))).astype(dtype).reshape(xs)
+ y = (1 + np.linspace(0, 2 - 2j, np.prod(ys))).astype(dtype).reshape(ys)
+ else:
+ x = (1 + np.linspace(0, 5, np.prod(xs))).astype(dtype).reshape(xs)
+ y = (1 + np.linspace(0, 5, np.prod(ys))).astype(dtype).reshape(ys)
+ self._compareCpu(x, y, np_func, tf_func)
+ if x.dtype in (np.float16, np.float32, np.float64):
+ # TODO(aselle): Make the test work for dtypes:
+ # (np.complex64, np.complex128).
+ if tf_func not in (_FLOORDIV, math_ops.floordiv):
+ if x.dtype == np.float16:
+ # Compare fp16 theoretical gradients to fp32 numerical gradients,
+ # since fp16 numerical gradients are too imprecise unless great
+ # care is taken with choosing the inputs and the delta. This is
+ # a weaker check (in particular, it does not test the op itself,
+ # only its gradient), but it's much better than nothing.
+ self._compareGradientX(x, y, np_func, tf_func, np.float)
+ self._compareGradientY(x, y, np_func, tf_func, np.float)
+ else:
+ self._compareGradientX(x, y, np_func, tf_func)
+ self._compareGradientY(x, y, np_func, tf_func)
+ self._compareGpu(x, y, np_func, tf_func)
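The fp16 branch above compares analytic fp16 gradients against fp32 numerical gradients. A condensed sketch of that idea is shown below; the helper name is illustrative, it assumes same-shaped float16 inputs and an elementwise tf_func, and it relies on the module's existing ops, gradient_checker, and np imports.

def _compare_fp16_gradient_sketch(self, x16, y16, tf_func):
  # Analytic Jacobian in float16, numerical Jacobian in float32, compared
  # at a float16-friendly tolerance.
  shape = list(x16.shape)
  with self.cached_session():
    inx16 = ops.convert_to_tensor(x16)
    iny16 = ops.convert_to_tensor(y16)
    jacob_t, _ = gradient_checker.compute_gradient(
        inx16, shape, tf_func(inx16, iny16), shape, x_init_value=x16)
    x32 = x16.astype(np.float32)
    y32 = y16.astype(np.float32)
    inx32 = ops.convert_to_tensor(x32)
    iny32 = ops.convert_to_tensor(y32)
    _, jacob_n = gradient_checker.compute_gradient(
        inx32, shape, tf_func(inx32, iny32), shape,
        x_init_value=x32, delta=1e-3)
    self.assertAllClose(jacob_t, jacob_n.astype(np.float16),
                        rtol=1e-3, atol=1e-3)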
+
+ # TODO(josh11b,vrv): Refactor this to use parameterized tests.
+ def _testBCastByFunc(self, funcs, xs, ys):
+ dtypes = [
+ np.float16,
+ np.float32,
+ np.float64,
+ np.int32,
+ np.int64,
+ np.complex64,
+ np.complex128,
+ ]
+ for dtype in dtypes:
+ for (np_func, tf_func) in funcs:
+ if (dtype in (np.complex64, np.complex128) and
+ tf_func in (_FLOORDIV, math_ops.floordiv)):
+ continue # floordiv makes no sense for complex numbers
+ self._compareBCast(xs, ys, dtype, np_func, tf_func)
+ self._compareBCast(ys, xs, dtype, np_func, tf_func)
+
+ def _testBCastA(self, xs, ys):
+ funcs = [
+ (np.add, math_ops.add),
+ (np.add, _ADD),
+ ]
+ self._testBCastByFunc(funcs, xs, ys)
+
+ def _testBCastB(self, xs, ys):
+ funcs = [
+ (np.subtract, math_ops.subtract),
+ (np.subtract, _SUB),
+ (np.power, math_ops.pow),
+ ]
+ self._testBCastByFunc(funcs, xs, ys)
+
+ def _testBCastC(self, xs, ys):
+ funcs = [
+ (np.multiply, math_ops.multiply),
+ (np.multiply, _MUL),
+ ]
+ self._testBCastByFunc(funcs, xs, ys)
+
+ def _testBCastD(self, xs, ys):
+ funcs = [
+ (np.true_divide, math_ops.truediv),
+ (np.floor_divide, math_ops.floordiv),
+ (np.true_divide, _TRUEDIV),
+ (np.floor_divide, _FLOORDIV),
+ ]
+ self._testBCastByFunc(funcs, xs, ys)
+
+ def testBCast_0A(self):
+ self._testBCastA([1, 3, 2], [1])
+
+ def testBCast_0B(self):
+ self._testBCastB([1, 3, 2], [1])
+
+ def testBCast_0C(self):
+ self._testBCastC([1, 3, 2], [1])
+
+ def testBCast_0D(self):
+ self._testBCastD([1, 3, 2], [1])
+
+ def testBCast_1A(self):
+ self._testBCastA([1, 3, 2], [2])
+
+ def testBCast_1B(self):
+ self._testBCastB([1, 3, 2], [2])
+
+ def testBCast_1C(self):
+ self._testBCastC([1, 3, 2], [2])
+
+ def testBCast_1D(self):
+ self._testBCastD([1, 3, 2], [2])
+
+ def testBCast_2A(self):
+ self._testBCastA([1, 3, 2], [3, 2])
+
+ def testBCast_2B(self):
+ self._testBCastB([1, 3, 2], [3, 2])
+
+ def testBCast_2C(self):
+ self._testBCastC([1, 3, 2], [3, 2])
+
+ def testBCast_2D(self):
+ self._testBCastD([1, 3, 2], [3, 2])
+
+ def testBCast_3A(self):
+ self._testBCastA([1, 3, 2], [3, 1])
+
+ def testBCast_3B(self):
+ self._testBCastB([1, 3, 2], [3, 1])
+
+ def testBCast_3C(self):
+ self._testBCastC([1, 3, 2], [3, 1])
+
+ def testBCast_3D(self):
+ self._testBCastD([1, 3, 2], [3, 1])
+
+ def testBCast_4A(self):
+ self._testBCastA([1, 3, 2], [1, 3, 2])
+
+ def testBCast_4B(self):
+ self._testBCastB([1, 3, 2], [1, 3, 2])
+
+ def testBCast_4C(self):
+ self._testBCastC([1, 3, 2], [1, 3, 2])
+
+ def testBCast_4D(self):
+ self._testBCastD([1, 3, 2], [1, 3, 2])
+
+ def testBCast_5A(self):
+ self._testBCastA([1, 3, 2], [2, 3, 1])
+
+ def testBCast_5B(self):
+ self._testBCastB([1, 3, 2], [2, 3, 1])
+
+ def testBCast_5C(self):
+ self._testBCastC([1, 3, 2], [2, 3, 1])
+
+ def testBCast_5D(self):
+ self._testBCastD([1, 3, 2], [2, 3, 1])
+
+ def testBCast_6A(self):
+ self._testBCastA([1, 3, 2], [2, 1, 1])
+
+ def testBCast_6B(self):
+ self._testBCastB([1, 3, 2], [2, 1, 1])
+
+ def testBCast_6C(self):
+ self._testBCastC([1, 3, 2], [2, 1, 1])
+
+ def testBCast_6D(self):
+ self._testBCastD([1, 3, 2], [2, 1, 1])
+
+ def testBCast_7A(self):
+ self._testBCastA([1, 3, 2], [1, 3, 1])
+
+ def testBCast_7B(self):
+ self._testBCastB([1, 3, 2], [1, 3, 1])
+
+ def testBCast_7C(self):
+ self._testBCastC([1, 3, 2], [1, 3, 1])
+
+ def testBCast_7D(self):
+ self._testBCastD([1, 3, 2], [1, 3, 1])
+
+ def testBCast_8A(self):
+ self._testBCastA([2, 1, 5], [2, 3, 1])
+
+ def testBCast_8B(self):
+ self._testBCastB([2, 1, 5], [2, 3, 1])
+
+ def testBCast_8C(self):
+ self._testBCastC([2, 1, 5], [2, 3, 1])
+
+ def testBCast_8D(self):
+ self._testBCastD([2, 1, 5], [2, 3, 1])
+
+ def testBCast_9A(self):
+ self._testBCastA([2, 0, 5], [2, 0, 1])
+
+ def testBCast_9B(self):
+ self._testBCastB([2, 0, 5], [2, 0, 1])
+
+ def testBCast_9C(self):
+ self._testBCastC([2, 0, 5], [2, 0, 1])
+
+ def testBCast_9D(self):
+ self._testBCastD([2, 0, 5], [2, 0, 1])
+
+ def testBCast_10A(self):
+ self._testBCastA([2, 3, 0], [2, 3, 1])
+
+ def testBCast_10B(self):
+ self._testBCastB([2, 3, 0], [2, 3, 1])
+
+ def testBCast_10C(self):
+ self._testBCastC([2, 3, 0], [2, 3, 1])
+
+ def testBCast_10D(self):
+ self._testBCastD([2, 3, 0], [2, 3, 1])
+
+ def testBCast_11A(self):
+ self._testBCastA([1, 3, 2], [1, 3, 2])
+
+ def testBCast_11B(self):
+ self._testBCastB([1, 3, 2], [1, 3, 2])
+
+ def testBCast_11C(self):
+ self._testBCastC([1, 3, 2], [1, 3, 2])
+
+ def testBCast_11D(self):
+ self._testBCastD([1, 3, 2], [1, 3, 2])
+
+ def testBCast_12A(self):
+ self._testBCastA([1, 1, 1, 1, 3, 2], [1, 3, 2])
+
+ def testBCast_12B(self):
+ self._testBCastB([1, 1, 1, 1, 3, 2], [1, 3, 2])
+
+ def testBCast_12C(self):
+ self._testBCastC([1, 1, 1, 1, 3, 2], [1, 3, 2])
+
+ def testBCast_12D(self):
+ self._testBCastD([1, 1, 1, 1, 3, 2], [1, 3, 2])
+
+ def testBCast_13A(self):
+ self._testBCastA([1, 3, 2, 1, 1], [1])
+
+ def testBCast_13B(self):
+ self._testBCastB([1, 3, 2, 1, 1], [1])
+
+ def testBCast_13C(self):
+ self._testBCastC([1, 3, 2, 1, 1], [1])
+
+ def testBCast_13D(self):
+ self._testBCastD([1, 3, 2, 1, 1], [1])
+
+ def testBCast_14A(self):
+ self._testBCastA([2, 3, 1, 1, 5], [1])
+
+ def testBCast_14B(self):
+ self._testBCastB([2, 3, 1, 1, 5], [1])
+
+ def testBCast_14C(self):
+ self._testBCastC([2, 3, 1, 1, 5], [1])
+
+ def testBCast_14D(self):
+ self._testBCastD([2, 3, 1, 1, 5], [1])
+
+ def testBCast_15A(self):
+ self._testBCastA([10, 3, 1, 2], [3, 1, 2])
+
+ def testBCast_15B(self):
+ self._testBCastB([10, 3, 1, 2], [3, 1, 2])
+
+ def testBCast_15C(self):
+ self._testBCastC([10, 3, 1, 2], [3, 1, 2])
+
+ def testBCast_15D(self):
+ self._testBCastD([10, 3, 1, 2], [3, 1, 2])
+
+ def testMismatchedDimensions(self):
+ for func in [
+ math_ops.add, math_ops.subtract, math_ops.multiply, math_ops.div, _ADD,
+ _SUB, _MUL, _TRUEDIV, _FLOORDIV
+ ]:
+ with self.assertRaisesWithPredicateMatch(
+ ValueError, lambda e: "Dimensions must" in str(e)):
+ func(
+ ops.convert_to_tensor([10.0, 20.0, 30.0]),
+ ops.convert_to_tensor([[40.0, 50.0], [60.0, 70.0]]))
+
+ def testZeroPowGrad(self):
+ with self.cached_session():
+ for dtype in (np.float16, np.float32, np.float64, np.complex64,
+ np.complex128):
+ x = constant_op.constant(0.0, dtype=dtype)
+ y = constant_op.constant(2.0, dtype=dtype)
+ z = math_ops.pow(x, y)
+ error = gradient_checker.compute_gradient_error(y, [], z, [])
+ self.assertEqual(error, 0)
+
+ def testComplexPowGrad(self):
+ with self.cached_session():
+ for dtype in np.complex64, np.complex128:
+ for base in 2.0, -2.0:
+ x = constant_op.constant(base, dtype=dtype)
+ y = constant_op.constant(2.0, dtype=dtype)
+ z = math_ops.pow(x, y)
+ error = gradient_checker.compute_gradient_error(y, [], z, [])
+ self.assertLess(error, 2e-4)
+
+ def testAtan2SpecialValues(self):
+ x1l, x2l = zip((+0.0, +0.0), (+0.0, -0.0), (-0.0, +0.0), (-0.0, -0.0),
+ (1.2345, float("inf")), (1.2345, -float("inf")),
+ (-4.321, float("inf")), (-4.125, -float("inf")),
+ (float("inf"), float("inf")), (float("inf"), -float("inf")),
+ (-float("inf"), float("inf")),
+ (-float("inf"), -float("inf")))
+ for dtype in np.float32, np.float64:
+ x1 = np.array(x1l).astype(dtype)
+ x2 = np.array(x2l).astype(dtype)
+ self._compareCpu(x1, x2, np.arctan2, math_ops.atan2)
+ self._compareGpu(x1, x2, np.arctan2, math_ops.atan2)
+
+ def testPowNegativeExponent(self):
+ for dtype in [np.int32, np.int64]:
+ with self.test_session(use_gpu=False) as sess:
+ with self.assertRaisesRegexp(
+ errors_impl.InvalidArgumentError,
+ "Integers to negative integer powers are not allowed"):
+ x = np.array([5, 2]).astype(dtype)
+ y = np.array([-2, 3]).astype(dtype)
+ sess.run(math_ops.pow(x, y))
+
+ with self.test_session(use_gpu=False) as sess:
+ with self.assertRaisesRegexp(
+ errors_impl.InvalidArgumentError,
+ "Integers to negative integer powers are not allowed"):
+ x = np.array([5, 2]).astype(dtype)
+ y = np.array([2, -3]).astype(dtype)
+ sess.run(math_ops.pow(x, y))
+
+ with self.test_session(use_gpu=False) as sess:
+ with self.assertRaisesRegexp(
+ errors_impl.InvalidArgumentError,
+ "Integers to negative integer powers are not allowed"):
+ x = np.array([5, 2]).astype(dtype)
+ y = -3
+ sess.run(math_ops.pow(x, y))
+
+
+class ComparisonOpTest(test.TestCase):
+
+ def _compareScalar(self, func, x, y, dtype):
+ with self.test_session(force_gpu=test_util.is_gpu_available()):
+ out = func(
+ ops.convert_to_tensor(np.array([x]).astype(dtype)),
+ ops.convert_to_tensor(np.array([y]).astype(dtype)))
+ ret = out.eval()
+ return ret[0]
+
+ def testScalarCompareScalar(self):
+ dtypes = [np.float16, np.float32, np.float64, np.int32, np.int64]
+ data = [-1, 0, 1]
+ for t in dtypes:
+ for x in data:
+ for y in data:
+ self.assertEqual(self._compareScalar(math_ops.less, x, y, t), x < y)
+ self.assertEqual(
+ self._compareScalar(math_ops.less_equal, x, y, t), x <= y)
+ self.assertEqual(
+ self._compareScalar(math_ops.greater, x, y, t), x > y)
+ self.assertEqual(
+ self._compareScalar(math_ops.greater_equal, x, y, t), x >= y)
+ self.assertEqual(self._compareScalar(math_ops.equal, x, y, t), x == y)
+ self.assertEqual(
+ self._compareScalar(math_ops.not_equal, x, y, t), x != y)
+ data = [-1, 0, 1, -1j, 1j, 1 + 1j, 1 - 1j]
+ for t in [np.complex64, np.complex128]:
+ for x in data:
+ for y in data:
+ self.assertEqual(self._compareScalar(math_ops.equal, x, y, t), x == y)
+ self.assertEqual(
+ self._compareScalar(math_ops.not_equal, x, y, t), x != y)
+
+ def _compare(self, x, y, np_func, tf_func):
+ np_ans = np_func(x, y)
+ with self.test_session(force_gpu=test_util.is_gpu_available()):
+ out = tf_func(ops.convert_to_tensor(x), ops.convert_to_tensor(y))
+ tf_ans = out.eval()
+ self.assertAllEqual(np_ans, tf_ans)
+
+ def testTensorCompareTensor(self):
+ x = np.linspace(-15, 15, 6).reshape(1, 3, 2)
+ y = np.linspace(20, -10, 6).reshape(1, 3, 2)
+ for t in [np.float16, np.float32, np.float64, np.int32, np.int64]:
+ xt = x.astype(t)
+ yt = y.astype(t)
+ self._compare(xt, yt, np.less, math_ops.less)
+ self._compare(xt, yt, np.less_equal, math_ops.less_equal)
+ self._compare(xt, yt, np.greater, math_ops.greater)
+ self._compare(xt, yt, np.greater_equal, math_ops.greater_equal)
+ self._compare(xt, yt, np.equal, math_ops.equal)
+ self._compare(xt, yt, np.not_equal, math_ops.not_equal)
+ # Complex types do not support ordering but do support equality tests.
+ for t in [np.complex64, np.complex128]:
+ xt = x.astype(t)
+ xt -= 1j * xt
+ yt = y.astype(t)
+ yt -= 1j * yt
+ self._compare(xt, yt, np.equal, math_ops.equal)
+ self._compare(xt, yt, np.not_equal, math_ops.not_equal)
+
+ def _compareBCast(self, xs, ys, dtype, np_func, tf_func):
+ x = np.linspace(-15, 15, np.prod(xs)).astype(dtype).reshape(xs)
+ y = np.linspace(20, -10, np.prod(ys)).astype(dtype).reshape(ys)
+ if dtype in (np.complex64, np.complex128):
+ x -= 1j * x
+ y -= 1j * y
+ self._compare(x, y, np_func, tf_func)
+ self._compare(y, x, np_func, tf_func)
+
+ def _testBCastByFunc(self, np_func, tf_func, include_complex=False):
+ shapes = [
+ ([1, 3, 2], [1]),
+ ([1, 3, 2], [2]),
+ ([1, 3, 2], [3, 2]),
+ ([1, 3, 2], [3, 1]),
+ ([1, 3, 2], [1, 3, 2]),
+ ([1, 3, 2], [2, 3, 1]),
+ ([1, 3, 2], [2, 1, 1]),
+ ([1, 3, 2], [1, 3, 1]),
+ ([2, 1, 5], [2, 3, 1]),
+ ([2, 0, 5], [2, 0, 1]),
+ ([2, 3, 0], [2, 3, 1]),
+ ]
+ dtypes = [
+ np.float16,
+ np.float32,
+ np.float64,
+ np.int32,
+ np.int64,
+ ]
+ if include_complex:
+ dtypes.extend([np.complex64, np.complex128])
+
+ for (xs, ys) in shapes:
+ for dtype in dtypes:
+ self._compareBCast(xs, ys, dtype, np_func, tf_func)
+
+ def testBCastLess(self):
+ self._testBCastByFunc(np.less, math_ops.less)
+
+ def testBCastLessEqual(self):
+ self._testBCastByFunc(np.less_equal, math_ops.less_equal)
+
+ def testBCastGreater(self):
+ self._testBCastByFunc(np.greater, math_ops.greater)
+
+ def testBCastGreaterEqual(self):
+ self._testBCastByFunc(np.greater_equal, math_ops.greater_equal)
+
+ def testBCastEqual(self):
+ self._testBCastByFunc(np.equal, math_ops.equal, include_complex=True)
+
+ def testBCastNotEqual(self):
+ self._testBCastByFunc(
+ np.not_equal, math_ops.not_equal, include_complex=True)
+
+ def testShapeMismatch(self):
+ dtypes = [np.float16, np.float32, np.float64, np.int32, np.int64]
+ funcs = [
+ math_ops.less, math_ops.less_equal, math_ops.greater,
+ math_ops.greater_equal, math_ops.equal, math_ops.not_equal
+ ]
+ x = np.arange(0, 10).reshape([2, 5])
+ y = np.arange(0, 10).reshape([5, 2])
+ for t in dtypes:
+ for f in funcs:
+ with self.assertRaisesWithPredicateMatch(
+ ValueError, lambda e: "Dimensions must" in str(e)):
+ f(x.astype(t), y.astype(t))
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/kernel_tests/cwise_ops_test.py b/tensorflow/python/kernel_tests/cwise_ops_test.py
index 00d7f956c2..c5311ad834 100644
--- a/tensorflow/python/kernel_tests/cwise_ops_test.py
+++ b/tensorflow/python/kernel_tests/cwise_ops_test.py
@@ -18,25 +18,19 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import math
-
import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes as dtypes_lib
-from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import gradient_checker
-from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_grad # pylint: disable=unused-import
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
-from tensorflow.python.platform import tf_logging
_ADD = lambda x, y: x + y
_SUB = lambda x, y: x - y
@@ -45,8 +39,6 @@ _POW = lambda x, y: x**y
_TRUEDIV = lambda x, y: x / y
_FLOORDIV = lambda x, y: x // y
_MOD = lambda x, y: x % y
-_NEG = lambda x: -x
-_ABS = abs
_LT = lambda x, y: x < y
_LE = lambda x, y: x <= y
@@ -74,8 +66,11 @@ def _sparsify(x, thresh=0.5, index_dtype=np.int64):
def _default_tolerance(dtype):
- """Returns a sensible default tolerance for comparing results of a given
- type"""
+ """Returns a sensible default tolerance for comparing results of a given type.
+
+ Args:
+ dtype: A datatype.
+ """
if dtype == np.float16:
return 5e-3
elif dtype in (np.float32, np.complex64):
@@ -86,1147 +81,6 @@ def _default_tolerance(dtype):
return None # Fail fast for unexpected types
-class UnaryOpTest(test.TestCase):
-
- def _compareCpu(self, x, np_func, tf_func, grad_rtol=None, grad_atol=None):
- if grad_rtol is None:
- grad_rtol = _default_tolerance(x.dtype)
- if grad_atol is None:
- grad_atol = _default_tolerance(x.dtype)
- np_ans = np_func(x)
- with self.test_session(use_gpu=False):
- inx = ops.convert_to_tensor(x)
- if x.dtype in (np.float32, np.float64,
- dtypes_lib.bfloat16.as_numpy_dtype):
- y = 1.1 * tf_func(inx)
- np_ans *= 1.1
- else:
- y = tf_func(inx)
- tf_cpu = y.eval()
- self.assertShapeEqual(np_ans, y)
- if x.dtype == np.float16:
- self.assertAllClose(np_ans, tf_cpu, rtol=1e-3, atol=1e-3)
- elif x.dtype == dtypes_lib.bfloat16.as_numpy_dtype:
- self.assertAllClose(np_ans, tf_cpu, rtol=1e-2, atol=1e-2)
- else:
- self.assertAllClose(np_ans, tf_cpu)
-
- if x.dtype in (np.complex64, np.complex128) and tf_func == math_ops.sign:
- return # Return early
-
- if x.dtype == np.float16:
- s = list(np.shape(x))
- jacob_t, _ = gradient_checker.compute_gradient(
- inx, s, y, s, x_init_value=x)
- xf = x.astype(np.float)
- inxf = ops.convert_to_tensor(xf)
- yf = tf_func(inxf)
- _, jacob_n = gradient_checker.compute_gradient(
- inxf, s, yf, s, x_init_value=xf, delta=1e-2)
- jacob_n = jacob_n.astype(np.float16)
- self.assertAllClose(jacob_t, jacob_n, rtol=grad_rtol, atol=grad_atol)
- elif x.dtype in (np.float32, np.complex64):
- s = list(np.shape(x))
- jacob_t, jacob_n = gradient_checker.compute_gradient(
- inx, s, y, s, x_init_value=x, delta=1e-3)
- self.assertAllClose(jacob_t, jacob_n, rtol=grad_rtol, atol=grad_atol)
- elif x.dtype in (np.float64, np.complex128):
- s = list(np.shape(x))
- jacob_t, jacob_n = gradient_checker.compute_gradient(
- inx, s, y, s, x_init_value=x, delta=1e-5)
- self.assertAllClose(jacob_t, jacob_n, rtol=grad_rtol, atol=grad_atol)
-
- def _check(self, result_tensor, result_np, input_sp_t, tol):
- self.assertTrue(isinstance(result_tensor, sparse_tensor.SparseTensor))
- self.assertTrue(isinstance(input_sp_t, sparse_tensor.SparseTensor))
- self.assertAllEqual(input_sp_t.indices.eval(), result_tensor.indices.eval())
- self.assertAllEqual(input_sp_t.dense_shape.eval(),
- result_tensor.dense_shape.eval())
- if tol is None:
- self.assertAllClose(result_np, result_tensor.values.eval())
- else:
- self.assertAllClose(
- result_np, result_tensor.values.eval(), rtol=tol, atol=tol)
-
- def _compareSparseCpu(self, x, np_func, tf_func, tol):
- x_sp, x_sp_vals = _sparsify(x)
- res_np = np_func(x_sp_vals)
- with self.test_session(use_gpu=False):
- self._check(tf_func(x_sp), res_np, x_sp, tol)
-
- def _compareGpu(self, x, np_func, tf_func):
- np_ans = np_func(x)
- with self.test_session(force_gpu=test_util.is_gpu_available()):
- result = tf_func(ops.convert_to_tensor(x))
- tf_gpu = result.eval()
- if x.dtype == np.float16:
- self.assertAllClose(np_ans, tf_gpu, rtol=1e-3, atol=1e-3)
- else:
- self.assertAllClose(np_ans, tf_gpu)
- # TODO(zhifengc/ke): make gradient checker work on GPU.
-
- def _compareSparseGpu(self, x, np_func, tf_func, tol):
- x_sp, x_sp_vals = _sparsify(x)
- res_np = np_func(x_sp_vals)
- with self.test_session(force_gpu=test_util.is_gpu_available()):
- self._check(tf_func(x_sp), res_np, x_sp, tol)
-
- def _compareBoth(self, x, np_func, tf_func):
- self._compareCpu(x, np_func, tf_func)
- self._compareGpu(x, np_func, tf_func)
-
- def _compareBothSparse(self, x, np_func, tf_func, tol=None):
- self._compareSparseCpu(x, np_func, tf_func, tol)
- self._compareSparseGpu(x, np_func, tf_func, tol)
-
- def _inv(self, x):
- return 1.0 / x
-
- def _rsqrt(self, x):
- return self._inv(np.sqrt(x))
-
- def _sigmoid(self, x):
- return 1.0 / (1.0 + np.exp(-x))
-
- def _log_sigmoid(self, x):
- return np.log(self._sigmoid(x))
-
- def _replace_domain_error_with_inf(self, fn):
-
- def func(x):
- try:
- return fn(x)
- except ValueError as e:
- if "domain error" in str(e):
- return np.inf * np.ones_like(x)
- else:
- raise e
-
- return func
-
- def testFloatBasic(self):
- x = np.arange(-3, 3).reshape(1, 3, 2).astype(np.float32)
- w = x - x.min() + 1.02 # all greater than 1
- y = (x + .5).astype(np.float32) # no zero
- z = (x + 15.5).astype(np.float32) # all positive
- k = np.arange(-0.90, 0.90, 0.25).astype(np.float32) # between -1 and 1
-
- self._compareBoth(x, np.abs, math_ops.abs)
- self._compareBoth(x, np.abs, _ABS)
- self._compareBoth(x, np.negative, math_ops.negative)
- self._compareBoth(x, np.negative, _NEG)
- self._compareBoth(y, self._inv, math_ops.reciprocal)
- self._compareBoth(x, np.square, math_ops.square)
- self._compareBoth(z, np.sqrt, math_ops.sqrt)
- self._compareBoth(z, self._rsqrt, math_ops.rsqrt)
- self._compareBoth(x, np.exp, math_ops.exp)
- self._compareBoth(x, np.expm1, math_ops.expm1)
- self._compareBoth(z, np.log, math_ops.log)
- self._compareBoth(z, np.log1p, math_ops.log1p)
- self._compareBoth(x, np.sinh, math_ops.sinh)
- self._compareBoth(x, np.cosh, math_ops.cosh)
- self._compareBoth(x, np.tanh, math_ops.tanh)
- self._compareBoth(x, np.arcsinh, math_ops.asinh)
- self._compareBoth(w, np.arccosh, math_ops.acosh)
- self._compareBoth(k, np.arctanh, math_ops.atanh)
- self._compareBoth(x, self._sigmoid, math_ops.sigmoid)
- self._compareBoth(x, self._log_sigmoid, math_ops.log_sigmoid)
- self._compareBoth(y, np.sign, math_ops.sign)
- self._compareBoth(x, np.sin, math_ops.sin)
- self._compareBoth(x, np.cos, math_ops.cos)
- self._compareBoth(k, np.arcsin, math_ops.asin)
- self._compareBoth(k, np.arccos, math_ops.acos)
- self._compareBoth(x, np.arctan, math_ops.atan)
- self._compareBoth(x, np.tan, math_ops.tan)
- self._compareBoth(y,
- np.vectorize(
- self._replace_domain_error_with_inf(math.lgamma)),
- math_ops.lgamma)
- self._compareBoth(x, np.vectorize(math.erf), math_ops.erf)
- self._compareBoth(x, np.vectorize(math.erfc), math_ops.erfc)
- try:
- from scipy import special # pylint: disable=g-import-not-at-top
- self._compareBoth(x, special.i0e, math_ops.bessel_i0e)
- self._compareBoth(x, special.i1e, math_ops.bessel_i1e)
- except ImportError as e:
- tf_logging.warn("Cannot test special functions: %s" % str(e))
-
- self._compareBothSparse(x, np.abs, math_ops.abs)
- self._compareBothSparse(x, np.negative, math_ops.negative)
- self._compareBothSparse(x, np.square, math_ops.square)
- self._compareBothSparse(z, np.sqrt, math_ops.sqrt, tol=1e-3)
- self._compareBothSparse(x, np.tanh, math_ops.tanh)
- self._compareBothSparse(y, np.sign, math_ops.sign)
- self._compareBothSparse(x, np.vectorize(math.erf), math_ops.erf)
-
- def testFloatTanhEdge(self):
- x = np.arange(40, 40 + 6).reshape(6).astype(np.float32)
- self._compareBoth(x, np.tanh, math_ops.tanh)
- x = np.arange(-40, -40 + 6).reshape(6).astype(np.float32)
- self._compareBoth(x, np.tanh, math_ops.tanh)
-
- def testFloatEmpty(self):
- x = np.empty((2, 0, 5), dtype=np.float32)
- self._compareBoth(x, np.abs, math_ops.abs)
- self._compareBoth(x, np.abs, _ABS)
- self._compareBoth(x, np.negative, math_ops.negative)
- self._compareBoth(x, np.negative, _NEG)
- self._compareBoth(x, self._inv, math_ops.reciprocal)
- self._compareBoth(x, np.square, math_ops.square)
- self._compareBoth(x, np.sqrt, math_ops.sqrt)
- self._compareBoth(x, self._rsqrt, math_ops.rsqrt)
- self._compareBoth(x, np.exp, math_ops.exp)
- self._compareBoth(x, np.expm1, math_ops.expm1)
- self._compareBoth(x, np.log, math_ops.log)
- self._compareBoth(x, np.log1p, math_ops.log1p)
- self._compareBoth(x, np.sinh, math_ops.sinh)
- self._compareBoth(x, np.arcsinh, math_ops.asinh)
- self._compareBoth(x, np.cosh, math_ops.cosh)
- self._compareBoth(x, np.tanh, math_ops.tanh)
- self._compareBoth(x, self._sigmoid, math_ops.sigmoid)
- self._compareBoth(x, np.sign, math_ops.sign)
- self._compareBoth(x, np.sin, math_ops.sin)
- self._compareBoth(x, np.cos, math_ops.cos)
- # Can't use vectorize below, so just use some arbitrary function
- self._compareBoth(x, np.sign, math_ops.lgamma)
- self._compareBoth(x, np.sign, math_ops.erf)
- self._compareBoth(x, np.sign, math_ops.erfc)
- self._compareBoth(x, np.tan, math_ops.tan)
- self._compareBoth(x, np.arcsin, math_ops.asin)
- self._compareBoth(x, np.arccos, math_ops.acos)
- self._compareBoth(x, np.arctan, math_ops.atan)
- try:
- from scipy import special # pylint: disable=g-import-not-at-top
- self._compareBoth(x, special.i0e, math_ops.bessel_i0e)
- self._compareBoth(x, special.i1e, math_ops.bessel_i1e)
- except ImportError as e:
- tf_logging.warn("Cannot test special functions: %s" % str(e))
-
- self._compareBothSparse(x, np.abs, math_ops.abs)
- self._compareBothSparse(x, np.negative, math_ops.negative)
- self._compareBothSparse(x, np.square, math_ops.square)
- self._compareBothSparse(x, np.sqrt, math_ops.sqrt, tol=1e-3)
- self._compareBothSparse(x, np.tanh, math_ops.tanh)
- self._compareBothSparse(x, np.sign, math_ops.sign)
- self._compareBothSparse(x, np.sign, math_ops.erf)
-
- def testDoubleBasic(self):
- x = np.arange(-3, 3).reshape(1, 3, 2).astype(np.float64)
- w = x - x.min() + 1.02 # all greater than 1
- y = (x + .5).astype(np.float64) # no zero
- z = (x + 15.5).astype(np.float64) # all positive
- k = np.arange(-0.90, 0.90,
- 0.35).reshape(1, 3, 2).astype(np.float64) # between -1 and 1
- self._compareBoth(x, np.abs, math_ops.abs)
- self._compareBoth(x, np.abs, _ABS)
- self._compareBoth(x, np.negative, math_ops.negative)
- self._compareBoth(x, np.negative, _NEG)
- self._compareBoth(y, self._inv, math_ops.reciprocal)
- self._compareBoth(x, np.square, math_ops.square)
- self._compareBoth(z, np.sqrt, math_ops.sqrt)
- self._compareBoth(z, self._rsqrt, math_ops.rsqrt)
- self._compareBoth(x, np.exp, math_ops.exp)
- self._compareBoth(x, np.expm1, math_ops.expm1)
- self._compareBoth(z, np.log, math_ops.log)
- self._compareBoth(z, np.log1p, math_ops.log1p)
- self._compareBoth(x, np.sinh, math_ops.sinh)
- self._compareBoth(x, np.cosh, math_ops.cosh)
- self._compareBoth(x, np.tanh, math_ops.tanh)
- self._compareBoth(x, np.arcsinh, math_ops.asinh)
- self._compareBoth(w, np.arccosh, math_ops.acosh)
- self._compareBoth(k, np.arctanh, math_ops.atanh)
- self._compareBoth(x, self._sigmoid, math_ops.sigmoid)
- self._compareBoth(y, np.sign, math_ops.sign)
- self._compareBoth(x, np.sin, math_ops.sin)
- self._compareBoth(x, np.cos, math_ops.cos)
- self._compareBoth(y,
- np.vectorize(
- self._replace_domain_error_with_inf(math.lgamma)),
- math_ops.lgamma)
- self._compareBoth(x, np.vectorize(math.erf), math_ops.erf)
- self._compareBoth(x, np.vectorize(math.erfc), math_ops.erfc)
- self._compareBoth(x, np.arctan, math_ops.atan)
- self._compareBoth(k, np.arcsin, math_ops.asin)
- self._compareBoth(k, np.arccos, math_ops.acos)
- self._compareBoth(k, np.tan, math_ops.tan)
- try:
- from scipy import special # pylint: disable=g-import-not-at-top
- self._compareBoth(x, special.i0e, math_ops.bessel_i0e)
- self._compareBoth(x, special.i1e, math_ops.bessel_i1e)
- except ImportError as e:
- tf_logging.warn("Cannot test special functions: %s" % str(e))
-
- self._compareBothSparse(x, np.abs, math_ops.abs)
- self._compareBothSparse(x, np.negative, math_ops.negative)
- self._compareBothSparse(x, np.square, math_ops.square)
- self._compareBothSparse(z, np.sqrt, math_ops.sqrt, tol=1e-3)
- self._compareBothSparse(x, np.tanh, math_ops.tanh)
- self._compareBothSparse(y, np.sign, math_ops.sign)
- self._compareBothSparse(x, np.vectorize(math.erf), math_ops.erf)
-
- def testHalfBasic(self):
- x = np.arange(-3, 3).reshape(1, 3, 2).astype(np.float16)
- y = (x + .5).astype(np.float16) # no zero
- z = (x + 15.5).astype(np.float16) # all positive
- self._compareBoth(x, np.abs, math_ops.abs)
- self._compareBoth(x, np.abs, _ABS)
- self._compareBoth(x, np.negative, math_ops.negative)
- self._compareBoth(x, np.negative, _NEG)
- self._compareBoth(y, self._inv, math_ops.reciprocal)
- self._compareBoth(x, np.square, math_ops.square)
- self._compareBoth(z, np.sqrt, math_ops.sqrt)
- self._compareBoth(z, self._rsqrt, math_ops.rsqrt)
- self._compareBoth(x, np.exp, math_ops.exp)
- self._compareBoth(x, np.expm1, math_ops.expm1)
- self._compareBoth(z, np.log, math_ops.log)
- self._compareBoth(z, np.log1p, math_ops.log1p)
- self._compareBoth(x, np.tanh, math_ops.tanh)
- self._compareBoth(x, self._sigmoid, math_ops.sigmoid)
- self._compareBoth(y, np.sign, math_ops.sign)
- self._compareBoth(x, np.sin, math_ops.sin)
- self._compareBoth(x, np.cos, math_ops.cos)
- self._compareBoth(y,
- np.vectorize(
- self._replace_domain_error_with_inf(math.lgamma)),
- math_ops.lgamma)
- self._compareBoth(x, np.vectorize(math.erf), math_ops.erf)
- self._compareBoth(x, np.vectorize(math.erfc), math_ops.erfc)
- try:
- from scipy import special # pylint: disable=g-import-not-at-top
- self._compareBoth(x, special.i0e, math_ops.bessel_i0e)
- self._compareBoth(x, special.i1e, math_ops.bessel_i1e)
- except ImportError as e:
- tf_logging.warn("Cannot test special functions: %s" % str(e))
-
- self._compareBothSparse(x, np.abs, math_ops.abs)
- self._compareBothSparse(x, np.negative, math_ops.negative)
- self._compareBothSparse(x, np.square, math_ops.square)
- self._compareBothSparse(z, np.sqrt, math_ops.sqrt, tol=1e-3)
- self._compareBothSparse(x, np.tanh, math_ops.tanh)
- self._compareBothSparse(y, np.sign, math_ops.sign)
- self._compareBothSparse(x, np.vectorize(math.erf), math_ops.erf, tol=1e-3)
-
- def testInt32Basic(self):
- x = np.arange(-6, 6, 2).reshape(1, 3, 2).astype(np.int32)
- self._compareCpu(x, np.abs, math_ops.abs)
- self._compareCpu(x, np.abs, _ABS)
- self._compareBoth(x, np.negative, math_ops.negative)
- self._compareBoth(x, np.negative, _NEG)
- self._compareBoth(x, np.square, math_ops.square)
- self._compareCpu(x, np.sign, math_ops.sign)
-
- self._compareBothSparse(x, np.abs, math_ops.abs)
- self._compareBothSparse(x, np.negative, math_ops.negative)
- self._compareBothSparse(x, np.square, math_ops.square)
- self._compareBothSparse(x, np.sign, math_ops.sign)
-
- def testInt64Basic(self):
- x = np.arange(-6 << 40, 6 << 40, 2 << 40).reshape(1, 3, 2).astype(np.int64)
- self._compareCpu(x, np.abs, math_ops.abs)
- self._compareCpu(x, np.abs, _ABS)
- self._compareCpu(x, np.negative, math_ops.negative)
- self._compareCpu(x, np.negative, _NEG)
- self._compareCpu(x, np.sign, math_ops.sign)
-
- self._compareBothSparse(x, np.abs, math_ops.abs)
- self._compareBothSparse(x, np.negative, math_ops.negative)
- self._compareBothSparse(x, np.sign, math_ops.sign)
-
- def testInt64Square(self):
- x = np.arange(-6 << 20, 6 << 20, 2 << 20).reshape(1, 3, 2).astype(np.int64)
- self._compareCpu(x, np.square, math_ops.square)
- self._compareBothSparse(x, np.square, math_ops.square)
-
- def testComplex64Basic(self):
- x = np.complex(1, 1) * np.arange(-3, 3).reshape(1, 3, 2).astype(
- np.complex64)
- y = x + np.complex(0.5, 0.5) # no zeros
- self._compareBoth(x, np.abs, math_ops.abs)
- self._compareBoth(x, np.abs, _ABS)
- self._compareBoth(x, np.negative, math_ops.negative)
- self._compareBoth(x, np.negative, _NEG)
- self._compareCpu(y, self._inv, math_ops.reciprocal)
- self._compareCpu(x, np.square, math_ops.square)
- self._compareCpu(y, np.sqrt, math_ops.sqrt)
- self._compareCpu(y, self._rsqrt, math_ops.rsqrt)
- self._compareBoth(x, np.exp, math_ops.exp)
- self._compareCpu(x, np.expm1, math_ops.expm1)
- self._compareCpu(y, np.log, math_ops.log)
- self._compareCpu(y, np.log1p, math_ops.log1p)
- self._compareCpu(x, np.sinh, math_ops.sinh)
- self._compareCpu(x, np.cosh, math_ops.cosh)
- self._compareCpu(x, np.tanh, math_ops.tanh)
-
- # Complex64 versions of asinh() and acosh() in libstdc++ only have 6 digits
- # of precision.
- # Small gradient values + low precision --> High relative error
- self._compareCpu(y, np.arcsinh, math_ops.asinh, grad_rtol=1e-2)
- self._compareCpu(y, np.arccosh, math_ops.acosh, grad_rtol=1e-2)
-
- self._compareCpu(y, np.arctanh, math_ops.atanh)
- self._compareCpu(x, self._sigmoid, math_ops.sigmoid)
- self._compareCpu(x, np.sin, math_ops.sin)
- self._compareCpu(x, np.cos, math_ops.cos)
-
- self._compareBothSparse(x, np.abs, math_ops.abs)
- self._compareBothSparse(x, np.negative, math_ops.negative)
- self._compareBothSparse(x, np.square, math_ops.square)
- self._compareBothSparse(x, np.sqrt, math_ops.sqrt, 1e-3)
- self._compareBothSparse(x, np.tanh, math_ops.tanh)
-
- # Numpy uses an incorrect definition of sign; use the right one instead.
- def complex_sign(x):
- return x / np.abs(x)
-
- self._compareBoth(y, complex_sign, math_ops.sign)
- self._compareBothSparse(y, complex_sign, math_ops.sign)
-
- def testComplex128Basic(self):
- x = np.complex(1, 1) * np.arange(-3, 3).reshape(1, 3, 2).astype(
- np.complex128)
- y = x + np.complex(0.5, 0.5) # no zeros
- self._compareBoth(x, np.abs, math_ops.abs)
- self._compareBoth(x, np.abs, _ABS)
- self._compareBoth(x, np.negative, math_ops.negative)
- self._compareBoth(x, np.negative, _NEG)
- self._compareCpu(y, self._inv, math_ops.reciprocal)
- self._compareCpu(x, np.square, math_ops.square)
- self._compareCpu(y, np.sqrt, math_ops.sqrt)
- self._compareCpu(y, self._rsqrt, math_ops.rsqrt)
- self._compareBoth(x, np.exp, math_ops.exp)
- self._compareCpu(x, np.expm1, math_ops.expm1)
- self._compareCpu(y, np.log, math_ops.log)
- self._compareCpu(y, np.log1p, math_ops.log1p)
- self._compareCpu(x, np.sinh, math_ops.sinh)
- self._compareCpu(x, np.cosh, math_ops.cosh)
- self._compareCpu(x, np.tanh, math_ops.tanh)
- self._compareCpu(y, np.arcsinh, math_ops.asinh)
- self._compareCpu(y, np.arccosh, math_ops.acosh)
- self._compareCpu(y, np.arctanh, math_ops.atanh)
- self._compareCpu(x, self._sigmoid, math_ops.sigmoid)
- self._compareCpu(x, np.sin, math_ops.sin)
- self._compareCpu(x, np.cos, math_ops.cos)
-
- self._compareBothSparse(x, np.abs, math_ops.abs)
- self._compareBothSparse(x, np.negative, math_ops.negative)
- self._compareBothSparse(x, np.square, math_ops.square)
- self._compareBothSparse(x, np.sqrt, math_ops.sqrt, 1e-3)
- self._compareBothSparse(x, np.tanh, math_ops.tanh)
-
- # Numpy uses an incorrect definition of sign; use the right one instead.
- def complex_sign(x):
- return x / np.abs(x)
-
- self._compareBoth(y, complex_sign, math_ops.sign)
- self._compareBothSparse(y, complex_sign, math_ops.sign)
-
- def testGradGrad(self):
- np.random.seed(7)
- shape = (5,)
- dtype_tols = [(np.float32, 5e-4), (np.float64, 1e-6), (np.complex64, 5e-4),
- (np.complex128, 1e-6)]
- op_range = [
- (gen_math_ops.reciprocal_grad, [-2, 2]),
- (gen_math_ops.rsqrt_grad, [0.1, 3]),
- (gen_math_ops.sigmoid_grad, [-2, 2]),
- (gen_math_ops.sqrt_grad, [0.1, 3]),
- (gen_math_ops.tanh_grad, [-2, 2]),
- ]
-
- def rand(dtype):
- x = np.random.uniform(
- real_range[0], real_range[1], size=shape[0]).astype(dtype)
- if dtype in (np.complex64, np.complex128):
- x += 1j * np.random.uniform(-2, 2, size=shape[0]).astype(dtype)
- return x
-
- for op, real_range in op_range:
- with self.cached_session():
- for dtype, tol in dtype_tols:
- x = constant_op.constant(rand(dtype))
- y = constant_op.constant(rand(dtype))
- z = op(x, y)
- grads = gradient_checker.compute_gradient(
- [x, y], [shape, shape],
- z,
- shape,
- x_init_value=[rand(dtype), rand(dtype)])
- if isinstance(grads, tuple):
- grads = [grads]
- for analytical, numerical in grads:
- self.assertAllClose(analytical, numerical, rtol=tol, atol=tol)
-
-
-class BinaryOpTest(test.TestCase):
-
- def _compareCpu(self, x, y, np_func, tf_func, also_compare_variables=False):
- np_ans = np_func(x, y)
- with self.test_session(use_gpu=False):
- inx = ops.convert_to_tensor(x)
- iny = ops.convert_to_tensor(y)
- out = tf_func(inx, iny)
- tf_cpu = out.eval()
- # Test that the op takes precedence over numpy operators.
- np_left = tf_func(x, iny).eval()
- np_right = tf_func(inx, y).eval()
-
- if also_compare_variables:
- var_x = variables.Variable(x)
- var_y = variables.Variable(y)
- variables.global_variables_initializer().run()
- print(type(x), type(y), type(var_x), type(var_y))
- print(type(tf_func(x, var_y)), type(tf_func(var_x, y)))
- np_var_left = tf_func(x, var_y).eval()
- np_var_right = tf_func(var_x, y).eval()
-
- if np_ans.dtype != np.object:
- self.assertAllClose(np_ans, tf_cpu)
- self.assertAllClose(np_ans, np_left)
- self.assertAllClose(np_ans, np_right)
- if also_compare_variables:
- self.assertAllClose(np_ans, np_var_left)
- self.assertAllClose(np_ans, np_var_right)
- self.assertShapeEqual(np_ans, out)
-
- _GRAD_TOL = {
- dtypes_lib.float16: 1e-3,
- dtypes_lib.float32: 1e-3,
- dtypes_lib.complex64: 1e-2,
- dtypes_lib.float64: 1e-5,
- dtypes_lib.complex128: 1e-4
- }
-
- def _compareGradientX(self,
- x,
- y,
- np_func,
- tf_func,
- numeric_gradient_type=None):
- z = np_func(x, y)
- zs = list(z.shape)
- with self.cached_session():
- inx = ops.convert_to_tensor(x)
- iny = ops.convert_to_tensor(y)
- if x.dtype in (np.float32, np.float64):
- out = 1.1 * tf_func(inx, iny)
- else:
- out = tf_func(inx, iny)
- xs = list(x.shape)
- jacob_t, jacob_n = gradient_checker.compute_gradient(
- inx, xs, out, zs, x_init_value=x)
- if numeric_gradient_type is not None:
- xf = x.astype(numeric_gradient_type)
- yf = y.astype(numeric_gradient_type)
- inxf = ops.convert_to_tensor(xf)
- inyf = ops.convert_to_tensor(yf)
- outf = tf_func(inxf, inyf)
- _, jacob_n = gradient_checker.compute_gradient(
- inxf, xs, outf, zs, x_init_value=xf, delta=1e-3)
- jacob_n = jacob_n.astype(x.dtype)
- tol = self._GRAD_TOL[dtypes_lib.as_dtype(x.dtype)]
- self.assertAllClose(jacob_t, jacob_n, rtol=tol, atol=tol)
-
- def _compareGradientY(self,
- x,
- y,
- np_func,
- tf_func,
- numeric_gradient_type=None):
- z = np_func(x, y)
- zs = list(z.shape)
- with self.cached_session():
- inx = ops.convert_to_tensor(x)
- iny = ops.convert_to_tensor(y)
- if x.dtype in (np.float32, np.float64):
- out = 1.1 * tf_func(inx, iny)
- else:
- out = tf_func(inx, iny)
- ys = list(np.shape(y))
- jacob_t, jacob_n = gradient_checker.compute_gradient(
- iny, ys, out, zs, x_init_value=y)
- if numeric_gradient_type is not None:
- xf = x.astype(numeric_gradient_type)
- yf = y.astype(numeric_gradient_type)
- inxf = ops.convert_to_tensor(xf)
- inyf = ops.convert_to_tensor(yf)
- outf = tf_func(inxf, inyf)
- _, jacob_n = gradient_checker.compute_gradient(
- inyf, ys, outf, zs, x_init_value=yf)
- jacob_n = jacob_n.astype(x.dtype)
- tol = self._GRAD_TOL[dtypes_lib.as_dtype(x.dtype)]
- self.assertAllClose(jacob_t, jacob_n, rtol=tol, atol=tol)
-
- def _compareGpu(self, x, y, np_func, tf_func):
- np_ans = np_func(x, y)
- with self.test_session(force_gpu=test_util.is_gpu_available()):
- inx = ops.convert_to_tensor(x)
- iny = ops.convert_to_tensor(y)
- out = tf_func(inx, iny)
- tf_gpu = out.eval()
- self.assertAllClose(np_ans, tf_gpu)
- self.assertShapeEqual(np_ans, out)
- # TODO(zhifengc/ke): make gradient checker work on GPU.
-
- def _compareBoth(self, x, y, np_func, tf_func, also_compare_variables=False):
- self._compareCpu(x, y, np_func, tf_func, also_compare_variables)
- if x.dtype in (np.float16, np.float32, np.float64, np.complex64,
- np.complex128):
- if tf_func not in (_FLOORDIV, math_ops.floordiv, math_ops.zeta,
- math_ops.polygamma):
- self._compareGradientX(x, y, np_func, tf_func)
- self._compareGradientY(x, y, np_func, tf_func)
- if tf_func in (math_ops.zeta, math_ops.polygamma):
- # These methods only support gradients in the second parameter
- self._compareGradientY(x, y, np_func, tf_func)
- self._compareGpu(x, y, np_func, tf_func)
-
- def testFloatBasic(self):
- x = np.linspace(-5, 20, 15).reshape(1, 3, 5).astype(np.float32)
- y = np.linspace(20, -5, 15).reshape(1, 3, 5).astype(np.float32)
- self._compareBoth(x, y, np.add, math_ops.add, also_compare_variables=True)
- self._compareBoth(x, y, np.subtract, math_ops.subtract)
- self._compareBoth(x, y, np.multiply, math_ops.multiply)
- self._compareBoth(x, y + 0.1, np.true_divide, math_ops.truediv)
- self._compareBoth(x, y + 0.1, np.floor_divide, math_ops.floordiv)
- self._compareBoth(x, y, np.add, _ADD)
- self._compareBoth(x, y, np.subtract, _SUB)
- self._compareBoth(x, y, np.multiply, _MUL)
- self._compareBoth(x, y + 0.1, np.true_divide, _TRUEDIV)
- self._compareBoth(x, y + 0.1, np.floor_divide, _FLOORDIV)
- self._compareBoth(x, y, np.arctan2, math_ops.atan2)
- x1 = np.random.randn(5, 6).astype(np.float32)
- x2 = np.random.randn(5, 6).astype(np.float32)
- # Remove tiny values--atan2 gradients are flaky near the origin.
- x1[np.abs(x1) < 0.05] = 0.05 * np.sign(x1[np.abs(x1) < 0.05])
- x2[np.abs(x2) < 0.05] = 0.05 * np.sign(x2[np.abs(x2) < 0.05])
- self._compareBoth(x1, x2, np.arctan2, math_ops.atan2)
- try:
- from scipy import special # pylint: disable=g-import-not-at-top
- a_pos_small = np.linspace(0.1, 2, 15).reshape(1, 3, 5).astype(np.float32)
- x_pos_small = np.linspace(0.1, 10, 15).reshape(1, 3, 5).astype(np.float32)
- self._compareBoth(a_pos_small, x_pos_small, special.gammainc,
- math_ops.igamma)
- self._compareBoth(a_pos_small, x_pos_small, special.gammaincc,
- math_ops.igammac)
- # Need x > 1
- self._compareBoth(x_pos_small + 1, a_pos_small, special.zeta,
- math_ops.zeta)
- n_small = np.arange(0, 15).reshape(1, 3, 5).astype(np.float32)
- self._compareBoth(n_small, x_pos_small, special.polygamma,
- math_ops.polygamma)
- except ImportError as e:
- tf_logging.warn("Cannot test special functions: %s" % str(e))
-
- def testFloatDifferentShapes(self):
- x = np.array([1, 2, 3, 4]).reshape(2, 2).astype(np.float32)
- y = np.array([1, 2]).reshape(2, 1).astype(np.float32)
- with self.cached_session() as sess:
- inx = ops.convert_to_tensor(x)
- iny = ops.convert_to_tensor(y)
- s = math_ops.reduce_sum(inx * iny)
- gx, gy = sess.run(gradients_impl.gradients(s, [inx, iny]))
- # gx is simply the broadcasted y
- self.assertAllEqual(gx,
- np.array([1, 1, 2, 2]).reshape(2, 2).astype(np.float32))
- # gy is x's column summed up
- self.assertAllEqual(gy, np.array([3, 7]).reshape(2, 1).astype(np.float32))
-
- def testFloatVariableOverload(self):
- x = np.array([1, 2, 3, 4]).reshape(2, 2).astype(np.int32)
- y = np.array([1, 2]).reshape(2, 1).astype(np.int32)
- var_x = variables.Variable(x)
- var_y = variables.Variable(y)
- with self.cached_session() as sess:
- sess.run([var_x.initializer, var_y.initializer])
- left_result = (var_x * y).eval()
- right_result = (x * var_y).eval()
- np_result = x * y
- self.assertAllEqual(np_result, left_result)
- self.assertAllEqual(np_result, right_result)
-
- def testDoubleBasic(self):
- x = np.linspace(-5, 20, 15).reshape(1, 3, 5).astype(np.float64)
- y = np.linspace(20, -5, 15).reshape(1, 3, 5).astype(np.float64)
- self._compareBoth(x, y, np.add, math_ops.add)
- self._compareBoth(x, y, np.subtract, math_ops.subtract)
- self._compareBoth(x, y, np.multiply, math_ops.multiply)
- self._compareBoth(x, y + 0.1, np.true_divide, math_ops.truediv)
- self._compareBoth(x, y + 0.1, np.floor_divide, math_ops.floordiv)
- self._compareBoth(x, y, np.add, _ADD)
- self._compareBoth(x, y, np.subtract, _SUB)
- self._compareBoth(x, y, np.multiply, _MUL)
- self._compareBoth(x, y + 0.1, np.true_divide, _TRUEDIV)
- self._compareBoth(x, y + 0.1, np.floor_divide, _FLOORDIV)
- self._compareBoth(x, y, np.arctan2, math_ops.atan2)
- x1 = np.random.randn(7, 4).astype(np.float64)
- x2 = np.random.randn(7, 4).astype(np.float64)
- # Remove tiny values--atan2 gradients are flaky near the origin.
- x1[np.abs(x1) < 0.5] = 0.5 * np.sign(x1[np.abs(x1) < 0.5])
- x2[np.abs(x2) < 0.5] = 0.5 * np.sign(x2[np.abs(x2) < 0.5])
- self._compareBoth(x1, x2, np.arctan2, math_ops.atan2)
- try:
- from scipy import special # pylint: disable=g-import-not-at-top
- a_pos_small = np.linspace(0.1, 2, 15).reshape(1, 3, 5).astype(np.float32)
- x_pos_small = np.linspace(0.1, 10, 15).reshape(1, 3, 5).astype(np.float32)
- self._compareBoth(a_pos_small, x_pos_small, special.gammainc,
- math_ops.igamma)
- self._compareBoth(a_pos_small, x_pos_small, special.gammaincc,
- math_ops.igammac)
- except ImportError as e:
- tf_logging.warn("Cannot test special functions: %s" % str(e))
-
- def testUint8Basic(self):
- x = np.arange(1, 13, 2).reshape(1, 3, 2).astype(np.uint8)
- y = np.arange(1, 7, 1).reshape(1, 3, 2).astype(np.uint8)
- self._compareBoth(x, y, np.add, math_ops.add)
-
- def testInt8Basic(self):
- x = np.arange(1, 13, 2).reshape(1, 3, 2).astype(np.int8)
- y = np.arange(1, 7, 1).reshape(1, 3, 2).astype(np.int8)
- self._compareBoth(x, y, np.multiply, math_ops.multiply)
- self._compareBoth(x, y, np.multiply, _MUL)
-
- def testInt16Basic(self):
- x = np.arange(1, 13, 2).reshape(1, 3, 2).astype(np.int16)
- y = np.arange(1, 7, 1).reshape(1, 3, 2).astype(np.int16)
- self._compareBoth(x, y, np.multiply, math_ops.multiply)
- self._compareBoth(x, y, np.multiply, _MUL)
-
- def testUint16Basic(self):
- x = np.arange(1, 13, 2).reshape(1, 3, 2).astype(np.uint16)
- y = np.arange(1, 7, 1).reshape(1, 3, 2).astype(np.uint16)
- self._compareBoth(x, y, np.multiply, math_ops.multiply)
- self._compareBoth(x, y, np.multiply, _MUL)
- self._compareBoth(x, y, np.true_divide, math_ops.truediv)
- self._compareBoth(x, y, np.floor_divide, math_ops.floordiv)
- self._compareBoth(x, y, np.true_divide, _TRUEDIV)
- self._compareBoth(x, y, np.floor_divide, _FLOORDIV)
-
- def testInt32Basic(self):
- x = np.arange(1, 13, 2).reshape(1, 3, 2).astype(np.int32)
- y = np.arange(1, 7, 1).reshape(1, 3, 2).astype(np.int32)
- self._compareBoth(x, y, np.add, math_ops.add)
- self._compareBoth(x, y, np.subtract, math_ops.subtract)
- self._compareBoth(x, y, np.multiply, math_ops.multiply)
- self._compareBoth(x, y, np.true_divide, math_ops.truediv)
- self._compareBoth(x, y, np.floor_divide, math_ops.floordiv)
- self._compareBoth(x, y, np.mod, math_ops.mod)
- self._compareBoth(x, y, np.add, _ADD)
- self._compareBoth(x, y, np.subtract, _SUB)
- self._compareBoth(x, y, np.multiply, _MUL)
- self._compareBoth(x, y, np.true_divide, _TRUEDIV)
- self._compareBoth(x, y, np.floor_divide, _FLOORDIV)
- self._compareBoth(x, y, np.mod, _MOD)
- # _compareBoth tests on GPU only for floating point types, so test
- # _MOD for int32 on GPU by calling _compareGpu
- self._compareGpu(x, y, np.mod, _MOD)
-
- def testInt64Basic(self):
- x = np.arange(1 << 40, 13 << 40, 2 << 40).reshape(1, 3, 2).astype(np.int64)
- y = np.arange(1, 7, 1).reshape(1, 3, 2).astype(np.int64)
- self._compareBoth(x, y, np.subtract, math_ops.subtract)
- self._compareBoth(x, y, np.multiply, math_ops.multiply)
- self._compareBoth(x, y, np.true_divide, math_ops.truediv)
- self._compareBoth(x, y, np.floor_divide, math_ops.floordiv)
- self._compareBoth(x, y, np.mod, math_ops.mod)
- self._compareBoth(x, y, np.subtract, _SUB)
- self._compareBoth(x, y, np.multiply, _MUL)
- self._compareBoth(x, y, np.true_divide, _TRUEDIV)
- self._compareBoth(x, y, np.floor_divide, _FLOORDIV)
- self._compareBoth(x, y, np.mod, _MOD)
-
- def testComplex64Basic(self):
- x = np.complex(1, 1) * np.linspace(-10, 10, 6).reshape(1, 3, 2).astype(
- np.complex64)
- y = np.complex(1, 1) * np.linspace(20, -20, 6).reshape(1, 3, 2).astype(
- np.complex64)
- self._compareBoth(x, y, np.add, math_ops.add)
- self._compareBoth(x, y, np.subtract, math_ops.subtract)
- self._compareBoth(x, y, np.multiply, math_ops.multiply)
- self._compareBoth(x, y + 0.1, np.true_divide, math_ops.truediv)
- self._compareBoth(x, y, np.add, _ADD)
- self._compareBoth(x, y, np.subtract, _SUB)
- self._compareBoth(x, y, np.multiply, _MUL)
- self._compareBoth(x, y + 0.1, np.true_divide, _TRUEDIV)
-
- def testComplex128Basic(self):
- x = np.complex(1, 1) * np.linspace(-10, 10, 6).reshape(1, 3, 2).astype(
- np.complex128)
- y = np.complex(1, 1) * np.linspace(20, -20, 6).reshape(1, 3, 2).astype(
- np.complex128)
- self._compareBoth(x, y, np.add, math_ops.add)
- self._compareBoth(x, y, np.subtract, math_ops.subtract)
- self._compareBoth(x, y, np.multiply, math_ops.multiply)
- self._compareBoth(x, y + 0.1, np.true_divide, math_ops.truediv)
- self._compareBoth(x, y, np.add, _ADD)
- self._compareBoth(x, y, np.subtract, _SUB)
- self._compareBoth(x, y, np.multiply, _MUL)
- self._compareBoth(x, y + 0.1, np.true_divide, _TRUEDIV)
-
- def testStringComparison(self):
- x = np.array([["abc", "bh"], ["c", ""]])
- y = np.array([["abc", "bh"], ["def", "hi"]])
- with self.test_session(use_gpu=False) as sess:
- cmp_eq = math_ops.equal(x, y)
- cmp_not_eq = math_ops.not_equal(x, y)
- values = sess.run([cmp_eq, cmp_not_eq])
- self.assertAllEqual([[True, True], [False, False]], values[0])
- self.assertAllEqual([[False, False], [True, True]], values[1])
-
- def testString(self):
- x = np.array(
- [["x_0_0", "x_0_1", "x_0_2"], ["x_1_0", "x_1_1", "x_1_2"],
- ["x_2_0", "x_2_1", "x_2_2"]],
- dtype=np.object)
- y = np.array(
- [["y_0_0", "y_0_1", "y_0_2"], ["y_1_0", "y_1_1", "y_1_2"],
- ["y_2_0", "y_2_1", "y_2_2"]],
- dtype=np.object)
- z = np.array([["z_0", "z_1", "z_2"]], dtype=np.object)
- w = np.array("w", dtype=np.object)
- self._compareCpu(x, y, _ADD, _ADD)
- self._compareCpu(x, z, _ADD, _ADD)
- self._compareCpu(x, w, _ADD, _ADD)
- self._compareCpu(z, w, _ADD, _ADD)
-
- def _compareBCast(self, xs, ys, dtype, np_func, tf_func):
- if dtype in (np.complex64, np.complex128):
- x = (1 + np.linspace(0, 2 + 3j, np.prod(xs))).astype(dtype).reshape(xs)
- y = (1 + np.linspace(0, 2 - 2j, np.prod(ys))).astype(dtype).reshape(ys)
- else:
- x = (1 + np.linspace(0, 5, np.prod(xs))).astype(dtype).reshape(xs)
- y = (1 + np.linspace(0, 5, np.prod(ys))).astype(dtype).reshape(ys)
- self._compareCpu(x, y, np_func, tf_func)
- if x.dtype in (np.float16, np.float32, np.float64):
- # TODO(aselle): Make the test work for dtypes:
- # (np.complex64, np.complex128).
- if tf_func not in (_FLOORDIV, math_ops.floordiv):
- if x.dtype == np.float16:
- # Compare fp16 theoretical gradients to fp32 numerical gradients,
- # since fp16 numerical gradients are too imprecise unless great
- # care is taken with choosing the inputs and the delta. This is
- # a weaker check (in particular, it does not test the op itself,
- # only its gradient), but it's much better than nothing.
- self._compareGradientX(x, y, np_func, tf_func, np.float)
- self._compareGradientY(x, y, np_func, tf_func, np.float)
- else:
- self._compareGradientX(x, y, np_func, tf_func)
- self._compareGradientY(x, y, np_func, tf_func)
- self._compareGpu(x, y, np_func, tf_func)
-
- # TODO(josh11b,vrv): Refactor this to use parameterized tests.
- def _testBCastByFunc(self, funcs, xs, ys):
- dtypes = [
- np.float16,
- np.float32,
- np.float64,
- np.int32,
- np.int64,
- np.complex64,
- np.complex128,
- ]
- for dtype in dtypes:
- for (np_func, tf_func) in funcs:
- if (dtype in (np.complex64, np.complex128) and
- tf_func in (_FLOORDIV, math_ops.floordiv)):
- continue # floordiv makes no sense for complex numbers
- self._compareBCast(xs, ys, dtype, np_func, tf_func)
- self._compareBCast(ys, xs, dtype, np_func, tf_func)
-
- def _testBCastA(self, xs, ys):
- funcs = [
- (np.add, math_ops.add),
- (np.add, _ADD),
- ]
- self._testBCastByFunc(funcs, xs, ys)
-
- def _testBCastB(self, xs, ys):
- funcs = [
- (np.subtract, math_ops.subtract),
- (np.subtract, _SUB),
- (np.power, math_ops.pow),
- ]
- self._testBCastByFunc(funcs, xs, ys)
-
- def _testBCastC(self, xs, ys):
- funcs = [
- (np.multiply, math_ops.multiply),
- (np.multiply, _MUL),
- ]
- self._testBCastByFunc(funcs, xs, ys)
-
- def _testBCastD(self, xs, ys):
- funcs = [
- (np.true_divide, math_ops.truediv),
- (np.floor_divide, math_ops.floordiv),
- (np.true_divide, _TRUEDIV),
- (np.floor_divide, _FLOORDIV),
- ]
- self._testBCastByFunc(funcs, xs, ys)
-
- def testBCast_0A(self):
- self._testBCastA([1, 3, 2], [1])
-
- def testBCast_0B(self):
- self._testBCastB([1, 3, 2], [1])
-
- def testBCast_0C(self):
- self._testBCastC([1, 3, 2], [1])
-
- def testBCast_0D(self):
- self._testBCastD([1, 3, 2], [1])
-
- def testBCast_1A(self):
- self._testBCastA([1, 3, 2], [2])
-
- def testBCast_1B(self):
- self._testBCastB([1, 3, 2], [2])
-
- def testBCast_1C(self):
- self._testBCastC([1, 3, 2], [2])
-
- def testBCast_1D(self):
- self._testBCastD([1, 3, 2], [2])
-
- def testBCast_2A(self):
- self._testBCastA([1, 3, 2], [3, 2])
-
- def testBCast_2B(self):
- self._testBCastB([1, 3, 2], [3, 2])
-
- def testBCast_2C(self):
- self._testBCastC([1, 3, 2], [3, 2])
-
- def testBCast_2D(self):
- self._testBCastD([1, 3, 2], [3, 2])
-
- def testBCast_3A(self):
- self._testBCastA([1, 3, 2], [3, 1])
-
- def testBCast_3B(self):
- self._testBCastB([1, 3, 2], [3, 1])
-
- def testBCast_3C(self):
- self._testBCastC([1, 3, 2], [3, 1])
-
- def testBCast_3D(self):
- self._testBCastD([1, 3, 2], [3, 1])
-
- def testBCast_4A(self):
- self._testBCastA([1, 3, 2], [1, 3, 2])
-
- def testBCast_4B(self):
- self._testBCastB([1, 3, 2], [1, 3, 2])
-
- def testBCast_4C(self):
- self._testBCastC([1, 3, 2], [1, 3, 2])
-
- def testBCast_4D(self):
- self._testBCastD([1, 3, 2], [1, 3, 2])
-
- def testBCast_5A(self):
- self._testBCastA([1, 3, 2], [2, 3, 1])
-
- def testBCast_5B(self):
- self._testBCastB([1, 3, 2], [2, 3, 1])
-
- def testBCast_5C(self):
- self._testBCastC([1, 3, 2], [2, 3, 1])
-
- def testBCast_5D(self):
- self._testBCastD([1, 3, 2], [2, 3, 1])
-
- def testBCast_6A(self):
- self._testBCastA([1, 3, 2], [2, 1, 1])
-
- def testBCast_6B(self):
- self._testBCastB([1, 3, 2], [2, 1, 1])
-
- def testBCast_6C(self):
- self._testBCastC([1, 3, 2], [2, 1, 1])
-
- def testBCast_6D(self):
- self._testBCastD([1, 3, 2], [2, 1, 1])
-
- def testBCast_7A(self):
- self._testBCastA([1, 3, 2], [1, 3, 1])
-
- def testBCast_7B(self):
- self._testBCastB([1, 3, 2], [1, 3, 1])
-
- def testBCast_7C(self):
- self._testBCastC([1, 3, 2], [1, 3, 1])
-
- def testBCast_7D(self):
- self._testBCastD([1, 3, 2], [1, 3, 1])
-
- def testBCast_8A(self):
- self._testBCastA([2, 1, 5], [2, 3, 1])
-
- def testBCast_8B(self):
- self._testBCastB([2, 1, 5], [2, 3, 1])
-
- def testBCast_8C(self):
- self._testBCastC([2, 1, 5], [2, 3, 1])
-
- def testBCast_8D(self):
- self._testBCastD([2, 1, 5], [2, 3, 1])
-
- def testBCast_9A(self):
- self._testBCastA([2, 0, 5], [2, 0, 1])
-
- def testBCast_9B(self):
- self._testBCastB([2, 0, 5], [2, 0, 1])
-
- def testBCast_9C(self):
- self._testBCastC([2, 0, 5], [2, 0, 1])
-
- def testBCast_9D(self):
- self._testBCastD([2, 0, 5], [2, 0, 1])
-
- def testBCast_10A(self):
- self._testBCastA([2, 3, 0], [2, 3, 1])
-
- def testBCast_10B(self):
- self._testBCastB([2, 3, 0], [2, 3, 1])
-
- def testBCast_10C(self):
- self._testBCastC([2, 3, 0], [2, 3, 1])
-
- def testBCast_10D(self):
- self._testBCastD([2, 3, 0], [2, 3, 1])
-
- def testBCast_11A(self):
- self._testBCastA([1, 3, 2], [1, 3, 2])
-
- def testBCast_11B(self):
- self._testBCastB([1, 3, 2], [1, 3, 2])
-
- def testBCast_11C(self):
- self._testBCastC([1, 3, 2], [1, 3, 2])
-
- def testBCast_11D(self):
- self._testBCastD([1, 3, 2], [1, 3, 2])
-
- def testBCast_12A(self):
- self._testBCastA([1, 1, 1, 1, 3, 2], [1, 3, 2])
-
- def testBCast_12B(self):
- self._testBCastB([1, 1, 1, 1, 3, 2], [1, 3, 2])
-
- def testBCast_12C(self):
- self._testBCastC([1, 1, 1, 1, 3, 2], [1, 3, 2])
-
- def testBCast_12D(self):
- self._testBCastD([1, 1, 1, 1, 3, 2], [1, 3, 2])
-
- def testBCast_13A(self):
- self._testBCastA([1, 3, 2, 1, 1], [1])
-
- def testBCast_13B(self):
- self._testBCastB([1, 3, 2, 1, 1], [1])
-
- def testBCast_13C(self):
- self._testBCastC([1, 3, 2, 1, 1], [1])
-
- def testBCast_13D(self):
- self._testBCastD([1, 3, 2, 1, 1], [1])
-
- def testBCast_14A(self):
- self._testBCastA([2, 3, 1, 1, 5], [1])
-
- def testBCast_14B(self):
- self._testBCastB([2, 3, 1, 1, 5], [1])
-
- def testBCast_14C(self):
- self._testBCastC([2, 3, 1, 1, 5], [1])
-
- def testBCast_14D(self):
- self._testBCastD([2, 3, 1, 1, 5], [1])
-
- def testBCast_15A(self):
- self._testBCastA([10, 3, 1, 2], [3, 1, 2])
-
- def testBCast_15B(self):
- self._testBCastB([10, 3, 1, 2], [3, 1, 2])
-
- def testBCast_15C(self):
- self._testBCastC([10, 3, 1, 2], [3, 1, 2])
-
- def testBCast_15D(self):
- self._testBCastD([10, 3, 1, 2], [3, 1, 2])
-
- def testMismatchedDimensions(self):
- for func in [
- math_ops.add, math_ops.subtract, math_ops.multiply, math_ops.div, _ADD,
- _SUB, _MUL, _TRUEDIV, _FLOORDIV
- ]:
- with self.assertRaisesWithPredicateMatch(
- ValueError, lambda e: "Dimensions must" in str(e)):
- func(
- ops.convert_to_tensor([10.0, 20.0, 30.0]),
- ops.convert_to_tensor([[40.0, 50.0], [60.0, 70.0]]))
-
- def testZeroPowGrad(self):
- with self.cached_session():
- for dtype in (np.float16, np.float32, np.float64, np.complex64,
- np.complex128):
- x = constant_op.constant(0.0, dtype=dtype)
- y = constant_op.constant(2.0, dtype=dtype)
- z = math_ops.pow(x, y)
- error = gradient_checker.compute_gradient_error(y, [], z, [])
- self.assertEqual(error, 0)
-
- def testComplexPowGrad(self):
- with self.cached_session():
- for dtype in np.complex64, np.complex128:
- for base in 2.0, -2.0:
- x = constant_op.constant(base, dtype=dtype)
- y = constant_op.constant(2.0, dtype=dtype)
- z = math_ops.pow(x, y)
- error = gradient_checker.compute_gradient_error(y, [], z, [])
- self.assertLess(error, 2e-4)
-
- def testAtan2SpecialValues(self):
- x1l, x2l = zip((+0.0, +0.0), (+0.0, -0.0), (-0.0, +0.0), (-0.0, -0.0),
- (1.2345, float("inf")), (1.2345, -float("inf")),
- (-4.321, float("inf")), (-4.125, -float("inf")),
- (float("inf"), float("inf")), (float("inf"), -float("inf")),
- (-float("inf"), float("inf")),
- (-float("inf"), -float("inf")))
- for dtype in np.float32, np.float64:
- x1 = np.array(x1l).astype(dtype)
- x2 = np.array(x2l).astype(dtype)
- self._compareCpu(x1, x2, np.arctan2, math_ops.atan2)
- self._compareGpu(x1, x2, np.arctan2, math_ops.atan2)
-
- def testPowNegativeExponent(self):
- for dtype in [np.int32, np.int64]:
- with self.test_session(use_gpu=False) as sess:
- with self.assertRaisesRegexp(
- errors_impl.InvalidArgumentError,
- "Integers to negative integer powers are not allowed"):
- x = np.array([5, 2]).astype(dtype)
- y = np.array([-2, 3]).astype(dtype)
- sess.run(math_ops.pow(x, y))
-
- with self.test_session(use_gpu=False) as sess:
- with self.assertRaisesRegexp(
- errors_impl.InvalidArgumentError,
- "Integers to negative integer powers are not allowed"):
- x = np.array([5, 2]).astype(dtype)
- y = np.array([2, -3]).astype(dtype)
- sess.run(math_ops.pow(x, y))
-
- with self.test_session(use_gpu=False) as sess:
- with self.assertRaisesRegexp(
- errors_impl.InvalidArgumentError,
- "Integers to negative integer powers are not allowed"):
- x = np.array([5, 2]).astype(dtype)
- y = -3
- sess.run(math_ops.pow(x, y))
-
-
class ComparisonOpTest(test.TestCase):
def _compareScalar(self, func, x, y, dtype):
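Editorial note: the removed BinaryOpTest code above relies on one non-obvious trick, spelled out in the comment inside _compareBCast: for float16 inputs the analytic (theoretical) Jacobian is compared against a float32 numerical Jacobian, because fp16 finite differences are too imprecise to serve as a reference. A minimal standalone sketch of that pattern follows; it is illustrative only, and the choice of tf.square as the op under test is an assumption, not part of this change.

# Sketch: fp16 analytic Jacobian checked against an fp32 numeric Jacobian.
import numpy as np
import tensorflow as tf
from tensorflow.python.ops import gradient_checker

x16 = np.linspace(-2.0, 2.0, 6).astype(np.float16).reshape(2, 3)
s = list(x16.shape)
with tf.Session():
  inx = tf.convert_to_tensor(x16)
  y = tf.square(inx)
  # Analytic Jacobian, computed in fp16.
  jacob_t, _ = gradient_checker.compute_gradient(inx, s, y, s, x_init_value=x16)
  # Numeric Jacobian, computed in fp32 with a coarse delta, then cast back.
  x32 = x16.astype(np.float32)
  inx32 = tf.convert_to_tensor(x32)
  y32 = tf.square(inx32)
  _, jacob_n = gradient_checker.compute_gradient(
      inx32, s, y32, s, x_init_value=x32, delta=1e-2)
np.testing.assert_allclose(jacob_t, jacob_n.astype(np.float16), rtol=1e-3, atol=1e-3)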
diff --git a/tensorflow/python/kernel_tests/cwise_ops_unary_test.py b/tensorflow/python/kernel_tests/cwise_ops_unary_test.py
new file mode 100644
index 0000000000..77f182784e
--- /dev/null
+++ b/tensorflow/python/kernel_tests/cwise_ops_unary_test.py
@@ -0,0 +1,541 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Functional tests for unary coefficient-wise operations."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+
+import numpy as np
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes as dtypes_lib
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import gen_math_ops
+from tensorflow.python.ops import gradient_checker
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_grad # pylint: disable=unused-import
+from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging
+
+_NEG = lambda x: -x
+_ABS = abs
+
+
+# TODO(zongheng): it'd be great to factor out this function and various random
+# SparseTensor gen funcs.
+def _sparsify(x, thresh=0.5, index_dtype=np.int64):
+ x[x < thresh] = 0
+
+ non_zero = np.where(x)
+ x_indices = np.vstack(non_zero).astype(index_dtype).T
+ x_values = x[non_zero]
+ x_shape = x.shape
+
+ return sparse_tensor.SparseTensor(
+ indices=x_indices, values=x_values, dense_shape=x_shape), x_values
+
+
+def _default_tolerance(dtype):
+ """Returns a sensible default tolerance for comparing results of a given type.
+
+ Args:
+ dtype: A datatype.
+ """
+ if dtype == np.float16:
+ return 5e-3
+ elif dtype in (np.float32, np.complex64):
+ return 1e-3
+ elif dtype in (np.float64, np.complex128):
+ return 1e-5
+ else:
+ return None # Fail fast for unexpected types
+
+
+class UnaryOpTest(test.TestCase):
+
+ def _compareCpu(self, x, np_func, tf_func, grad_rtol=None, grad_atol=None):
+ if grad_rtol is None:
+ grad_rtol = _default_tolerance(x.dtype)
+ if grad_atol is None:
+ grad_atol = _default_tolerance(x.dtype)
+ np_ans = np_func(x)
+ with self.test_session(use_gpu=False):
+ inx = ops.convert_to_tensor(x)
+ if x.dtype in (np.float32, np.float64,
+ dtypes_lib.bfloat16.as_numpy_dtype):
+ y = 1.1 * tf_func(inx)
+ np_ans *= 1.1
+ else:
+ y = tf_func(inx)
+ tf_cpu = y.eval()
+ self.assertShapeEqual(np_ans, y)
+ if x.dtype == np.float16:
+ self.assertAllClose(np_ans, tf_cpu, rtol=1e-3, atol=1e-3)
+ elif x.dtype == dtypes_lib.bfloat16.as_numpy_dtype:
+ self.assertAllClose(np_ans, tf_cpu, rtol=1e-2, atol=1e-2)
+ else:
+ self.assertAllClose(np_ans, tf_cpu)
+
+ if x.dtype in (np.complex64, np.complex128) and tf_func == math_ops.sign:
+ return # Return early
+
+ if x.dtype == np.float16:
+ s = list(np.shape(x))
+ jacob_t, _ = gradient_checker.compute_gradient(
+ inx, s, y, s, x_init_value=x)
+ xf = x.astype(np.float)
+ inxf = ops.convert_to_tensor(xf)
+ yf = tf_func(inxf)
+ _, jacob_n = gradient_checker.compute_gradient(
+ inxf, s, yf, s, x_init_value=xf, delta=1e-2)
+ jacob_n = jacob_n.astype(np.float16)
+ self.assertAllClose(jacob_t, jacob_n, rtol=grad_rtol, atol=grad_atol)
+ elif x.dtype in (np.float32, np.complex64):
+ s = list(np.shape(x))
+ jacob_t, jacob_n = gradient_checker.compute_gradient(
+ inx, s, y, s, x_init_value=x, delta=1e-3)
+ self.assertAllClose(jacob_t, jacob_n, rtol=grad_rtol, atol=grad_atol)
+ elif x.dtype in (np.float64, np.complex128):
+ s = list(np.shape(x))
+ jacob_t, jacob_n = gradient_checker.compute_gradient(
+ inx, s, y, s, x_init_value=x, delta=1e-5)
+ self.assertAllClose(jacob_t, jacob_n, rtol=grad_rtol, atol=grad_atol)
+
+ def _check(self, result_tensor, result_np, input_sp_t, tol):
+ self.assertTrue(isinstance(result_tensor, sparse_tensor.SparseTensor))
+ self.assertTrue(isinstance(input_sp_t, sparse_tensor.SparseTensor))
+ self.assertAllEqual(input_sp_t.indices.eval(), result_tensor.indices.eval())
+ self.assertAllEqual(input_sp_t.dense_shape.eval(),
+ result_tensor.dense_shape.eval())
+ if tol is None:
+ self.assertAllClose(result_np, result_tensor.values.eval())
+ else:
+ self.assertAllClose(
+ result_np, result_tensor.values.eval(), rtol=tol, atol=tol)
+
+ def _compareSparseCpu(self, x, np_func, tf_func, tol):
+ x_sp, x_sp_vals = _sparsify(x)
+ res_np = np_func(x_sp_vals)
+ with self.test_session(use_gpu=False):
+ self._check(tf_func(x_sp), res_np, x_sp, tol)
+
+ def _compareGpu(self, x, np_func, tf_func):
+ np_ans = np_func(x)
+ with self.test_session(force_gpu=test_util.is_gpu_available()):
+ result = tf_func(ops.convert_to_tensor(x))
+ tf_gpu = result.eval()
+ if x.dtype == np.float16:
+ self.assertAllClose(np_ans, tf_gpu, rtol=1e-3, atol=1e-3)
+ else:
+ self.assertAllClose(np_ans, tf_gpu)
+ # TODO(zhifengc/ke): make gradient checker work on GPU.
+
+ def _compareSparseGpu(self, x, np_func, tf_func, tol):
+ x_sp, x_sp_vals = _sparsify(x)
+ res_np = np_func(x_sp_vals)
+ with self.test_session(force_gpu=test_util.is_gpu_available()):
+ self._check(tf_func(x_sp), res_np, x_sp, tol)
+
+ def _compareBoth(self, x, np_func, tf_func):
+ self._compareCpu(x, np_func, tf_func)
+ self._compareGpu(x, np_func, tf_func)
+
+ def _compareBothSparse(self, x, np_func, tf_func, tol=None):
+ self._compareSparseCpu(x, np_func, tf_func, tol)
+ self._compareSparseGpu(x, np_func, tf_func, tol)
+
+ def _inv(self, x):
+ return 1.0 / x
+
+ def _rsqrt(self, x):
+ return self._inv(np.sqrt(x))
+
+ def _sigmoid(self, x):
+ return 1.0 / (1.0 + np.exp(-x))
+
+ def _log_sigmoid(self, x):
+ return np.log(self._sigmoid(x))
+
+ def _replace_domain_error_with_inf(self, fn):
+
+ def func(x):
+ try:
+ return fn(x)
+ except ValueError as e:
+ if "domain error" in str(e):
+ return np.inf * np.ones_like(x)
+ else:
+ raise e
+
+ return func
+
+ def testFloatBasic(self):
+ x = np.arange(-3, 3).reshape(1, 3, 2).astype(np.float32)
+ w = x - x.min() + 1.02 # all greater than 1
+ y = (x + .5).astype(np.float32) # no zero
+ z = (x + 15.5).astype(np.float32) # all positive
+ k = np.arange(-0.90, 0.90, 0.25).astype(np.float32) # between -1 and 1
+
+ self._compareBoth(x, np.abs, math_ops.abs)
+ self._compareBoth(x, np.abs, _ABS)
+ self._compareBoth(x, np.negative, math_ops.negative)
+ self._compareBoth(x, np.negative, _NEG)
+ self._compareBoth(y, self._inv, math_ops.reciprocal)
+ self._compareBoth(x, np.square, math_ops.square)
+ self._compareBoth(z, np.sqrt, math_ops.sqrt)
+ self._compareBoth(z, self._rsqrt, math_ops.rsqrt)
+ self._compareBoth(x, np.exp, math_ops.exp)
+ self._compareBoth(x, np.expm1, math_ops.expm1)
+ self._compareBoth(z, np.log, math_ops.log)
+ self._compareBoth(z, np.log1p, math_ops.log1p)
+ self._compareBoth(x, np.sinh, math_ops.sinh)
+ self._compareBoth(x, np.cosh, math_ops.cosh)
+ self._compareBoth(x, np.tanh, math_ops.tanh)
+ self._compareBoth(x, np.arcsinh, math_ops.asinh)
+ self._compareBoth(w, np.arccosh, math_ops.acosh)
+ self._compareBoth(k, np.arctanh, math_ops.atanh)
+ self._compareBoth(x, self._sigmoid, math_ops.sigmoid)
+ self._compareBoth(x, self._log_sigmoid, math_ops.log_sigmoid)
+ self._compareBoth(y, np.sign, math_ops.sign)
+ self._compareBoth(x, np.sin, math_ops.sin)
+ self._compareBoth(x, np.cos, math_ops.cos)
+ self._compareBoth(k, np.arcsin, math_ops.asin)
+ self._compareBoth(k, np.arccos, math_ops.acos)
+ self._compareBoth(x, np.arctan, math_ops.atan)
+ self._compareBoth(x, np.tan, math_ops.tan)
+ self._compareBoth(
+ y, np.vectorize(self._replace_domain_error_with_inf(math.lgamma)),
+ math_ops.lgamma)
+ self._compareBoth(x, np.vectorize(math.erf), math_ops.erf)
+ self._compareBoth(x, np.vectorize(math.erfc), math_ops.erfc)
+ try:
+ from scipy import special # pylint: disable=g-import-not-at-top
+ self._compareBoth(x, special.i0e, math_ops.bessel_i0e)
+ self._compareBoth(x, special.i1e, math_ops.bessel_i1e)
+ except ImportError as e:
+ tf_logging.warn("Cannot test special functions: %s" % str(e))
+
+ self._compareBothSparse(x, np.abs, math_ops.abs)
+ self._compareBothSparse(x, np.negative, math_ops.negative)
+ self._compareBothSparse(x, np.square, math_ops.square)
+ self._compareBothSparse(z, np.sqrt, math_ops.sqrt, tol=1e-3)
+ self._compareBothSparse(x, np.tanh, math_ops.tanh)
+ self._compareBothSparse(y, np.sign, math_ops.sign)
+ self._compareBothSparse(x, np.vectorize(math.erf), math_ops.erf)
+
+ def testFloatTanhEdge(self):
+ x = np.arange(40, 40 + 6).reshape(6).astype(np.float32)
+ self._compareBoth(x, np.tanh, math_ops.tanh)
+ x = np.arange(-40, -40 + 6).reshape(6).astype(np.float32)
+ self._compareBoth(x, np.tanh, math_ops.tanh)
+
+ def testFloatEmpty(self):
+ x = np.empty((2, 0, 5), dtype=np.float32)
+ self._compareBoth(x, np.abs, math_ops.abs)
+ self._compareBoth(x, np.abs, _ABS)
+ self._compareBoth(x, np.negative, math_ops.negative)
+ self._compareBoth(x, np.negative, _NEG)
+ self._compareBoth(x, self._inv, math_ops.reciprocal)
+ self._compareBoth(x, np.square, math_ops.square)
+ self._compareBoth(x, np.sqrt, math_ops.sqrt)
+ self._compareBoth(x, self._rsqrt, math_ops.rsqrt)
+ self._compareBoth(x, np.exp, math_ops.exp)
+ self._compareBoth(x, np.expm1, math_ops.expm1)
+ self._compareBoth(x, np.log, math_ops.log)
+ self._compareBoth(x, np.log1p, math_ops.log1p)
+ self._compareBoth(x, np.sinh, math_ops.sinh)
+ self._compareBoth(x, np.arcsinh, math_ops.asinh)
+ self._compareBoth(x, np.cosh, math_ops.cosh)
+ self._compareBoth(x, np.tanh, math_ops.tanh)
+ self._compareBoth(x, self._sigmoid, math_ops.sigmoid)
+ self._compareBoth(x, np.sign, math_ops.sign)
+ self._compareBoth(x, np.sin, math_ops.sin)
+ self._compareBoth(x, np.cos, math_ops.cos)
+ # Can't use vectorize below, so just use some arbitrary function
+ self._compareBoth(x, np.sign, math_ops.lgamma)
+ self._compareBoth(x, np.sign, math_ops.erf)
+ self._compareBoth(x, np.sign, math_ops.erfc)
+ self._compareBoth(x, np.tan, math_ops.tan)
+ self._compareBoth(x, np.arcsin, math_ops.asin)
+ self._compareBoth(x, np.arccos, math_ops.acos)
+ self._compareBoth(x, np.arctan, math_ops.atan)
+ try:
+ from scipy import special # pylint: disable=g-import-not-at-top
+ self._compareBoth(x, special.i0e, math_ops.bessel_i0e)
+ self._compareBoth(x, special.i1e, math_ops.bessel_i1e)
+ except ImportError as e:
+ tf_logging.warn("Cannot test special functions: %s" % str(e))
+
+ self._compareBothSparse(x, np.abs, math_ops.abs)
+ self._compareBothSparse(x, np.negative, math_ops.negative)
+ self._compareBothSparse(x, np.square, math_ops.square)
+ self._compareBothSparse(x, np.sqrt, math_ops.sqrt, tol=1e-3)
+ self._compareBothSparse(x, np.tanh, math_ops.tanh)
+ self._compareBothSparse(x, np.sign, math_ops.sign)
+ self._compareBothSparse(x, np.sign, math_ops.erf)
+
+ def testDoubleBasic(self):
+ x = np.arange(-3, 3).reshape(1, 3, 2).astype(np.float64)
+ w = x - x.min() + 1.02 # all greater than 1
+ y = (x + .5).astype(np.float64) # no zero
+ z = (x + 15.5).astype(np.float64) # all positive
+ k = np.arange(-0.90, 0.90,
+ 0.35).reshape(1, 3, 2).astype(np.float64) # between -1 and 1
+ self._compareBoth(x, np.abs, math_ops.abs)
+ self._compareBoth(x, np.abs, _ABS)
+ self._compareBoth(x, np.negative, math_ops.negative)
+ self._compareBoth(x, np.negative, _NEG)
+ self._compareBoth(y, self._inv, math_ops.reciprocal)
+ self._compareBoth(x, np.square, math_ops.square)
+ self._compareBoth(z, np.sqrt, math_ops.sqrt)
+ self._compareBoth(z, self._rsqrt, math_ops.rsqrt)
+ self._compareBoth(x, np.exp, math_ops.exp)
+ self._compareBoth(x, np.expm1, math_ops.expm1)
+ self._compareBoth(z, np.log, math_ops.log)
+ self._compareBoth(z, np.log1p, math_ops.log1p)
+ self._compareBoth(x, np.sinh, math_ops.sinh)
+ self._compareBoth(x, np.cosh, math_ops.cosh)
+ self._compareBoth(x, np.tanh, math_ops.tanh)
+ self._compareBoth(x, np.arcsinh, math_ops.asinh)
+ self._compareBoth(w, np.arccosh, math_ops.acosh)
+ self._compareBoth(k, np.arctanh, math_ops.atanh)
+ self._compareBoth(x, self._sigmoid, math_ops.sigmoid)
+ self._compareBoth(y, np.sign, math_ops.sign)
+ self._compareBoth(x, np.sin, math_ops.sin)
+ self._compareBoth(x, np.cos, math_ops.cos)
+ self._compareBoth(
+ y, np.vectorize(self._replace_domain_error_with_inf(math.lgamma)),
+ math_ops.lgamma)
+ self._compareBoth(x, np.vectorize(math.erf), math_ops.erf)
+ self._compareBoth(x, np.vectorize(math.erfc), math_ops.erfc)
+ self._compareBoth(x, np.arctan, math_ops.atan)
+ self._compareBoth(k, np.arcsin, math_ops.asin)
+ self._compareBoth(k, np.arccos, math_ops.acos)
+ self._compareBoth(k, np.tan, math_ops.tan)
+ try:
+ from scipy import special # pylint: disable=g-import-not-at-top
+ self._compareBoth(x, special.i0e, math_ops.bessel_i0e)
+ self._compareBoth(x, special.i1e, math_ops.bessel_i1e)
+ except ImportError as e:
+ tf_logging.warn("Cannot test special functions: %s" % str(e))
+
+ self._compareBothSparse(x, np.abs, math_ops.abs)
+ self._compareBothSparse(x, np.negative, math_ops.negative)
+ self._compareBothSparse(x, np.square, math_ops.square)
+ self._compareBothSparse(z, np.sqrt, math_ops.sqrt, tol=1e-3)
+ self._compareBothSparse(x, np.tanh, math_ops.tanh)
+ self._compareBothSparse(y, np.sign, math_ops.sign)
+ self._compareBothSparse(x, np.vectorize(math.erf), math_ops.erf)
+
+ def testHalfBasic(self):
+ x = np.arange(-3, 3).reshape(1, 3, 2).astype(np.float16)
+ y = (x + .5).astype(np.float16) # no zero
+ z = (x + 15.5).astype(np.float16) # all positive
+ self._compareBoth(x, np.abs, math_ops.abs)
+ self._compareBoth(x, np.abs, _ABS)
+ self._compareBoth(x, np.negative, math_ops.negative)
+ self._compareBoth(x, np.negative, _NEG)
+ self._compareBoth(y, self._inv, math_ops.reciprocal)
+ self._compareBoth(x, np.square, math_ops.square)
+ self._compareBoth(z, np.sqrt, math_ops.sqrt)
+ self._compareBoth(z, self._rsqrt, math_ops.rsqrt)
+ self._compareBoth(x, np.exp, math_ops.exp)
+ self._compareBoth(x, np.expm1, math_ops.expm1)
+ self._compareBoth(z, np.log, math_ops.log)
+ self._compareBoth(z, np.log1p, math_ops.log1p)
+ self._compareBoth(x, np.tanh, math_ops.tanh)
+ self._compareBoth(x, self._sigmoid, math_ops.sigmoid)
+ self._compareBoth(y, np.sign, math_ops.sign)
+ self._compareBoth(x, np.sin, math_ops.sin)
+ self._compareBoth(x, np.cos, math_ops.cos)
+ self._compareBoth(
+ y, np.vectorize(self._replace_domain_error_with_inf(math.lgamma)),
+ math_ops.lgamma)
+ self._compareBoth(x, np.vectorize(math.erf), math_ops.erf)
+ self._compareBoth(x, np.vectorize(math.erfc), math_ops.erfc)
+ try:
+ from scipy import special # pylint: disable=g-import-not-at-top
+ self._compareBoth(x, special.i0e, math_ops.bessel_i0e)
+ self._compareBoth(x, special.i1e, math_ops.bessel_i1e)
+ except ImportError as e:
+ tf_logging.warn("Cannot test special functions: %s" % str(e))
+
+ self._compareBothSparse(x, np.abs, math_ops.abs)
+ self._compareBothSparse(x, np.negative, math_ops.negative)
+ self._compareBothSparse(x, np.square, math_ops.square)
+ self._compareBothSparse(z, np.sqrt, math_ops.sqrt, tol=1e-3)
+ self._compareBothSparse(x, np.tanh, math_ops.tanh)
+ self._compareBothSparse(y, np.sign, math_ops.sign)
+ self._compareBothSparse(x, np.vectorize(math.erf), math_ops.erf, tol=1e-3)
+
+ def testInt32Basic(self):
+ x = np.arange(-6, 6, 2).reshape(1, 3, 2).astype(np.int32)
+ self._compareCpu(x, np.abs, math_ops.abs)
+ self._compareCpu(x, np.abs, _ABS)
+ self._compareBoth(x, np.negative, math_ops.negative)
+ self._compareBoth(x, np.negative, _NEG)
+ self._compareBoth(x, np.square, math_ops.square)
+ self._compareCpu(x, np.sign, math_ops.sign)
+
+ self._compareBothSparse(x, np.abs, math_ops.abs)
+ self._compareBothSparse(x, np.negative, math_ops.negative)
+ self._compareBothSparse(x, np.square, math_ops.square)
+ self._compareBothSparse(x, np.sign, math_ops.sign)
+
+ def testInt64Basic(self):
+ x = np.arange(-6 << 40, 6 << 40, 2 << 40).reshape(1, 3, 2).astype(np.int64)
+ self._compareCpu(x, np.abs, math_ops.abs)
+ self._compareCpu(x, np.abs, _ABS)
+ self._compareCpu(x, np.negative, math_ops.negative)
+ self._compareCpu(x, np.negative, _NEG)
+ self._compareCpu(x, np.sign, math_ops.sign)
+
+ self._compareBothSparse(x, np.abs, math_ops.abs)
+ self._compareBothSparse(x, np.negative, math_ops.negative)
+ self._compareBothSparse(x, np.sign, math_ops.sign)
+
+ def testInt64Square(self):
+ x = np.arange(-6 << 20, 6 << 20, 2 << 20).reshape(1, 3, 2).astype(np.int64)
+ self._compareCpu(x, np.square, math_ops.square)
+ self._compareBothSparse(x, np.square, math_ops.square)
+
+ def testComplex64Basic(self):
+ x = np.complex(1, 1) * np.arange(-3, 3).reshape(1, 3, 2).astype(
+ np.complex64)
+ y = x + np.complex(0.5, 0.5) # no zeros
+ self._compareBoth(x, np.abs, math_ops.abs)
+ self._compareBoth(x, np.abs, _ABS)
+ self._compareBoth(x, np.negative, math_ops.negative)
+ self._compareBoth(x, np.negative, _NEG)
+ self._compareCpu(y, self._inv, math_ops.reciprocal)
+ self._compareCpu(x, np.square, math_ops.square)
+ self._compareCpu(y, np.sqrt, math_ops.sqrt)
+ self._compareCpu(y, self._rsqrt, math_ops.rsqrt)
+ self._compareBoth(x, np.exp, math_ops.exp)
+ self._compareCpu(x, np.expm1, math_ops.expm1)
+ self._compareCpu(y, np.log, math_ops.log)
+ self._compareCpu(y, np.log1p, math_ops.log1p)
+ self._compareCpu(x, np.sinh, math_ops.sinh)
+ self._compareCpu(x, np.cosh, math_ops.cosh)
+ self._compareCpu(x, np.tanh, math_ops.tanh)
+
+ # Complex64 versions of asinh() and acosh() in libstdc++ only have 6 digits
+ # of precision.
+ # Small gradient values + low precision --> High relative error
+ self._compareCpu(y, np.arcsinh, math_ops.asinh, grad_rtol=1e-2)
+ self._compareCpu(y, np.arccosh, math_ops.acosh, grad_rtol=1e-2)
+
+ self._compareCpu(y, np.arctanh, math_ops.atanh)
+ self._compareCpu(x, self._sigmoid, math_ops.sigmoid)
+ self._compareCpu(x, np.sin, math_ops.sin)
+ self._compareCpu(x, np.cos, math_ops.cos)
+
+ self._compareBothSparse(x, np.abs, math_ops.abs)
+ self._compareBothSparse(x, np.negative, math_ops.negative)
+ self._compareBothSparse(x, np.square, math_ops.square)
+ self._compareBothSparse(x, np.sqrt, math_ops.sqrt, 1e-3)
+ self._compareBothSparse(x, np.tanh, math_ops.tanh)
+
+ # Numpy uses an incorrect definition of sign; use the right one instead.
+ def complex_sign(x):
+ return x / np.abs(x)
+
+ self._compareBoth(y, complex_sign, math_ops.sign)
+ self._compareBothSparse(y, complex_sign, math_ops.sign)
+
+ def testComplex128Basic(self):
+ x = np.complex(1, 1) * np.arange(-3, 3).reshape(1, 3, 2).astype(
+ np.complex128)
+ y = x + np.complex(0.5, 0.5) # no zeros
+ self._compareBoth(x, np.abs, math_ops.abs)
+ self._compareBoth(x, np.abs, _ABS)
+ self._compareBoth(x, np.negative, math_ops.negative)
+ self._compareBoth(x, np.negative, _NEG)
+ self._compareCpu(y, self._inv, math_ops.reciprocal)
+ self._compareCpu(x, np.square, math_ops.square)
+ self._compareCpu(y, np.sqrt, math_ops.sqrt)
+ self._compareCpu(y, self._rsqrt, math_ops.rsqrt)
+ self._compareBoth(x, np.exp, math_ops.exp)
+ self._compareCpu(x, np.expm1, math_ops.expm1)
+ self._compareCpu(y, np.log, math_ops.log)
+ self._compareCpu(y, np.log1p, math_ops.log1p)
+ self._compareCpu(x, np.sinh, math_ops.sinh)
+ self._compareCpu(x, np.cosh, math_ops.cosh)
+ self._compareCpu(x, np.tanh, math_ops.tanh)
+ self._compareCpu(y, np.arcsinh, math_ops.asinh)
+ self._compareCpu(y, np.arccosh, math_ops.acosh)
+ self._compareCpu(y, np.arctanh, math_ops.atanh)
+ self._compareCpu(x, self._sigmoid, math_ops.sigmoid)
+ self._compareCpu(x, np.sin, math_ops.sin)
+ self._compareCpu(x, np.cos, math_ops.cos)
+
+ self._compareBothSparse(x, np.abs, math_ops.abs)
+ self._compareBothSparse(x, np.negative, math_ops.negative)
+ self._compareBothSparse(x, np.square, math_ops.square)
+ self._compareBothSparse(x, np.sqrt, math_ops.sqrt, 1e-3)
+ self._compareBothSparse(x, np.tanh, math_ops.tanh)
+
+ # Numpy uses an incorrect definition of sign; use the right one instead.
+ def complex_sign(x):
+ return x / np.abs(x)
+
+ self._compareBoth(y, complex_sign, math_ops.sign)
+ self._compareBothSparse(y, complex_sign, math_ops.sign)
+
+ def testGradGrad(self):
+ np.random.seed(7)
+ shape = (5,)
+ dtype_tols = [(np.float32, 5e-4), (np.float64, 1e-6), (np.complex64, 5e-4),
+ (np.complex128, 1e-6)]
+ op_range = [
+ (gen_math_ops.reciprocal_grad, [-2, 2]),
+ (gen_math_ops.rsqrt_grad, [0.1, 3]),
+ (gen_math_ops.sigmoid_grad, [-2, 2]),
+ (gen_math_ops.sqrt_grad, [0.1, 3]),
+ (gen_math_ops.tanh_grad, [-2, 2]),
+ ]
+
+ def rand(dtype, real_range):
+ x = np.random.uniform(
+ real_range[0], real_range[1], size=shape[0]).astype(dtype)
+ if dtype in (np.complex64, np.complex128):
+ x += 1j * np.random.uniform(-2, 2, size=shape[0]).astype(dtype)
+ return x
+
+ for op, real_range in op_range:
+ with self.cached_session():
+ for dtype, tol in dtype_tols:
+ x = constant_op.constant(rand(dtype, real_range))
+ y = constant_op.constant(rand(dtype, real_range))
+ z = op(x, y)
+ grads = gradient_checker.compute_gradient(
+ [x, y], [shape, shape],
+ z,
+ shape,
+ x_init_value=[rand(dtype, real_range),
+ rand(dtype, real_range)])
+ if isinstance(grads, tuple):
+ grads = [grads]
+ for analytical, numerical in grads:
+ self.assertAllClose(analytical, numerical, rtol=tol, atol=tol)
+
+
+if __name__ == "__main__":
+ test.main()
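Editorial note: the new file's complex tests define a local complex_sign helper because, for complex inputs, NumPy (prior to 2.0) defines np.sign from the sign of the real part, whereas math_ops.sign returns x / |x|. A two-line illustration of the discrepancy, outside the diff:

# Sketch: why the tests use complex_sign instead of np.sign for complex dtypes.
import numpy as np

z = np.complex64(3 + 4j)
print(np.sign(z))     # NumPy < 2.0: (1+0j), the sign of the real part
print(z / np.abs(z))  # (0.6+0.8j), which matches math_ops.sign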
diff --git a/tensorflow/python/kernel_tests/dense_update_ops_test.py b/tensorflow/python/kernel_tests/dense_update_ops_test.py
index 06c3271850..120e10314f 100644
--- a/tensorflow/python/kernel_tests/dense_update_ops_test.py
+++ b/tensorflow/python/kernel_tests/dense_update_ops_test.py
@@ -87,7 +87,7 @@ class AssignOpTest(test.TestCase):
def testAssignNonStrictShapeChecking(self):
with self.cached_session():
data = array_ops.fill([1024, 1024], 0)
- p = variables.Variable([1])
+ p = variables.VariableV1([1])
a = state_ops.assign(p, data, validate_shape=False)
a.op.run()
self.assertAllEqual(p.eval(), data.eval())
@@ -100,14 +100,14 @@ class AssignOpTest(test.TestCase):
def testInitRequiredAssignAdd(self):
with self.cached_session():
- p = variables.Variable(array_ops.fill([1024, 1024], 1), dtypes.int32)
+ p = variables.VariableV1(array_ops.fill([1024, 1024], 1), dtypes.int32)
a = state_ops.assign_add(p, array_ops.fill([1024, 1024], 0))
with self.assertRaisesOpError("use uninitialized"):
a.op.run()
def testInitRequiredAssignSub(self):
with self.cached_session():
- p = variables.Variable(array_ops.fill([1024, 1024], 1), dtypes.int32)
+ p = variables.VariableV1(array_ops.fill([1024, 1024], 1), dtypes.int32)
a = state_ops.assign_sub(p, array_ops.fill([1024, 1024], 0))
with self.assertRaisesOpError("use uninitialized"):
a.op.run()
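Editorial note: the dense_update_ops_test.py hunks swap variables.Variable for variables.VariableV1, keeping the reference-variable semantics that state_ops.assign(..., validate_shape=False) depends on. The sketch below mirrors the pattern in testAssignNonStrictShapeChecking using the same internal modules the test imports; it is illustrative only and not a recommended public API.

# Sketch: reassigning a reference variable to a tensor of a different shape.
import tensorflow as tf
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables

with tf.Session():
  p = variables.VariableV1([1])                        # starts as int32, shape [1]
  data = array_ops.fill([1024, 1024], 0)               # int32 zeros, shape [1024, 1024]
  a = state_ops.assign(p, data, validate_shape=False)  # shape check disabled
  a.op.run()                                           # assign itself defines p; no initializer run
  print(p.eval().shape)                                # (1024, 1024)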
diff --git a/tensorflow/python/kernel_tests/depthwise_conv_op_test.py b/tensorflow/python/kernel_tests/depthwise_conv_op_test.py
index 5741f2ec64..6d1ead20be 100644
--- a/tensorflow/python/kernel_tests/depthwise_conv_op_test.py
+++ b/tensorflow/python/kernel_tests/depthwise_conv_op_test.py
@@ -128,7 +128,7 @@ class DepthwiseConv2DTest(test.TestCase):
x2 = [f * 1.0 / filter_size for f in range(1, filter_size + 1)]
ops.reset_default_graph()
graph = ops.get_default_graph()
- with self.test_session(graph=graph, use_gpu=use_gpu) as sess:
+ with self.session(graph=graph, use_gpu=use_gpu) as sess:
tolerance = {
dtypes.float16: 4e-2,
dtypes.float32: 1e-8,
@@ -191,7 +191,7 @@ class DepthwiseConv2DTest(test.TestCase):
tf_logging.info(
"Testing DepthwiseConv2D, %dth config: %r * %r, stride: %d, padding: "
"%s", index, input_size, filter_size, stride, padding)
- for data_type in [dtypes.float16, dtypes.float32, dtypes.float64]:
+ for data_type in [dtypes.float32, dtypes.float64]:
tf_logging.info("Testing without grouped_conv")
self._VerifyValues(
input_size, filter_size, stride, padding, data_type, use_gpu=True)
@@ -227,7 +227,7 @@ class DepthwiseConv2DTest(test.TestCase):
tf_logging.info(
"Testing DepthwiseConv2DFormat, %dth config: %r * %r, stride: %d, "
"padding: %s", index, input_size, filter_size, stride, padding)
- for data_type in [dtypes.float16, dtypes.float32, dtypes.float64]:
+ for data_type in [dtypes.float32, dtypes.float64]:
self._VerifyValues(
input_size,
filter_size,
@@ -366,7 +366,7 @@ class DepthwiseConv2DTest(test.TestCase):
filter_data = [x * 1.0 / filter_size for x in range(0, filter_size)]
ops.reset_default_graph()
graph = ops.get_default_graph()
- with self.test_session(graph=graph, use_gpu=use_gpu) as sess:
+ with self.session(graph=graph, use_gpu=use_gpu) as sess:
tolerance = {
dtypes.float16: 4e-0,
dtypes.float32: 8e-4,
@@ -434,7 +434,7 @@ class DepthwiseConv2DTest(test.TestCase):
tf_logging.info(
"Testing DepthwiseConv2DInputGrad, %dth config: %r * %r, stride: %d, "
"padding: %s", index, input_size, filter_size, stride, padding)
- for data_type in [dtypes.float16, dtypes.float32, dtypes.float64]:
+ for data_type in [dtypes.float32, dtypes.float64]:
self._ConstructAndTestGradient(
input_size,
filter_size,
@@ -465,7 +465,7 @@ class DepthwiseConv2DTest(test.TestCase):
"Testing DepthwiseConv2DInputGradFormat, %dth config: %r * %r, "
"stride: %d, padding: %s", index, input_size, filter_size, stride,
padding)
- for data_type in [dtypes.float16, dtypes.float32, dtypes.float64]:
+ for data_type in [dtypes.float32, dtypes.float64]:
self._ConstructAndTestGradient(
input_size,
filter_size,
@@ -483,7 +483,7 @@ class DepthwiseConv2DTest(test.TestCase):
tf_logging.info(
"Testing DepthwiseConv2DFilterGrad, %dth config: %r * %r, stride: "
"%d, padding: %s", index, input_size, filter_size, stride, padding)
- for data_type in [dtypes.float16, dtypes.float32, dtypes.float64]:
+ for data_type in [dtypes.float32, dtypes.float64]:
self._ConstructAndTestGradient(
input_size,
filter_size,
@@ -504,7 +504,7 @@ class DepthwiseConv2DTest(test.TestCase):
"Testing DepthwiseConv2DFilterGradFormat, %dth config: %r * %r, "
"stride: %d, padding: %s", index, input_size, filter_size, stride,
padding)
- for data_type in [dtypes.float16, dtypes.float32, dtypes.float64]:
+ for data_type in [dtypes.float32, dtypes.float64]:
self._ConstructAndTestGradient(
input_size,
filter_size,
diff --git a/tensorflow/python/kernel_tests/distributions/bernoulli_test.py b/tensorflow/python/kernel_tests/distributions/bernoulli_test.py
index 26d013bccb..37b35ba51a 100644
--- a/tensorflow/python/kernel_tests/distributions/bernoulli_test.py
+++ b/tensorflow/python/kernel_tests/distributions/bernoulli_test.py
@@ -118,7 +118,9 @@ class BernoulliTest(test.TestCase):
self.assertEqual(dist.probs.dtype, dist.stddev().dtype)
self.assertEqual(dist.probs.dtype, dist.entropy().dtype)
self.assertEqual(dist.probs.dtype, dist.prob(0).dtype)
+ self.assertEqual(dist.probs.dtype, dist.prob(0.5).dtype)
self.assertEqual(dist.probs.dtype, dist.log_prob(0).dtype)
+ self.assertEqual(dist.probs.dtype, dist.log_prob(0.5).dtype)
dist64 = make_bernoulli([], dtypes.int64)
self.assertEqual(dist64.dtype, dtypes.int64)
@@ -181,6 +183,16 @@ class BernoulliTest(test.TestCase):
return
self._testPmf(logits=special.logit(p))
+ @test_util.run_in_graph_and_eager_modes
+ def testPmfWithFloatArgReturnsXEntropy(self):
+ p = [[0.2], [0.4], [0.3], [0.6]]
+ samps = [0, 0.1, 0.8]
+ self.assertAllClose(
+ np.float32(samps) * np.log(np.float32(p)) +
+ (1 - np.float32(samps)) * np.log(1 - np.float32(p)),
+ self.evaluate(
+ bernoulli.Bernoulli(probs=p, validate_args=False).log_prob(samps)))
+
def testBroadcasting(self):
with self.cached_session():
p = array_ops.placeholder(dtypes.float32)
diff --git a/tensorflow/python/kernel_tests/distributions/normal_test.py b/tensorflow/python/kernel_tests/distributions/normal_test.py
index de73a40b23..6625a88843 100644
--- a/tensorflow/python/kernel_tests/distributions/normal_test.py
+++ b/tensorflow/python/kernel_tests/distributions/normal_test.py
@@ -78,6 +78,14 @@ class NormalTest(test.TestCase):
self.assertEqual(expected, sigma_shape)
@test_util.run_in_graph_and_eager_modes
+ def testSampleLikeArgsGetDistDType(self):
+ dist = normal_lib.Normal(0., 1.)
+ self.assertEqual(dtypes.float32, dist.dtype)
+ for method in ("log_prob", "prob", "log_cdf", "cdf",
+ "log_survival_function", "survival_function", "quantile"):
+ self.assertEqual(dtypes.float32, getattr(dist, method)(1).dtype)
+
+ @test_util.run_in_graph_and_eager_modes
def testParamShapes(self):
sample_shape = [10, 3, 4]
self._testParamShapes(sample_shape, sample_shape)
diff --git a/tensorflow/python/kernel_tests/extract_volume_patches_op_test.py b/tensorflow/python/kernel_tests/extract_volume_patches_op_test.py
new file mode 100644
index 0000000000..64757a3e07
--- /dev/null
+++ b/tensorflow/python/kernel_tests/extract_volume_patches_op_test.py
@@ -0,0 +1,131 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Functional tests for ExtractVolumePatches op."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+class ExtractVolumePatches(test.TestCase):
+ """Functional tests for ExtractVolumePatches op."""
+
+ def _VerifyValues(self, image, ksizes, strides, padding, patches):
+ """Tests input-output pairs for the ExtractVolumePatches op.
+
+ Args:
+ image: Input tensor with shape:
+ [batch, in_planes, in_rows, in_cols, depth].
+ ksizes: Patch size specified as: [ksize_planes, ksize_rows, ksize_cols].
+ strides: Output strides, specified as:
+ [stride_planes, stride_rows, stride_cols].
+ padding: Padding type.
+ patches: Expected output.
+
+ Note:
+ Rates are not supported at this time.
+ """
+ ksizes = [1] + ksizes + [1]
+ strides = [1] + strides + [1]
+
+ with self.test_session(use_gpu=True):
+ out_tensor = array_ops.extract_volume_patches(
+ constant_op.constant(image),
+ ksizes=ksizes,
+ strides=strides,
+ padding=padding,
+ name="im2col_3d")
+ self.assertAllClose(patches, out_tensor.eval())
+
+ # pylint: disable=bad-whitespace
+ def testKsize1x1x1Stride1x1x1(self):
+ """Verifies that for a 1x1x1 kernel the output equals the input."""
+ image = np.arange(2 * 3 * 4 * 5 * 6).reshape([2, 3, 4, 5, 6]) + 1
+ patches = image
+ for padding in ["VALID", "SAME"]:
+ self._VerifyValues(
+ image,
+ ksizes=[1, 1, 1],
+ strides=[1, 1, 1],
+ padding=padding,
+ patches=patches)
+
+ def testKsize1x1x1Stride2x3x4(self):
+ """Test for 1x1x1 kernel with 2x3x4 strides."""
+ image = np.arange(6 * 2 * 4 * 5 * 3).reshape([6, 2, 4, 5, 3]) + 1
+ patches = image[:, ::2, ::3, ::4, :]
+ for padding in ["VALID", "SAME"]:
+ self._VerifyValues(
+ image,
+ ksizes=[1, 1, 1],
+ strides=[2, 3, 4],
+ padding=padding,
+ patches=patches)
+
+ def testKsize1x1x2Stride2x2x3(self):
+ """Test for 1x1x2 kernel with 2x2x3 strides."""
+ image = np.arange(45).reshape([1, 3, 3, 5, 1]) + 1
+ patches = np.array([[[[[ 1, 2],
+ [ 4, 5]],
+ [[11, 12],
+ [14, 15]]],
+ [[[31, 32],
+ [34, 35]],
+ [[41, 42],
+ [44, 45]]]]])
+ for padding in ["VALID", "SAME"]:
+ self._VerifyValues(
+ image,
+ ksizes=[1, 1, 2],
+ strides=[2, 2, 3],
+ padding=padding,
+ patches=patches)
+
+ def testKsize2x2x2Stride1x1x1Valid(self):
+ """Test for 2x2x2 kernel with VALID padding."""
+ image = np.arange(8).reshape([1, 2, 2, 2, 1]) + 1
+ patches = np.array([[[[[1, 2, 3, 4, 5, 6, 7, 8]]]]])
+ self._VerifyValues(
+ image,
+ ksizes=[2, 2, 2],
+ strides=[1, 1, 1],
+ padding="VALID",
+ patches=patches)
+
+ def testKsize2x2x2Stride1x1x1Same(self):
+ """Test for 2x2x2 kernel with SAME padding."""
+ image = np.arange(8).reshape([1, 2, 2, 2, 1]) + 1
+ patches = np.array([[[[[1, 2, 3, 4, 5, 6, 7, 8],
+ [2, 0, 4, 0, 6, 0, 8, 0]],
+ [[3, 4, 0, 0, 7, 8, 0, 0],
+ [4, 0, 0, 0, 8, 0, 0, 0]]],
+ [[[5, 6, 7, 8, 0, 0, 0, 0],
+ [6, 0, 8, 0, 0, 0, 0, 0]],
+ [[7, 8, 0, 0, 0, 0, 0, 0],
+ [8, 0, 0, 0, 0, 0, 0, 0]]]]])
+ self._VerifyValues(
+ image,
+ ksizes=[2, 2, 2],
+ strides=[1, 1, 1],
+ padding="SAME",
+ patches=patches)
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/kernel_tests/functional_ops_test.py b/tensorflow/python/kernel_tests/functional_ops_test.py
index e39daf1371..30d11852c7 100644
--- a/tensorflow/python/kernel_tests/functional_ops_test.py
+++ b/tensorflow/python/kernel_tests/functional_ops_test.py
@@ -735,7 +735,7 @@ class FunctionalOpsTest(test.TestCase):
def Run(sess, n):
return sess.run(functional_ops.While([n, 0.], Cond, Body))[1]
- with self.test_session(graph=g, use_gpu=use_gpu) as sess:
+ with self.session(graph=g, use_gpu=use_gpu) as sess:
self.assertAllEqual(Run(sess, 20.), 210.)
self.assertAllEqual(Run(sess, 100.), 5050.)
@@ -765,7 +765,7 @@ class FunctionalOpsTest(test.TestCase):
fetch = outputs[1]
else:
fetch = "my_while:1"
- with self.test_session(graph=g, use_gpu=use_gpu) as sess:
+ with self.session(graph=g, use_gpu=use_gpu) as sess:
return sess.run(fetch)
self.assertAllEqual(Run(20., False), 210.)
@@ -793,7 +793,7 @@ class FunctionalOpsTest(test.TestCase):
def BodyReturnsTooManyArgs(n, x):
return n - 1, x + n, x
- with self.test_session(graph=g, use_gpu=use_gpu):
+ with self.session(graph=g, use_gpu=use_gpu):
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
"Expected a single scalar.*got 2 tensors."):
@@ -818,7 +818,7 @@ class FunctionalOpsTest(test.TestCase):
def Body(n, x):
return n - 1, x + n
- with self.test_session(graph=g, use_gpu=use_gpu) as sess:
+ with self.session(graph=g, use_gpu=use_gpu) as sess:
n = array_ops.placeholder(dtypes.float32)
_, result = functional_ops.While([n, 0.], Cond, Body)
c = constant_op.constant(37.)
@@ -831,7 +831,7 @@ class FunctionalOpsTest(test.TestCase):
def _tfSum(self, use_gpu, rewrite_with_while):
with ops.Graph().as_default() as g:
- with self.test_session(graph=g, use_gpu=use_gpu) as sess:
+ with self.session(graph=g, use_gpu=use_gpu) as sess:
@function.Defun(dtypes.int32, dtypes.float32)
def Body(n, x):
diff --git a/tensorflow/python/kernel_tests/identity_op_py_test.py b/tensorflow/python/kernel_tests/identity_op_py_test.py
index 37f9f716f8..88ea10c22a 100644
--- a/tensorflow/python/kernel_tests/identity_op_py_test.py
+++ b/tensorflow/python/kernel_tests/identity_op_py_test.py
@@ -61,7 +61,7 @@ class IdentityOpTest(test.TestCase):
def testRefIdentityShape(self):
with self.cached_session():
shape = [2, 3]
- tensor = variables.Variable(
+ tensor = variables.VariableV1(
constant_op.constant(
[[1, 2, 3], [6, 5, 4]], dtype=dtypes.int32))
self.assertEquals(shape, tensor.get_shape())
diff --git a/tensorflow/python/kernel_tests/init_ops_test.py b/tensorflow/python/kernel_tests/init_ops_test.py
index 79ce965242..292679e4b9 100644
--- a/tensorflow/python/kernel_tests/init_ops_test.py
+++ b/tensorflow/python/kernel_tests/init_ops_test.py
@@ -522,7 +522,7 @@ class LinSpaceTest(test.TestCase):
def _LinSpace(self, start, stop, num):
# NOTE(touts): Needs to pass a graph to get a new session each time.
with ops.Graph().as_default() as graph:
- with self.test_session(graph=graph, force_gpu=self.force_gpu):
+ with self.session(graph=graph, force_gpu=self.force_gpu):
tf_ans = math_ops.linspace(start, stop, num, name="linspace")
self.assertEqual([num], tf_ans.get_shape())
return tf_ans.eval()
@@ -606,7 +606,7 @@ class OrthogonalInitializerTest(test.TestCase):
def testInvalidShape(self):
init1 = init_ops.orthogonal_initializer()
- with self.test_session(graph=ops.Graph(), use_gpu=True):
+ with self.session(graph=ops.Graph(), use_gpu=True):
self.assertRaises(ValueError, init1, shape=[5])
def testGain(self):
@@ -614,7 +614,7 @@ class OrthogonalInitializerTest(test.TestCase):
for dtype in [dtypes.float32, dtypes.float64]:
init1 = init_ops.orthogonal_initializer(seed=1, dtype=dtype)
init2 = init_ops.orthogonal_initializer(gain=3.14, seed=1, dtype=dtype)
- with self.test_session(graph=ops.Graph(), use_gpu=True):
+ with self.session(graph=ops.Graph(), use_gpu=True):
t1 = init1(shape).eval()
t2 = init2(shape).eval()
return np.allclose(t1, t2 / 3.14, rtol=1e-15, atol=1e-15)
@@ -624,7 +624,7 @@ class OrthogonalInitializerTest(test.TestCase):
for shape in [(10, 10), (10, 9, 8), (100, 5, 5), (50, 40), (40, 50)]:
init = init_ops.orthogonal_initializer(dtype=dtype)
tol = 1e-5 if dtype == dtypes.float32 else 1e-12
- with self.test_session(graph=ops.Graph(), use_gpu=True):
+ with self.session(graph=ops.Graph(), use_gpu=True):
# Check the shape
t = init(shape).eval()
self.assertAllEqual(shape, t.shape)
@@ -663,7 +663,7 @@ class ConvolutionDeltaOrthogonalInitializerTest(test.TestCase):
def testInvalidShape(self):
init1 = init_ops.convolutional_delta_orthogonal()
- with self.test_session(graph=ops.Graph(), use_gpu=True):
+ with self.session(graph=ops.Graph(), use_gpu=True):
self.assertRaises(ValueError, init1, shape=[3, 3, 6, 5])
def testGain(self):
@@ -672,7 +672,7 @@ class ConvolutionDeltaOrthogonalInitializerTest(test.TestCase):
init1 = init_ops.convolutional_delta_orthogonal(seed=1, dtype=dtype)
init2 = init_ops.convolutional_delta_orthogonal(gain=3.14,
seed=1, dtype=dtype)
- with self.test_session(graph=ops.Graph(), use_gpu=True):
+ with self.session(graph=ops.Graph(), use_gpu=True):
t1 = init1(shape).eval()
t2 = init2(shape).eval()
return np.allclose(t1, t2 / 3.14, rtol=1e-15, atol=1e-15)
@@ -763,7 +763,7 @@ class ConvolutionOrthogonal1dInitializerTest(test.TestCase):
def testInvalidShape(self):
init1 = init_ops.convolutional_orthogonal_1d()
- with self.test_session(graph=ops.Graph(), use_gpu=True):
+ with self.session(graph=ops.Graph(), use_gpu=True):
self.assertRaises(ValueError, init1, shape=[3, 6, 5])
def testGain(self):
@@ -772,7 +772,7 @@ class ConvolutionOrthogonal1dInitializerTest(test.TestCase):
init1 = init_ops.convolutional_orthogonal_1d(seed=1, dtype=dtype)
init2 = init_ops.convolutional_orthogonal_1d(gain=3.14,
seed=1, dtype=dtype)
- with self.test_session(graph=ops.Graph(), use_gpu=True):
+ with self.session(graph=ops.Graph(), use_gpu=True):
t1 = init1(shape).eval()
t2 = init2(shape).eval()
return np.allclose(t1, t2 / 3.14, rtol=1e-15, atol=1e-15)
@@ -877,7 +877,7 @@ class ConvolutionOrthogonal2dInitializerTest(test.TestCase):
def testInvalidShape(self):
init1 = init_ops.convolutional_orthogonal_2d()
- with self.test_session(graph=ops.Graph(), use_gpu=True):
+ with self.session(graph=ops.Graph(), use_gpu=True):
self.assertRaises(ValueError, init1, shape=[3, 3, 6, 5])
def testGain(self):
@@ -886,7 +886,7 @@ class ConvolutionOrthogonal2dInitializerTest(test.TestCase):
init1 = init_ops.convolutional_orthogonal_2d(seed=1, dtype=dtype)
init2 = init_ops.convolutional_orthogonal_2d(gain=3.14,
seed=1, dtype=dtype)
- with self.test_session(graph=ops.Graph(), use_gpu=True):
+ with self.session(graph=ops.Graph(), use_gpu=True):
t1 = init1(shape).eval()
t2 = init2(shape).eval()
return np.allclose(t1, t2 / 3.14, rtol=1e-15, atol=1e-15)
@@ -972,7 +972,7 @@ class ConvolutionOrthogonal3dInitializerTest(test.TestCase):
def testInvalidShape(self):
init1 = init_ops.convolutional_orthogonal_3d()
- with self.test_session(graph=ops.Graph(), use_gpu=True):
+ with self.session(graph=ops.Graph(), use_gpu=True):
self.assertRaises(ValueError, init1, shape=[3, 3, 3, 6, 5])
def testGain(self):
@@ -981,7 +981,7 @@ class ConvolutionOrthogonal3dInitializerTest(test.TestCase):
init1 = init_ops.convolutional_orthogonal_3d(seed=1, dtype=dtype)
init2 = init_ops.convolutional_orthogonal_3d(gain=3.14,
seed=1, dtype=dtype)
- with self.test_session(graph=ops.Graph(), use_gpu=True):
+ with self.session(graph=ops.Graph(), use_gpu=True):
t1 = init1(shape).eval()
t2 = init2(shape).eval()
return np.allclose(t1, t2 / 3.14, rtol=1e-15, atol=1e-15)
@@ -1080,7 +1080,7 @@ class IdentityInitializerTest(test.TestCase):
def testInvalidShape(self):
init = init_ops.identity_initializer()
- with self.test_session(graph=ops.Graph(), use_gpu=True):
+ with self.session(graph=ops.Graph(), use_gpu=True):
self.assertRaises(ValueError, init, shape=[5, 7, 7])
self.assertRaises(ValueError, init, shape=[5])
self.assertRaises(ValueError, init, shape=[])
@@ -1088,7 +1088,7 @@ class IdentityInitializerTest(test.TestCase):
def testNonSquare(self):
init = init_ops.identity_initializer()
shape = (10, 5)
- with self.test_session(graph=ops.Graph(), use_gpu=True):
+ with self.session(graph=ops.Graph(), use_gpu=True):
self.assertAllClose(init(shape).eval(), np.eye(*shape))
def testGain(self):
@@ -1096,16 +1096,16 @@ class IdentityInitializerTest(test.TestCase):
for dtype in [dtypes.float32, dtypes.float64]:
init_default = init_ops.identity_initializer(dtype=dtype)
init_custom = init_ops.identity_initializer(gain=0.9, dtype=dtype)
- with self.test_session(graph=ops.Graph(), use_gpu=True):
+ with self.session(graph=ops.Graph(), use_gpu=True):
self.assertAllClose(init_default(shape).eval(), np.eye(*shape))
- with self.test_session(graph=ops.Graph(), use_gpu=True):
+ with self.session(graph=ops.Graph(), use_gpu=True):
self.assertAllClose(init_custom(shape).eval(), np.eye(*shape) * 0.9)
def testPartitions(self):
shape = (10, 10)
init = init_ops.identity_initializer()
partitioner = partitioned_variables.variable_axis_size_partitioner(1)
- with self.test_session(graph=ops.Graph(), use_gpu=True):
+ with self.session(graph=ops.Graph(), use_gpu=True):
with variable_scope.variable_scope(
"foo", partitioner=partitioner, initializer=init):
v = array_ops.identity(variable_scope.get_variable("bar", shape=shape))
diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_addition_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_addition_test.py
index 7c79fedf65..cf56168d63 100644
--- a/tensorflow/python/kernel_tests/linalg/linear_operator_addition_test.py
+++ b/tensorflow/python/kernel_tests/linalg/linear_operator_addition_test.py
@@ -76,7 +76,7 @@ class LinearOperatorAdditionCorrectnessTest(test.TestCase):
[1., 1.], is_positive_definite=True, name="A")
op_b = linalg.LinearOperatorDiag(
[2., 2.], is_positive_definite=True, name="B")
- with self.test_session():
+ with self.cached_session():
op_sum = add_operators([op_a, op_b])
self.assertEqual(1, len(op_sum))
op = op_sum[0]
@@ -98,7 +98,7 @@ class LinearOperatorAdditionCorrectnessTest(test.TestCase):
[2., 2.], is_positive_definite=True, name="op2")
op3 = linalg.LinearOperatorDiag(
[3., 3.], is_positive_definite=True, name="op3")
- with self.test_session():
+ with self.cached_session():
op_sum = add_operators([op1, op2, op3])
self.assertEqual(1, len(op_sum))
op = op_sum[0]
@@ -121,7 +121,7 @@ class LinearOperatorAdditionCorrectnessTest(test.TestCase):
name="tril")
op3 = linalg.LinearOperatorDiag(
[3., 3.], is_non_singular=True, name="diag_b")
- with self.test_session():
+ with self.cached_session():
op_sum = add_operators([op1, op2, op3])
self.assertEqual(1, len(op_sum))
op = op_sum[0]
@@ -143,7 +143,7 @@ class LinearOperatorAdditionCorrectnessTest(test.TestCase):
op2 = linalg.LinearOperatorLowerTriangular(
[[2., 0.], [1.5, 2.]], name="tril")
op3 = linalg.LinearOperatorDiag([3., 3.], name="diag_b")
- with self.test_session():
+ with self.cached_session():
op_sum = add_operators([op0, op1, op2, op3], operator_name="my_operator")
self.assertEqual(1, len(op_sum))
op = op_sum[0]
@@ -233,7 +233,7 @@ class LinearOperatorOrderOfAdditionTest(test.TestCase):
self.assertEqual(2, len(op_sum))
found_diag = False
found_tril = False
- with self.test_session():
+ with self.cached_session():
for op in op_sum:
if isinstance(op, linalg.LinearOperatorDiag):
found_diag = True
@@ -273,7 +273,7 @@ class AddAndReturnScaledIdentityTest(test.TestCase):
operator = self._adder.add(id1, id2, "my_operator", hints)
self.assertIsInstance(operator, linalg.LinearOperatorScaledIdentity)
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(2 *
linalg_ops.eye(num_rows=2, batch_shape=[3]).eval(),
operator.to_dense().eval())
@@ -291,7 +291,7 @@ class AddAndReturnScaledIdentityTest(test.TestCase):
operator = self._adder.add(id1, id2, "my_operator", hints)
self.assertIsInstance(operator, linalg.LinearOperatorScaledIdentity)
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(3.2 *
linalg_ops.eye(num_rows=2, batch_shape=[3]).eval(),
operator.to_dense().eval())
@@ -310,7 +310,7 @@ class AddAndReturnScaledIdentityTest(test.TestCase):
operator = self._adder.add(id1, id2, "my_operator", hints)
self.assertIsInstance(operator, linalg.LinearOperatorScaledIdentity)
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(1.2 *
linalg_ops.eye(num_rows=2, batch_shape=[3]).eval(),
operator.to_dense().eval())
@@ -334,7 +334,7 @@ class AddAndReturnDiagTest(test.TestCase):
operator = self._adder.add(id1, id2, "my_operator", hints)
self.assertIsInstance(operator, linalg.LinearOperatorDiag)
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(2 *
linalg_ops.eye(num_rows=2, batch_shape=[3]).eval(),
operator.to_dense().eval())
@@ -354,7 +354,7 @@ class AddAndReturnDiagTest(test.TestCase):
operator = self._adder.add(op1, op2, "my_operator", hints)
self.assertIsInstance(operator, linalg.LinearOperatorDiag)
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(
linalg.LinearOperatorDiag(diag1 + diag2).to_dense().eval(),
operator.to_dense().eval())
@@ -379,7 +379,7 @@ class AddAndReturnTriLTest(test.TestCase):
operator = self._adder.add(diag, tril, "my_operator", hints)
self.assertIsInstance(operator, linalg.LinearOperatorLowerTriangular)
- with self.test_session():
+ with self.cached_session():
self.assertAllClose([[11., 0.], [30., 2.]], operator.to_dense().eval())
self.assertTrue(operator.is_positive_definite)
self.assertTrue(operator.is_non_singular)
@@ -401,7 +401,7 @@ class AddAndReturnMatrixTest(test.TestCase):
operator = self._adder.add(diag1, diag2, "my_operator", hints)
self.assertIsInstance(operator, linalg.LinearOperatorFullMatrix)
- with self.test_session():
+ with self.cached_session():
self.assertAllClose([[0., 0.], [0., 5.]], operator.to_dense().eval())
self.assertFalse(operator.is_positive_definite)
self.assertFalse(operator.is_non_singular)
diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_circulant_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_circulant_test.py
index 7261d4bb3b..f1e151ebd8 100644
--- a/tensorflow/python/kernel_tests/linalg/linear_operator_circulant_test.py
+++ b/tensorflow/python/kernel_tests/linalg/linear_operator_circulant_test.py
@@ -37,8 +37,10 @@ class LinearOperatorCirculantBaseTest(object):
"""Common class for circulant tests."""
@contextlib.contextmanager
- def test_session(self, *args, **kwargs):
- with test.TestCase.test_session(self, *args, **kwargs) as sess:
+ def _constrain_devices_and_set_default(self, sess, use_gpu, force_gpu):
+ """We override the FFT operation mapping for testing."""
+ with test.TestCase._constrain_devices_and_set_default(
+ self, sess, use_gpu, force_gpu) as sess:
with spectral_ops_test_util.fft_kernel_label_map():
yield sess
@@ -110,8 +112,7 @@ class LinearOperatorCirculantTestSelfAdjointOperator(
lin_op_spectrum = spectrum
if use_placeholder:
- lin_op_spectrum = array_ops.placeholder_with_default(
- spectrum, shape=None)
+ lin_op_spectrum = array_ops.placeholder_with_default(spectrum, shape=None)
operator = linalg.LinearOperatorCirculant(
lin_op_spectrum, is_self_adjoint=True, input_output_dtype=dtype)
@@ -121,7 +122,7 @@ class LinearOperatorCirculantTestSelfAdjointOperator(
return operator, mat
def test_simple_hermitian_spectrum_gives_operator_with_zero_imag_part(self):
- with self.test_session():
+ with self.cached_session():
spectrum = math_ops.cast([1., 1j, -1j], dtypes.complex64)
operator = linalg.LinearOperatorCirculant(
spectrum, input_output_dtype=dtypes.complex64)
@@ -171,8 +172,7 @@ class LinearOperatorCirculantTestHermitianSpectrum(
lin_op_spectrum = spectrum
if use_placeholder:
- lin_op_spectrum = array_ops.placeholder_with_default(
- spectrum, shape=None)
+ lin_op_spectrum = array_ops.placeholder_with_default(spectrum, shape=None)
operator = linalg.LinearOperatorCirculant(
lin_op_spectrum, input_output_dtype=dtype)
@@ -182,7 +182,7 @@ class LinearOperatorCirculantTestHermitianSpectrum(
return operator, mat
def test_simple_hermitian_spectrum_gives_operator_with_zero_imag_part(self):
- with self.test_session():
+ with self.cached_session():
spectrum = math_ops.cast([1., 1j, -1j], dtypes.complex64)
operator = linalg.LinearOperatorCirculant(
spectrum, input_output_dtype=dtypes.complex64)
@@ -217,8 +217,7 @@ class LinearOperatorCirculantTestNonHermitianSpectrum(
lin_op_spectrum = spectrum
if use_placeholder:
- lin_op_spectrum = array_ops.placeholder_with_default(
- spectrum, shape=None)
+ lin_op_spectrum = array_ops.placeholder_with_default(spectrum, shape=None)
operator = linalg.LinearOperatorCirculant(
lin_op_spectrum, input_output_dtype=dtype)
@@ -228,7 +227,7 @@ class LinearOperatorCirculantTestNonHermitianSpectrum(
return operator, mat
def test_simple_hermitian_spectrum_gives_operator_with_zero_imag_part(self):
- with self.test_session():
+ with self.cached_session():
spectrum = math_ops.cast([1., 1j, -1j], dtypes.complex64)
operator = linalg.LinearOperatorCirculant(
spectrum, input_output_dtype=dtypes.complex64)
@@ -238,7 +237,7 @@ class LinearOperatorCirculantTestNonHermitianSpectrum(
np.testing.assert_allclose(0, imag_matrix.eval(), rtol=0, atol=eps * 3)
def test_simple_positive_real_spectrum_gives_self_adjoint_pos_def_oper(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
spectrum = math_ops.cast([6., 4, 2], dtypes.complex64)
operator = linalg.LinearOperatorCirculant(
spectrum, input_output_dtype=dtypes.complex64)
@@ -250,7 +249,7 @@ class LinearOperatorCirculantTestNonHermitianSpectrum(
operator.assert_self_adjoint().run() # Should not fail
def test_defining_operator_using_real_convolution_kernel(self):
- with self.test_session():
+ with self.cached_session():
convolution_kernel = [1., 2., 1.]
spectrum = math_ops.fft(
math_ops.cast(convolution_kernel, dtypes.complex64))
@@ -266,7 +265,7 @@ class LinearOperatorCirculantTestNonHermitianSpectrum(
np.testing.assert_allclose(0, np.imag(matrix), atol=1e-6)
def test_hermitian_spectrum_gives_operator_with_zero_imag_part(self):
- with self.test_session():
+ with self.cached_session():
# Make spectrum the FFT of a real convolution kernel h. This ensures that
# spectrum is Hermitian.
h = linear_operator_test_util.random_normal(shape=(3, 4))
@@ -281,7 +280,7 @@ class LinearOperatorCirculantTestNonHermitianSpectrum(
def test_convolution_kernel_same_as_first_row_of_to_dense(self):
spectrum = [[3., 2., 1.], [2., 1.5, 1.]]
- with self.test_session():
+ with self.cached_session():
operator = linalg.LinearOperatorCirculant(spectrum)
h = operator.convolution_kernel()
c = operator.to_dense()
@@ -293,27 +292,27 @@ class LinearOperatorCirculantTestNonHermitianSpectrum(
def test_assert_non_singular_fails_for_singular_operator(self):
spectrum = math_ops.cast([0, 4, 2j + 2], dtypes.complex64)
operator = linalg.LinearOperatorCirculant(spectrum)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesOpError("Singular operator"):
operator.assert_non_singular().run()
def test_assert_non_singular_does_not_fail_for_non_singular_operator(self):
spectrum = math_ops.cast([-3j, 4, 2j + 2], dtypes.complex64)
operator = linalg.LinearOperatorCirculant(spectrum)
- with self.test_session():
+ with self.cached_session():
operator.assert_non_singular().run() # Should not fail
def test_assert_positive_definite_fails_for_non_positive_definite(self):
spectrum = math_ops.cast([6., 4, 2j], dtypes.complex64)
operator = linalg.LinearOperatorCirculant(spectrum)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesOpError("Not positive definite"):
operator.assert_positive_definite().run()
def test_assert_positive_definite_does_not_fail_when_pos_def(self):
spectrum = math_ops.cast([6., 4, 2j + 2], dtypes.complex64)
operator = linalg.LinearOperatorCirculant(spectrum)
- with self.test_session():
+ with self.cached_session():
operator.assert_positive_definite().run() # Should not fail
def test_real_spectrum_and_not_self_adjoint_hint_raises(self):
@@ -331,8 +330,10 @@ class LinearOperatorCirculant2DBaseTest(object):
"""Common class for 2D circulant tests."""
@contextlib.contextmanager
- def test_session(self, *args, **kwargs):
- with test.TestCase.test_session(self, *args, **kwargs) as sess:
+ def _constrain_devices_and_set_default(self, sess, use_gpu, force_gpu):
+ """We override the FFT operation mapping for testing."""
+ with test.TestCase._constrain_devices_and_set_default(
+ self, sess, use_gpu, force_gpu) as sess:
with spectral_ops_test_util.fft_kernel_label_map():
yield sess
@@ -446,8 +447,7 @@ class LinearOperatorCirculant2DTestHermitianSpectrum(
lin_op_spectrum = spectrum
if use_placeholder:
- lin_op_spectrum = array_ops.placeholder_with_default(
- spectrum, shape=None)
+ lin_op_spectrum = array_ops.placeholder_with_default(spectrum, shape=None)
operator = linalg.LinearOperatorCirculant2D(
lin_op_spectrum, input_output_dtype=dtype)
@@ -482,8 +482,7 @@ class LinearOperatorCirculant2DTestNonHermitianSpectrum(
lin_op_spectrum = spectrum
if use_placeholder:
- lin_op_spectrum = array_ops.placeholder_with_default(
- spectrum, shape=None)
+ lin_op_spectrum = array_ops.placeholder_with_default(spectrum, shape=None)
operator = linalg.LinearOperatorCirculant2D(
lin_op_spectrum, input_output_dtype=dtype)
@@ -493,7 +492,7 @@ class LinearOperatorCirculant2DTestNonHermitianSpectrum(
return operator, mat
def test_real_hermitian_spectrum_gives_real_symmetric_operator(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# This is a real and hermitian spectrum.
spectrum = [[1., 2., 2.], [3., 4., 4.], [3., 4., 4.]]
operator = linalg.LinearOperatorCirculant(spectrum)
@@ -510,7 +509,7 @@ class LinearOperatorCirculant2DTestNonHermitianSpectrum(
self.assertAllClose(matrix, matrix_transpose, atol=0)
def test_real_spectrum_gives_self_adjoint_operator(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# This is a real and hermitian spectrum.
spectrum = linear_operator_test_util.random_normal(
shape=(3, 3), dtype=dtypes.float32)
@@ -526,27 +525,27 @@ class LinearOperatorCirculant2DTestNonHermitianSpectrum(
def test_assert_non_singular_fails_for_singular_operator(self):
spectrum = math_ops.cast([[0, 4], [2j + 2, 3.]], dtypes.complex64)
operator = linalg.LinearOperatorCirculant2D(spectrum)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesOpError("Singular operator"):
operator.assert_non_singular().run()
def test_assert_non_singular_does_not_fail_for_non_singular_operator(self):
spectrum = math_ops.cast([[-3j, 4], [2j + 2, 3.]], dtypes.complex64)
operator = linalg.LinearOperatorCirculant2D(spectrum)
- with self.test_session():
+ with self.cached_session():
operator.assert_non_singular().run() # Should not fail
def test_assert_positive_definite_fails_for_non_positive_definite(self):
spectrum = math_ops.cast([[6., 4], [2j, 3.]], dtypes.complex64)
operator = linalg.LinearOperatorCirculant2D(spectrum)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesOpError("Not positive definite"):
operator.assert_positive_definite().run()
def test_assert_positive_definite_does_not_fail_when_pos_def(self):
spectrum = math_ops.cast([[6., 4], [2j + 2, 3.]], dtypes.complex64)
operator = linalg.LinearOperatorCirculant2D(spectrum)
- with self.test_session():
+ with self.cached_session():
operator.assert_positive_definite().run() # Should not fail
def test_real_spectrum_and_not_self_adjoint_hint_raises(self):
@@ -574,13 +573,15 @@ class LinearOperatorCirculant3DTest(test.TestCase):
"""Simple test of the 3D case. See also the 1D and 2D tests."""
@contextlib.contextmanager
- def test_session(self, *args, **kwargs):
- with test.TestCase.test_session(self, *args, **kwargs) as sess:
+ def _constrain_devices_and_set_default(self, sess, use_gpu, force_gpu):
+ """We override the FFT operation mapping for testing."""
+ with test.TestCase._constrain_devices_and_set_default(
+ self, sess, use_gpu, force_gpu) as sess:
with spectral_ops_test_util.fft_kernel_label_map():
yield sess
def test_real_spectrum_gives_self_adjoint_operator(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# This is a real and hermitian spectrum.
spectrum = linear_operator_test_util.random_normal(
shape=(2, 2, 3, 5), dtype=dtypes.float32)
@@ -597,7 +598,7 @@ class LinearOperatorCirculant3DTest(test.TestCase):
self.assertAllClose(matrix, matrix_h)
def test_defining_operator_using_real_convolution_kernel(self):
- with self.test_session():
+ with self.cached_session():
convolution_kernel = linear_operator_test_util.random_normal(
shape=(2, 2, 3, 5), dtype=dtypes.float32)
# Convolution kernel is real ==> spectrum is Hermitian.
@@ -615,7 +616,7 @@ class LinearOperatorCirculant3DTest(test.TestCase):
np.testing.assert_allclose(0, np.imag(matrix), atol=1e-6)
def test_defining_spd_operator_by_taking_real_part(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# S is real and positive.
s = linear_operator_test_util.random_uniform(
shape=(10, 2, 3, 4), dtype=dtypes.float32, minval=1., maxval=2.)
diff --git a/tensorflow/python/kernel_tests/linalg_grad_test.py b/tensorflow/python/kernel_tests/linalg_grad_test.py
index cd6a34d657..e52f303fe0 100644
--- a/tensorflow/python/kernel_tests/linalg_grad_test.py
+++ b/tensorflow/python/kernel_tests/linalg_grad_test.py
@@ -120,7 +120,7 @@ def _GetMatrixBinaryFunctorGradientTest(functor_,
delta = epsilon**(1.0 / 3.0)
# tolerance obtained by looking at actual differences using
# np.linalg.norm(theoretical-numerical, np.inf) on -mavx build
- tol = 1e-6 if dtype_ == np.float64 else float32_tol_fudge * 0.04
+ tol = 1e-6 if dtype_ == np.float64 else float32_tol_fudge * 0.05
# The gradients for a and b may be of very different magnitudes,
# so to not get spurious failures we test them separately.
for factor, factor_init in [a, a_np], [b, b_np]:
diff --git a/tensorflow/python/kernel_tests/list_ops_test.py b/tensorflow/python/kernel_tests/list_ops_test.py
index 0f5607712b..ae413edaec 100644
--- a/tensorflow/python/kernel_tests/list_ops_test.py
+++ b/tensorflow/python/kernel_tests/list_ops_test.py
@@ -170,6 +170,32 @@ class ListOpsTest(test_util.TensorFlowTestCase):
list_ops.tensor_list_pop_back(
l_cpu, element_dtype=dtypes.float32)[1]), 2.0)
+ @test_util.run_in_graph_and_eager_modes
+ def testCPUGPUCopyNested(self):
+ if not context.num_gpus():
+ return
+ t = constant_op.constant([1.0, 2.0])
+ child_l = list_ops.tensor_list_from_tensor(t, element_shape=scalar_shape())
+ l = list_ops.empty_tensor_list(
+ element_shape=constant_op.constant([], dtype=dtypes.int32),
+ element_dtype=dtypes.variant)
+ l = list_ops.tensor_list_push_back(l, child_l)
+ with context.device("gpu:0"):
+ l_gpu = array_ops.identity(l)
+ _, child_l_gpu = list_ops.tensor_list_pop_back(
+ l_gpu, element_dtype=dtypes.variant)
+ self.assertAllEqual(
+ self.evaluate(
+ list_ops.tensor_list_pop_back(
+ child_l_gpu, element_dtype=dtypes.float32)[1]), 2.0)
+ l_cpu = array_ops.identity(l_gpu)
+ _, child_l_cpu = list_ops.tensor_list_pop_back(
+ l_cpu, element_dtype=dtypes.variant)
+ self.assertAllEqual(
+ self.evaluate(
+ list_ops.tensor_list_pop_back(
+ child_l_cpu, element_dtype=dtypes.float32)[1]), 2.0)
+
def testGraphStack(self):
with self.cached_session():
tl = list_ops.empty_tensor_list(
diff --git a/tensorflow/python/kernel_tests/logging_ops_logging_level_test.py b/tensorflow/python/kernel_tests/logging_ops_logging_level_test.py
new file mode 100644
index 0000000000..0e8197dccb
--- /dev/null
+++ b/tensorflow/python/kernel_tests/logging_ops_logging_level_test.py
@@ -0,0 +1,70 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tensorflow.kernels.logging_ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import sys
+
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import logging_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging
+
+
+class PrintV2LoggingLevelTest(test.TestCase):
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintOneTensorLogInfo(self):
+ with self.cached_session():
+ tensor = math_ops.range(10)
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2(
+ tensor, output_stream=tf_logging.info)
+ self.evaluate(print_op)
+ self.assertTrue("I" in printed.contents())
+ expected = "[0 1 2 ... 7 8 9]"
+ self.assertTrue(expected in printed.contents())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintOneTensorLogWarning(self):
+ with self.cached_session():
+ tensor = math_ops.range(10)
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2(
+ tensor, output_stream=tf_logging.warning)
+ self.evaluate(print_op)
+ self.assertTrue("W" in printed.contents())
+ expected = "[0 1 2 ... 7 8 9]"
+ self.assertTrue(expected in printed.contents())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintOneTensorLogError(self):
+ with self.cached_session():
+ tensor = math_ops.range(10)
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2(
+ tensor, output_stream=tf_logging.error)
+ self.evaluate(print_op)
+ self.assertTrue("E" in printed.contents())
+ expected = "[0 1 2 ... 7 8 9]"
+ self.assertTrue(expected in printed.contents())
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/kernel_tests/logging_ops_test.py b/tensorflow/python/kernel_tests/logging_ops_test.py
index 82729b9e27..4beddd00bb 100644
--- a/tensorflow/python/kernel_tests/logging_ops_test.py
+++ b/tensorflow/python/kernel_tests/logging_ops_test.py
@@ -18,13 +18,21 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import sys
+
+from tensorflow.python.eager import context
+from tensorflow.python.eager import function
from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import string_ops
+from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@@ -57,6 +65,269 @@ class LoggingOpsTest(test.TestCase):
out.eval()
+class PrintV2Test(test.TestCase):
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintOneTensor(self):
+ with self.cached_session():
+ tensor = math_ops.range(10)
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2(tensor)
+ self.evaluate(print_op)
+
+ expected = "[0 1 2 ... 7 8 9]"
+ self.assertTrue((expected + "\n") in printed.contents())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintOneTensorVarySummarize(self):
+ with self.cached_session():
+ tensor = math_ops.range(10)
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2(tensor, summarize=1)
+ self.evaluate(print_op)
+
+ expected = "[0 ... 9]"
+ self.assertTrue((expected + "\n") in printed.contents())
+
+ with self.cached_session():
+ tensor = math_ops.range(10)
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2(tensor, summarize=2)
+ self.evaluate(print_op)
+
+ expected = "[0 1 ... 8 9]"
+ self.assertTrue((expected + "\n") in printed.contents())
+
+ with self.cached_session():
+ tensor = math_ops.range(10)
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2(tensor, summarize=3)
+ self.evaluate(print_op)
+
+ expected = "[0 1 2 ... 7 8 9]"
+ self.assertTrue((expected + "\n") in printed.contents())
+
+ with self.cached_session():
+ tensor = math_ops.range(10)
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2(tensor, summarize=-1)
+ self.evaluate(print_op)
+
+ expected = "[0 1 2 3 4 5 6 7 8 9]"
+ self.assertTrue((expected + "\n") in printed.contents())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintOneVariable(self):
+ with self.cached_session():
+ var = variables.Variable(math_ops.range(10))
+ if not context.executing_eagerly():
+ variables.global_variables_initializer().run()
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2(var)
+ self.evaluate(print_op)
+ expected = "[0 1 2 ... 7 8 9]"
+ self.assertTrue((expected + "\n") in printed.contents())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintTwoVariablesInStructWithAssignAdd(self):
+ with self.cached_session():
+ var_one = variables.Variable(2.14)
+ plus_one = var_one.assign_add(1.0)
+ var_two = variables.Variable(math_ops.range(10))
+ if not context.executing_eagerly():
+ variables.global_variables_initializer().run()
+ with self.captureWritesToStream(sys.stderr) as printed:
+ self.evaluate(plus_one)
+ print_op = logging_ops.print_v2(var_one, {"second": var_two})
+ self.evaluate(print_op)
+ expected = "3.14 {'second': [0 1 2 ... 7 8 9]}"
+ self.assertTrue((expected + "\n") in printed.contents())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintTwoTensors(self):
+ with self.cached_session():
+ tensor = math_ops.range(10)
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2(tensor, tensor * 10)
+ self.evaluate(print_op)
+ expected = "[0 1 2 ... 7 8 9] [0 10 20 ... 70 80 90]"
+ self.assertTrue((expected + "\n") in printed.contents())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintPlaceholderGeneration(self):
+ with self.cached_session():
+ tensor = math_ops.range(10)
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2("{}6", {"{}": tensor * 10})
+ self.evaluate(print_op)
+ expected = "{}6 {'{}': [0 10 20 ... 70 80 90]}"
+ self.assertTrue((expected + "\n") in printed.contents())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintNoTensors(self):
+ with self.cached_session():
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2(23, [23, 5], {"6": 12})
+ self.evaluate(print_op)
+ expected = "23 [23, 5] {'6': 12}"
+ self.assertTrue((expected + "\n") in printed.contents())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintFloatScalar(self):
+ with self.cached_session():
+ tensor = ops.convert_to_tensor(434.43)
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2(tensor)
+ self.evaluate(print_op)
+ expected = "434.43"
+ self.assertTrue((expected + "\n") in printed.contents())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintStringScalar(self):
+ with self.cached_session():
+ tensor = ops.convert_to_tensor("scalar")
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2(tensor)
+ self.evaluate(print_op)
+ expected = "scalar"
+ self.assertTrue((expected + "\n") in printed.contents())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintComplexTensorStruct(self):
+ with self.cached_session():
+ tensor = math_ops.range(10)
+ small_tensor = constant_op.constant([0.3, 12.4, -16.1])
+ big_tensor = math_ops.mul(tensor, 10)
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2(
+ "first:", tensor, "middle:",
+ {"small": small_tensor, "Big": big_tensor}, 10,
+ [tensor * 2, tensor])
+ self.evaluate(print_op)
+ # Note that the keys in the dict will always be sorted,
+ # so 'Big' comes before 'small'
+ expected = ("first: [0 1 2 ... 7 8 9] "
+ "middle: {'Big': [0 10 20 ... 70 80 90], "
+ "'small': [0.3 12.4 -16.1]} "
+ "10 [[0 2 4 ... 14 16 18], [0 1 2 ... 7 8 9]]")
+ self.assertTrue((expected + "\n") in printed.contents())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintSparseTensor(self):
+ with self.cached_session():
+ ind = [[0, 0], [1, 0], [1, 3], [4, 1], [1, 4], [3, 2], [3, 3]]
+ val = [0, 10, 13, 4, 14, 32, 33]
+ shape = [5, 6]
+
+ sparse = sparse_tensor.SparseTensor(
+ constant_op.constant(ind, dtypes.int64),
+ constant_op.constant(val, dtypes.int64),
+ constant_op.constant(shape, dtypes.int64))
+
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2(sparse)
+ self.evaluate(print_op)
+ expected = ("'SparseTensor(indices=[[0 0]\n"
+ " [1 0]\n"
+ " [1 3]\n"
+ " ...\n"
+ " [1 4]\n"
+ " [3 2]\n"
+ " [3 3]], values=[0 10 13 ... 14 32 33], shape=[5 6])'")
+ self.assertTrue((expected + "\n") in printed.contents())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintSparseTensorInDataStruct(self):
+ with self.cached_session():
+ ind = [[0, 0], [1, 0], [1, 3], [4, 1], [1, 4], [3, 2], [3, 3]]
+ val = [0, 10, 13, 4, 14, 32, 33]
+ shape = [5, 6]
+
+ sparse = sparse_tensor.SparseTensor(
+ constant_op.constant(ind, dtypes.int64),
+ constant_op.constant(val, dtypes.int64),
+ constant_op.constant(shape, dtypes.int64))
+
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2([sparse])
+ self.evaluate(print_op)
+ expected = ("['SparseTensor(indices=[[0 0]\n"
+ " [1 0]\n"
+ " [1 3]\n"
+ " ...\n"
+ " [1 4]\n"
+ " [3 2]\n"
+ " [3 3]], values=[0 10 13 ... 14 32 33], shape=[5 6])']")
+ self.assertTrue((expected + "\n") in printed.contents())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintOneTensorStdout(self):
+ with self.cached_session():
+ tensor = math_ops.range(10)
+ with self.captureWritesToStream(sys.stdout) as printed:
+ print_op = logging_ops.print_v2(
+ tensor, output_stream=sys.stdout)
+ self.evaluate(print_op)
+ expected = "[0 1 2 ... 7 8 9]"
+ self.assertTrue((expected + "\n") in printed.contents())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testInvalidOutputStreamRaisesError(self):
+ with self.cached_session():
+ tensor = math_ops.range(10)
+ with self.assertRaises(ValueError):
+ print_op = logging_ops.print_v2(
+ tensor, output_stream="unknown")
+ self.evaluate(print_op)
+
+ def testPrintOpName(self):
+ with self.cached_session():
+ tensor = math_ops.range(10)
+ print_op = logging_ops.print_v2(tensor, name="print_name")
+ self.assertEqual(print_op.name, "print_name")
+
+ def testNoDuplicateFormatOpGraphModeAfterExplicitFormat(self):
+ with self.cached_session():
+ tensor = math_ops.range(10)
+ formatted_string = string_ops.string_format("{}", tensor)
+ print_op = logging_ops.print_v2(formatted_string)
+ self.evaluate(print_op)
+ graph_ops = ops.get_default_graph().get_operations()
+ format_ops = [op for op in graph_ops if op.type == "StringFormat"]
+ # Should be only 1 format_op for graph mode.
+ self.assertEqual(len(format_ops), 1)
+
+ def testPrintOneTensorEagerOnOpCreate(self):
+ with self.cached_session():
+ with context.eager_mode():
+ tensor = math_ops.range(10)
+ expected = "[0 1 2 ... 7 8 9]"
+ with self.captureWritesToStream(sys.stderr) as printed:
+ logging_ops.print_v2(tensor)
+ self.assertTrue((expected + "\n") in printed.contents())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintInDefunWithoutExplicitEvalOfPrint(self):
+ @function.defun
+ def f():
+ tensor = math_ops.range(10)
+ logging_ops.print_v2(tensor)
+ return tensor
+
+ expected = "[0 1 2 ... 7 8 9]"
+ with self.captureWritesToStream(sys.stderr) as printed_one:
+ x = f()
+ self.evaluate(x)
+ self.assertTrue((expected + "\n") in printed_one.contents())
+
+ # We execute the function again to make sure it doesn't only print on the
+ # first call.
+ with self.captureWritesToStream(sys.stderr) as printed_two:
+ y = f()
+ self.evaluate(y)
+ self.assertTrue((expected + "\n") in printed_two.contents())
+
+
class PrintGradientTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
@@ -65,6 +336,11 @@ class PrintGradientTest(test.TestCase):
inp_printed = logging_ops.Print(inp, [inp])
self.assertEqual(inp.get_shape(), inp_printed.get_shape())
+ def testPrintString(self):
+ inp = constant_op.constant(2.0, shape=[100, 32])
+ inp_printed = logging_ops.Print(inp, ["hello"])
+ self.assertEqual(inp.get_shape(), inp_printed.get_shape())
+
def testPrintGradient(self):
with self.cached_session():
inp = constant_op.constant(2.0, shape=[100, 32], name="in")
diff --git a/tensorflow/python/kernel_tests/lookup_ops_test.py b/tensorflow/python/kernel_tests/lookup_ops_test.py
index 38b14e34cc..6791a03e2e 100644
--- a/tensorflow/python/kernel_tests/lookup_ops_test.py
+++ b/tensorflow/python/kernel_tests/lookup_ops_test.py
@@ -21,6 +21,7 @@ import os
import numpy as np
from tensorflow.python.client import session
+from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
@@ -29,6 +30,7 @@ from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import lookup_ops
+from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import server_lib
@@ -53,6 +55,12 @@ class HashTableOpTest(test.TestCase):
result = output.eval()
self.assertAllEqual([0, 1, -1], result)
+ exported_keys_tensor, exported_values_tensor = table.export()
+
+ self.assertItemsEqual([b"brain", b"salad", b"surgery"],
+ exported_keys_tensor.eval())
+ self.assertItemsEqual([0, 1, 2], exported_values_tensor.eval())
+
def testHashTableFindHighRank(self):
with self.cached_session():
default_val = -1
@@ -181,6 +189,11 @@ class HashTableOpTest(test.TestCase):
lookup_ops.KeyValueTensorInitializer(keys, values), default_val)
table.init.run()
+ # Ref types do not produce a lookup signature mismatch.
+ input_string_ref = variables.Variable("brain")
+ variables.global_variables_initializer().run()
+ self.assertEqual(0, table.lookup(input_string_ref).eval())
+
input_string = constant_op.constant([1, 2, 3], dtypes.int64)
with self.assertRaises(TypeError):
table.lookup(input_string)
@@ -261,6 +274,21 @@ class HashTableOpTest(test.TestCase):
table.init.run()
self.assertAllEqual(3, table.size().eval())
+ def testHashTableInt32String(self):
+ with self.cached_session():
+ default_val = "n/a"
+ keys = constant_op.constant([0, 1, 2], dtypes.int32)
+ values = constant_op.constant(["brain", "salad", "surgery"])
+ table = lookup_ops.HashTable(
+ lookup_ops.KeyValueTensorInitializer(keys, values), default_val)
+ table.init.run()
+
+ input_tensor = constant_op.constant([0, 1, -1])
+ output = table.lookup(input_tensor)
+
+ result = output.eval()
+ self.assertAllEqual([b"brain", b"salad", b"n/a"], result)
+
class IndexTableFromFile(test.TestCase):
@@ -335,6 +363,7 @@ class IndexTableFromFile(test.TestCase):
ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
self.assertRaises(errors_impl.OpError, ids.eval)
+
feed_dict = {vocabulary_placeholder.name: vocabulary_file}
lookup_ops.tables_initializer().run(feed_dict=feed_dict)
self.assertAllEqual((1, 2, 3), ids.eval())
@@ -531,15 +560,22 @@ class KeyValueTensorInitializerTest(test.TestCase):
class IndexTableFromTensor(test.TestCase):
+ @test_util.run_in_graph_and_eager_modes
def test_index_table_from_tensor_with_tensor_init(self):
- with self.cached_session():
+ table = lookup_ops.index_table_from_tensor(
+ vocabulary_list=("brain", "salad", "surgery"), num_oov_buckets=1)
+
+ if not context.executing_eagerly():
+ with self.assertRaises(errors_impl.OpError):
+ self.evaluate(
+ table.lookup(constant_op.constant(("salad", "surgery", "tarkus"))))
+ else:
+ # Reinitializing a table in eager should work.
table = lookup_ops.index_table_from_tensor(
vocabulary_list=("brain", "salad", "surgery"), num_oov_buckets=1)
- ids = table.lookup(constant_op.constant(("salad", "surgery", "tarkus")))
-
- self.assertRaises(errors_impl.OpError, ids.eval)
- lookup_ops.tables_initializer().run()
- self.assertAllEqual((1, 2, 3), ids.eval())
+ self.evaluate(lookup_ops.tables_initializer())
+ ids = table.lookup(constant_op.constant(("salad", "surgery", "tarkus")))
+ self.assertAllEqual((1, 2, 3), self.evaluate(ids))
def test_int32_index_table_from_tensor_with_tensor_init(self):
with self.cached_session():
@@ -761,22 +797,20 @@ class InitializeTableFromFileOpTest(test.TestCase):
f.write("\n".join(values) + "\n")
return vocabulary_file
+ @test_util.run_in_graph_and_eager_modes
def testInitializeStringTable(self):
vocabulary_file = self._createVocabFile("one_column_1.txt")
+ default_value = -1
+ table = lookup_ops.HashTable(
+ lookup_ops.TextFileInitializer(
+ vocabulary_file, dtypes.string, lookup_ops.TextFileIndex.WHOLE_LINE,
+ dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER), default_value)
+ self.evaluate(table.init)
- with self.cached_session():
- default_value = -1
- table = lookup_ops.HashTable(
- lookup_ops.TextFileInitializer(
- vocabulary_file, dtypes.string,
- lookup_ops.TextFileIndex.WHOLE_LINE, dtypes.int64,
- lookup_ops.TextFileIndex.LINE_NUMBER), default_value)
- table.init.run()
-
- output = table.lookup(constant_op.constant(["brain", "salad", "tank"]))
+ output = table.lookup(constant_op.constant(["brain", "salad", "tank"]))
- result = output.eval()
- self.assertAllEqual([0, 1, -1], result)
+ result = self.evaluate(output)
+ self.assertAllEqual([0, 1, -1], result)
def testInitializeInt64Table(self):
vocabulary_file = self._createVocabFile(
diff --git a/tensorflow/python/kernel_tests/numerics_test.py b/tensorflow/python/kernel_tests/numerics_test.py
index 89ada8430e..6cc70f7c89 100644
--- a/tensorflow/python/kernel_tests/numerics_test.py
+++ b/tensorflow/python/kernel_tests/numerics_test.py
@@ -66,7 +66,7 @@ class VerifyTensorAllFiniteTest(test.TestCase):
class NumericsTest(test.TestCase):
def testInf(self):
- with self.test_session(graph=ops.Graph()):
+ with self.session(graph=ops.Graph()):
t1 = constant_op.constant(1.0)
t2 = constant_op.constant(0.0)
a = math_ops.div(t1, t2)
@@ -76,7 +76,7 @@ class NumericsTest(test.TestCase):
a.eval()
def testNaN(self):
- with self.test_session(graph=ops.Graph()):
+ with self.session(graph=ops.Graph()):
t1 = constant_op.constant(0.0)
t2 = constant_op.constant(0.0)
a = math_ops.div(t1, t2)
@@ -86,7 +86,7 @@ class NumericsTest(test.TestCase):
a.eval()
def testBoth(self):
- with self.test_session(graph=ops.Graph()):
+ with self.session(graph=ops.Graph()):
t1 = constant_op.constant([1.0, 0.0])
t2 = constant_op.constant([0.0, 0.0])
a = math_ops.div(t1, t2)
@@ -96,7 +96,7 @@ class NumericsTest(test.TestCase):
a.eval()
def testPassThrough(self):
- with self.test_session(graph=ops.Graph()):
+ with self.session(graph=ops.Graph()):
t1 = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], shape=[2, 3])
checked = array_ops.check_numerics(t1, message="pass through test")
value = checked.eval()
diff --git a/tensorflow/python/kernel_tests/random/random_ops_test.py b/tensorflow/python/kernel_tests/random/random_ops_test.py
index 0ef6a95cfc..d199a9d9dd 100644
--- a/tensorflow/python/kernel_tests/random/random_ops_test.py
+++ b/tensorflow/python/kernel_tests/random/random_ops_test.py
@@ -320,6 +320,15 @@ class RandomUniformTest(RandomOpTestCommon):
error = np.abs(counts - mean)
self.assertLess(error.max(), 5 * std)
+ # Check that minval = maxval is fine iff we're producing no numbers
+ def testUniformIntsDegenerate(self):
+ for dt in dtypes.int32, dtypes.int64:
+ def sample(n):
+ return self._Sampler(n, minv=0, maxv=0, dtype=dt, use_gpu=True)()
+ self.assertEqual(sample(0).shape, (10, 0))
+ with self.assertRaisesOpError('Need minval < maxval, got 0 >= 0'):
+ sample(1)
+
# Checks that the CPU and GPU implementation returns the same results,
# given the same random seed
def testCPUGPUMatch(self):
diff --git a/tensorflow/python/kernel_tests/reduction_ops_test.py b/tensorflow/python/kernel_tests/reduction_ops_test.py
index 496a452a03..248036a82a 100644
--- a/tensorflow/python/kernel_tests/reduction_ops_test.py
+++ b/tensorflow/python/kernel_tests/reduction_ops_test.py
@@ -212,7 +212,7 @@ class SumReductionTest(BaseReductionTest):
arr = np.ones([68000], dtype=np.float16)
- with self.test_session(graph=ops.Graph(), use_gpu=True) as sess:
+ with self.session(graph=ops.Graph(), use_gpu=True) as sess:
tf_arr = variables.Variable(arr)
variables.global_variables_initializer().run()
tf_mean = math_ops.reduce_mean(tf_arr, 0, False)
@@ -235,7 +235,7 @@ class SumReductionTest(BaseReductionTest):
col_sum = np.sum(arr, axis=0)
row_sum = np.sum(arr, axis=1)
- with self.test_session(graph=ops.Graph(), use_gpu=True) as sess:
+ with self.session(graph=ops.Graph(), use_gpu=True) as sess:
tf_row_sum = self._tf_reduce(arr, 1, False)
tf_col_sum = self._tf_reduce(arr, 0, False)
tf_out_row, tf_out_col = sess.run([tf_row_sum, tf_col_sum])
@@ -249,7 +249,7 @@ class SumReductionTest(BaseReductionTest):
sum_y = np.sum(arr, axis=1)
sum_xz = np.sum(arr, axis=(0, 2))
- with self.test_session(graph=ops.Graph(), use_gpu=True) as sess:
+ with self.session(graph=ops.Graph(), use_gpu=True) as sess:
tf_sum_xz = self._tf_reduce(arr, [0, 2], False)
tf_sum_y = self._tf_reduce(arr, 1, False)
tf_out_sum_xz, tf_out_sum_y = sess.run([tf_sum_xz, tf_sum_y])
diff --git a/tensorflow/python/kernel_tests/reduction_ops_test_big.py b/tensorflow/python/kernel_tests/reduction_ops_test_big.py
index d70360775a..1e8524f72a 100644
--- a/tensorflow/python/kernel_tests/reduction_ops_test_big.py
+++ b/tensorflow/python/kernel_tests/reduction_ops_test_big.py
@@ -63,7 +63,7 @@ class BigReductionTest(BaseReductionTest):
row_sum = np.ones([size_x], dtype=np.float32) * size_y
full_sum = np.ones([], dtype=np.float32) * size_x * size_y
- with self.test_session(graph=ops.Graph(), use_gpu=True) as sess:
+ with self.session(graph=ops.Graph(), use_gpu=True) as sess:
tf_row_sum = self._tf_reduce_sum(arr, 1, False)
tf_col_sum = self._tf_reduce_sum(arr, 0, False)
tf_full_sum = self._tf_reduce_sum(arr, [0, 1], False)
@@ -81,7 +81,7 @@ class BigReductionTest(BaseReductionTest):
sum_y = np.ones([size_x, size_z], dtype=np.float32)
sum_xz = np.ones([size_y], dtype=np.float32)
- with self.test_session(graph=ops.Graph(), use_gpu=True) as sess:
+ with self.session(graph=ops.Graph(), use_gpu=True) as sess:
tf_sum_xz = self._tf_reduce_mean(arr, [0, 2], False)
tf_sum_y = self._tf_reduce_mean(arr, 1, False)
tf_out_sum_xz, tf_out_sum_y = sess.run([tf_sum_xz, tf_sum_y])
@@ -106,7 +106,7 @@ class BigReductionTest(BaseReductionTest):
row_max = np.max(arr, axis=1)
full_max = np.max(col_max)
- with self.test_session(graph=ops.Graph(), use_gpu=True) as sess:
+ with self.session(graph=ops.Graph(), use_gpu=True) as sess:
tf_row_max = self._tf_reduce_max(arr, 1, False)
tf_col_max = self._tf_reduce_max(arr, 0, False)
tf_full_max = self._tf_reduce_max(arr, [0, 1], False)
@@ -125,7 +125,7 @@ class BigReductionTest(BaseReductionTest):
sum_y = np.max(arr, axis=1)
sum_xz = np.max(arr, axis=(0, 2))
- with self.test_session(graph=ops.Graph(), use_gpu=True) as sess:
+ with self.session(graph=ops.Graph(), use_gpu=True) as sess:
tf_sum_xz = self._tf_reduce_max(arr, [0, 2], False)
tf_sum_y = self._tf_reduce_max(arr, 1, False)
tf_out_sum_xz, tf_out_sum_y = sess.run([tf_sum_xz, tf_sum_y])
@@ -149,7 +149,7 @@ class BigReductionTest(BaseReductionTest):
row_sum = np.ones([size_x], dtype=np.bool)
full_sum = np.ones([1], dtype=np.bool).reshape([])
- with self.test_session(graph=ops.Graph(), use_gpu=True) as sess:
+ with self.session(graph=ops.Graph(), use_gpu=True) as sess:
tf_row_sum = self._tf_reduce_all(arr, 1, False)
tf_col_sum = self._tf_reduce_all(arr, 0, False)
tf_full_sum = self._tf_reduce_all(arr, [0, 1], False)
@@ -167,7 +167,7 @@ class BigReductionTest(BaseReductionTest):
sum_y = np.ones([size_x, size_z], dtype=np.bool)
sum_xz = np.ones([size_y], dtype=np.bool)
- with self.test_session(graph=ops.Graph(), use_gpu=True) as sess:
+ with self.session(graph=ops.Graph(), use_gpu=True) as sess:
tf_sum_xz = self._tf_reduce_all(arr, [0, 2], False)
tf_sum_y = self._tf_reduce_all(arr, 1, False)
tf_out_sum_xz, tf_out_sum_y = sess.run([tf_sum_xz, tf_sum_y])
diff --git a/tensorflow/python/kernel_tests/regex_full_match_op_test.py b/tensorflow/python/kernel_tests/regex_full_match_op_test.py
index e81f562a2a..98746e7d9b 100644
--- a/tensorflow/python/kernel_tests/regex_full_match_op_test.py
+++ b/tensorflow/python/kernel_tests/regex_full_match_op_test.py
@@ -42,7 +42,7 @@ class RegexFullMatchOpVariantsTest(test.TestCase, parameterized.TestCase):
def testRegexFullMatchTwoDims(self, op):
values = [["abaaba", "abcdabcde"], ["acdcba", "ebcda"]]
- with self.test_session():
+ with self.cached_session():
input_tensor = constant_op.constant(values, dtypes.string)
matched = op(input_tensor, "a.*a").eval()
self.assertAllEqual([[True, False], [True, False]], matched)
@@ -68,7 +68,7 @@ class RegexFullMatchOpTest(test.TestCase):
def testRegexFullMatchDelegation(self):
with compat.forward_compatibility_horizon(2018, 11, 1):
- with self.test_session():
+ with self.cached_session():
input_tensor = constant_op.constant("foo", dtypes.string)
pattern = "[a-z]"
op = string_ops.regex_full_match(input_tensor, pattern)
@@ -80,7 +80,7 @@ class RegexFullMatchOpTest(test.TestCase):
def testStaticRegexFullMatchDelegation(self):
with compat.forward_compatibility_horizon(2018, 11, 20):
- with self.test_session():
+ with self.cached_session():
input_tensor = constant_op.constant("foo", dtypes.string)
pattern = "[a-z]*"
op = string_ops.regex_full_match(input_tensor, pattern)
diff --git a/tensorflow/python/kernel_tests/regex_replace_op_test.py b/tensorflow/python/kernel_tests/regex_replace_op_test.py
index feac3a8b08..d9b7ed28d2 100644
--- a/tensorflow/python/kernel_tests/regex_replace_op_test.py
+++ b/tensorflow/python/kernel_tests/regex_replace_op_test.py
@@ -33,7 +33,7 @@ from tensorflow.python.platform import test
class RegexReplaceOpVariantsTest(test.TestCase, parameterized.TestCase):
def testForwarding(self, op):
- with self.test_session():
+ with self.cached_session():
# Generate an input that is uniquely consumed by the regex op.
# This exercises code paths which are optimized for this case
# (e.g., using forwarding).
@@ -47,7 +47,7 @@ class RegexReplaceOpVariantsTest(test.TestCase, parameterized.TestCase):
def testRemovePrefix(self, op):
values = ["a:foo", "a:bar", "a:foo", "b:baz", "b:qux", "ca:b"]
- with self.test_session():
+ with self.cached_session():
input_vector = constant_op.constant(values, dtypes.string)
stripped = op(input_vector, "^(a:|b:)", "", replace_global=False).eval()
self.assertAllEqual([b"foo", b"bar", b"foo", b"baz", b"qux", b"ca:b"],
@@ -55,21 +55,21 @@ class RegexReplaceOpVariantsTest(test.TestCase, parameterized.TestCase):
def testRegexReplace(self, op):
values = ["aba\naba", "abcdabcde"]
- with self.test_session():
+ with self.cached_session():
input_vector = constant_op.constant(values, dtypes.string)
stripped = op(input_vector, "a.*a", "(\\0)").eval()
self.assertAllEqual([b"(aba)\n(aba)", b"(abcda)bcde"], stripped)
def testEmptyMatch(self, op):
values = ["abc", "1"]
- with self.test_session():
+ with self.cached_session():
input_vector = constant_op.constant(values, dtypes.string)
stripped = op(input_vector, "", "x").eval()
self.assertAllEqual([b"xaxbxcx", b"x1x"], stripped)
def testInvalidPattern(self, op):
values = ["abc", "1"]
- with self.test_session():
+ with self.cached_session():
input_vector = constant_op.constant(values, dtypes.string)
invalid_pattern = "A["
replace = op(input_vector, invalid_pattern, "x")
@@ -78,7 +78,7 @@ class RegexReplaceOpVariantsTest(test.TestCase, parameterized.TestCase):
def testGlobal(self, op):
values = ["ababababab", "abcabcabc", ""]
- with self.test_session():
+ with self.cached_session():
input_vector = constant_op.constant(values, dtypes.string)
stripped = op(input_vector, "ab", "abc", True).eval()
self.assertAllEqual([b"abcabcabcabcabc", b"abccabccabcc", b""], stripped)
@@ -99,7 +99,7 @@ class RegexReplaceTest(test.TestCase, parameterized.TestCase):
(as_tensor, as_string),
(as_tensor, as_tensor))
def testRegexReplaceDelegation(self, pattern_fn, rewrite_fn):
- with self.test_session():
+ with self.cached_session():
input_vector = constant_op.constant("foo", dtypes.string)
pattern = pattern_fn("[a-z]")
replace = rewrite_fn(".")
@@ -107,7 +107,7 @@ class RegexReplaceTest(test.TestCase, parameterized.TestCase):
self.assertTrue(op.name.startswith("RegexReplace"))
def testStaticRegexReplaceDelegation(self):
- with self.test_session():
+ with self.cached_session():
input_vector = constant_op.constant("foo", dtypes.string)
pattern = "[a-z]"
replace = "."
diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
index f90545f84c..1365d4b240 100644
--- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py
+++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
@@ -290,7 +290,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
self.assertEqual(self.evaluate(read), [[2]])
def testUseResource(self):
- v = variables.Variable(1.0, use_resource=True)
+ v = variables.VariableV1(1.0, use_resource=True)
self.assertTrue(isinstance(v, resource_variable_ops.ResourceVariable))
def testEagerNoUseResource(self):
diff --git a/tensorflow/python/kernel_tests/rnn_test.py b/tensorflow/python/kernel_tests/rnn_test.py
index a28cdc3b26..05ad9f6336 100644
--- a/tensorflow/python/kernel_tests/rnn_test.py
+++ b/tensorflow/python/kernel_tests/rnn_test.py
@@ -516,7 +516,7 @@ class RNNTest(test.TestCase):
fix_weights_generator.build((None, input_shape))
weights = fix_weights_generator.get_weights()
- with self.test_session(graph=ops_lib.Graph()) as sess:
+ with self.session(graph=ops_lib.Graph()) as sess:
inputs = array_ops.placeholder(
dtypes.float32, shape=(None, timestep, input_shape))
cell = keras.layers.SimpleRNNCell(output_shape)
@@ -524,7 +524,7 @@ class RNNTest(test.TestCase):
cell, inputs, dtype=dtypes.float32)
cell.set_weights(weights)
[tf_out, tf_state] = sess.run([tf_out, tf_state], {inputs: x_train})
- with self.test_session(graph=ops_lib.Graph()) as sess:
+ with self.session(graph=ops_lib.Graph()) as sess:
k_input = keras.Input(shape=(timestep, input_shape),
dtype=dtypes.float32)
cell = keras.layers.SimpleRNNCell(output_shape)
@@ -536,7 +536,7 @@ class RNNTest(test.TestCase):
self.assertAllClose(tf_state, k_state)
def testBasicLSTMCellInterchangeWithLSTMCell(self):
- with self.test_session(graph=ops_lib.Graph()) as sess:
+ with self.session(graph=ops_lib.Graph()) as sess:
basic_cell = rnn_cell_impl.BasicLSTMCell(1)
basic_cell(array_ops.ones([1, 1]),
state=basic_cell.get_initial_state(inputs=None,
@@ -548,7 +548,7 @@ class RNNTest(test.TestCase):
prefix = os.path.join(self.get_temp_dir(), "ckpt")
save_path = save.save(sess, prefix)
- with self.test_session(graph=ops_lib.Graph()) as sess:
+ with self.session(graph=ops_lib.Graph()) as sess:
lstm_cell = rnn_cell_impl.LSTMCell(1, name="basic_lstm_cell")
lstm_cell(array_ops.ones([1, 1]),
state=lstm_cell.get_initial_state(inputs=None,
diff --git a/tensorflow/python/kernel_tests/scalar_test.py b/tensorflow/python/kernel_tests/scalar_test.py
index 287919bab7..d15f2c7b50 100644
--- a/tensorflow/python/kernel_tests/scalar_test.py
+++ b/tensorflow/python/kernel_tests/scalar_test.py
@@ -53,7 +53,7 @@ class ScalarTest(test.TestCase):
for version in strict + lenient:
with ops.Graph().as_default() as g:
test_util.set_producer_version(g, version)
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
feed = {}
xs = placeholders(args, feed)
x = op(*xs)
diff --git a/tensorflow/python/kernel_tests/scatter_nd_ops_test.py b/tensorflow/python/kernel_tests/scatter_nd_ops_test.py
index 86e063cb36..4b92309e4d 100644
--- a/tensorflow/python/kernel_tests/scatter_nd_ops_test.py
+++ b/tensorflow/python/kernel_tests/scatter_nd_ops_test.py
@@ -136,7 +136,7 @@ class StatefulScatterNdTest(test.TestCase):
new = ref.copy()
np_scatter(new, indices, updates)
# Scatter via tensorflow
- ref_var = variables.Variable(ref)
+ ref_var = variables.VariableV1(ref)
ref_var.initializer.run()
tf_scatter(ref_var, indices, updates).eval()
@@ -258,7 +258,7 @@ class StatefulScatterNdTest(test.TestCase):
params = np.array([1, 2, 3, 4, 5, 6]).astype(np.float32)
updates = np.array([-3, -4, -5]).astype(np.float32)
with self.test_session(use_gpu=False):
- ref = variables.Variable(params)
+ ref = variables.VariableV1(params)
ref.initializer.run()
# Indices all in range, no problem.
diff --git a/tensorflow/python/kernel_tests/scatter_ops_test.py b/tensorflow/python/kernel_tests/scatter_ops_test.py
index 1a0fa744ae..527b7daf10 100644
--- a/tensorflow/python/kernel_tests/scatter_ops_test.py
+++ b/tensorflow/python/kernel_tests/scatter_ops_test.py
@@ -178,7 +178,7 @@ class ScatterTest(test.TestCase):
np_scatter = _TF_OPS_TO_NUMPY[tf_scatter]
np_scatter(new, indices, updates)
# Scatter via tensorflow
- ref = variables.Variable(old)
+ ref = variables.VariableV1(old)
ref.initializer.run()
tf_scatter(ref, indices, updates).eval()
self.assertAllClose(ref.eval(), new)
@@ -294,7 +294,7 @@ class ScatterTest(test.TestCase):
updates = np.array([-3, -4, -5]).astype(np.float32)
if not test.is_gpu_available():
with self.test_session(use_gpu=False):
- ref = variables.Variable(params)
+ ref = variables.VariableV1(params)
ref.initializer.run()
# Indices all in range, no problem.
diff --git a/tensorflow/python/kernel_tests/segment_reduction_ops_test.py b/tensorflow/python/kernel_tests/segment_reduction_ops_test.py
index ce507e4ad7..2931877c11 100644
--- a/tensorflow/python/kernel_tests/segment_reduction_ops_test.py
+++ b/tensorflow/python/kernel_tests/segment_reduction_ops_test.py
@@ -300,7 +300,7 @@ class UnsortedSegmentTest(SegmentReductionHelper):
tf_ans = s.eval()
if dtype is dtypes_lib.bfloat16:
tf_ans = tf_ans.astype(np.float32)
- self.assertAllClose(np_ans, tf_ans)
+ self.assertAllCloseAccordingToType(np_ans, tf_ans)
self.assertShapeEqual(np_ans, s)
def testNumSegmentsTypes(self):
diff --git a/tensorflow/python/kernel_tests/softmax_op_test.py b/tensorflow/python/kernel_tests/softmax_op_test.py
index e53347c4bc..89f4697e5c 100644
--- a/tensorflow/python/kernel_tests/softmax_op_test.py
+++ b/tensorflow/python/kernel_tests/softmax_op_test.py
@@ -22,7 +22,6 @@ import unittest
import numpy as np
-from tensorflow.python.compat import compat
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.ops import array_ops
@@ -163,10 +162,9 @@ class SoftmaxTest(test.TestCase):
self._testOverflow(use_gpu=False)
def test1DTensorAsInputNoReshape(self):
- with compat.forward_compatibility_horizon(2018, 8, 27):
- self._testSoftmax(
- np.array([3., 2., 3., 9.]).astype(np.float64), use_gpu=False)
- self._testOverflow(use_gpu=False)
+ self._testSoftmax(
+ np.array([3., 2., 3., 9.]).astype(np.float64), use_gpu=False)
+ self._testOverflow(use_gpu=False)
def test3DTensorAsInput(self):
self._testSoftmax(
@@ -177,13 +175,12 @@ class SoftmaxTest(test.TestCase):
self._testOverflow(use_gpu=False)
def test3DTensorAsInputNoReshape(self):
- with compat.forward_compatibility_horizon(2018, 8, 27):
- self._testSoftmax(
- np.array([[[1., 1., 1., 1.], [1., 2., 3., 4.]],
- [[2., 3., 4., 5.], [6., 7., 8., 9.]],
- [[5., 4., 3., 2.], [1., 2., 3., 4.]]]).astype(np.float32),
- use_gpu=False)
- self._testOverflow(use_gpu=False)
+ self._testSoftmax(
+ np.array([[[1., 1., 1., 1.], [1., 2., 3., 4.]],
+ [[2., 3., 4., 5.], [6., 7., 8., 9.]],
+ [[5., 4., 3., 2.], [1., 2., 3., 4.]]]).astype(np.float32),
+ use_gpu=False)
+ self._testOverflow(use_gpu=False)
def testAlongFirstDimension(self):
self._testSoftmax(
diff --git a/tensorflow/python/kernel_tests/softplus_op_test.py b/tensorflow/python/kernel_tests/softplus_op_test.py
index afe3df6178..636ed4747e 100644
--- a/tensorflow/python/kernel_tests/softplus_op_test.py
+++ b/tensorflow/python/kernel_tests/softplus_op_test.py
@@ -21,7 +21,6 @@ from __future__ import print_function
import numpy as np
from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import errors
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import nn_ops
@@ -125,9 +124,9 @@ class SoftplusTest(test.TestCase):
def testNoInts(self):
with self.cached_session():
with self.assertRaisesRegexp(
- errors.InvalidArgumentError,
- "No OpKernel was registered to support Op 'Softplus'"):
- nn_ops.softplus(constant_op.constant(7)).eval()
+ TypeError,
+ "'features' has DataType int32 not in list of allowed values"):
+ nn_ops.softplus(constant_op.constant(42)).eval()
if __name__ == "__main__":
diff --git a/tensorflow/python/kernel_tests/softsign_op_test.py b/tensorflow/python/kernel_tests/softsign_op_test.py
index 05a7c53dee..1b4db9fa46 100644
--- a/tensorflow/python/kernel_tests/softsign_op_test.py
+++ b/tensorflow/python/kernel_tests/softsign_op_test.py
@@ -21,7 +21,6 @@ from __future__ import print_function
import numpy as np
from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import errors
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import nn_ops
import tensorflow.python.ops.nn_grad # pylint: disable=unused-import
@@ -69,8 +68,8 @@ class SoftsignTest(test.TestCase):
def testNoInts(self):
with self.cached_session():
with self.assertRaisesRegexp(
- errors.InvalidArgumentError,
- "No OpKernel was registered to support Op 'Softsign'"):
+ TypeError,
+ "'features' has DataType int32 not in list of allowed values"):
nn_ops.softsign(constant_op.constant(7)).eval()
diff --git a/tensorflow/python/kernel_tests/sparse_conditional_accumulator_test.py b/tensorflow/python/kernel_tests/sparse_conditional_accumulator_test.py
index 477720302d..a824d5c826 100644
--- a/tensorflow/python/kernel_tests/sparse_conditional_accumulator_test.py
+++ b/tensorflow/python/kernel_tests/sparse_conditional_accumulator_test.py
@@ -195,7 +195,7 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase):
self.assertAllEqual([-1, 2], val.dense_shape)
def testAccumulatorTakeGradSum(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.SparseConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=(), reduction_type="SUM")
@@ -289,7 +289,7 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase):
val, sess)
def testParallelApplyGradSum(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.SparseConditionalAccumulator(
dtypes_lib.float32,
name="Q",
diff --git a/tensorflow/python/kernel_tests/sparse_tensors_map_ops_test.py b/tensorflow/python/kernel_tests/sparse_tensors_map_ops_test.py
index 96793d5af3..31e84341ae 100644
--- a/tensorflow/python/kernel_tests/sparse_tensors_map_ops_test.py
+++ b/tensorflow/python/kernel_tests/sparse_tensors_map_ops_test.py
@@ -76,7 +76,7 @@ class SparseTensorsMapTest(test.TestCase):
return sparse_tensor_lib.SparseTensorValue(ind, val, shape)
def testAddTakeMany(self):
- with self.test_session(graph=ops.Graph(), use_gpu=False) as sess:
+ with self.session(graph=ops.Graph(), use_gpu=False) as sess:
sp_input0 = self._SparseTensorValue_5x6(np.arange(6))
sp_input1 = self._SparseTensorValue_3x4(np.arange(6))
handle0 = add_sparse_to_tensors_map(sp_input0, shared_name="a")
diff --git a/tensorflow/python/kernel_tests/string_format_op_test.py b/tensorflow/python/kernel_tests/string_format_op_test.py
new file mode 100644
index 0000000000..74a5072bab
--- /dev/null
+++ b/tensorflow/python/kernel_tests/string_format_op_test.py
@@ -0,0 +1,384 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tensorflow.kernels.logging_ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+from tensorflow.python.eager import context
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import string_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+from tensorflow.python.util import compat
+
+
+class StringFormatOpTest(test.TestCase):
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatOneTensorOneDim(self):
+ with self.cached_session():
+ tensor = math_ops.range(10)
+ format_output = string_ops.string_format("{}", tensor)
+ out = self.evaluate(format_output)
+ expected = "[0 1 2 ... 7 8 9]"
+ self.assertEqual(compat.as_text(out), expected)
+
+ with self.cached_session():
+ tensor = math_ops.range(10)
+ format_output = string_ops.string_format("{}", [tensor])
+ out = self.evaluate(format_output)
+ expected = "[0 1 2 ... 7 8 9]"
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatOneVariableScalar(self):
+ with self.cached_session():
+ var = variables.Variable(3.34)
+ format_output = string_ops.string_format("{}", [var])
+ if not context.executing_eagerly():
+ variables.global_variables_initializer().run()
+ out = self.evaluate(format_output)
+ expected = "3.34"
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatOneVariableOneDim(self):
+ with self.cached_session():
+ var = variables.Variable(math_ops.range(10))
+ format_output = string_ops.string_format("{}", [var])
+ if not context.executing_eagerly():
+ variables.global_variables_initializer().run()
+ out = self.evaluate(format_output)
+ expected = "[0 1 2 ... 7 8 9]"
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatTwoVariablesWithAssignAdd(self):
+ with self.cached_session():
+ var_one = variables.Variable(2.14)
+ plus_one = var_one.assign_add(1.0)
+ var_two = variables.Variable(math_ops.range(10))
+ format_output = string_ops.string_format("{}, {}", [var_one, var_two])
+ if not context.executing_eagerly():
+ variables.global_variables_initializer().run()
+ self.evaluate(plus_one)
+ out = self.evaluate(format_output)
+ expected = "3.14, [0 1 2 ... 7 8 9]"
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatOneTensorOneDimFloat(self):
+ with self.cached_session():
+ tensor = constant_op.constant([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7])
+ format_output = string_ops.string_format("{}", tensor)
+ out = self.evaluate(format_output)
+ expected = "[0 0.1 0.2 ... 0.5 0.6 0.7]"
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatOneTensorOneDimMatchesSummarize(self):
+ with self.cached_session():
+ tensor = math_ops.range(6)
+ format_output = string_ops.string_format("{}", tensor, summarize=3)
+ out = self.evaluate(format_output)
+ expected = "[0 1 2 3 4 5]"
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatOneTensorOneDimVarySummarize(self):
+ with self.cached_session():
+ tensor = math_ops.range(6)
+ format_output = string_ops.string_format("{}", tensor, summarize=-1)
+ out = self.evaluate(format_output)
+ expected = "[0 1 2 3 4 5]"
+ self.assertEqual(compat.as_text(out), expected)
+
+ with self.cached_session():
+ tensor = math_ops.range(6)
+ format_output = string_ops.string_format("{}", tensor, summarize=1)
+ out = self.evaluate(format_output)
+ expected = "[0 ... 5]"
+ self.assertEqual(compat.as_text(out), expected)
+
+ with self.cached_session():
+ tensor = math_ops.range(6)
+ format_output = string_ops.string_format("{}", tensor, summarize=2)
+ out = self.evaluate(format_output)
+ expected = "[0 1 ... 4 5]"
+ self.assertEqual(compat.as_text(out), expected)
+
+ with self.cached_session():
+ tensor = math_ops.range(6)
+ format_output = string_ops.string_format("{}", tensor, summarize=10)
+ out = self.evaluate(format_output)
+ expected = "[0 1 2 3 4 5]"
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatOneTensorOneDimAlmostSummarize(self):
+ with self.cached_session():
+ tensor = math_ops.range(5)
+ format_output = string_ops.string_format("{}", tensor, summarize=3)
+ out = self.evaluate(format_output)
+ expected = "[0 1 2 3 4]"
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatOneTensorTwoDimLessThanSummarize(self):
+ with self.cached_session():
+ tensor = array_ops.reshape(math_ops.range(4), [2, 2])
+ format_output = string_ops.string_format("{}", tensor, summarize=3)
+ out = self.evaluate(format_output)
+ expected = ("[[0 1]\n"
+ " [2 3]]")
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatOneTensorTwoDim(self):
+ with self.cached_session():
+ tensor = array_ops.reshape(math_ops.range(100), [10, 10])
+ format_output = string_ops.string_format("{}", tensor)
+ out = self.evaluate(format_output)
+ expected = ("[[0 1 2 ... 7 8 9]\n"
+ " [10 11 12 ... 17 18 19]\n"
+ " [20 21 22 ... 27 28 29]\n"
+ " ...\n"
+ " [70 71 72 ... 77 78 79]\n"
+ " [80 81 82 ... 87 88 89]\n"
+ " [90 91 92 ... 97 98 99]]")
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatOneTensorTwoDimSummarizeTwo(self):
+ with self.cached_session():
+ tensor = array_ops.reshape(math_ops.range(100), [10, 10])
+ format_output = string_ops.string_format("{}", tensor, summarize=2)
+ out = self.evaluate(format_output)
+ expected = ("[[0 1 ... 8 9]\n"
+ " [10 11 ... 18 19]\n"
+ " ...\n"
+ " [80 81 ... 88 89]\n"
+ " [90 91 ... 98 99]]")
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatOneTensorThreeDim(self):
+ with self.cached_session():
+ tensor = array_ops.reshape(math_ops.range(1000), [10, 10, 10])
+ format_output = string_ops.string_format("{}", tensor)
+ out = self.evaluate(format_output)
+ expected = ("[[[0 1 2 ... 7 8 9]\n"
+ " [10 11 12 ... 17 18 19]\n"
+ " [20 21 22 ... 27 28 29]\n"
+ " ...\n"
+ " [70 71 72 ... 77 78 79]\n"
+ " [80 81 82 ... 87 88 89]\n"
+ " [90 91 92 ... 97 98 99]]\n"
+ "\n"
+ " [[100 101 102 ... 107 108 109]\n"
+ " [110 111 112 ... 117 118 119]\n"
+ " [120 121 122 ... 127 128 129]\n"
+ " ...\n [170 171 172 ... 177 178 179]\n"
+ " [180 181 182 ... 187 188 189]\n"
+ " [190 191 192 ... 197 198 199]]\n"
+ "\n"
+ " [[200 201 202 ... 207 208 209]\n"
+ " [210 211 212 ... 217 218 219]\n"
+ " [220 221 222 ... 227 228 229]\n"
+ " ...\n"
+ " [270 271 272 ... 277 278 279]\n"
+ " [280 281 282 ... 287 288 289]\n"
+ " [290 291 292 ... 297 298 299]]\n"
+ "\n"
+ " ...\n"
+ "\n"
+ " [[700 701 702 ... 707 708 709]\n"
+ " [710 711 712 ... 717 718 719]\n"
+ " [720 721 722 ... 727 728 729]\n"
+ " ...\n"
+ " [770 771 772 ... 777 778 779]\n"
+ " [780 781 782 ... 787 788 789]\n"
+ " [790 791 792 ... 797 798 799]]\n"
+ "\n"
+ " [[800 801 802 ... 807 808 809]\n"
+ " [810 811 812 ... 817 818 819]\n"
+ " [820 821 822 ... 827 828 829]\n"
+ " ...\n"
+ " [870 871 872 ... 877 878 879]\n"
+ " [880 881 882 ... 887 888 889]\n"
+ " [890 891 892 ... 897 898 899]]\n"
+ "\n"
+ " [[900 901 902 ... 907 908 909]\n"
+ " [910 911 912 ... 917 918 919]\n"
+ " [920 921 922 ... 927 928 929]\n"
+ " ...\n"
+ " [970 971 972 ... 977 978 979]\n"
+ " [980 981 982 ... 987 988 989]\n"
+ " [990 991 992 ... 997 998 999]]]")
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatOneTensorTemplatePrefix(self):
+ with self.cached_session():
+ tensor = array_ops.reshape(math_ops.range(100), [10, 10])
+ format_output = string_ops.string_format("tensor summary: {}", tensor)
+ out = self.evaluate(format_output)
+ expected = ("tensor summary: [[0 1 2 ... 7 8 9]\n"
+ " [10 11 12 ... 17 18 19]\n"
+ " [20 21 22 ... 27 28 29]\n"
+ " ...\n"
+ " [70 71 72 ... 77 78 79]\n"
+ " [80 81 82 ... 87 88 89]\n"
+ " [90 91 92 ... 97 98 99]]")
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatOneTensorTemplatePrefixAndSuffix(self):
+ with self.cached_session():
+ tensor = array_ops.reshape(math_ops.range(100), [10, 10])
+ format_output = string_ops.string_format("tensor summary: {}, suffix",
+ tensor)
+ out = self.evaluate(format_output)
+ expected = ("tensor summary: [[0 1 2 ... 7 8 9]\n"
+ " [10 11 12 ... 17 18 19]\n"
+ " [20 21 22 ... 27 28 29]\n"
+ " ...\n"
+ " [70 71 72 ... 77 78 79]\n"
+ " [80 81 82 ... 87 88 89]\n"
+ " [90 91 92 ... 97 98 99]], suffix")
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatOneTensorTemplateSuffix(self):
+ with self.cached_session():
+ tensor = array_ops.reshape(math_ops.range(100), [10, 10])
+ format_output = string_ops.string_format("{}, suffix", tensor)
+ out = self.evaluate(format_output)
+ expected = ("[[0 1 2 ... 7 8 9]\n"
+ " [10 11 12 ... 17 18 19]\n"
+ " [20 21 22 ... 27 28 29]\n"
+ " ...\n"
+ " [70 71 72 ... 77 78 79]\n"
+ " [80 81 82 ... 87 88 89]\n"
+ " [90 91 92 ... 97 98 99]], suffix")
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatNoTensor(self):
+ with self.cached_session():
+ format_output = string_ops.string_format("No tensor.", ())
+ out = self.evaluate(format_output)
+ expected = "No tensor."
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatMultiTensor(self):
+ with self.cached_session():
+ tensor_one = array_ops.reshape(math_ops.range(100), [10, 10])
+ tensor_two = tensor_one * 10
+ format_output = string_ops.string_format("One: {},\nTwo: {}",
+ (tensor_one, tensor_two))
+ out = self.evaluate(format_output)
+ expected = ("One: [[0 1 2 ... 7 8 9]\n"
+ " [10 11 12 ... 17 18 19]\n"
+ " [20 21 22 ... 27 28 29]\n"
+ " ...\n"
+ " [70 71 72 ... 77 78 79]\n"
+ " [80 81 82 ... 87 88 89]\n"
+ " [90 91 92 ... 97 98 99]],\n"
+ "Two: [[0 10 20 ... 70 80 90]\n"
+ " [100 110 120 ... 170 180 190]\n"
+ " [200 210 220 ... 270 280 290]\n"
+ " ...\n"
+ " [700 710 720 ... 770 780 790]\n"
+ " [800 810 820 ... 870 880 890]\n"
+ " [900 910 920 ... 970 980 990]]")
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatSummarizeOne(self):
+ with self.cached_session():
+ tensor = array_ops.reshape(math_ops.range(100), [10, 10])
+ format_output = string_ops.string_format("tensor summary: {}", tensor,
+ summarize=1)
+ out = self.evaluate(format_output)
+ expected = ("tensor summary: [[0 ... 9]\n"
+ " ...\n"
+ " [90 ... 99]]")
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatSummarizeTwo(self):
+ with self.cached_session():
+ tensor = array_ops.reshape(math_ops.range(100), [10, 10])
+ format_output = string_ops.string_format("tensor summary: {}", tensor,
+ summarize=2)
+ out = self.evaluate(format_output)
+ expected = ("tensor summary: [[0 1 ... 8 9]\n"
+ " [10 11 ... 18 19]\n"
+ " ...\n"
+ " [80 81 ... 88 89]\n"
+ " [90 91 ... 98 99]]")
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatPlaceholder(self):
+ with self.cached_session():
+ tensor = array_ops.reshape(math_ops.range(100), [10, 10])
+ format_output = string_ops.string_format("tensor summary: %t%", tensor,
+ placeholder="%t%")
+ out = self.evaluate(format_output)
+ expected = ("tensor summary: [[0 1 2 ... 7 8 9]\n"
+ " [10 11 12 ... 17 18 19]\n"
+ " [20 21 22 ... 27 28 29]\n"
+ " ...\n"
+ " [70 71 72 ... 77 78 79]\n"
+ " [80 81 82 ... 87 88 89]\n"
+ " [90 91 92 ... 97 98 99]]")
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testTensorCountMustMatchPlaceholderCount(self):
+ with self.cached_session():
+ with self.assertRaisesRegexp(
+ ValueError, r"2 placeholder\(s\) in template does not match 1 "
+ r"tensor\(s\) provided as input"):
+ tensor = math_ops.range(10)
+ format_output = string_ops.string_format("{} {}", tensor)
+ self.evaluate(format_output)
+ with self.cached_session():
+ with self.assertRaisesRegexp(
+ ValueError, r"2 placeholder\(s\) in template does not match 1 "
+ r"tensor\(s\) provided as input"):
+ tensor = math_ops.range(10)
+ format_output = string_ops.string_format("{} {}", [tensor])
+ self.evaluate(format_output)
+ with self.cached_session():
+ with self.assertRaisesRegexp(
+ ValueError, r"1 placeholder\(s\) in template does not match 2 "
+ r"tensor\(s\) provided as input"):
+ tensor = math_ops.range(10)
+ format_output = string_ops.string_format("{}", (tensor, tensor))
+ self.evaluate(format_output)
+
+
+if __name__ == "__main__":
+ test.main()
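
Note on the summarize behaviour exercised above: across the 1-D cases, summarize=k keeps the first and last k entries of a dimension and elides the middle with "...", while summarize=-1 (or any k at least half the length) prints every entry. The following standalone sketch in plain Python is not part of the patch; the helper name summarize_1d is hypothetical and only mirrors the expected strings asserted in these tests.

    # Sketch: emulate the 1-D summarization rule checked by the tests above.
    # summarize=k keeps the first and last k entries; k < 0 means "print all".
    def summarize_1d(values, summarize=3):
        values = list(values)
        if summarize < 0 or len(values) <= 2 * summarize:
            body = " ".join(str(v) for v in values)
        else:
            head = " ".join(str(v) for v in values[:summarize])
            tail = " ".join(str(v) for v in values[-summarize:])
            body = head + " ... " + tail
        return "[" + body + "]"

    assert summarize_1d(range(10)) == "[0 1 2 ... 7 8 9]"
    assert summarize_1d(range(6), summarize=1) == "[0 ... 5]"
    assert summarize_1d(range(6), summarize=-1) == "[0 1 2 3 4 5]"
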
diff --git a/tensorflow/python/kernel_tests/string_length_op_test.py b/tensorflow/python/kernel_tests/string_length_op_test.py
index 9f013c2c7e..4afe3ad3f4 100644
--- a/tensorflow/python/kernel_tests/string_length_op_test.py
+++ b/tensorflow/python/kernel_tests/string_length_op_test.py
@@ -32,6 +32,33 @@ class StringLengthOpTest(test.TestCase):
values = sess.run(lengths)
self.assertAllEqual(values, [[[1, 2], [3, 4], [5, 6]]])
+ def testUnit(self):
+ unicode_strings = [u"H\xc3llo", u"\U0001f604"]
+ utf8_strings = [s.encode("utf-8") for s in unicode_strings]
+ expected_utf8_byte_lengths = [6, 4]
+ expected_utf8_char_lengths = [5, 1]
+
+ with self.test_session() as sess:
+ utf8_byte_lengths = string_ops.string_length(utf8_strings, unit="BYTE")
+ utf8_char_lengths = string_ops.string_length(
+ utf8_strings, unit="UTF8_CHAR")
+ self.assertAllEqual(
+ sess.run(utf8_byte_lengths), expected_utf8_byte_lengths)
+ self.assertAllEqual(
+ sess.run(utf8_char_lengths), expected_utf8_char_lengths)
+ with self.assertRaisesRegexp(
+ ValueError, "Attr 'unit' of 'StringLength' Op passed string 'XYZ' "
+ 'not in: "BYTE", "UTF8_CHAR"'):
+ string_ops.string_length(utf8_strings, unit="XYZ")
+
+ def testLegacyPositionalName(self):
+ # Code that predates the 'unit' parameter may have used a positional
+ # argument for the 'name' parameter. Check that we don't break such code.
+ strings = [[["1", "12"], ["123", "1234"], ["12345", "123456"]]]
+ lengths = string_ops.string_length(strings, "some_name")
+ with self.test_session():
+ self.assertAllEqual(lengths.eval(), [[[1, 2], [3, 4], [5, 6]]])
+
if __name__ == "__main__":
test.main()
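
For reference, the expected values in testUnit above can be reproduced with plain Python (a sketch independent of TensorFlow, assuming Python 3 string semantics): unit="BYTE" counts UTF-8 bytes, while unit="UTF8_CHAR" counts code points.

    # Sketch (not part of the patch): reproduce the expected lengths from testUnit.
    unicode_strings = [u"H\xc3llo", u"\U0001f604"]
    utf8_strings = [s.encode("utf-8") for s in unicode_strings]
    byte_lengths = [len(b) for b in utf8_strings]     # unit="BYTE"      -> [6, 4]
    char_lengths = [len(s) for s in unicode_strings]  # unit="UTF8_CHAR" -> [5, 1]
    assert byte_lengths == [6, 4] and char_lengths == [5, 1]
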
diff --git a/tensorflow/python/kernel_tests/substr_op_test.py b/tensorflow/python/kernel_tests/substr_op_test.py
index 4d163a0f6f..cd3fe14883 100644
--- a/tensorflow/python/kernel_tests/substr_op_test.py
+++ b/tensorflow/python/kernel_tests/substr_op_test.py
@@ -46,7 +46,7 @@ class SubstrOpTest(test.TestCase, parameterized.TestCase):
expected_value = b"ell"
substr_op = string_ops.substr(test_string, position, length)
- with self.test_session():
+ with self.cached_session():
substr = substr_op.eval()
self.assertAllEqual(substr, expected_value)
@@ -57,7 +57,7 @@ class SubstrOpTest(test.TestCase, parameterized.TestCase):
expected_value = b""
substr_op = string_ops.substr(test_string, position, length)
- with self.test_session():
+ with self.cached_session():
substr = substr_op.eval()
self.assertAllEqual(substr, expected_value)
@@ -79,7 +79,7 @@ class SubstrOpTest(test.TestCase, parameterized.TestCase):
expected_value = [b"ell", b"orl"]
substr_op = string_ops.substr(test_string, position, length)
- with self.test_session():
+ with self.cached_session():
substr = substr_op.eval()
self.assertAllEqual(substr, expected_value)
@@ -104,7 +104,7 @@ class SubstrOpTest(test.TestCase, parameterized.TestCase):
[b"ixte", b"even", b"ight"]]
substr_op = string_ops.substr(test_string, position, length)
- with self.test_session():
+ with self.cached_session():
substr = substr_op.eval()
self.assertAllEqual(substr, expected_value)
@@ -196,7 +196,7 @@ class SubstrOpTest(test.TestCase, parameterized.TestCase):
position = np.array(-7, dtype)
length = np.array(3, dtype)
substr_op = string_ops.substr(test_string, position, length)
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(errors_impl.InvalidArgumentError):
substr = substr_op.eval()
@@ -234,7 +234,7 @@ class SubstrOpTest(test.TestCase, parameterized.TestCase):
position = np.array([[1, 2, -3], [1, 2, -4], [1, 2, -3]], dtype)
length = np.array([[3, 2, 1], [1, 2, 3], [2, 2, 2]], dtype)
substr_op = string_ops.substr(test_string, position, length)
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(errors_impl.InvalidArgumentError):
substr = substr_op.eval()
@@ -252,7 +252,7 @@ class SubstrOpTest(test.TestCase, parameterized.TestCase):
position = np.array([-1, -2, -4], dtype)
length = np.array([1, 2, 3], dtype)
substr_op = string_ops.substr(test_string, position, length)
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(errors_impl.InvalidArgumentError):
substr = substr_op.eval()
diff --git a/tensorflow/python/kernel_tests/summary_audio_op_test.py b/tensorflow/python/kernel_tests/summary_audio_op_test.py
index eaae671192..e59a2ceef7 100644
--- a/tensorflow/python/kernel_tests/summary_audio_op_test.py
+++ b/tensorflow/python/kernel_tests/summary_audio_op_test.py
@@ -50,7 +50,7 @@ class SummaryAudioOpTest(test.TestCase):
def testAudioSummary(self):
np.random.seed(7)
for channels in (1, 2, 5, 8):
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
num_frames = 7
shape = (4, num_frames, channels)
# Generate random audio in the range [-1.0, 1.0).
diff --git a/tensorflow/python/kernel_tests/summary_image_op_test.py b/tensorflow/python/kernel_tests/summary_image_op_test.py
index 4718827e88..b650e10404 100644
--- a/tensorflow/python/kernel_tests/summary_image_op_test.py
+++ b/tensorflow/python/kernel_tests/summary_image_op_test.py
@@ -52,7 +52,7 @@ class SummaryImageOpTest(test.TestCase):
def testImageSummary(self):
for depth in (1, 3, 4):
for positive in False, True:
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
shape = (4, 5, 7) + (depth,)
bad_color = [255, 0, 0, 255][:depth]
# Build a mostly random image with one nan
@@ -87,7 +87,7 @@ class SummaryImageOpTest(test.TestCase):
def testImageSummaryUint8(self):
np.random.seed(7)
for depth in (1, 3, 4):
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
shape = (4, 5, 7) + (depth,)
# Build a random uint8 image
diff --git a/tensorflow/python/kernel_tests/tensor_array_ops_test.py b/tensorflow/python/kernel_tests/tensor_array_ops_test.py
index 6de6fbe767..0ad2063558 100644
--- a/tensorflow/python/kernel_tests/tensor_array_ops_test.py
+++ b/tensorflow/python/kernel_tests/tensor_array_ops_test.py
@@ -1504,6 +1504,19 @@ class TensorArrayTest(test.TestCase):
vdx, vdy = sess.run([dx, dy])
self.assertAllClose(vdx, vdy)
+ def testTensorArrayInt64GPU(self):
+ if not test.is_gpu_available():
+ return
+ with self.test_session(use_gpu=True, force_gpu=True) as sess:
+ value = array_ops.placeholder(dtypes.int64)
+ ta = tensor_array_ops.TensorArray(dtype=dtypes.int64, size=2)
+ ta = ta.scatter([0, 1], value)
+ r0 = ta.read(0)
+ r1 = ta.read(1)
+ v0, v1 = sess.run([r0, r1], feed_dict={value: [-3, 100]})
+ self.assertAllEqual(v0, -3)
+ self.assertAllEqual(v1, 100)
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/kernel_tests/unicode_script_op_test.py b/tensorflow/python/kernel_tests/unicode_script_op_test.py
new file mode 100644
index 0000000000..927e5459ed
--- /dev/null
+++ b/tensorflow/python/kernel_tests/unicode_script_op_test.py
@@ -0,0 +1,57 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#===============================================================================
+"""Functional tests for UnicodeScript op."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import string_ops
+from tensorflow.python.platform import test
+
+
+class UnicodeScriptOpTest(test.TestCase):
+
+ def testValidScripts(self):
+ inputs = [
+ ord("a"),
+ 0x0411, # CYRILLIC CAPITAL LETTER BE
+ 0x82b8, # CJK UNIFIED IDEOGRAPH-82B8
+ ord(",")
+ ]
+ with self.cached_session():
+ input_vector = constant_op.constant(inputs, dtypes.int32)
+ outputs = string_ops.unicode_script(input_vector).eval()
+ self.assertAllEqual(
+ outputs,
+ [
+ 25, # USCRIPT_LATIN (LATN)
+ 8, # USCRIPT_CYRILLIC (CYRL)
+ 17, # USCRIPT_HAN (HANI)
+ 0 # USCRIPT_COMMON (ZYYY)
+ ])
+
+ def testInvalidScript(self):
+ inputs = [-100, 0xffffff]
+ with self.cached_session():
+ input_vector = constant_op.constant(inputs, dtypes.int32)
+ outputs = string_ops.unicode_script(input_vector).eval()
+ self.assertAllEqual(outputs, [-1, -1])
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/kernel_tests/variable_scope_test.py b/tensorflow/python/kernel_tests/variable_scope_test.py
index 401e1ae102..33f464fb90 100644
--- a/tensorflow/python/kernel_tests/variable_scope_test.py
+++ b/tensorflow/python/kernel_tests/variable_scope_test.py
@@ -394,10 +394,10 @@ class VariableScopeTest(test.TestCase):
old = variable_scope._DEFAULT_USE_RESOURCE
try:
variable_scope.enable_resource_variables()
- self.assertTrue(isinstance(variables_lib.Variable(1.0),
+ self.assertTrue(isinstance(variables_lib.VariableV1(1.0),
resource_variable_ops.ResourceVariable))
variable_scope.disable_resource_variables()
- self.assertFalse(isinstance(variables_lib.Variable(1.0),
+ self.assertFalse(isinstance(variables_lib.VariableV1(1.0),
resource_variable_ops.ResourceVariable))
finally:
variable_scope._DEFAULT_USE_RESOURCE = old
diff --git a/tensorflow/python/kernel_tests/variables_test.py b/tensorflow/python/kernel_tests/variables_test.py
index 2e7975667c..942ceedc8b 100644
--- a/tensorflow/python/kernel_tests/variables_test.py
+++ b/tensorflow/python/kernel_tests/variables_test.py
@@ -43,14 +43,14 @@ class VariablesTestCase(test.TestCase):
def testInitialization(self):
with self.cached_session():
- var0 = variables.Variable(0.0)
+ var0 = variables.VariableV1(0.0)
self.assertEqual("Variable:0", var0.name)
self.assertEqual("Variable", var0._shared_name)
self.assertEqual([], var0.get_shape())
self.assertEqual([], var0.get_shape())
self.assertEqual([], var0.shape)
- var1 = variables.Variable(1.1)
+ var1 = variables.VariableV1(1.1)
self.assertEqual("Variable_1:0", var1.name)
self.assertEqual("Variable_1", var1._shared_name)
self.assertEqual([], var1.get_shape())
@@ -143,7 +143,7 @@ class VariablesTestCase(test.TestCase):
def testZeroSizeStringAssign(self):
with self.cached_session() as sess:
- array = variables.Variable(
+ array = variables.VariableV1(
initial_value=array_ops.zeros((0,), dtype=dtypes.string),
name="foo",
trainable=False,
@@ -192,7 +192,7 @@ class VariablesTestCase(test.TestCase):
# d get the control dep.
d = constant_op.constant(2.0)
# variables do not.
- var_x = variables.Variable(2.0)
+ var_x = variables.VariableV1(2.0)
self.assertEqual([c.op], d.op.control_inputs)
self.assertEqual([], var_x.initializer.control_inputs)
self.assertEqual([], var_x.value().op.control_inputs)
@@ -280,10 +280,10 @@ class VariablesTestCase(test.TestCase):
def testCollections(self):
with self.cached_session():
- var_x = variables.Variable(2.0)
- var_y = variables.Variable(2.0, trainable=False)
- var_z = variables.Variable(2.0, trainable=True)
- var_t = variables.Variable(
+ var_x = variables.VariableV1(2.0)
+ var_y = variables.VariableV1(2.0, trainable=False)
+ var_z = variables.VariableV1(2.0, trainable=True)
+ var_t = variables.VariableV1(
2.0,
trainable=True,
collections=[
@@ -296,9 +296,9 @@ class VariablesTestCase(test.TestCase):
def testCollectionsWithScope(self):
with self.cached_session():
with ops.name_scope("scope_1"):
- var_x = variables.Variable(2.0)
+ var_x = variables.VariableV1(2.0)
with ops.name_scope("scope_2"):
- var_y = variables.Variable(2.0)
+ var_y = variables.VariableV1(2.0)
self.assertEqual([var_x, var_y], variables.global_variables())
self.assertEqual([var_x], variables.global_variables("scope_1"))
@@ -399,7 +399,7 @@ class VariablesTestCase(test.TestCase):
def testColocation(self):
with ops.device("/job:ps"):
- var = variables.Variable(0, name="v")
+ var = variables.VariableV1(0, name="v")
with ops.device("/job:worker/task:7"):
assign_op = var.assign(1)
self.assertDeviceEqual("/job:ps", assign_op.device)
@@ -522,7 +522,7 @@ class VariablesTestCase(test.TestCase):
self.assertAllClose(np.ones((5, 5), np.float32), var.eval())
def testRepr(self):
- var = variables.Variable(np.zeros((5, 5), np.float32), name="noop")
+ var = variables.VariableV1(np.zeros((5, 5), np.float32), name="noop")
self.assertEqual(
"<tf.Variable 'noop:0' shape=(5, 5) dtype=float32_ref>",
repr(var))
@@ -556,8 +556,8 @@ class IsInitializedTest(test.TestCase):
def testVariableList(self):
with ops.Graph().as_default(), self.cached_session() as sess:
- v = variables.Variable([1, 2], name="v")
- w = variables.Variable([3, 4], name="w")
+ v = variables.VariableV1([1, 2], name="v")
+ w = variables.VariableV1([3, 4], name="w")
uninited = variables.report_uninitialized_variables()
self.assertAllEqual(np.array([b"v", b"w"]), sess.run(uninited))
sess.run(w.initializer)
@@ -593,8 +593,8 @@ class ObsoleteIsInitializedTest(test.TestCase):
def testVariables(self):
with ops.Graph().as_default(), self.cached_session() as sess:
- v = variables.Variable([1, 2])
- w = variables.Variable([3, 4])
+ v = variables.VariableV1([1, 2])
+ w = variables.VariableV1([3, 4])
_ = v, w
inited = variables.assert_variables_initialized()
with self.assertRaisesOpError("Attempting to use uninitialized value"):
@@ -604,8 +604,8 @@ class ObsoleteIsInitializedTest(test.TestCase):
def testVariableList(self):
with ops.Graph().as_default(), self.cached_session() as sess:
- v = variables.Variable([1, 2])
- w = variables.Variable([3, 4])
+ v = variables.VariableV1([1, 2])
+ w = variables.VariableV1([3, 4])
inited = variables.assert_variables_initialized([v])
with self.assertRaisesOpError("Attempting to use uninitialized value"):
inited.op.run()
diff --git a/tensorflow/python/kernel_tests/while_v2_test.py b/tensorflow/python/kernel_tests/while_v2_test.py
new file mode 100644
index 0000000000..3a070544e8
--- /dev/null
+++ b/tensorflow/python/kernel_tests/while_v2_test.py
@@ -0,0 +1,276 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for while_v2."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+
+from tensorflow.core.protobuf import rewriter_config_pb2
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import meta_graph
+from tensorflow.python.framework import ops
+from tensorflow.python.grappler import tf_optimizer
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import list_ops
+from tensorflow.python.ops import while_v2
+from tensorflow.python.ops.control_flow_ops import while_loop as while_loop_v1
+from tensorflow.python.ops.while_v2 import while_loop as while_loop_v2
+from tensorflow.python.platform import test
+
+
+class WhileV2Test(test.TestCase, parameterized.TestCase):
+
+ def testSingleLoopVar(self):
+ x = constant_op.constant(2.)
+ ret = while_loop_v2(lambda v: v < 8., lambda v: v * v, [x])
+ grad = gradients_impl.gradients(ret, [x])
+ with self.cached_session() as sess:
+ self.assertEqual(sess.run(ret), 16.)
+ self.assertSequenceEqual(sess.run(grad), [32.])
+
+ def testMultipleLoopVarsBasic(self):
+ x = constant_op.constant(5.)
+ y = constant_op.constant(3.)
+
+ # x = 5.
+ # y = 3.
+ # while x < 45.:
+ # x = x * y
+ ret = while_loop_v2(lambda v, _: v < 45., lambda v, w: (v * w, w), [x, y])
+ # ret = [x*y^2, y]
+
+ # Note: This is simply d_ret[0]/d_x since d_ret[1]/d_x is 0.
+ grad = gradients_impl.gradients(ret, [x]) # [y**2]
+ with self.cached_session() as sess:
+ self.assertSequenceEqual(sess.run(ret), [45., 3.])
+ self.assertSequenceEqual(sess.run(grad), [9.])
+
+ def testMultipleLoopVars(self):
+ x = constant_op.constant(5.)
+ y = constant_op.constant(3.)
+
+ # x = 5.
+ # y = 3.
+ # while x < 45.:
+ # x = x * y
+ # y = x + y
+ ret = while_loop_v2(lambda v, _: v < 45., lambda v, w: (v * w, v + w),
+ [x, y])
+ # ret = [y*x**2 + x*y**2, x*y + x + y]
+
+ gradx_0 = gradients_impl.gradients(ret[0], [x]) # [2*x*y + y**2]
+ gradx_1 = gradients_impl.gradients(ret[1], [x]) # [y + 1]
+ gradx_2 = gradients_impl.gradients(ret, [x]) # [2*x*y + y**2 + 2*y + 1]
+ grady_0 = gradients_impl.gradients(ret[0], [y]) # [2*x*y + x**2]
+ grady_1 = gradients_impl.gradients(ret[1], [y]) # [x + 1]
+ grady_2 = gradients_impl.gradients(ret, [y]) # [2*x*y + x**2 + x + 1]
+ with self.cached_session() as sess:
+ self.assertSequenceEqual(sess.run(ret), [120., 23.])
+ self.assertSequenceEqual(sess.run(gradx_0), [39.])
+ self.assertSequenceEqual(sess.run(gradx_1), [4.])
+ self.assertSequenceEqual(sess.run(gradx_2), [43.])
+ self.assertSequenceEqual(sess.run(grady_0), [55.])
+ self.assertSequenceEqual(sess.run(grady_1), [6.])
+ self.assertSequenceEqual(sess.run(grady_2), [61.])
+
+ def testMultipleWhileLoops(self):
+ x = constant_op.constant(2.)
+ ret1 = while_loop_v2(lambda v: v < 4., lambda v: v * v, [x]) # x**2
+ ret2 = while_loop_v2(lambda v: v < 16., lambda v: v * v, ret1) # x**4
+ grad = gradients_impl.gradients(ret2, [x]) # 4x**3
+ grad_grad = gradients_impl.gradients(grad, [x]) # 12x**2
+ with self.cached_session() as sess:
+ self.assertSequenceEqual(sess.run(grad), [32.])
+ self.assertSequenceEqual(sess.run(grad_grad), [48.])
+
+ def testDoubleDerivative(self):
+ x = constant_op.constant(2.)
+ ret = while_loop_v2(lambda v: v < 8., lambda v: v**2, [x]) # x**4
+ grad = gradients_impl.gradients(ret, [x]) # 4x**3
+ grad_grad = gradients_impl.gradients(grad, [x]) # 12x**2
+ with self.cached_session() as sess:
+ self.assertEqual(sess.run(ret), 16.)
+ self.assertSequenceEqual(sess.run(grad), [32.])
+ self.assertSequenceEqual(sess.run(grad_grad), [48.])
+
+ def testPruning(self):
+ x = constant_op.constant(1)
+
+ tensor_list = list_ops.empty_tensor_list(
+ element_dtype=x.dtype, element_shape=x.shape)
+
+ def Cond(x, tl):
+ del tl # Unused for Cond.
+ return x < 5
+
+ def Body(x, tl):
+ return x + 1, list_ops.tensor_list_push_back(tl, x)
+
+ outputs = while_loop_v1(Cond, Body, [x, tensor_list])
+
+ train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
+ train_op.append(outputs[0])
+
+ def GetOptimizedGraph():
+ mg = meta_graph.create_meta_graph_def(graph=ops.get_default_graph())
+ rewriter_config = rewriter_config_pb2.RewriterConfig(
+ constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
+ memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL)
+ return tf_optimizer.OptimizeGraph(rewriter_config, mg)
+
+ g = GetOptimizedGraph()
+ self.assertEqual(len([n for n in g.node if n.op == "Enter"]), 1)
+
+ stack = list_ops.tensor_list_stack(outputs[1], element_dtype=x.dtype)
+ train_op.append(stack)
+ g = GetOptimizedGraph()
+ self.assertEqual(len([n for n in g.node if n.op == "Enter"]), 2)
+
+ def testCaptureExternalTensorInCond(self):
+ x = constant_op.constant(2.)
+ y = constant_op.constant(1.)
+ ret = while_loop_v2(lambda v: v + y < 9., lambda v: v * 3., [x])
+ grad = gradients_impl.gradients(ret, [x])
+ with self.cached_session() as sess:
+ self.assertEqual(sess.run(ret), 18.)
+ self.assertSequenceEqual(sess.run(grad), [9.])
+
+ def testCaptureExternalTensorInBody(self):
+ x = constant_op.constant(2.)
+ y = constant_op.constant(3.)
+ ret = while_loop_v2(lambda v: v < 8., lambda v: v * y, [x])
+ grad = gradients_impl.gradients(ret, [x])
+ with self.cached_session() as sess:
+ self.assertEqual(sess.run(ret), 18.)
+ self.assertSequenceEqual(sess.run(grad), [9.])
+
+ def testLoopWithTensorListPushBack(self):
+ x = constant_op.constant(2.)
+
+ tensor_list = list_ops.empty_tensor_list(
+ element_dtype=dtypes.float32, element_shape=ScalarShape())
+
+ def Cond(x, tl):
+ del tl # Unused for Cond.
+ return x < 5.
+
+ def Body(x, tl):
+ tl = list_ops.tensor_list_push_back(tl, x)
+ tl = list_ops.tensor_list_push_back(tl, constant_op.constant(100.))
+ return x**2., tl
+
+ ret = while_loop_v2(Cond, Body, [x, tensor_list])
+ grad = gradients_impl.gradients(ret[0], x)
+ with self.cached_session() as sess:
+ self.assertEqual(sess.run(ret[0]), 16.)
+ self.assertSequenceEqual(sess.run(grad), [32.])
+
+ def testDuplicateAccumulator(self):
+ x = constant_op.constant(2.)
+
+ tensor_list = list_ops.empty_tensor_list(
+ element_dtype=dtypes.float32, element_shape=ScalarShape())
+
+ def Cond(x, tl):
+ del tl # Unused for Cond.
+ return x < 5.
+
+ def Body(x, tl):
+ # There is an accumulator in the loop already so we should not add
+ # another.
+ tl = list_ops.tensor_list_push_back(tl, x)
+ return x**2., tl
+
+ ret = while_loop_v2(Cond, Body, [x, tensor_list])
+
+ for op in ops.get_default_graph().get_operations():
+ if op.type == "While":
+ while_op = op
+
+ body_graph = while_v2._get_body_graph(while_op)
+ # body_graph.inputs: [counter_arg, x_arg, tl_arg, *accumulators]
+ x_input_t = body_graph.inputs[1]
+ accumulator_count = len(
+ [c for c in x_input_t.consumers() if c.type == "TensorListPushBack"])
+ self.assertEqual(accumulator_count, 1)
+
+ grad = gradients_impl.gradients(ret[0], x)
+ with self.cached_session() as sess:
+ self.assertEqual(sess.run(ret[0]), 16.)
+ self.assertSequenceEqual(sess.run(grad), [32.])
+
+ @parameterized.named_parameters(
+ ("UnknownShape", None),
+ ("PartiallyDefinedShape", [None, 2]),
+ ("FullyDefinedShape", [1, 2]),
+ )
+ def testTensorListOutputElementShape(self, shape):
+
+ def MatchShape(actual_tensor_shape):
+ # Compare the shapes, treating None dimensions as equal. We do not
+ # directly check actual_tensor_shape and tf.TensorShape(shape) for
+ # equality because tf.Dimension.__eq__ returns None if either dimension is
+ # None.
+ if shape is None:
+ self.assertIsNone(actual_tensor_shape.dims)
+ else:
+ self.assertListEqual(actual_tensor_shape.as_list(), shape)
+
+ def GetAccumulatorForInputAtIndex(while_op, idx):
+ body_graph = while_v2._get_body_graph(while_op)
+ y_input_t = body_graph.inputs[idx]
+ push_back_node = [c for c in y_input_t.consumers()
+ if c.type == "TensorListPushBack"][0]
+ output_idx = body_graph.outputs.index(push_back_node.outputs[0])
+ return while_op.outputs[output_idx]
+
+ x = constant_op.constant(2.)
+ y = array_ops.placeholder(dtype=dtypes.float32, shape=shape)
+
+ # Forward pass.
+ ret = while_loop_v2(lambda v, u: v < 8., lambda v, u: (v * v, u), [x, y])
+ while_op = ret[0].op
+ # Get the TensorList output of While op containing the accumulated values
+ # of y.
+ # while_op.inputs: [counter_arg, x_arg, y_arg, *accumulators]
+ output = GetAccumulatorForInputAtIndex(while_op, 2)
+ _, val = list_ops.tensor_list_pop_back(output,
+ element_dtype=dtypes.float32)
+ MatchShape(val.shape)
+
+ # Gradient pass.
+ grad = gradients_impl.gradients(ret[1], y)
+ grad_while_op = grad[0].op
+ # Get the TensorList output of gradient While op containing the accumulated
+ # values of grad_y.
+ # grad_while_op.inputs:
+ # [counter_arg, total_iters_arg, grad_x_arg, grad_y_arg, *other_args]
+ grad_output = GetAccumulatorForInputAtIndex(grad_while_op, 4)
+ _, val = list_ops.tensor_list_pop_back(grad_output,
+ element_dtype=dtypes.float32)
+ MatchShape(val.shape)
+
+
+def ScalarShape():
+ return ops.convert_to_tensor([], dtype=dtypes.int32)
+
+
+if __name__ == "__main__":
+ test.main()
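
To check the expected values in the derivative tests above: with x = 2 and a body that squares v while v < 8, the loop runs twice, so the result is x**4 = 16, the first derivative is 4*x**3 = 32, and the second derivative is 12*x**2 = 48. The sketch below reproduces that arithmetic with the closed form the loop computes for this input; it is illustrative only and does not exercise the new while_v2 machinery.

```python
import tensorflow as tf

# Closed form of the loop for x = 2: squaring while v < 8 runs twice, so the
# loop computes x**4. This only checks the expected numbers in the tests
# above; it does not go through while_v2 itself.
x = tf.constant(2.)
ret = x ** 4
grad = tf.gradients(ret, [x])        # [4 * x**3] -> [32.]
grad_grad = tf.gradients(grad, [x])  # [12 * x**2] -> [48.]

with tf.Session() as sess:
  print(sess.run([ret, grad, grad_grad]))  # [16.0, [32.0], [48.0]]
```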
diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py
index 3ba880d7a1..e399ece232 100644
--- a/tensorflow/python/layers/base.py
+++ b/tensorflow/python/layers/base.py
@@ -131,10 +131,20 @@ class Layer(base_layer.Layer):
def add_loss(self, losses, inputs=None):
previous_losses_length = len(self._losses)
+ previous_callable_losses_length = len(self._callable_losses)
super(Layer, self).add_loss(losses, inputs=inputs)
- # TODO(fchollet): deprecate collection below.
- new_losses = self._losses[previous_losses_length:]
- _add_elements_to_collection(new_losses, ops.GraphKeys.REGULARIZATION_LOSSES)
+ if not context.executing_eagerly():
+ # TODO(fchollet): deprecate collection below.
+ new_losses = self._losses[previous_losses_length:]
+ new_callable_losses = self._callable_losses[
+ previous_callable_losses_length:]
+ for regularizer in new_callable_losses:
+ loss_tensor = regularizer()
+ if loss_tensor is not None:
+ new_losses.append(loss_tensor)
+ _add_elements_to_collection(
+ new_losses,
+ ops.GraphKeys.REGULARIZATION_LOSSES)
def _name_scope(self):
"""Determines op naming for the Layer."""
diff --git a/tensorflow/python/layers/convolutional_test.py b/tensorflow/python/layers/convolutional_test.py
index d61d3b6dba..257fa27156 100644
--- a/tensorflow/python/layers/convolutional_test.py
+++ b/tensorflow/python/layers/convolutional_test.py
@@ -207,7 +207,8 @@ class ConvTest(test.TestCase):
layer.apply(images)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertListEqual(layer.losses, loss_keys)
+ self.evaluate([v.initializer for v in layer.variables])
+ self.assertListEqual(self.evaluate(layer.losses), self.evaluate(loss_keys))
def testConv2DBiasRegularizer(self):
height, width = 7, 9
@@ -217,7 +218,8 @@ class ConvTest(test.TestCase):
layer.apply(images)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertListEqual(layer.losses, loss_keys)
+ self.evaluate([v.initializer for v in layer.variables])
+ self.assertListEqual(self.evaluate(layer.losses), self.evaluate(loss_keys))
def testConv2DNoBias(self):
height, width = 7, 9
@@ -445,7 +447,8 @@ class SeparableConv1DTest(test.TestCase):
layer.apply(data)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertEqual(layer.losses, loss_keys)
+ self.evaluate([v.initializer for v in layer.variables])
+ self.assertListEqual(self.evaluate(layer.losses), self.evaluate(loss_keys))
def testSeparableConv1DPointwiseRegularizer(self):
length = 9
@@ -455,7 +458,8 @@ class SeparableConv1DTest(test.TestCase):
layer.apply(data)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertEqual(layer.losses, loss_keys)
+ self.evaluate([v.initializer for v in layer.variables])
+ self.assertListEqual(self.evaluate(layer.losses), self.evaluate(loss_keys))
def testSeparableConv1DBiasRegularizer(self):
length = 9
@@ -465,7 +469,8 @@ class SeparableConv1DTest(test.TestCase):
layer.apply(data)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertEqual(layer.losses, loss_keys)
+ self.evaluate([v.initializer for v in layer.variables])
+ self.assertListEqual(self.evaluate(layer.losses), self.evaluate(loss_keys))
def testSeparableConv1DNoBias(self):
length = 9
@@ -682,7 +687,8 @@ class SeparableConv2DTest(test.TestCase):
layer.apply(images)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertListEqual(layer.losses, loss_keys)
+ self.evaluate([v.initializer for v in layer.variables])
+ self.assertListEqual(self.evaluate(layer.losses), self.evaluate(loss_keys))
def testSeparableConv2DPointwiseRegularizer(self):
height, width = 7, 9
@@ -692,7 +698,8 @@ class SeparableConv2DTest(test.TestCase):
layer.apply(images)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertListEqual(layer.losses, loss_keys)
+ self.evaluate([v.initializer for v in layer.variables])
+ self.assertListEqual(self.evaluate(layer.losses), self.evaluate(loss_keys))
def testSeparableConv2DBiasRegularizer(self):
height, width = 7, 9
@@ -702,7 +709,8 @@ class SeparableConv2DTest(test.TestCase):
layer.apply(images)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertListEqual(layer.losses, loss_keys)
+ self.evaluate([v.initializer for v in layer.variables])
+ self.assertListEqual(self.evaluate(layer.losses), self.evaluate(loss_keys))
def testSeparableConv2DNoBias(self):
height, width = 7, 9
@@ -839,7 +847,8 @@ class Conv2DTransposeTest(test.TestCase):
layer.apply(images)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertListEqual(layer.losses, loss_keys)
+ self.evaluate([v.initializer for v in layer.variables])
+ self.assertListEqual(self.evaluate(layer.losses), self.evaluate(loss_keys))
def testConv2DTransposeBiasRegularizer(self):
height, width = 7, 9
@@ -849,7 +858,8 @@ class Conv2DTransposeTest(test.TestCase):
layer.apply(images)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertListEqual(layer.losses, loss_keys)
+ self.evaluate([v.initializer for v in layer.variables])
+ self.assertListEqual(self.evaluate(layer.losses), self.evaluate(loss_keys))
def testConv2DTransposeNoBias(self):
height, width = 7, 9
@@ -1017,7 +1027,8 @@ class Conv3DTransposeTest(test.TestCase):
layer.apply(volumes)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertListEqual(layer.losses, loss_keys)
+ self.evaluate([v.initializer for v in layer.variables])
+ self.assertListEqual(self.evaluate(layer.losses), self.evaluate(loss_keys))
def testConv3DTransposeBiasRegularizer(self):
depth, height, width = 5, 7, 9
@@ -1027,7 +1038,8 @@ class Conv3DTransposeTest(test.TestCase):
layer.apply(volumes)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertListEqual(layer.losses, loss_keys)
+ self.evaluate([v.initializer for v in layer.variables])
+ self.assertListEqual(self.evaluate(layer.losses), self.evaluate(loss_keys))
def testConv3DTransposeNoBias(self):
depth, height, width = 5, 7, 9
diff --git a/tensorflow/python/layers/core_test.py b/tensorflow/python/layers/core_test.py
index 46009a30ac..d26f3f4789 100644
--- a/tensorflow/python/layers/core_test.py
+++ b/tensorflow/python/layers/core_test.py
@@ -197,7 +197,8 @@ class DenseTest(test.TestCase):
_ = dense(inputs)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertListEqual(dense.losses, loss_keys)
+ self.evaluate([v.initializer for v in dense.variables])
+ self.assertAllEqual(self.evaluate(dense.losses), self.evaluate(loss_keys))
def testKernelRegularizerWithReuse(self):
regularizer = lambda x: math_ops.reduce_sum(x) * 1e-3
@@ -218,7 +219,8 @@ class DenseTest(test.TestCase):
_ = dense(inputs)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertListEqual(dense.losses, loss_keys)
+ self.evaluate([v.initializer for v in dense.variables])
+ self.assertAllEqual(self.evaluate(dense.losses), self.evaluate(loss_keys))
def testFunctionalDense(self):
with self.cached_session():
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index c8b883350d..a7f57e94e3 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -2787,4 +2787,65 @@ def quantize(input, # pylint: disable=redefined-builtin
name=name)
+@tf_export("searchsorted")
+def searchsorted(sorted_sequence,
+ values,
+ side="left",
+ out_type=dtypes.int32,
+ name=None):
+ """Searches input tensor for values on the innermost dimension.
+
+ A 2-D example:
+
+ ```
+ sorted_sequence = [[0, 3, 9, 9, 10],
+ [1, 2, 3, 4, 5]]
+ values = [[2, 4, 9],
+ [0, 2, 6]]
+
+ result = searchsorted(sorted_sequence, values, side="left")
+
+ result == [[1, 2, 2],
+ [0, 1, 5]]
+
+ result = searchsorted(sorted_sequence, values, side="right")
+
+ result == [[1, 2, 4],
+ [0, 2, 5]]
+ ```
+
+ Args:
+ sorted_sequence: N-D `Tensor` containing a sorted sequence.
+ values: N-D `Tensor` containing the search values.
+ side: 'left' or 'right'; 'left' corresponds to lower_bound and 'right' to
+ upper_bound.
+ out_type: The output type (`int32` or `int64`). Default is `tf.int32`.
+ name: Optional name for the operation.
+
+ Returns:
+ An N-D `Tensor` the same shape as `values`, containing the result of
+ applying either lower_bound or upper_bound (depending on `side`) to each
+ value. The result is not a global index into the entire `Tensor`, but the
+ index within the last dimension.
+
+ Raises:
+ ValueError: If the last dimension of `sorted_sequence` has more than
+ `2^31 - 1` elements, if the total size of `values` exceeds `2^31 - 1`
+ elements, or if the first `N-1` dimensions of the two tensors don't match.
+ """
+ sequence_size = shape_internal(sorted_sequence)[-1]
+ values_size = shape_internal(values)[-1]
+ sorted_sequence_2d = reshape(sorted_sequence, [-1, sequence_size])
+ values_2d = reshape(values, [-1, values_size])
+ if side == "right":
+ output = gen_array_ops.upper_bound(sorted_sequence_2d, values_2d, out_type,
+ name)
+ elif side == "left":
+ output = gen_array_ops.lower_bound(sorted_sequence_2d, values_2d, out_type,
+ name)
+ else:
+ raise ValueError("side must be either 'right' or 'left'. Saw: %s." % side)
+ return reshape(output, shape_internal(values))
+
+
quantize.__doc__ = gen_array_ops.quantize_v2.__doc__
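
The docstring example for `searchsorted` can be replayed with NumPy's `np.searchsorted` applied row by row, since the op applies the same lower_bound/upper_bound logic to the innermost dimension. A small cross-check sketch:

```python
import numpy as np

# The docstring example above, replayed row by row with np.searchsorted.
sorted_sequence = np.array([[0, 3, 9, 9, 10],
                            [1, 2, 3, 4, 5]])
values = np.array([[2, 4, 9],
                   [0, 2, 6]])

left = np.stack([np.searchsorted(s, v, side="left")
                 for s, v in zip(sorted_sequence, values)])
right = np.stack([np.searchsorted(s, v, side="right")
                  for s, v in zip(sorted_sequence, values)])
print(left)   # [[1 2 2]
              #  [0 1 5]]
print(right)  # [[1 2 4]
              #  [0 2 5]]
```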
diff --git a/tensorflow/python/ops/cond_v2_impl.py b/tensorflow/python/ops/cond_v2_impl.py
index c6a6b2a7fa..f8b1ddb140 100644
--- a/tensorflow/python/ops/cond_v2_impl.py
+++ b/tensorflow/python/ops/cond_v2_impl.py
@@ -119,7 +119,11 @@ def cond_v2(pred, true_fn, false_fn, name="cond"):
attr_value_pb2.AttrValue(b=True))
# pylint: enable=protected-access
- return tuple(tensors[:num_cond_outputs])
+ result = tuple(tensors[:num_cond_outputs])
+ if len(result) == 1:
+ return result[0]
+ else:
+ return result
@ops.RegisterGradient("If")
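
The `cond_v2` change above unwraps a one-element result so a single-output conditional returns the bare tensor, matching `tf.cond`. A minimal sketch of the unwrapping rule on plain Python values:

```python
# Sketch of the unwrapping rule: a one-element result is returned as the bare
# value so a single-output cond_v2 matches tf.cond's return type.
def _unwrap_single_output(tensors, num_cond_outputs):
  result = tuple(tensors[:num_cond_outputs])
  return result[0] if len(result) == 1 else result

print(_unwrap_single_output(["t0"], 1))        # 't0' (bare value, not a tuple)
print(_unwrap_single_output(["t0", "t1"], 2))  # ('t0', 't1')
```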
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index 0e20fadb2b..9d7d31df22 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -60,8 +60,17 @@ from tensorflow.python.util import nest
from tensorflow.python.util import tf_should_use
from tensorflow.python.util.tf_export import tf_export
+# The while_v2 module.
+_while_v2 = None
ENABLE_COND_V2 = os.getenv("TF_ENABLE_COND_V2", "0") != "0"
+# Note: Setting this to True is not sufficient to switch to the v2 while_loop.
+# Users must also import the while_v2 module to set the _while_v2 module
+# variable above. We do this to avoid a circular dependency:
+# control_flow_ops -> while_v2 -> gradients_impl -> control_flow_ops
+# A ValueError is raised in tf.while_loop if this is set to True and the
+# `_while_v2` module is not set.
+ENABLE_WHILE_V2 = os.getenv("TF_ENABLE_WHILE_V2", "0") != "0"
# We override the 'tuple' for a control flow op, so we keep python's
@@ -610,9 +619,10 @@ def _EnforceShapeInvariant(merge_var, next_var):
"less-specific shape." %
(input_t.name, input_t.shape, n_shape))
else:
- if not isinstance(var, (ops.IndexedSlices, sparse_tensor.SparseTensor)):
- raise TypeError("Type %s not supported" % type(var))
- if isinstance(var, ops.IndexedSlices):
+ if not isinstance(merge_var,
+ (ops.IndexedSlices, sparse_tensor.SparseTensor)):
+ raise TypeError("Type %s not supported" % type(merge_var))
+ if isinstance(merge_var, ops.IndexedSlices):
m_values_shape = merge_var.values.get_shape()
m_indices_shape = merge_var.indices.get_shape()
m_shape_shape = tensor_shape.TensorShape(None)
@@ -3210,6 +3220,13 @@ def while_loop(cond,
```
"""
+ if ENABLE_WHILE_V2 and not context.executing_eagerly():
+ if not _while_v2:
+ raise ValueError("The while_v2 module is not set. Did you forget to "
+ "import tensorflow.python.ops."
+ "while_v2?")
+ return _while_v2.while_loop(cond, body, loop_vars, name)
+
with ops.name_scope(name, "while", loop_vars):
if not loop_vars:
raise ValueError("No loop variables provided")
diff --git a/tensorflow/python/ops/distributions/beta.py b/tensorflow/python/ops/distributions/beta.py
index 99d30b0bd1..2ba1ea6744 100644
--- a/tensorflow/python/ops/distributions/beta.py
+++ b/tensorflow/python/ops/distributions/beta.py
@@ -98,10 +98,13 @@ class Beta(distribution.Distribution):
#### Examples
```python
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
+
# Create a batch of three Beta distributions.
alpha = [1, 2, 3]
beta = [1, 2, 3]
- dist = tf.distributions.Beta(alpha, beta)
+ dist = tfd.Beta(alpha, beta)
dist.sample([4, 5]) # Shape [4, 5, 3]
@@ -117,7 +120,7 @@ class Beta(distribution.Distribution):
# Create batch_shape=[2, 3] via parameter broadcast:
alpha = [[1.], [2]] # Shape [2, 1]
beta = [3., 4, 5] # Shape [3]
- dist = tf.distributions.Beta(alpha, beta)
+ dist = tfd.Beta(alpha, beta)
# alpha broadcast as: [[1., 1, 1,],
# [2, 2, 2]]
@@ -138,7 +141,7 @@ class Beta(distribution.Distribution):
```python
alpha = tf.constant(1.0)
beta = tf.constant(2.0)
- dist = tf.distributions.Beta(alpha, beta)
+ dist = tfd.Beta(alpha, beta)
samples = dist.sample(5) # Shape [5]
loss = tf.reduce_mean(tf.square(samples)) # Arbitrary loss function
# Unbiased stochastic gradients of the loss function
diff --git a/tensorflow/python/ops/distributions/bijector_impl.py b/tensorflow/python/ops/distributions/bijector_impl.py
index 2e7aa30296..9c63385dd0 100644
--- a/tensorflow/python/ops/distributions/bijector_impl.py
+++ b/tensorflow/python/ops/distributions/bijector_impl.py
@@ -825,10 +825,21 @@ class Bijector(object):
min_event_ndims=self.inverse_min_event_ndims,
event_ndims=event_ndims)):
if not self._is_injective: # No caching for non-injective
- ildjs = self._inverse_log_det_jacobian(y, **kwargs)
- return tuple(self._reduce_jacobian_det_over_event(
- y, ildj, self.inverse_min_event_ndims, event_ndims)
- for ildj in ildjs)
+ try:
+ ildjs = self._inverse_log_det_jacobian(y, **kwargs)
+ return tuple(self._reduce_jacobian_det_over_event(
+ y, ildj, self.inverse_min_event_ndims, event_ndims)
+ for ildj in ildjs)
+ except NotImplementedError as original_exception:
+ try:
+ x = self._inverse(y, **kwargs)
+ fldjs = self._forward_log_det_jacobian(x, **kwargs)
+ return tuple(self._reduce_jacobian_det_over_event(
+ x, -fldj, self.forward_min_event_ndims, event_ndims)
+ for fldj in fldjs)
+ except NotImplementedError:
+ raise original_exception
+
mapping = self._lookup(y=y, kwargs=kwargs)
if mapping.ildj_map is not None and event_ndims in mapping.ildj_map:
return mapping.ildj_map[event_ndims]
@@ -917,11 +928,21 @@ class Bijector(object):
return -1. * self._constant_ildj_map[event_ndims]
x = ops.convert_to_tensor(x, name="x")
self._maybe_assert_dtype(x)
- if not self._is_injective:
- fldjs = self._forward_log_det_jacobian(x, **kwargs) # No caching.
- return tuple(self._reduce_jacobian_det_over_event(
- x, fldj, self.forward_min_event_ndims, event_ndims)
- for fldj in fldjs)
+ if not self._is_injective: # No caching for non-injective
+ try:
+ fldjs = self._forward_log_det_jacobian(x, **kwargs) # No caching.
+ return tuple(self._reduce_jacobian_det_over_event(
+ x, fldj, self.forward_min_event_ndims, event_ndims)
+ for fldj in fldjs)
+ except NotImplementedError as original_exception:
+ try:
+ y = self._forward(x, **kwargs)
+ ildjs = self._inverse_log_det_jacobian(y, **kwargs)
+ return tuple(self._reduce_jacobian_det_over_event(
+ y, -ildj, self.inverse_min_event_ndims, event_ndims)
+ for ildj in ildjs)
+ except NotImplementedError:
+ raise original_exception
mapping = self._lookup(x=x, kwargs=kwargs)
if mapping.ildj_map is not None and event_ndims in mapping.ildj_map:
return -mapping.ildj_map[event_ndims]
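
The fallbacks added above rely on the inverse function theorem: the inverse log det Jacobian at `y` equals the negated forward log det Jacobian evaluated at `x = f^{-1}(y)`, and vice versa. A quick numeric check of that identity with `f(x) = exp(x)` (illustrative, not tied to any particular Bijector subclass):

```python
import numpy as np

# Numeric check of the identity the fallback relies on:
# log|det J_{f^-1}(y)| == -log|det J_f(x)| at x = f^-1(y).
# Example bijector: f(x) = exp(x), so f^-1(y) = log(y).
x = 1.7
y = np.exp(x)

forward_log_det_jacobian = x            # log|d exp(x)/dx| = x
inverse_log_det_jacobian = -np.log(y)   # log|d log(y)/dy| = -log(y) = -x

assert np.isclose(inverse_log_det_jacobian, -forward_log_det_jacobian)
```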
diff --git a/tensorflow/python/ops/distributions/dirichlet.py b/tensorflow/python/ops/distributions/dirichlet.py
index 9104a1d071..415249a958 100644
--- a/tensorflow/python/ops/distributions/dirichlet.py
+++ b/tensorflow/python/ops/distributions/dirichlet.py
@@ -104,10 +104,13 @@ class Dirichlet(distribution.Distribution):
#### Examples
```python
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
+
# Create a single trivariate Dirichlet, with the 3rd class being three times
# more frequent than the first. I.e., batch_shape=[], event_shape=[3].
alpha = [1., 2, 3]
- dist = tf.distributions.Dirichlet(alpha)
+ dist = tfd.Dirichlet(alpha)
dist.sample([4, 5]) # shape: [4, 5, 3]
@@ -129,7 +132,7 @@ class Dirichlet(distribution.Distribution):
# Create batch_shape=[2], event_shape=[3]:
alpha = [[1., 2, 3],
[4, 5, 6]] # shape: [2, 3]
- dist = tf.distributions.Dirichlet(alpha)
+ dist = tfd.Dirichlet(alpha)
dist.sample([4, 5]) # shape: [4, 5, 2, 3]
@@ -144,7 +147,7 @@ class Dirichlet(distribution.Distribution):
```python
alpha = tf.constant([1.0, 2.0, 3.0])
- dist = tf.distributions.Dirichlet(alpha)
+ dist = tfd.Dirichlet(alpha)
samples = dist.sample(5) # Shape [5, 3]
loss = tf.reduce_mean(tf.square(samples)) # Arbitrary loss function
# Unbiased stochastic gradients of the loss function
diff --git a/tensorflow/python/ops/distributions/distribution.py b/tensorflow/python/ops/distributions/distribution.py
index 578e7b7dd2..12fd039392 100644
--- a/tensorflow/python/ops/distributions/distribution.py
+++ b/tensorflow/python/ops/distributions/distribution.py
@@ -25,6 +25,7 @@ import types
import numpy as np
import six
+from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
@@ -127,6 +128,18 @@ def _update_docstring(old_str, append_str):
return old_str + "\n\n" + append_str
+def _convert_to_tensor(value, name=None, preferred_dtype=None):
+ """Converts to tensor avoiding an eager bug that loses float precision."""
+ # TODO(b/116672045): Remove this function.
+ if (context.executing_eagerly() and preferred_dtype is not None and
+ (preferred_dtype.is_integer or preferred_dtype.is_bool)):
+ v = ops.convert_to_tensor(value, name=name)
+ if v.dtype.is_floating:
+ return v
+ return ops.convert_to_tensor(
+ value, name=name, preferred_dtype=preferred_dtype)
+
+
class _DistributionMeta(abc.ABCMeta):
def __new__(mcs, classname, baseclasses, attrs):
@@ -601,7 +614,8 @@ class Distribution(_BaseDistribution):
return type(self)(**parameters)
def _batch_shape_tensor(self):
- raise NotImplementedError("batch_shape_tensor is not implemented")
+ raise NotImplementedError(
+ "batch_shape_tensor is not implemented: {}".format(type(self).__name__))
def batch_shape_tensor(self, name="batch_shape_tensor"):
"""Shape of a single sample from a single event index as a 1-D `Tensor`.
@@ -640,7 +654,8 @@ class Distribution(_BaseDistribution):
return tensor_shape.as_shape(self._batch_shape())
def _event_shape_tensor(self):
- raise NotImplementedError("event_shape_tensor is not implemented")
+ raise NotImplementedError(
+ "event_shape_tensor is not implemented: {}".format(type(self).__name__))
def event_shape_tensor(self, name="event_shape_tensor"):
"""Shape of a single sample from a single batch as a 1-D int32 `Tensor`.
@@ -701,7 +716,8 @@ class Distribution(_BaseDistribution):
name="is_scalar_batch")
def _sample_n(self, n, seed=None):
- raise NotImplementedError("sample_n is not implemented")
+ raise NotImplementedError("sample_n is not implemented: {}".format(
+ type(self).__name__))
def _call_sample_n(self, sample_shape, seed, name, **kwargs):
with self._name_scope(name, values=[sample_shape]):
@@ -733,15 +749,20 @@ class Distribution(_BaseDistribution):
return self._call_sample_n(sample_shape, seed, name)
def _log_prob(self, value):
- raise NotImplementedError("log_prob is not implemented")
+ raise NotImplementedError("log_prob is not implemented: {}".format(
+ type(self).__name__))
def _call_log_prob(self, value, name, **kwargs):
with self._name_scope(name, values=[value]):
- value = ops.convert_to_tensor(value, name="value")
+ value = _convert_to_tensor(
+ value, name="value", preferred_dtype=self.dtype)
try:
return self._log_prob(value, **kwargs)
- except NotImplementedError:
- return math_ops.log(self._prob(value, **kwargs))
+ except NotImplementedError as original_exception:
+ try:
+ return math_ops.log(self._prob(value, **kwargs))
+ except NotImplementedError:
+ raise original_exception
def log_prob(self, value, name="log_prob"):
"""Log probability density/mass function.
@@ -757,15 +778,20 @@ class Distribution(_BaseDistribution):
return self._call_log_prob(value, name)
def _prob(self, value):
- raise NotImplementedError("prob is not implemented")
+ raise NotImplementedError("prob is not implemented: {}".format(
+ type(self).__name__))
def _call_prob(self, value, name, **kwargs):
with self._name_scope(name, values=[value]):
- value = ops.convert_to_tensor(value, name="value")
+ value = _convert_to_tensor(
+ value, name="value", preferred_dtype=self.dtype)
try:
return self._prob(value, **kwargs)
- except NotImplementedError:
- return math_ops.exp(self._log_prob(value, **kwargs))
+ except NotImplementedError as original_exception:
+ try:
+ return math_ops.exp(self._log_prob(value, **kwargs))
+ except NotImplementedError:
+ raise original_exception
def prob(self, value, name="prob"):
"""Probability density/mass function.
@@ -781,15 +807,20 @@ class Distribution(_BaseDistribution):
return self._call_prob(value, name)
def _log_cdf(self, value):
- raise NotImplementedError("log_cdf is not implemented")
+ raise NotImplementedError("log_cdf is not implemented: {}".format(
+ type(self).__name__))
def _call_log_cdf(self, value, name, **kwargs):
with self._name_scope(name, values=[value]):
- value = ops.convert_to_tensor(value, name="value")
+ value = _convert_to_tensor(
+ value, name="value", preferred_dtype=self.dtype)
try:
return self._log_cdf(value, **kwargs)
- except NotImplementedError:
- return math_ops.log(self._cdf(value, **kwargs))
+ except NotImplementedError as original_exception:
+ try:
+ return math_ops.log(self._cdf(value, **kwargs))
+ except NotImplementedError:
+ raise original_exception
def log_cdf(self, value, name="log_cdf"):
"""Log cumulative distribution function.
@@ -815,15 +846,20 @@ class Distribution(_BaseDistribution):
return self._call_log_cdf(value, name)
def _cdf(self, value):
- raise NotImplementedError("cdf is not implemented")
+ raise NotImplementedError("cdf is not implemented: {}".format(
+ type(self).__name__))
def _call_cdf(self, value, name, **kwargs):
with self._name_scope(name, values=[value]):
- value = ops.convert_to_tensor(value, name="value")
+ value = _convert_to_tensor(
+ value, name="value", preferred_dtype=self.dtype)
try:
return self._cdf(value, **kwargs)
- except NotImplementedError:
- return math_ops.exp(self._log_cdf(value, **kwargs))
+ except NotImplementedError as original_exception:
+ try:
+ return math_ops.exp(self._log_cdf(value, **kwargs))
+ except NotImplementedError:
+ raise original_exception
def cdf(self, value, name="cdf"):
"""Cumulative distribution function.
@@ -845,15 +881,21 @@ class Distribution(_BaseDistribution):
return self._call_cdf(value, name)
def _log_survival_function(self, value):
- raise NotImplementedError("log_survival_function is not implemented")
+ raise NotImplementedError(
+ "log_survival_function is not implemented: {}".format(
+ type(self).__name__))
def _call_log_survival_function(self, value, name, **kwargs):
with self._name_scope(name, values=[value]):
- value = ops.convert_to_tensor(value, name="value")
+ value = _convert_to_tensor(
+ value, name="value", preferred_dtype=self.dtype)
try:
return self._log_survival_function(value, **kwargs)
- except NotImplementedError:
- return math_ops.log1p(-self.cdf(value, **kwargs))
+ except NotImplementedError as original_exception:
+ try:
+ return math_ops.log1p(-self.cdf(value, **kwargs))
+ except NotImplementedError:
+ raise original_exception
def log_survival_function(self, value, name="log_survival_function"):
"""Log survival function.
@@ -880,15 +922,20 @@ class Distribution(_BaseDistribution):
return self._call_log_survival_function(value, name)
def _survival_function(self, value):
- raise NotImplementedError("survival_function is not implemented")
+ raise NotImplementedError("survival_function is not implemented: {}".format(
+ type(self).__name__))
def _call_survival_function(self, value, name, **kwargs):
with self._name_scope(name, values=[value]):
- value = ops.convert_to_tensor(value, name="value")
+ value = _convert_to_tensor(
+ value, name="value", preferred_dtype=self.dtype)
try:
return self._survival_function(value, **kwargs)
- except NotImplementedError:
- return 1. - self.cdf(value, **kwargs)
+ except NotImplementedError as original_exception:
+ try:
+ return 1. - self.cdf(value, **kwargs)
+ except NotImplementedError:
+ raise original_exception
def survival_function(self, value, name="survival_function"):
"""Survival function.
@@ -912,7 +959,8 @@ class Distribution(_BaseDistribution):
return self._call_survival_function(value, name)
def _entropy(self):
- raise NotImplementedError("entropy is not implemented")
+ raise NotImplementedError("entropy is not implemented: {}".format(
+ type(self).__name__))
def entropy(self, name="entropy"):
"""Shannon entropy in nats."""
@@ -920,7 +968,8 @@ class Distribution(_BaseDistribution):
return self._entropy()
def _mean(self):
- raise NotImplementedError("mean is not implemented")
+ raise NotImplementedError("mean is not implemented: {}".format(
+ type(self).__name__))
def mean(self, name="mean"):
"""Mean."""
@@ -928,11 +977,13 @@ class Distribution(_BaseDistribution):
return self._mean()
def _quantile(self, value):
- raise NotImplementedError("quantile is not implemented")
+ raise NotImplementedError("quantile is not implemented: {}".format(
+ type(self).__name__))
def _call_quantile(self, value, name, **kwargs):
with self._name_scope(name, values=[value]):
- value = ops.convert_to_tensor(value, name="value")
+ value = _convert_to_tensor(
+ value, name="value", preferred_dtype=self.dtype)
return self._quantile(value, **kwargs)
def quantile(self, value, name="quantile"):
@@ -955,7 +1006,8 @@ class Distribution(_BaseDistribution):
return self._call_quantile(value, name)
def _variance(self):
- raise NotImplementedError("variance is not implemented")
+ raise NotImplementedError("variance is not implemented: {}".format(
+ type(self).__name__))
def variance(self, name="variance"):
"""Variance.
@@ -979,11 +1031,15 @@ class Distribution(_BaseDistribution):
with self._name_scope(name):
try:
return self._variance()
- except NotImplementedError:
- return math_ops.square(self._stddev())
+ except NotImplementedError as original_exception:
+ try:
+ return math_ops.square(self._stddev())
+ except NotImplementedError:
+ raise original_exception
def _stddev(self):
- raise NotImplementedError("stddev is not implemented")
+ raise NotImplementedError("stddev is not implemented: {}".format(
+ type(self).__name__))
def stddev(self, name="stddev"):
"""Standard deviation.
@@ -1008,11 +1064,15 @@ class Distribution(_BaseDistribution):
with self._name_scope(name):
try:
return self._stddev()
- except NotImplementedError:
- return math_ops.sqrt(self._variance())
+ except NotImplementedError as original_exception:
+ try:
+ return math_ops.sqrt(self._variance())
+ except NotImplementedError:
+ raise original_exception
def _covariance(self):
- raise NotImplementedError("covariance is not implemented")
+ raise NotImplementedError("covariance is not implemented: {}".format(
+ type(self).__name__))
def covariance(self, name="covariance"):
"""Covariance.
@@ -1054,7 +1114,8 @@ class Distribution(_BaseDistribution):
return self._covariance()
def _mode(self):
- raise NotImplementedError("mode is not implemented")
+ raise NotImplementedError("mode is not implemented: {}".format(
+ type(self).__name__))
def mode(self, name="mode"):
"""Mode."""
@@ -1080,7 +1141,7 @@ class Distribution(_BaseDistribution):
where `F` denotes the support of the random variable `X ~ P`.
Args:
- other: `tf.distributions.Distribution` instance.
+ other: `tfp.distributions.Distribution` instance.
name: Python `str` prepended to names of ops created by this function.
Returns:
@@ -1111,7 +1172,7 @@ class Distribution(_BaseDistribution):
denotes (Shannon) cross entropy, and `H[.]` denotes (Shannon) entropy.
Args:
- other: `tf.distributions.Distribution` instance.
+ other: `tfp.distributions.Distribution` instance.
name: Python `str` prepended to names of ops created by this function.
Returns:
@@ -1123,7 +1184,7 @@ class Distribution(_BaseDistribution):
return self._kl_divergence(other)
def __str__(self):
- return ("tf.distributions.{type_name}("
+ return ("tfp.distributions.{type_name}("
"\"{self_name}\""
"{maybe_batch_shape}"
"{maybe_event_shape}"
@@ -1139,7 +1200,7 @@ class Distribution(_BaseDistribution):
dtype=self.dtype.name))
def __repr__(self):
- return ("<tf.distributions.{type_name} "
+ return ("<tfp.distributions.{type_name} "
"'{self_name}'"
" batch_shape={batch_shape}"
" event_shape={event_shape}"
diff --git a/tensorflow/python/ops/distributions/gamma.py b/tensorflow/python/ops/distributions/gamma.py
index b631f0247c..3293cda874 100644
--- a/tensorflow/python/ops/distributions/gamma.py
+++ b/tensorflow/python/ops/distributions/gamma.py
@@ -100,8 +100,11 @@ class Gamma(distribution.Distribution):
#### Examples
```python
- dist = tf.distributions.Gamma(concentration=3.0, rate=2.0)
- dist2 = tf.distributions.Gamma(concentration=[3.0, 4.0], rate=[2.0, 3.0])
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
+
+ dist = tfd.Gamma(concentration=3.0, rate=2.0)
+ dist2 = tfd.Gamma(concentration=[3.0, 4.0], rate=[2.0, 3.0])
```
Compute the gradients of samples w.r.t. the parameters:
@@ -109,7 +112,7 @@ class Gamma(distribution.Distribution):
```python
concentration = tf.constant(3.0)
rate = tf.constant(2.0)
- dist = tf.distributions.Gamma(concentration, rate)
+ dist = tfd.Gamma(concentration, rate)
samples = dist.sample(5) # Shape [5]
loss = tf.reduce_mean(tf.square(samples)) # Arbitrary loss function
# Unbiased stochastic gradients of the loss function
diff --git a/tensorflow/python/ops/distributions/kullback_leibler.py b/tensorflow/python/ops/distributions/kullback_leibler.py
index e3c6f3e789..fdeb97bf64 100644
--- a/tensorflow/python/ops/distributions/kullback_leibler.py
+++ b/tensorflow/python/ops/distributions/kullback_leibler.py
@@ -127,8 +127,8 @@ def cross_entropy(ref, other,
where `F` denotes the support of the random variable `X ~ P`.
Args:
- ref: `tf.distributions.Distribution` instance.
- other: `tf.distributions.Distribution` instance.
+ ref: `tfd.Distribution` instance.
+ other: `tfd.Distribution` instance.
allow_nan_stats: Python `bool`, default `True`. When `True`,
statistics (e.g., mean, mode, variance) use the value "`NaN`" to
indicate the result is undefined. When `False`, an exception is raised
diff --git a/tensorflow/python/ops/distributions/normal.py b/tensorflow/python/ops/distributions/normal.py
index d0a987ba7c..2feaf806c0 100644
--- a/tensorflow/python/ops/distributions/normal.py
+++ b/tensorflow/python/ops/distributions/normal.py
@@ -71,15 +71,18 @@ class Normal(distribution.Distribution):
Examples of initialization of one or a batch of distributions.
```python
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
+
# Define a single scalar Normal distribution.
- dist = tf.distributions.Normal(loc=0., scale=3.)
+ dist = tfd.Normal(loc=0., scale=3.)
# Evaluate the cdf at 1, returning a scalar.
dist.cdf(1.)
# Define a batch of two scalar valued Normals.
# The first has mean 1 and standard deviation 11, the second 2 and 22.
- dist = tf.distributions.Normal(loc=[1, 2.], scale=[11, 22.])
+ dist = tfd.Normal(loc=[1, 2.], scale=[11, 22.])
# Evaluate the pdf of the first distribution on 0, and the second on 1.5,
# returning a length two tensor.
@@ -94,7 +97,7 @@ class Normal(distribution.Distribution):
```python
# Define a batch of two scalar valued Normals.
# Both have mean 1, but different standard deviations.
- dist = tf.distributions.Normal(loc=1., scale=[11, 22.])
+ dist = tfd.Normal(loc=1., scale=[11, 22.])
# Evaluate the pdf of both distributions on the same point, 3.0,
# returning a length 2 tensor.
diff --git a/tensorflow/python/ops/distributions/student_t.py b/tensorflow/python/ops/distributions/student_t.py
index e0cf6f86f1..e8d214bbe0 100644
--- a/tensorflow/python/ops/distributions/student_t.py
+++ b/tensorflow/python/ops/distributions/student_t.py
@@ -91,8 +91,11 @@ class StudentT(distribution.Distribution):
Examples of initialization of one or a batch of distributions.
```python
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
+
# Define a single scalar Student t distribution.
- single_dist = tf.distributions.StudentT(df=3)
+ single_dist = tfd.StudentT(df=3)
# Evaluate the pdf at 1, returning a scalar Tensor.
single_dist.prob(1.)
@@ -100,9 +103,7 @@ class StudentT(distribution.Distribution):
# Define a batch of two scalar valued Student t's.
# The first has degrees of freedom 2, mean 1, and scale 11.
# The second 3, 2 and 22.
- multi_dist = tf.distributions.StudentT(df=[2, 3],
- loc=[1, 2.],
- scale=[11, 22.])
+ multi_dist = tfd.StudentT(df=[2, 3], loc=[1, 2.], scale=[11, 22.])
# Evaluate the pdf of the first distribution on 0, and the second on 1.5,
# returning a length two tensor.
@@ -117,7 +118,7 @@ class StudentT(distribution.Distribution):
```python
# Define a batch of two Student's t distributions.
# Both have df 2 and mean 1, but different scales.
- dist = tf.distributions.StudentT(df=2, loc=1, scale=[11, 22.])
+ dist = tfd.StudentT(df=2, loc=1, scale=[11, 22.])
# Evaluate the pdf of both distributions on the same point, 3.0,
# returning a length 2 tensor.
@@ -130,7 +131,7 @@ class StudentT(distribution.Distribution):
df = tf.constant(2.0)
loc = tf.constant(2.0)
scale = tf.constant(11.0)
- dist = tf.distributions.StudentT(df=df, loc=loc, scale=scale)
+ dist = tfd.StudentT(df=df, loc=loc, scale=scale)
samples = dist.sample(5) # Shape [5]
loss = tf.reduce_mean(tf.square(samples)) # Arbitrary loss function
# Unbiased stochastic gradients of the loss function
@@ -138,7 +139,6 @@ class StudentT(distribution.Distribution):
```
"""
- # pylint: enable=line-too-long
def __init__(self,
df,
diff --git a/tensorflow/python/ops/distributions/util.py b/tensorflow/python/ops/distributions/util.py
index 3e480a79f5..ad848dfee6 100644
--- a/tensorflow/python/ops/distributions/util.py
+++ b/tensorflow/python/ops/distributions/util.py
@@ -155,7 +155,8 @@ def get_logits_and_probs(logits=None,
probs=None,
multidimensional=False,
validate_args=False,
- name="get_logits_and_probs"):
+ name="get_logits_and_probs",
+ dtype=None):
"""Converts logit to probabilities (or vice-versa), and returns both.
Args:
@@ -169,6 +170,7 @@ def get_logits_and_probs(logits=None,
`0 <= probs <= 1` (if not `multidimensional`) or that the last dimension
of `probs` sums to one.
name: A name for this operation (optional).
+ dtype: `tf.DType` to prefer when converting args to `Tensor`s.
Returns:
logits, probs: Tuple of `Tensor`s. If `probs` has an entry that is `0` or
@@ -183,7 +185,7 @@ def get_logits_and_probs(logits=None,
raise ValueError("Must pass probs or logits, but not both.")
if probs is None:
- logits = ops.convert_to_tensor(logits, name="logits")
+ logits = ops.convert_to_tensor(logits, name="logits", dtype=dtype)
if not logits.dtype.is_floating:
raise TypeError("logits must having floating type.")
# We can early return since we constructed probs and therefore know
@@ -194,7 +196,7 @@ def get_logits_and_probs(logits=None,
return logits, nn.softmax(logits, name="probs")
return logits, math_ops.sigmoid(logits, name="probs")
- probs = ops.convert_to_tensor(probs, name="probs")
+ probs = ops.convert_to_tensor(probs, name="probs", dtype=dtype)
if not probs.dtype.is_floating:
raise TypeError("probs must having floating type.")
@@ -524,6 +526,8 @@ def matrix_diag_transform(matrix, transform=None, name=None):
Example of heteroskedastic 2-D linear regression.
```python
+ tfd = tfp.distributions
+
# Get a trainable Cholesky factor.
matrix_values = tf.contrib.layers.fully_connected(activations, 4)
matrix = tf.reshape(matrix_values, (batch_size, 2, 2))
@@ -533,7 +537,7 @@ def matrix_diag_transform(matrix, transform=None, name=None):
mu = tf.contrib.layers.fully_connected(activations, 2)
# This is a fully trainable multivariate normal!
- dist = tf.contrib.distributions.MVNCholesky(mu, chol)
+ dist = tfd.MultivariateNormalTriL(mu, chol)
# Standard log loss. Minimizing this will "train" mu and chol, and then dist
# will be a distribution predicting labels as multivariate Gaussians.
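
Illustrative use of the `dtype` argument added to `get_logits_and_probs` earlier in this file's diff; the import path is the internal module touched here, so this is a sketch rather than a public-API recommendation.

```python
import tensorflow as tf
from tensorflow.python.ops.distributions import util as distribution_util

# With the new `dtype` argument, a plain Python list is converted with an
# explicit preferred dtype instead of the default float32.
logits, probs = distribution_util.get_logits_and_probs(
    probs=[0.25, 0.75], dtype=tf.float64)
print(logits.dtype, probs.dtype)  # both float64
```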
diff --git a/tensorflow/python/ops/functional_ops.py b/tensorflow/python/ops/functional_ops.py
index a4e7c84ae4..119d9522bd 100644
--- a/tensorflow/python/ops/functional_ops.py
+++ b/tensorflow/python/ops/functional_ops.py
@@ -41,6 +41,7 @@ from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops.gen_functional_ops import remote_call
# pylint: enable=unused-import
from tensorflow.python.ops.gen_functional_ops import symbolic_gradient
+from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export
@@ -263,7 +264,7 @@ def foldr(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
@tf_export("map_fn")
-def map_fn(fn, elems, dtype=None, parallel_iterations=10, back_prop=True,
+def map_fn(fn, elems, dtype=None, parallel_iterations=None, back_prop=True,
swap_memory=False, infer_shape=True, name=None):
"""map on the list of tensors unpacked from `elems` on dimension 0.
@@ -305,6 +306,25 @@ def map_fn(fn, elems, dtype=None, parallel_iterations=10, back_prop=True,
instead.
+ When executing eagerly, map_fn does not execute in parallel even if
+ `parallel_iterations` is set to a value > 1. You can still get the
+ performance benefits of running a function in parallel by using the
+ `tf.contrib.eager.defun` decorator,
+
+ ```python
+ # Assume the function being used in map_fn is fn.
+ # To ensure map_fn calls fn in parallel, use the defun decorator.
+ @tf.contrib.eager.defun
+ def func(tensor):
+ return tf.map_fn(fn, tensor)
+ ```
+
+ Note that if you use the defun decorator, any non-TensorFlow Python code
+ that you may have written in your function won't get executed. See
+ `tf.contrib.eager.defun` for more details. The recommendation is to debug
+ without defun and then switch to defun to get the performance benefits of
+ running map_fn in parallel.
+
Args:
fn: The callable to be performed. It accepts one argument, which will
have the same (possibly nested) structure as `elems`. Its output
@@ -317,7 +337,8 @@ def map_fn(fn, elems, dtype=None, parallel_iterations=10, back_prop=True,
of Tensors differing from the structure of `elems`, then `dtype` is not
optional and must have the same structure as the output of `fn`.
parallel_iterations: (optional) The number of iterations allowed to run
- in parallel.
+ in parallel. When graph building, the default value is 10. While executing
+ eagerly, the default value is set to 1.
back_prop: (optional) True enables support for back propagation.
swap_memory: (optional) True enables GPU-CPU memory swapping.
infer_shape: (optional) False disables tests for consistent output shapes.
@@ -363,6 +384,20 @@ def map_fn(fn, elems, dtype=None, parallel_iterations=10, back_prop=True,
" SparseTensor(input.indices, map_fn(fn, input.values), "
"input.dense_shape)")
+ in_graph_mode = not context.executing_eagerly()
+ # Set the default number of parallel_iterations depending on graph/eager mode.
+ if in_graph_mode and not parallel_iterations:
+ parallel_iterations = 10
+ elif not in_graph_mode and not parallel_iterations:
+ parallel_iterations = 1
+
+ if not in_graph_mode and parallel_iterations > 1:
+ logging.log_first_n(logging.WARN, "Setting parallel_iterations > 1 has no "
+ "effect when executing eagerly. Consider calling map_fn"
+ " with tf.contrib.eager.defun to execute fn in "
+ "parallel.", 1)
+ parallel_iterations = 1
+
input_is_sequence = nest.is_sequence(elems)
input_flatten = lambda x: nest.flatten(x) if input_is_sequence else [x]
def input_pack(x):
@@ -381,7 +416,6 @@ def map_fn(fn, elems, dtype=None, parallel_iterations=10, back_prop=True,
elems_flat = input_flatten(elems)
- in_graph_mode = not context.executing_eagerly()
with ops.name_scope(name, "map", elems_flat):
# TODO(akshayka): Remove the in_graph_mode check once caching devices are
# supported in Eager
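
A runnable version of the pattern the new `map_fn` docstring recommends, assuming the TF 1.x-era eager APIs it references (`tf.enable_eager_execution`, `tf.contrib.eager.defun`):

```python
import tensorflow as tf

tf.enable_eager_execution()

def fn(x):
  return x * x

# Eagerly, map_fn runs fn serially regardless of parallel_iterations; wrapping
# the call in a defun turns it into a graph function, where
# parallel_iterations > 1 can take effect.
@tf.contrib.eager.defun
def squared(tensor):
  return tf.map_fn(fn, tensor, parallel_iterations=4)

print(squared(tf.constant([1., 2., 3.])))  # [1. 4. 9.]
```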
diff --git a/tensorflow/python/ops/gradients_test.py b/tensorflow/python/ops/gradients_test.py
index 4f6e5dc473..3c9b7a01c7 100644
--- a/tensorflow/python/ops/gradients_test.py
+++ b/tensorflow/python/ops/gradients_test.py
@@ -273,7 +273,7 @@ class GradientsTest(test_util.TensorFlowTestCase):
def testVariableRefGradient(self):
with ops.Graph().as_default():
init = constant_op.constant(100.0)
- var = variables.Variable(init)
+ var = variables.VariableV1(init)
gradient = gradients.gradients(var._ref(), var)
self.assertIsNotNone(gradient)
diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py
index de260f3140..1c75aab578 100644
--- a/tensorflow/python/ops/image_ops_impl.py
+++ b/tensorflow/python/ops/image_ops_impl.py
@@ -29,7 +29,6 @@ from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gen_image_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import math_ops
@@ -301,21 +300,21 @@ def random_flip_left_right(image, seed=None):
def _random_flip(image, flip_index, seed, scope_name):
"""Randomly (50% chance) flip an image along axis `flip_index`.
- Args:
- image: 4-D Tensor of shape `[batch, height, width, channels]` or
- 3-D Tensor of shape `[height, width, channels]`.
- flip_index: The dimension along which to flip the image.
- Vertical: 0, Horizontal: 1
- seed: A Python integer. Used to create a random seed. See
- `tf.set_random_seed`
- for behavior.
- scope_name: Name of the scope in which the ops are added.
- Returns:
- A tensor of the same type and shape as `image`.
+ Args:
+ image: 4-D Tensor of shape `[batch, height, width, channels]` or
+ 3-D Tensor of shape `[height, width, channels]`.
+ flip_index: Dimension along which to flip image. Vertical: 0, Horizontal: 1
+ seed: A Python integer. Used to create a random seed. See
+ `tf.set_random_seed`
+ for behavior.
+ scope_name: Name of the scope in which the ops are added.
- Raises:
- ValueError: if the shape of `image` not supported.
+ Returns:
+ A tensor of the same type and shape as `image`.
+
+ Raises:
+ ValueError: if the shape of `image` is not supported.
"""
with ops.name_scope(None, scope_name, [image]) as scope:
image = ops.convert_to_tensor(image, name='image')
@@ -330,19 +329,18 @@ def _random_flip(image, flip_index, seed, scope_name):
lambda: image,
name=scope
)
- if isinstance(result, tuple):
- result = result[0] # TODO(b/111124878) remove this logic (CondV2).
return fix_image_flip_shape(image, result)
elif shape.ndims == 4:
+ batch_size = array_ops.shape(image)[0]
uniform_random = random_ops.random_uniform(
- [array_ops.shape(image)[0]], 0, 1.0, seed=seed
+ [batch_size], 0, 1.0, seed=seed
)
- mirror_cond = math_ops.less(uniform_random, .5)
- return array_ops.where(
- mirror_cond,
- image,
- functional_ops.map_fn(lambda x: array_ops.reverse(x, [flip_index]), image, dtype=image.dtype)
+ flips = math_ops.round(
+ array_ops.reshape(uniform_random, [batch_size, 1, 1, 1])
)
+ flips = math_ops.cast(flips, image.dtype)
+ flipped_input = array_ops.reverse(image, [flip_index + 1])
+ return flips * flipped_input + (1 - flips) * image
else:
raise ValueError('\'image\' must have either 3 or 4 dimensions.')
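
The rewritten 4-D branch above avoids the per-image `map_fn` by drawing one uniform value per image, rounding it to 0 or 1, and blending `flips * flipped + (1 - flips) * image`. A NumPy sketch of that blend (shapes and seed are illustrative):

```python
import numpy as np

# NumPy sketch of the batched-flip blend: one 0/1 draw per image selects
# between the original and its flipped copy, with no per-image map_fn.
rng = np.random.RandomState(0)
image = rng.uniform(size=(4, 2, 3, 1))  # [batch, height, width, channels]
flip_index = 1                          # horizontal; batched axis is flip_index + 1

uniform_random = rng.uniform(0., 1., size=(4,))
flips = np.round(uniform_random).reshape(4, 1, 1, 1).astype(image.dtype)
flipped_input = np.flip(image, axis=flip_index + 1)
result = flips * flipped_input + (1 - flips) * image

for i in range(4):
  expected = flipped_input[i] if flips[i, 0, 0, 0] else image[i]
  assert np.allclose(result[i], expected)
```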
@@ -1029,10 +1027,10 @@ def resize_images(images,
scale_factor_width = (math_ops.to_float(new_width_const) /
math_ops.to_float(current_width))
scale_factor = math_ops.minimum(scale_factor_height, scale_factor_width)
- scaled_height_const = math_ops.to_int32(scale_factor *
- math_ops.to_float(current_height))
- scaled_width_const = math_ops.to_int32(scale_factor *
- math_ops.to_float(current_width))
+ scaled_height_const = math_ops.to_int32(
+ math_ops.round(scale_factor * math_ops.to_float(current_height)))
+ scaled_width_const = math_ops.to_int32(
+ math_ops.round(scale_factor * math_ops.to_float(current_width)))
# NOTE: Reset the size and other constants used later.
size = ops.convert_to_tensor([scaled_height_const, scaled_width_const],
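
The switch to `math_ops.round` above guards against float arithmetic landing a hair below the intended integer size, which truncation would then turn into an off-by-one (the case the new 299 -> 320 square test covers). A NumPy illustration with a hypothetical near-integer value:

```python
import numpy as np

# Hypothetical value standing in for 299 * (320 / 299) computed in float32:
# the product can land a hair below the intended integer.
scaled = np.float32(319.99997)
print(int(scaled))            # 319 with truncation (the old behavior)
print(int(np.round(scaled)))  # 320 with rounding (the new behavior)
```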
@@ -1176,7 +1174,7 @@ def resize_image_with_pad(image,
@tf_export('image.per_image_standardization')
def per_image_standardization(image):
- """Linearly scales `image` to have zero mean and unit norm.
+ """Linearly scales `image` to have zero mean and unit variance.
This op computes `(x - mean) / adjusted_stddev`, where `mean` is the average
of all values in image, and
@@ -1379,7 +1377,7 @@ def adjust_gamma(image, gamma=1, gain=1):
[1] http://en.wikipedia.org/wiki/Gamma_correction
"""
- with ops.op_scope([image, gamma, gain], None, 'adjust_gamma'):
+ with ops.name_scope(None, 'adjust_gamma', [image, gamma, gain]) as name:
# Convert pixel value to DT_FLOAT for computing adjusted image.
img = ops.convert_to_tensor(image, name='img', dtype=dtypes.float32)
# Keep image dtype for computing the scale of corresponding dtype.
diff --git a/tensorflow/python/ops/image_ops_test.py b/tensorflow/python/ops/image_ops_test.py
index 795e6bbc3e..35fdee4fad 100644
--- a/tensorflow/python/ops/image_ops_test.py
+++ b/tensorflow/python/ops/image_ops_test.py
@@ -2687,6 +2687,12 @@ class ResizeImagesTest(test_util.TensorFlowTestCase):
self._assertResizeCheckShape(x, x_shape, [3840, 2160], [3840, 2160, 3])
+ def testPreserveAspectRatioSquare(self):
+ x_shape = [299, 299, 3]
+ x = np.random.uniform(size=x_shape)
+
+ self._assertResizeCheckShape(x, x_shape, [320, 320], [320, 320, 3])
+
class ResizeImageWithPadTest(test_util.TensorFlowTestCase):
@@ -3667,7 +3673,7 @@ class NonMaxSuppressionTest(test_util.TensorFlowTestCase):
# Note: There are multiple versions of non_max_suppression v2, v3, v4.
# gen_image_ops.non_max_suppression_v2:
for dtype in [np.float16, np.float32]:
- with self.test_session():
+ with self.cached_session():
boxes = constant_op.constant(boxes_np, dtype=dtype)
scores = constant_op.constant(scores_np, dtype=dtype)
max_output_size = constant_op.constant(max_output_size_np)
@@ -3677,7 +3683,7 @@ class NonMaxSuppressionTest(test_util.TensorFlowTestCase):
self.assertAllClose(selected_indices, [3, 0, 5])
# image_ops.non_max_suppression = gen_image_ops.non_max_suppression_v3.
for dtype in [np.float16, np.float32]:
- with self.test_session():
+ with self.cached_session():
boxes = constant_op.constant(boxes_np, dtype=dtype)
scores = constant_op.constant(scores_np, dtype=dtype)
max_output_size = constant_op.constant(max_output_size_np)
@@ -3688,7 +3694,7 @@ class NonMaxSuppressionTest(test_util.TensorFlowTestCase):
# gen_image_ops.non_max_suppression_v4.
score_threshold = float('-inf')
for dtype in [np.float16, np.float32]:
- with self.test_session():
+ with self.cached_session():
boxes = constant_op.constant(boxes_np, dtype=dtype)
scores = constant_op.constant(scores_np, dtype=dtype)
max_output_size = constant_op.constant(max_output_size_np)
diff --git a/tensorflow/python/ops/linalg/linear_operator_test_util.py b/tensorflow/python/ops/linalg/linear_operator_test_util.py
index 78c85db557..76d659f109 100644
--- a/tensorflow/python/ops/linalg/linear_operator_test_util.py
+++ b/tensorflow/python/ops/linalg/linear_operator_test_util.py
@@ -184,7 +184,7 @@ class LinearOperatorDerivedClassTest(test.TestCase):
for use_placeholder in self._use_placeholder_options:
for build_info in self._operator_build_infos:
for dtype in self._dtypes_to_test:
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
operator, mat = self._operator_and_matrix(
build_info, dtype, use_placeholder=use_placeholder)
@@ -199,7 +199,7 @@ class LinearOperatorDerivedClassTest(test.TestCase):
for use_placeholder in self._use_placeholder_options:
for build_info in self._operator_build_infos:
for dtype in self._dtypes_to_test:
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
operator, mat = self._operator_and_matrix(
build_info, dtype, use_placeholder=use_placeholder)
@@ -215,7 +215,7 @@ class LinearOperatorDerivedClassTest(test.TestCase):
for use_placeholder in self._use_placeholder_options:
for build_info in self._operator_build_infos:
for dtype in self._dtypes_to_test:
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
operator, mat = self._operator_and_matrix(
build_info, dtype, use_placeholder=use_placeholder)
@@ -240,7 +240,7 @@ class LinearOperatorDerivedClassTest(test.TestCase):
for dtype in self._dtypes_to_test:
for adjoint in self._adjoint_options:
for adjoint_arg in self._adjoint_arg_options:
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
operator, mat = self._operator_and_matrix(
build_info, dtype, use_placeholder=use_placeholder)
@@ -283,7 +283,7 @@ class LinearOperatorDerivedClassTest(test.TestCase):
for dtype in self._dtypes_to_test:
for adjoint in self._adjoint_options:
for adjoint_arg in self._adjoint_arg_options:
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
operator, mat = self._operator_and_matrix(
build_info, dtype, use_placeholder=use_placeholder)
@@ -319,7 +319,7 @@ class LinearOperatorDerivedClassTest(test.TestCase):
for use_placeholder in self._use_placeholder_options:
for build_info in self._operator_build_infos:
for dtype in self._dtypes_to_test:
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
operator, mat = self._operator_and_matrix(
build_info, dtype, use_placeholder=use_placeholder)
@@ -335,7 +335,7 @@ class LinearOperatorDerivedClassTest(test.TestCase):
for use_placeholder in self._use_placeholder_options:
for build_info in self._operator_build_infos:
for dtype in self._dtypes_to_test:
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
operator, mat = self._operator_and_matrix(
build_info, dtype, use_placeholder=use_placeholder)
@@ -353,7 +353,7 @@ class LinearOperatorDerivedClassTest(test.TestCase):
for use_placeholder in self._use_placeholder_options:
for build_info in self._operator_build_infos:
for dtype in self._dtypes_to_test:
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
operator, mat = self._operator_and_matrix(
build_info, dtype, use_placeholder=use_placeholder)
diff --git a/tensorflow/python/ops/logging_ops.py b/tensorflow/python/ops/logging_ops.py
index df41933f8a..4c53f33af1 100644
--- a/tensorflow/python/ops/logging_ops.py
+++ b/tensorflow/python/ops/logging_ops.py
@@ -19,13 +19,24 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import pprint
+import random
+import sys
+
+import six
+
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import gen_logging_ops
+from tensorflow.python.ops import string_ops
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_logging_ops import *
# pylint: enable=wildcard-import
+from tensorflow.python.platform import tf_logging
+from tensorflow.python.util import nest
from tensorflow.python.util.deprecation import deprecated
from tensorflow.python.util.tf_export import tf_export
@@ -40,7 +51,32 @@ from tensorflow.python.util.tf_export import tf_export
# For users with Python 3 or Python 2.7
# with `from __future__ import print_function`, we could also allow lowercase.
# See https://github.com/tensorflow/tensorflow/issues/18053
-@tf_export("Print")
+
+
+# pylint: disable=invalid-name
+@deprecated("2018-08-20", "Use tf.print instead of tf.Print. Note that "
+ "tf.print returns a no-output operator that directly "
+ "prints the output. Outside of defuns or eager mode, "
+ "this operator will not be executed unless it is "
+ "directly specified in session.run or used as a "
+ "control dependency for other operators. This is "
+ "only a concern in graph mode. Below is an example "
+ "of how to ensure tf.print executes in graph mode:\n"
+ """```python
+ sess = tf.Session()
+ with sess.as_default():
+ tensor = tf.range(10)
+ print_op = tf.print(tensor)
+ with tf.control_dependencies([print_op]):
+ out = tf.add(tensor, tensor)
+ sess.run(out)
+ ```
+Additionally, to use tf.print in python 2.7, users must make sure to import
+the following:
+
+ `from __future__ import print_function`
+""")
+@tf_export(v1=["Print"])
def Print(input_, data, message=None, first_n=None, summarize=None,
name=None):
"""Prints a list of tensors.
@@ -66,6 +102,228 @@ def Print(input_, data, message=None, first_n=None, summarize=None,
A `Tensor`. Has the same type and contents as `input_`.
"""
return gen_logging_ops._print(input_, data, message, first_n, summarize, name)
+# pylint: enable=invalid-name
+
+
+def _generate_placeholder_string(x, default_placeholder="{}"):
+ """Generate and return a string that does not appear in `x`."""
+ placeholder = default_placeholder
+ rng = random.Random(5)
+ while placeholder in x:
+ placeholder = placeholder + str(rng.randint(0, 9))
+ return placeholder
+
+
+# Temporarily disable pylint g-doc-args error to allow giving more context
+# about what the kwargs are.
+# Because we are using arbitrary-length positional arguments, python 2
+# does not support explicitly specifying the keyword arguments in the
+# function definition.
+# pylint: disable=g-doc-args
+@tf_export("print")
+def print_v2(*inputs, **kwargs):
+ """Print the specified inputs.
+
+ Returns an operator that prints the specified inputs to a desired
+ output stream or logging level. The inputs may be dense or sparse Tensors,
+ primitive python objects, data structures that contain Tensors, and printable
+ python objects. Printed tensors will recursively show the first and last
+ `summarize` elements of each dimension.
+
+ With eager execution enabled and/or inside a `tf.contrib.eager.defun` this
+ operator will automatically execute, and users only need to call `tf.print`
+ without using the return value. When constructing graphs outside of a
+ `tf.contrib.eager.defun`, one must either include the returned op
+ in the input to `session.run`, or use the operator as a control dependency for
+ executed ops by specifying `with tf.control_dependencies([print_op])`.
+
+ @compatibility(python2)
+ In python 2.7, make sure to import the following:
+ `from __future__ import print_function`
+ @end_compatibility
+
+ Example:
+ Single-input usage:
+ ```python
+ tf.enable_eager_execution()
+ tensor = tf.range(10)
+ tf.print(tensor, output_stream=sys.stderr)
+ ```
+ (This prints "[0 1 2 ... 7 8 9]" to sys.stderr)
+
+ Multi-input usage:
+ ```python
+ tf.enable_eager_execution()
+ tensor = tf.range(10)
+ tf.print("tensors:", tensor, {2: tensor * 2}, output_stream=sys.stdout)
+ ```
+ (This prints "tensors: [0 1 2 ... 7 8 9] {2: [0 2 4 ... 14 16 18]}" to
+ sys.stdout)
+
+ Usage in a defun:
+ ```python
+ tf.enable_eager_execution()
+
+ @tf.contrib.eager.defun
+ def f():
+ tensor = tf.range(10)
+ tf.print(tensor, output_stream=sys.stderr)
+ return tensor
+
+ range_tensor = f()
+ ```
+ (This prints "[0 1 2 ... 7 8 9]" to sys.stderr)
+
+ Usage when constructing graphs:
+ ```python
+ sess = tf.Session()
+ with sess.as_default():
+ tensor = tf.range(10)
+ print_op = tf.print("tensors:", tensor, {2: tensor * 2},
+ output_stream=sys.stdout)
+ with tf.control_dependencies([print_op]):
+ tripled_tensor = tensor * 3
+ sess.run(tripled_tensor)
+ ```
+ (This prints "tensors: [0 1 2 ... 7 8 9] {2: [0 2 4 ... 14 16 18]}" to
+ sys.stdout)
+
+ Note: This op is only partially compatible with Jupyter notebooks and colabs.
+ Because it prints to the C++ standard out / standard error, this will go
+ in the notebook kernel's console output, not in the notebook cell output.
+
+ Args:
+ *inputs: Positional arguments that are the inputs to print. Inputs in the
+ printed output will be separated by spaces. Inputs may be python
+ primitives, tensors, data structures such as dicts and lists that
+ may contain tensors (with the data structures possibly nested in
+ arbitrary ways), and printable python objects.
+ output_stream: The output stream or logging level to print to. Defaults to
+ sys.stderr, but sys.stdout, tf.logging.info, tf.logging.warning, and
+ tf.logging.error are also supported.
+ summarize: The first and last `summarize` elements within each dimension are
+ recursively printed per Tensor. If None, then the first 3 and last 3
+ elements of each dimension are printed for each tensor. If set to -1, it
+ will print all elements of every tensor.
+ name: A name for the operation (optional).
+
+ Returns:
+ A print operator that prints the specified inputs in the specified output
+ stream or logging level.
+
+ Raises:
+ ValueError: If an unsupported output stream is specified.
+ """
+ # Because we are using arbitrary-length positional arguments, python 2
+ # does not support explicitly specifying the keyword arguments in the
+ # function definition. So, we manually get the keyword arguments w/ default
+ # values here.
+ output_stream = kwargs.pop("output_stream", sys.stderr)
+ name = kwargs.pop("name", None)
+ summarize = kwargs.pop("summarize", 3)
+ if kwargs:
+ raise ValueError("Unrecognized keyword arguments for tf.print: %s" % kwargs)
+ format_name = None
+ if name:
+ format_name = name + "_format"
+
+ # Match the C++ string constants representing the different output streams.
+ # Keep this updated!
+ output_stream_to_constant = {
+ sys.stdout: "stdout",
+ sys.stderr: "stderr",
+ tf_logging.INFO: "log(info)",
+ tf_logging.info: "log(info)",
+ tf_logging.WARN: "log(warning)",
+ tf_logging.warning: "log(warning)",
+ tf_logging.warn: "log(warning)",
+ tf_logging.ERROR: "log(error)",
+ tf_logging.error: "log(error)",
+ }
+
+ output_stream_string = output_stream_to_constant.get(output_stream)
+ if not output_stream_string:
+ raise ValueError(
+ "Unsupported output stream or logging level " +
+ str(output_stream) + ". Supported streams are sys.stdout, "
+ "sys.stderr, tf.logging.info, "
+ "tf.logging.warning, tf.logging.error")
+
+ # If we are only printing a single string scalar, there is no need to format
+ if (len(inputs) == 1 and tensor_util.is_tensor(inputs[0])
+ and (not isinstance(inputs[0], sparse_tensor.SparseTensor))
+ and inputs[0].shape and (inputs[0].dtype == dtypes.string)):
+ formatted_string = inputs[0]
+ # Otherwise, we construct an appropriate template for the tensors we are
+ # printing, and format the template using those tensors.
+ else:
+ # For each input to this print function, we extract any nested tensors,
+ # and construct an appropriate template to format representing the
+ # printed input.
+ templates = []
+ tensors = []
+ tensor_free_structure = nest.map_structure(
+ lambda x: "" if tensor_util.is_tensor(x) else x,
+ inputs)
+ tensor_free_template = " ".join(pprint.pformat(x)
+ for x in tensor_free_structure)
+ placeholder = _generate_placeholder_string(tensor_free_template)
+
+ for input_ in inputs:
+ placeholders = []
+ # Use the nest utilities to flatten & process any nested elements in this
+ # input. The placeholder for a tensor in the template should be the
+ # placeholder string, and the placeholder for a non-tensor can just be
+ # the printed value of the non-tensor itself.
+ for x in nest.flatten(input_):
+ # support sparse tensors
+ if isinstance(x, sparse_tensor.SparseTensor):
+ tensors.extend([x.indices, x.values, x.dense_shape])
+ placeholders.append(
+ "SparseTensor(indices={}, values={}, shape={})".format(
+ placeholder, placeholder, placeholder)
+ )
+ elif tensor_util.is_tensor(x):
+ tensors.append(x)
+ placeholders.append(placeholder)
+ else:
+ placeholders.append(x)
+
+ if isinstance(input_, six.string_types):
+ # If the current input to format/print is a normal string, that string
+ # can act as the template.
+ cur_template = input_
+ else:
+ # We pack the placeholders into a data structure that matches the
+ # input data structure format, then format that data structure
+ # into a string template.
+ #
+ # NOTE: We must use pprint.pformat here for building the template for
+ # unordered data structures such as `dict`, because `str` doesn't
+ # guarantee orderings, while pprint prints in sorted order. pprint
+ # will match the ordering of `nest.flatten`.
+ # This even works when nest.flatten reorders OrderedDicts, because
+ # pprint is printing *after* the OrderedDicts have been reordered.
+ cur_template = pprint.pformat(
+ nest.pack_sequence_as(input_, placeholders))
+ templates.append(cur_template)
+
+ # We join the templates for the various inputs into a single larger
+ # template. We also remove all quotes surrounding the placeholders, so that
+ # the formatted/printed output will not contain quotes around tensors.
+ # (example of where these quotes might appear: if we have added a
+ # placeholder string into a list, then pretty-formatted that list)
+ template = " ".join(templates)
+ template = template.replace("'" + placeholder + "'", placeholder)
+ formatted_string = string_ops.string_format(
+ inputs=tensors, template=template, placeholder=placeholder,
+ summarize=summarize,
+ name=format_name)
+
+ return gen_logging_ops.print_v2(formatted_string,
+ output_stream=output_stream_string,
+ name=name)
+# pylint: enable=g-doc-args
@ops.RegisterGradient("Print")
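The new `print_v2` funnels everything through a single string template: it picks a placeholder token that cannot collide with the literal text, substitutes it for every tensor, and hands the result to `string_ops.string_format`. A minimal plain-Python sketch of that placeholder/template step (the helper name and sample inputs below are illustrative, not part of the change):

```python
import pprint
import random

def generate_placeholder_string(text, default_placeholder="{}"):
    # Keep appending digits until the placeholder no longer occurs in `text`,
    # mirroring _generate_placeholder_string above.
    placeholder = default_placeholder
    rng = random.Random(5)
    while placeholder in text:
        placeholder = placeholder + str(rng.randint(0, 9))
    return placeholder

inputs = ("tensors:", {2: "<tensor stand-in>"})  # stand-ins for real inputs
tensor_free_template = " ".join(pprint.pformat(x) for x in inputs)
placeholder = generate_placeholder_string(tensor_free_template)
print(placeholder)  # "{}": the text contains braces, but not the exact "{}" sequence
```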
diff --git a/tensorflow/python/ops/lookup_ops.py b/tensorflow/python/ops/lookup_ops.py
index 561a341cf3..5443699ddd 100644
--- a/tensorflow/python/ops/lookup_ops.py
+++ b/tensorflow/python/ops/lookup_ops.py
@@ -422,7 +422,7 @@ class TextFileInitializer(TableInitializerBase):
* `palmer -> 30`
```python
- table = tf.contrib.lookup.HashTable(tf.contrib.lookup.TextFileInitializer(
+ table = tf.lookup.HashTable(tf.lookup.TextFileInitializer(
"test.txt", tf.string, 0, tf.int64, 1, delimiter=" "), -1)
...
table.init.run()
@@ -435,9 +435,9 @@ class TextFileInitializer(TableInitializerBase):
* `palmer 30 -> 2`
```python
- table = tf.contrib.lookup.HashTable(tf.contrib.lookup.TextFileInitializer(
- "test.txt", tf.string, tf.contrib.lookup.TextFileIndex.WHOLE_LINE,
- tf.int64, tf.contrib.lookup.TextFileIndex.LINE_NUMBER, delimiter=" "), -1)
+ table = tf.lookup.HashTable(tf.lookup.TextFileInitializer(
+ "test.txt", tf.string, tf.lookup.TextFileIndex.WHOLE_LINE,
+ tf.int64, tf.lookup.TextFileIndex.LINE_NUMBER, delimiter=" "), -1)
...
table.init.run()
```
@@ -953,7 +953,7 @@ def index_table_from_file(vocabulary_file=None,
```python
features = tf.constant(["emerson", "lake", "and", "palmer"])
- table = tf.contrib.lookup.index_table_from_file(
+ table = tf.lookup.index_table_from_file(
vocabulary_file="test.txt", num_oov_buckets=1)
ids = table.lookup(features)
...
@@ -1054,21 +1054,21 @@ def index_table_from_tensor(vocabulary_list,
Any lookup of an out-of-vocabulary token will return a bucket ID based on its
hash if `num_oov_buckets` is greater than zero. Otherwise it is assigned the
- `default_value`.
- The bucket ID range is `[mapping size, mapping size + num_oov_buckets - 1]`.
+ `default_value`. The bucket ID range is
+ `[vocabulary list size, vocabulary list size + num_oov_buckets - 1]`.
The underlying table must be initialized by calling
`tf.tables_initializer.run()` or `table.init.run()` once.
- Elements in `mapping` cannot have duplicates, otherwise when executing the
- table initializer op, it will throw a `FailedPreconditionError`.
+ Elements in `vocabulary_list` cannot have duplicates, otherwise when executing
+ the table initializer op, it will throw a `FailedPreconditionError`.
Sample Usages:
```python
vocabulary_list = tf.constant(["emerson", "lake", "palmer"])
- table = tf.contrib.lookup.index_table_from_tensor(
- mapping=vocabulary_list, num_oov_buckets=1, default_value=-1)
+ table = tf.lookup.index_table_from_tensor(
+ vocabulary_list=vocabulary_list, num_oov_buckets=1, default_value=-1)
features = tf.constant(["emerson", "lake", "and", "palmer"])
ids = table.lookup(features)
...
@@ -1093,7 +1093,7 @@ def index_table_from_tensor(vocabulary_list,
The lookup table to map an input `Tensor` to index `int64` `Tensor`.
Raises:
- ValueError: If `mapping` is invalid.
+ ValueError: If `vocabulary_list` is invalid.
ValueError: If `num_oov_buckets` is negative.
"""
if vocabulary_list is None:
@@ -1185,7 +1185,7 @@ def index_to_string_table_from_file(vocabulary_file,
```python
indices = tf.constant([1, 5], tf.int64)
- table = tf.contrib.lookup.index_to_string_table_from_file(
+ table = tf.lookup.index_to_string_table_from_file(
vocabulary_file="test.txt", default_value="UNKNOWN")
values = table.lookup(indices)
...
@@ -1250,25 +1250,25 @@ def index_to_string_table_from_tensor(vocabulary_list,
"""Returns a lookup table that maps a `Tensor` of indices into strings.
This operation constructs a lookup table to map int64 indices into string
- values. The mapping is initialized from a string `mapping` 1-D `Tensor` where
- each element is a value and the corresponding index within the tensor is the
- key.
+ values. The mapping is initialized from a string `vocabulary_list` 1-D
+ `Tensor` where each element is a value and the corresponding index within the
+ tensor is the key.
- Any input which does not have a corresponding index in 'mapping'
+ Any input which does not have a corresponding index in 'vocabulary_list'
  (an out-of-vocabulary entry) is assigned the `default_value`.
The underlying table must be initialized by calling
`tf.tables_initializer.run()` or `table.init.run()` once.
- Elements in `mapping` cannot have duplicates, otherwise when executing the
- table initializer op, it will throw a `FailedPreconditionError`.
+ Elements in `vocabulary_list` cannot have duplicates, otherwise when executing
+ the table initializer op, it will throw a `FailedPreconditionError`.
Sample Usages:
```python
vocabulary_list = tf.constant(["emerson", "lake", "palmer"])
indices = tf.constant([1, 5], tf.int64)
- table = tf.contrib.lookup.index_to_string_table_from_tensor(
+ table = tf.lookup.index_to_string_table_from_tensor(
vocabulary_list, default_value="UNKNOWN")
values = table.lookup(indices)
...
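The renamed `tf.lookup` symbols are used the same way as their `tf.contrib.lookup` predecessors. A short sketch assembled from the docstring examples above, assuming the `tf.lookup` export names shown there and the TF 1.x session API:

```python
import tensorflow as tf

vocabulary_list = tf.constant(["emerson", "lake", "palmer"])
table = tf.lookup.index_table_from_tensor(
    vocabulary_list=vocabulary_list, num_oov_buckets=1, default_value=-1)
features = tf.constant(["emerson", "lake", "and", "palmer"])
ids = table.lookup(features)

with tf.Session() as sess:
    sess.run(tf.tables_initializer())
    print(sess.run(ids))  # e.g. [0 1 3 2]; "and" falls into the single OOV bucket
```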
diff --git a/tensorflow/python/ops/losses/util_test.py b/tensorflow/python/ops/losses/util_test.py
index 7fa7a41fca..df2e60e2e4 100644
--- a/tensorflow/python/ops/losses/util_test.py
+++ b/tensorflow/python/ops/losses/util_test.py
@@ -28,7 +28,7 @@ class LossesUtilTest(test.TestCase):
def testGetRegularizationLoss(self):
# Empty regularization collection should evaluate to 0.0.
- with self.test_session():
+ with self.cached_session():
self.assertEqual(0.0, util.get_regularization_loss().eval())
# Loss should sum.
@@ -36,14 +36,14 @@ class LossesUtilTest(test.TestCase):
ops.GraphKeys.REGULARIZATION_LOSSES, constant_op.constant(2.0))
ops.add_to_collection(
ops.GraphKeys.REGULARIZATION_LOSSES, constant_op.constant(3.0))
- with self.test_session():
+ with self.cached_session():
self.assertEqual(5.0, util.get_regularization_loss().eval())
# Check scope capture mechanism.
with ops.name_scope('scope1'):
ops.add_to_collection(
ops.GraphKeys.REGULARIZATION_LOSSES, constant_op.constant(-1.0))
- with self.test_session():
+ with self.cached_session():
self.assertEqual(-1.0, util.get_regularization_loss('scope1').eval())
diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py
index 8e11c4bce1..35278d9680 100644
--- a/tensorflow/python/ops/math_grad.py
+++ b/tensorflow/python/ops/math_grad.py
@@ -516,6 +516,40 @@ def _Log1pGrad(op, grad):
return grad * math_ops.reciprocal(1 + x)
+@ops.RegisterGradient("Xlogy")
+def _XLogyGrad(op, grad):
+ """Returns gradient of xlogy(x, y) with respect to x and y."""
+ x = op.inputs[0]
+ y = op.inputs[1]
+ sx = array_ops.shape(x)
+ sy = array_ops.shape(y)
+ rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
+ with ops.control_dependencies([grad]):
+ not_zero_x = math_ops.cast(
+ math_ops.not_equal(x, math_ops.cast(0., dtype=x.dtype)), dtype=x.dtype)
+ partial_x = gen_math_ops.xlogy(not_zero_x, y)
+ partial_y = gen_math_ops.xdivy(x, y)
+ return (array_ops.reshape(math_ops.reduce_sum(partial_x * grad, rx), sx),
+ array_ops.reshape(math_ops.reduce_sum(partial_y * grad, ry), sy))
+
+
+@ops.RegisterGradient("Xdivy")
+def _XDivyGrad(op, grad):
+ """Returns gradient of xdivy(x, y) with respect to x and y."""
+ x = op.inputs[0]
+ y = op.inputs[1]
+ sx = array_ops.shape(x)
+ sy = array_ops.shape(y)
+ rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
+ with ops.control_dependencies([grad]):
+ not_zero_x = math_ops.cast(
+ math_ops.not_equal(x, math_ops.cast(0., dtype=x.dtype)), dtype=x.dtype)
+ partial_x = gen_math_ops.xdivy(not_zero_x, y)
+ partial_y = gen_math_ops.xdivy(math_ops.negative(x), y**2)
+ return (array_ops.reshape(math_ops.reduce_sum(partial_x * grad, rx), sx),
+ array_ops.reshape(math_ops.reduce_sum(partial_y * grad, ry), sy))
+
+
@ops.RegisterGradient("Sinh")
def _SinhGrad(op, grad):
"""Returns grad * cosh(x)."""
diff --git a/tensorflow/python/ops/math_grad_test.py b/tensorflow/python/ops/math_grad_test.py
index 7110e0958c..9cfb050942 100644
--- a/tensorflow/python/ops/math_grad_test.py
+++ b/tensorflow/python/ops/math_grad_test.py
@@ -256,5 +256,93 @@ class DivNoNanGradientTest(test.TestCase):
self.assertAllClose(dy.eval(), np.zeros(y.shape.as_list()))
+class XlogyTest(test.TestCase):
+
+ def _xlogy_gradients(self, x, y):
+ xlogy_xgrad = self.evaluate(gradients.gradients(math_ops.xlogy(x, y), x)[0])
+ xlogy_ygrad = self.evaluate(gradients.gradients(math_ops.xlogy(x, y), y)[0])
+ return xlogy_xgrad, xlogy_ygrad
+
+ def testNonZeroValuesGrad(self):
+ for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
+ x = constant_op.constant(0.1, dtype=dtype)
+ y = constant_op.constant(3.1, dtype=dtype)
+ xlogy_xgrad, xlogy_ygrad = self._xlogy_gradients(x, y)
+ xlogy_expected_xgrad = self.evaluate(math_ops.log(y))
+ xlogy_expected_ygrad = self.evaluate(x / y)
+ self.assertAllClose(xlogy_expected_xgrad, xlogy_xgrad)
+ self.assertAllClose(xlogy_expected_ygrad, xlogy_ygrad)
+
+ def testZeroXGrad(self):
+ for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
+ x = constant_op.constant(0., dtype=dtype)
+ y = constant_op.constant(3.1, dtype=dtype)
+ xlogy_xgrad, xlogy_ygrad = self._xlogy_gradients(x, y)
+ zero = self.evaluate(x)
+ self.assertAllClose(zero, xlogy_xgrad)
+ self.assertAllClose(zero, xlogy_ygrad)
+
+ def testZeroYGrad(self):
+ for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
+ x = constant_op.constant(0.1, dtype=dtype)
+ y = constant_op.constant(0., dtype=dtype)
+ xlogy_xgrad, xlogy_ygrad = self._xlogy_gradients(x, y)
+ self.assertAllClose(-np.inf, xlogy_xgrad)
+ self.assertAllClose(np.inf, xlogy_ygrad)
+
+ def testZeroXYGrad(self):
+ for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
+ x = constant_op.constant(0., dtype=dtype)
+ y = constant_op.constant(0., dtype=dtype)
+ xlogy_xgrad, xlogy_ygrad = self._xlogy_gradients(x, y)
+ zero = self.evaluate(x)
+ self.assertAllClose(zero, xlogy_xgrad)
+ self.assertAllClose(zero, xlogy_ygrad)
+
+
+class XdivyTest(test.TestCase):
+
+ def _xdivy_gradients(self, x, y):
+ xdivy_xgrad = self.evaluate(gradients.gradients(math_ops.xdivy(x, y), x)[0])
+ xdivy_ygrad = self.evaluate(gradients.gradients(math_ops.xdivy(x, y), y)[0])
+ return xdivy_xgrad, xdivy_ygrad
+
+ def testNonZeroValuesGrad(self):
+ for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
+ x = constant_op.constant(0.1, dtype=dtype)
+ y = constant_op.constant(3.1, dtype=dtype)
+ xdivy_xgrad, xdivy_ygrad = self._xdivy_gradients(x, y)
+ xdivy_expected_xgrad = self.evaluate(1 / y)
+ xdivy_expected_ygrad = self.evaluate(-x / y**2)
+ self.assertAllClose(xdivy_expected_xgrad, xdivy_xgrad)
+ self.assertAllClose(xdivy_expected_ygrad, xdivy_ygrad)
+
+ def testZeroXGrad(self):
+ for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
+ x = constant_op.constant(0., dtype=dtype)
+ y = constant_op.constant(3.1, dtype=dtype)
+ xdivy_xgrad, xdivy_ygrad = self._xdivy_gradients(x, y)
+ zero = self.evaluate(x)
+ self.assertAllClose(zero, xdivy_xgrad)
+ self.assertAllClose(zero, xdivy_ygrad)
+
+ def testZeroYGrad(self):
+ for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
+ x = constant_op.constant(0.1, dtype=dtype)
+ y = constant_op.constant(0., dtype=dtype)
+ xdivy_xgrad, xdivy_ygrad = self._xdivy_gradients(x, y)
+ self.assertAllClose(np.inf, xdivy_xgrad)
+ self.assertAllClose(-np.inf, xdivy_ygrad)
+
+ def testZeroXYGrad(self):
+ for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
+ x = constant_op.constant(0., dtype=dtype)
+ y = constant_op.constant(0., dtype=dtype)
+ xdivy_xgrad, xdivy_ygrad = self._xdivy_gradients(x, y)
+ zero = self.evaluate(x)
+ self.assertAllClose(zero, xdivy_xgrad)
+ self.assertAllClose(zero, xdivy_ygrad)
+
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index acd5a32e82..f57abf6704 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -2898,21 +2898,23 @@ def tensordot(a, b, axes, name=None):
shape_a = a.get_shape().as_list()
axes = [i if i >= 0 else i + len(shape_a) for i in axes]
free = [i for i in xrange(len(shape_a)) if i not in axes]
- free_dims_static = [shape_a[i] for i in free]
+ axes_dims = [shape_a[i] for i in axes]
+ free_dims = [shape_a[i] for i in free]
+ free_dims_static = free_dims
+ axes = ops.convert_to_tensor(axes, dtype=dtypes.int32, name="axes")
+ free = ops.convert_to_tensor(free, dtype=dtypes.int32, name="free")
+ shape_a = array_ops.shape(a)
else:
free_dims_static = None
- shape_a = array_ops.shape(a)
- rank_a = array_ops.rank(a)
- axes = ops.convert_to_tensor(axes, dtype=dtypes.int32, name="axes")
- axes = cast(axes >= 0, dtypes.int32) * axes + cast(
- axes < 0, dtypes.int32) * (
- axes + rank_a)
- free, _ = array_ops.setdiff1d(range(rank_a), axes)
+ shape_a = array_ops.shape(a)
+ rank_a = array_ops.rank(a)
+ axes = ops.convert_to_tensor(axes, dtype=dtypes.int32, name="axes")
+ axes = array_ops.where(axes >= 0, axes, axes + rank_a)
+ free, _ = array_ops.setdiff1d(range(rank_a), axes)
free_dims = array_ops.gather(shape_a, free)
axes_dims = array_ops.gather(shape_a, axes)
prod_free_dims = reduce_prod(free_dims)
prod_axes_dims = reduce_prod(axes_dims)
- perm = array_ops.concat([axes_dims, free_dims], 0)
if flipped:
perm = array_ops.concat([axes, free], 0)
new_shape = array_ops.stack([prod_axes_dims, prod_free_dims])
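The rewritten static-shape branch above keeps `free_dims` and `axes_dims` as Python ints while still converting `axes` and `free` to tensors for the later transpose and reshape; the dynamic branch normalizes negative axes with `array_ops.where`. The index bookkeeping itself is plain Python and can be sketched without TensorFlow (example shape and axes chosen here for illustration):

```python
shape_a = [2, 3, 4, 5]  # static shape of `a`
axes = [-1, 1]          # axes of `a` to contract

rank_a = len(shape_a)
axes = [i if i >= 0 else i + rank_a for i in axes]  # -> [3, 1]
free = [i for i in range(rank_a) if i not in axes]  # -> [0, 2]
axes_dims = [shape_a[i] for i in axes]              # -> [5, 3]
free_dims = [shape_a[i] for i in free]              # -> [2, 4]

# `a` is then transposed to free + axes order and reshaped to
# [prod(free_dims), prod(axes_dims)] before the matmul.
print(free_dims, axes_dims)
```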
diff --git a/tensorflow/python/ops/math_ops_test.py b/tensorflow/python/ops/math_ops_test.py
index 1b01d1d37f..f051850d92 100644
--- a/tensorflow/python/ops/math_ops_test.py
+++ b/tensorflow/python/ops/math_ops_test.py
@@ -21,6 +21,7 @@ import numpy as np
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
@@ -488,5 +489,75 @@ class DivNoNanTest(test_util.TensorFlowTestCase):
self.assertAllEqual(tf_result, np_result)
+class XlogyTest(test_util.TensorFlowTestCase):
+
+ @test_util.run_in_graph_and_eager_modes
+ def testXlogyNoZero(self):
+ for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
+ x = constant_op.constant([[0.1, 0.2, 3.5], [-2., -5., 30.]], dtype=dtype)
+ y = constant_op.constant([[0.1, 0.2, 3.5], [3.1, 4., 2.]], dtype=dtype)
+ with self.cached_session(use_gpu=True):
+ xlogy = self.evaluate(math_ops.xlogy(x, y))
+ xtimeslogy = self.evaluate(x * math_ops.log(y))
+ self.assertAllClose(xlogy, xtimeslogy)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testXlogyWithZero(self):
+ for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
+ x = constant_op.constant(np.zeros((2, 3)), dtype=dtype)
+ y = constant_op.constant([[0.1, 0.2, 3.5], [0., 1., 2.]], dtype=dtype)
+ with self.cached_session(use_gpu=True):
+ xlogy_tf_np = self.evaluate(math_ops.xlogy(x, y))
+ zeros_np = self.evaluate(array_ops.zeros_like(y))
+ self.assertAllClose(xlogy_tf_np, zeros_np)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testXlogyWithZeroBroadcast(self):
+ for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
+ x = constant_op.constant([[0.], [1.]], dtype=dtype)
+ y = constant_op.constant([[0.1, 0.2, 3.5], [0., 1., 2.]], dtype=dtype)
+ with self.cached_session(use_gpu=True):
+ xlogy_tf_np = self.evaluate(math_ops.xlogy(x, y))
+ zeros_np = self.evaluate(array_ops.zeros_like(y[0]))
+ xtimes_logy = self.evaluate(math_ops.log(y[1]))
+ self.assertAllClose(zeros_np, xlogy_tf_np[0])
+ self.assertAllClose(xtimes_logy, xlogy_tf_np[1])
+
+
+class XdivyTest(test_util.TensorFlowTestCase):
+
+ @test_util.run_in_graph_and_eager_modes
+ def testXdivyNoZero(self):
+ for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
+ x = constant_op.constant([[0.1, 0.2, 3.5], [-2., -5., 30.]], dtype=dtype)
+ y = constant_op.constant([[0.1, 0.2, 3.5], [3.1, 4., 2.]], dtype=dtype)
+ with self.cached_session(use_gpu=True):
+ xdivy = self.evaluate(math_ops.xdivy(x, y))
+ x_over_y = self.evaluate(x / y)
+ self.assertAllClose(xdivy, x_over_y)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testXdivyWithZero(self):
+ for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
+ x = constant_op.constant(np.zeros((2, 3)), dtype=dtype)
+ y = constant_op.constant([[0.1, 0.2, 3.5], [0., 1., 2.]], dtype=dtype)
+ with self.cached_session(use_gpu=True):
+ xdivy_tf_np = self.evaluate(math_ops.xdivy(x, y))
+ zeros_np = self.evaluate(array_ops.zeros_like(y))
+ self.assertAllClose(xdivy_tf_np, zeros_np)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testXdivyWithZeroBroadcast(self):
+ for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
+ x = constant_op.constant([[0.], [1.]], dtype=dtype)
+ y = constant_op.constant([[0.1, 0.2, 3.5], [0., 1., 2.]], dtype=dtype)
+ with self.cached_session(use_gpu=True):
+ xdivy_tf_np = self.evaluate(math_ops.xdivy(x, y))
+ zeros_np = self.evaluate(array_ops.zeros_like(y[0]))
+ x_over_y = self.evaluate(1 / y[1])
+ self.assertAllClose(zeros_np, xdivy_tf_np[0])
+ self.assertAllClose(x_over_y, xdivy_tf_np[1])
+
+
if __name__ == "__main__":
googletest.main()
diff --git a/tensorflow/python/ops/matmul_benchmark.py b/tensorflow/python/ops/matmul_benchmark.py
index 6e5fe74290..138149e63d 100644
--- a/tensorflow/python/ops/matmul_benchmark.py
+++ b/tensorflow/python/ops/matmul_benchmark.py
@@ -49,13 +49,13 @@ def build_graph(device, n, m, k, transpose_a, transpose_b, dtype):
"""
with ops.device('%s' % device):
if not transpose_a:
- x = variables.Variable(random_ops.random_uniform([n, m], dtype=dtype))
+ x = variables.VariableV1(random_ops.random_uniform([n, m], dtype=dtype))
else:
- x = variables.Variable(random_ops.random_uniform([m, n], dtype=dtype))
+ x = variables.VariableV1(random_ops.random_uniform([m, n], dtype=dtype))
if not transpose_b:
- y = variables.Variable(random_ops.random_uniform([m, k], dtype=dtype))
+ y = variables.VariableV1(random_ops.random_uniform([m, k], dtype=dtype))
else:
- y = variables.Variable(random_ops.random_uniform([k, m], dtype=dtype))
+ y = variables.VariableV1(random_ops.random_uniform([k, m], dtype=dtype))
z = math_ops.matmul(x, y, transpose_a=transpose_a, transpose_b=transpose_b)
return control_flow_ops.group(z)
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index 2526e6fee2..9ef177e97b 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -22,7 +22,6 @@ import numbers
import numpy as np
-from tensorflow.python.compat import compat
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import graph_util
@@ -1670,47 +1669,24 @@ def _softmax(logits, compute_op, dim=-1, name=None):
shape = logits.get_shape()
is_last_dim = (dim is -1) or (dim == shape.ndims - 1)
- # TODO(phawkins): remove after 2018/8/27 and simplify this code.
- softmax_accepts_r1_or_greater = compat.forward_compatible(2018, 8, 27)
- reshape_required = (not softmax_accepts_r1_or_greater) and shape.ndims != 2
if is_last_dim:
- if reshape_required:
- # If dim is the last dimension, simply reshape the logits to a matrix and
- # apply the internal softmax.
- input_shape = array_ops.shape(logits)
- logits = _flatten_outer_dims(logits)
- output = compute_op(logits)
- output = array_ops.reshape(output, input_shape, name=name)
- return output
return compute_op(logits, name=name)
- # If dim is not the last dimension, we have to do a reshape and transpose so
- # that we can still perform softmax on its last dimension.
+ # If dim is not the last dimension, we have to do a transpose so that we can
+ # still perform softmax on its last dimension.
# Swap logits' dimension of dim and its last dimension.
input_rank = array_ops.rank(logits)
dim_axis = dim % shape.ndims
logits = _swap_axis(logits, dim_axis, math_ops.subtract(input_rank, 1))
- shape_after_swap = array_ops.shape(logits)
- if reshape_required:
- # Reshape logits into a matrix.
- logits = _flatten_outer_dims(logits)
-
- # Do the actual softmax on its last dimension.
- output = compute_op(logits)
-
- # Transform back the output tensor.
- output = array_ops.reshape(output, shape_after_swap)
- else:
- # Do the actual softmax on its last dimension.
- output = compute_op(logits)
+ # Do the actual softmax on its last dimension.
+ output = compute_op(logits)
output = _swap_axis(
output, dim_axis, math_ops.subtract(input_rank, 1), name=name)
- # Make shape inference work since reshape and transpose may erase its static
- # shape.
+ # Make shape inference work since transpose may erase its static shape.
output.set_shape(shape)
return output
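With the forward-compatibility reshape path removed, `_softmax` handles a non-final `dim` purely by swapping that axis with the last one, applying the compute op, and swapping back. The same idea in NumPy (a sketch, not the TensorFlow implementation):

```python
import numpy as np

def softmax_last_dim(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def softmax(x, dim=-1):
    last = x.ndim - 1
    dim = dim % x.ndim
    if dim == last:
        return softmax_last_dim(x)
    swapped = np.swapaxes(x, dim, last)   # move `dim` to the end
    out = softmax_last_dim(swapped)
    return np.swapaxes(out, dim, last)    # restore the original layout

x = np.random.rand(2, 3, 4)
assert np.allclose(softmax(x, dim=1).sum(axis=1), 1.0)
```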
diff --git a/tensorflow/python/ops/parallel_for/BUILD b/tensorflow/python/ops/parallel_for/BUILD
index 015181af47..07fc9433a2 100644
--- a/tensorflow/python/ops/parallel_for/BUILD
+++ b/tensorflow/python/ops/parallel_for/BUILD
@@ -123,6 +123,8 @@ cuda_py_test(
"//third_party/py/numpy",
"//tensorflow/python:layers",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:functional_ops",
"//tensorflow/python:random_ops",
"//tensorflow/python/ops/losses",
],
diff --git a/tensorflow/python/ops/parallel_for/gradients.py b/tensorflow/python/ops/parallel_for/gradients.py
index 460de0a97f..1f026b3660 100644
--- a/tensorflow/python/ops/parallel_for/gradients.py
+++ b/tensorflow/python/ops/parallel_for/gradients.py
@@ -42,6 +42,7 @@ def jacobian(output, inputs, use_pfor=True):
[y_1, ..., y_n, x_1, ..., x_m].
"""
flat_inputs = nest.flatten(inputs)
+ output_tensor_shape = output.shape
output_shape = array_ops.shape(output)
output = array_ops.reshape(output, [-1])
@@ -65,6 +66,7 @@ def jacobian(output, inputs, use_pfor=True):
new_shape = array_ops.concat(
[output_shape, array_ops.shape(out)[1:]], axis=0)
out = array_ops.reshape(out, new_shape)
+ out.set_shape(output_tensor_shape.concatenate(flat_inputs[i].shape))
pfor_outputs[i] = out
return nest.pack_sequence_as(inputs, pfor_outputs)
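The added `set_shape` call pins the statically known Jacobian shape, which is the output shape concatenated with the input shape. A small NumPy illustration of that `[y_dims, x_dims]` layout for a linear map (illustrative only):

```python
import numpy as np

# Jacobian of y = A @ x for a concrete A, showing the y.shape + x.shape layout
# that the added set_shape call makes static.
A = np.arange(6.0).reshape(2, 3)
x = np.ones(3)
y = A @ x                               # shape (2,)

jacobian = np.empty(y.shape + x.shape)  # shape (2, 3): output dims, then input dims
for i in range(y.shape[0]):
    for j in range(x.shape[0]):
        jacobian[i, j] = A[i, j]        # dy_i/dx_j of a linear map is A[i, j]

print(jacobian.shape)                   # (2, 3)
```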
diff --git a/tensorflow/python/ops/parallel_for/gradients_test.py b/tensorflow/python/ops/parallel_for/gradients_test.py
index 628c6764cd..5467f55af6 100644
--- a/tensorflow/python/ops/parallel_for/gradients_test.py
+++ b/tensorflow/python/ops/parallel_for/gradients_test.py
@@ -32,6 +32,8 @@ from tensorflow.python.framework import ops
from tensorflow.python.keras.engine import training as keras_training
from tensorflow.python.layers import layers as tf_layers
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops as tf_control_flow_ops
+from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gradients as gradient_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
@@ -355,6 +357,30 @@ class GradientsTest(test.TestCase):
self.run_and_assert_equal(answer, jacobian_pfor)
self.run_and_assert_equal(answer, jacobian_while)
+ def test_jacobian_scan_shape(self):
+ # Shape x: [3, 4]
+ x = random_ops.random_uniform([3, 4])
+ elems = random_ops.random_uniform([6])
+ # Shape y: [6, 3, 4]
+ y = functional_ops.scan(lambda a, e: a + e, elems, initializer=x)
+ jacobian = gradients.jacobian(y, x)
+
+ expected_shape = [6, 3, 4, 3, 4]
+ self.assertAllEqual(expected_shape, jacobian.shape.as_list())
+
+ def test_jacobian_while_loop_shape(self):
+ # Shape x: [3, 4]
+ x = random_ops.random_uniform([3, 4])
+ _, y = tf_control_flow_ops.while_loop(lambda i, a: i > 5.,
+ lambda i, a: (i + 1, a + i),
+ (constant_op.constant(0.), x))
+ # Shape y: [2, 3]
+ y = y[:2, :3]
+ jacobian = gradients.jacobian(y, x)
+
+ expected_shape = [2, 3, 3, 4]
+ self.assertAllEqual(expected_shape, jacobian.shape.as_list())
+
def test_jacobian_unknown_shape(self):
with self.cached_session() as sess:
x = array_ops.placeholder(dtypes.float32, shape=[None, None])
diff --git a/tensorflow/python/ops/parsing_ops.py b/tensorflow/python/ops/parsing_ops.py
index bb8da3162a..b3e03a0135 100644
--- a/tensorflow/python/ops/parsing_ops.py
+++ b/tensorflow/python/ops/parsing_ops.py
@@ -981,9 +981,10 @@ def parse_sequence_example(serialized,
name: A name for this operation (optional).
Returns:
- A tuple of two `dict`s, each mapping keys to `Tensor`s and `SparseTensor`s.
- The first dict contains the context key/values.
- The second dict contains the feature_list key/values.
+ A tuple of three `dict`s, each mapping keys to `Tensor`s and
+ `SparseTensor`s. The first dict contains the context key/values,
+ the second dict contains the feature_list key/values, and the final dict
+ contains the lengths of any dense feature_list features.
Raises:
ValueError: if any feature is invalid.
diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py
index 55c2eb5fa4..4a126e9d7a 100644
--- a/tensorflow/python/ops/resource_variable_ops.py
+++ b/tensorflow/python/ops/resource_variable_ops.py
@@ -48,14 +48,14 @@ def get_resource_handle_data(graph_op):
assert ops._USE_C_SHAPES # pylint: disable=protected-access
assert type(graph_op) == ops.Tensor # pylint: disable=unidiomatic-typecheck
- handle_data = pywrap_tensorflow.GetResourceHandleShapeAndType(
+ handle_data = pywrap_tensorflow.GetHandleShapeAndType(
graph_op.graph._c_graph, graph_op._as_tf_output()) # pylint: disable=protected-access
return cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData.FromString(
compat.as_bytes(handle_data))
-def _eager_safe_variable_handle(shape, dtype, shared_name, name, graph_mode):
+def eager_safe_variable_handle(shape, dtype, shared_name, name, graph_mode):
"""Creates a variable handle with information to do shape inference."""
container = ops.get_default_graph()._container # pylint: disable=protected-access
if container is None:
@@ -397,61 +397,33 @@ class ResourceVariable(variables.RefVariable):
# When in eager mode use a uid for the shared_name, to prevent
# accidental sharing.
shared_name = "%s_%d" % (handle_name, ops.uid())
- if init_from_fn:
- # Use attr_scope and device(None) to simulate the behavior of
- # colocate_with when the variable we want to colocate with doesn't
- # yet exist.
- if self._in_graph_mode:
- attr = attr_value_pb2.AttrValue(
- list=attr_value_pb2.AttrValue.ListValue(
- s=[compat.as_bytes("loc:@%s" % handle_name)]))
- with ops.get_default_graph()._attr_scope({"_class": attr}):
- with ops.name_scope("Initializer"), ops.device(None):
- initial_value = ops.convert_to_tensor(
- initial_value(), name="initial_value", dtype=dtype)
- self._handle = _eager_safe_variable_handle(
- shape=initial_value.get_shape(),
- dtype=initial_value.dtype.base_dtype,
- shared_name=shared_name,
- name=name,
- graph_mode=self._in_graph_mode)
- self._shape = initial_value.get_shape()
- else:
- initial_value = initial_value()
- with ops.name_scope("Initializer"):
- initial_value = ops.convert_to_tensor(
- initial_value, name="initial_value", dtype=dtype)
- self._handle = _eager_safe_variable_handle(
- shape=initial_value.get_shape(),
- dtype=initial_value.dtype.base_dtype,
- shared_name=shared_name,
- name=name,
- graph_mode=False)
- self._shape = initial_value.get_shape()
- # pylint: enable=protected-access
-
- # Or get the initial value from a Tensor or Python object.
- else:
- with ops.name_scope("Initializer"):
+ # Use attr_scope and device(None) to simulate the behavior of
+ # colocate_with when the variable we want to colocate with doesn't
+ # yet exist.
+ attr = attr_value_pb2.AttrValue(
+ list=attr_value_pb2.AttrValue.ListValue(
+ s=[compat.as_bytes("loc:@%s" % handle_name)]))
+ with ops.get_default_graph()._attr_scope({"_class": attr}):
+ with ops.name_scope("Initializer"), ops.device(None):
initial_value = ops.convert_to_tensor(
- initial_value, name="initial_value", dtype=dtype)
- # pylint: disable=protected-access
- if (self._in_graph_mode and initial_value is not None and
- initial_value.op._get_control_flow_context() is not None):
- raise ValueError(
- "Initializer for variable %s is from inside a control-flow "
- "construct, such as a loop or conditional. When creating a "
- "variable inside a loop or conditional, use a lambda as the "
- "initializer." % name)
- # pylint: enable=protected-access
- self._handle = _eager_safe_variable_handle(
+ initial_value() if init_from_fn else initial_value,
+ name="initial_value", dtype=dtype)
+ self._handle = eager_safe_variable_handle(
shape=initial_value.get_shape(),
dtype=initial_value.dtype.base_dtype,
shared_name=shared_name,
name=name,
graph_mode=self._in_graph_mode)
- self._shape = initial_value.get_shape()
-
+ self._shape = initial_value.shape
+ # pylint: disable=protected-access
+ if (self._in_graph_mode and initial_value is not None and
+ initial_value.op._get_control_flow_context() is not None):
+ raise ValueError(
+ "Initializer for variable %s is from inside a control-flow "
+ "construct, such as a loop or conditional. When creating a "
+ "variable inside a loop or conditional, use a lambda as the "
+ "initializer." % name)
+ # pylint: enable=protected-access
self._unique_id = shared_name
self._initial_value = initial_value if self._in_graph_mode else None
self._handle_name = handle_name + ":0"
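The refactor above collapses the two initializer branches into a single `initial_value() if init_from_fn else initial_value` expression evaluated under the colocation attr scope. The pattern itself is ordinary Python; a minimal sketch with an illustrative helper name:

```python
def resolve_initial_value(initial_value):
    # Mirrors `initial_value() if init_from_fn else initial_value` above:
    # callables are invoked, tensor-like values pass through unchanged.
    init_from_fn = callable(initial_value)
    return initial_value() if init_from_fn else initial_value

print(resolve_initial_value(3.0))          # 3.0
print(resolve_initial_value(lambda: 3.0))  # 3.0
```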
diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py
index 43cca1a498..c2751e529a 100644
--- a/tensorflow/python/ops/rnn_cell_impl.py
+++ b/tensorflow/python/ops/rnn_cell_impl.py
@@ -611,7 +611,7 @@ class LSTMStateTuple(_LSTMStateTuple):
# TODO(scottzhu): Stop exporting this class in TF 2.0.
@tf_export("nn.rnn_cell.BasicLSTMCell")
class BasicLSTMCell(LayerRNNCell):
- """DEPRECATED: Please use @{tf.nn.rnn_cell.LSTMCell} instead.
+ """DEPRECATED: Please use `tf.nn.rnn_cell.LSTMCell` instead.
Basic LSTM recurrent network cell.
diff --git a/tensorflow/python/ops/string_ops.py b/tensorflow/python/ops/string_ops.py
index b2c6937368..046a48d192 100644
--- a/tensorflow/python/ops/string_ops.py
+++ b/tensorflow/python/ops/string_ops.py
@@ -29,16 +29,19 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_string_ops
from tensorflow.python.ops import math_ops
-from tensorflow.python.util import compat as util_compat
# go/tf-wildcard-import
# pylint: disable=wildcard-import
+# pylint: disable=g-bad-import-order
from tensorflow.python.ops.gen_string_ops import *
+from tensorflow.python.util import compat as util_compat
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
+# pylint: enable=g-bad-import-order
# pylint: enable=wildcard-import
@@ -103,6 +106,87 @@ def regex_replace(source, pattern, rewrite, replace_global=True):
rewrite=rewrite, replace_global=replace_global)
+@tf_export("strings.format")
+def string_format(template, inputs, placeholder="{}", summarize=3, name=None):
+ r"""Formats a string template using a list of tensors.
+
+ Formats a string template using a list of tensors, abbreviating tensors by
+ only printing the first and last `summarize` elements of each dimension
+ (recursively). If formatting only one tensor into a template, the tensor does
+ not have to be wrapped in a list.
+
+ Example:
+ Formatting a single-tensor template:
+ ```python
+ sess = tf.Session()
+ with sess.as_default():
+ tensor = tf.range(10)
+ formatted = tf.strings.format("tensor: {}, suffix", tensor)
+ out = sess.run(formatted)
+ expected = "tensor: [0 1 2 ... 7 8 9], suffix"
+
+ assert(out.decode() == expected)
+ ```
+
+ Formatting a multi-tensor template:
+ ```python
+ sess = tf.Session()
+ with sess.as_default():
+ tensor_one = tf.reshape(tf.range(100), [10, 10])
+ tensor_two = tf.range(10)
+ formatted = tf.strings.format("first: {}, second: {}, suffix",
+ (tensor_one, tensor_two))
+
+ out = sess.run(formatted)
+ expected = ("first: [[0 1 2 ... 7 8 9]\n"
+ " [10 11 12 ... 17 18 19]\n"
+ " [20 21 22 ... 27 28 29]\n"
+ " ...\n"
+ " [70 71 72 ... 77 78 79]\n"
+ " [80 81 82 ... 87 88 89]\n"
+ " [90 91 92 ... 97 98 99]], second: [0 1 2 ... 7 8 9], suffix")
+
+ assert(out.decode() == expected)
+ ```
+
+ Args:
+ template: A string template to format tensor values into.
+ inputs: A list of `Tensor` objects, or a single Tensor.
+ The list of tensors to format into the template string. If a solitary
+ tensor is passed in, the input tensor will automatically be wrapped as a
+ list.
+ placeholder: An optional `string`. Defaults to `{}`.
+ At each placeholder occurring in the template, a subsequent tensor
+ will be inserted.
+ summarize: An optional `int`. Defaults to `3`.
+ When formatting the tensors, show the first and last `summarize`
+ entries of each tensor dimension (recursively). If set to -1, all
+ elements of the tensor will be shown.
+ name: A name for the operation (optional).
+
+ Returns:
+ A scalar `Tensor` of type `string`.
+
+ Raises:
+ ValueError: if the number of placeholders does not match the number of
+ inputs.
+ """
+ # If there is only one tensor to format, we will automatically wrap it in a
+ # list to simplify the user experience
+ if tensor_util.is_tensor(inputs):
+ inputs = [inputs]
+ if template.count(placeholder) != len(inputs):
+ raise ValueError("%s placeholder(s) in template does not match %s tensor(s)"
+ " provided as input" % (template.count(placeholder),
+ len(inputs)))
+
+ return gen_string_ops.string_format(inputs,
+ template=template,
+ placeholder=placeholder,
+ summarize=summarize,
+ name=name)
+
+
@tf_export("string_split")
def string_split(source, delimiter=" ", skip_empty=True): # pylint: disable=invalid-name
"""Split elements of `source` based on `delimiter` into a `SparseTensor`.
@@ -246,6 +330,17 @@ def reduce_join(inputs, axis=None,
reduce_join.__doc__ = deprecation.rewrite_argument_docstring(
gen_string_ops.reduce_join.__doc__, "reduction_indices", "axis")
+
+# This wrapper provides backwards compatibility for code that predates the
+# unit argument and that passed 'name' as a positional argument.
+@tf_export("strings.length")
+def string_length(input, name=None, unit="BYTE"):
+ return gen_string_ops.string_length(input, unit=unit, name=name)
+
+
+string_length.__doc__ = gen_string_ops.string_length.__doc__
+
+
ops.NotDifferentiable("RegexReplace")
ops.NotDifferentiable("StringToHashBucket")
ops.NotDifferentiable("StringToHashBucketFast")
diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py
index a43676cd70..af5c7d4050 100644
--- a/tensorflow/python/ops/variable_scope.py
+++ b/tensorflow/python/ops/variable_scope.py
@@ -198,7 +198,7 @@ VariableSynchronization = variables.VariableSynchronization # pylint: disable=i
VariableAggregation = variables.VariableAggregation # pylint: disable=invalid-name
AUTO_REUSE = _ReuseMode.AUTO_REUSE
-tf_export("AUTO_REUSE").export_constant(__name__, "AUTO_REUSE")
+tf_export(v1=["AUTO_REUSE"]).export_constant(__name__, "AUTO_REUSE")
AUTO_REUSE.__doc__ = """
When passed in as the value for the `reuse` flag, AUTO_REUSE indicates that
get_variable() should create the requested variable if it doesn't exist or, if
@@ -515,8 +515,10 @@ class _VariableStore(object):
"synchronization": synchronization,
"aggregation": aggregation,
}
- # `fn_args` can handle functions, `functools.partial`, `lambda`.
- if "constraint" in function_utils.fn_args(custom_getter):
+ # `fn_args` and `has_kwargs` can handle functions, `functools.partial`,
+ # `lambda`.
+ if ("constraint" in function_utils.fn_args(custom_getter) or
+ function_utils.has_kwargs(custom_getter)):
custom_getter_kwargs["constraint"] = constraint
return custom_getter(**custom_getter_kwargs)
else:
@@ -906,7 +908,7 @@ class _VariableStore(object):
if use_resource is None:
# Set the default value if unspecified.
use_resource = _DEFAULT_USE_RESOURCE
- v = variable(
+ v = variables.VariableV1(
initial_value=init_val,
name=name,
trainable=trainable,
@@ -992,7 +994,7 @@ def no_regularizer(_):
# TODO(alive): support caching devices and partitioned variables in Eager mode.
-@tf_export("VariableScope")
+@tf_export(v1=["VariableScope"])
class VariableScope(object):
"""Variable scope object to carry defaults to provide to `get_variable`.
@@ -1340,7 +1342,7 @@ def get_variable_scope_store():
return scope_store
-@tf_export("get_variable_scope")
+@tf_export(v1=["get_variable_scope"])
def get_variable_scope():
"""Returns the current variable scope."""
return get_variable_scope_store().current_scope
@@ -1449,7 +1451,7 @@ class EagerVariableStore(object):
# The argument list for get_variable must match arguments to get_local_variable.
# So, if you are updating the arguments, also update arguments to
# get_local_variable below.
-@tf_export("get_variable")
+@tf_export(v1=["get_variable"])
def get_variable(name,
shape=None,
dtype=None,
@@ -1594,7 +1596,7 @@ get_variable.__doc__ = get_variable_or_local_docstring % (
# The argument list for get_local_variable must match arguments to get_variable.
# So, if you are updating the arguments, also update arguments to get_variable.
-@tf_export("get_local_variable")
+@tf_export(v1=["get_local_variable"])
def get_local_variable( # pylint: disable=missing-docstring
name,
shape=None,
@@ -1939,7 +1941,7 @@ def _get_unique_variable_scope(prefix):
# Named like a function for backwards compatibility with the
# @tf_contextlib.contextmanager version, which was switched to a class to avoid
# some object creation overhead.
-@tf_export("variable_scope") # pylint: disable=invalid-name
+@tf_export(v1=["variable_scope"]) # pylint: disable=invalid-name
class variable_scope(object):
"""A context manager for defining ops that creates variables (layers).
@@ -2320,7 +2322,7 @@ class variable_scope(object):
# pylint: disable=g-doc-return-or-yield
-@tf_export("variable_op_scope")
+@tf_export(v1=["variable_op_scope"])
@tf_contextlib.contextmanager
def variable_op_scope(values,
name_or_scope,
@@ -2441,7 +2443,33 @@ def default_variable_creator(next_creator=None, **kwargs):
expected_shape=expected_shape, import_scope=import_scope)
+def default_variable_creator_v2(next_creator=None, **kwargs):
+ """Default variable creator."""
+ assert next_creator is None
+ initial_value = kwargs.get("initial_value", None)
+ trainable = kwargs.get("trainable", None)
+ validate_shape = kwargs.get("validate_shape", True)
+ caching_device = kwargs.get("caching_device", None)
+ name = kwargs.get("name", None)
+ variable_def = kwargs.get("variable_def", None)
+ dtype = kwargs.get("dtype", None)
+ import_scope = kwargs.get("import_scope", None)
+ constraint = kwargs.get("constraint", None)
+
+ # Set trainable value based on synchronization value.
+ synchronization = kwargs.get("synchronization", VariableSynchronization.AUTO)
+ trainable = _get_trainable_value(
+ synchronization=synchronization, trainable=trainable)
+
+ return resource_variable_ops.ResourceVariable(
+ initial_value=initial_value, trainable=trainable,
+ validate_shape=validate_shape, caching_device=caching_device,
+ name=name, dtype=dtype, constraint=constraint, variable_def=variable_def,
+ import_scope=import_scope)
+
+
variables.default_variable_creator = default_variable_creator
+variables.default_variable_creator_v2 = default_variable_creator_v2
def _make_getter(captured_getter, captured_previous):
@@ -2450,11 +2478,12 @@ def _make_getter(captured_getter, captured_previous):
# TODO(apassos) remove forwarding symbol
-variable = variables.Variable
+variable = variables.VariableV1
+@tf_export(v1=["variable_creator_scope"])
@tf_contextlib.contextmanager
-def variable_creator_scope(variable_creator):
+def variable_creator_scope_v1(variable_creator):
"""Scope which defines a variable creation function to be used by variable().
variable_creator is expected to be a function with the following signature:
@@ -2525,3 +2554,73 @@ def variable_creator_scope(variable_creator):
"""
with ops.get_default_graph()._variable_creator_scope(variable_creator): # pylint: disable=protected-access
yield
+
+
+# Note: only the docstrings differ between this and v1.
+@tf_export(v2=["variable_creator_scope"])
+@tf_contextlib.contextmanager
+def variable_creator_scope(variable_creator):
+ """Scope which defines a variable creation function to be used by variable().
+
+ variable_creator is expected to be a function with the following signature:
+
+ ```
+ def variable_creator(next_creator, **kwargs)
+ ```
+
+ The creator is supposed to eventually call the next_creator to create a
+ variable if it does want to create a variable and not call Variable or
+ ResourceVariable directly. This helps make creators composable. A creator may
+ choose to create multiple variables, return already existing variables, or
+ simply register that a variable was created and defer to the next creators in
+ line. Creators can also modify the keyword arguments seen by the next
+ creators.
+
+ Custom getters in the variable scope will eventually resolve down to these
+ custom creators when they do create variables.
+
+ The valid keyword arguments in kwds are:
+ initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
+ which is the initial value for the Variable. The initial value must have
+ a shape specified unless `validate_shape` is set to False. Can also be a
+ callable with no argument that returns the initial value when called. In
+ that case, `dtype` must be specified. (Note that initializer functions
+ from init_ops.py must first be bound to a shape before being used here.)
+ trainable: If `True`, the default, GradientTapes automatically watch
+ uses of this Variable.
+ validate_shape: If `False`, allows the variable to be initialized with a
+ value of unknown shape. If `True`, the default, the shape of
+ `initial_value` must be known.
+ caching_device: Optional device string describing where the Variable
+ should be cached for reading. Defaults to the Variable's device.
+ If not `None`, caches on another device. Typical use is to cache
+ on the device where the Ops using the Variable reside, to deduplicate
+ copying through `Switch` and other conditional statements.
+ name: Optional name for the variable. Defaults to `'Variable'` and gets
+ uniquified automatically.
+ dtype: If set, initial_value will be converted to the given type.
+ If `None`, either the datatype will be kept (if `initial_value` is
+ a Tensor), or `convert_to_tensor` will decide.
+ constraint: A constraint function to be applied to the variable after
+ updates by some algorithms.
+ synchronization: Indicates when a distributed variable will be
+ aggregated. Accepted values are constants defined in the class
+ `tf.VariableSynchronization`. By default the synchronization is set to
+ `AUTO` and the current `DistributionStrategy` chooses
+ when to synchronize. If `synchronization` is set to `ON_READ`,
+ `trainable` must not be set to `True`.
+ aggregation: Indicates how a distributed variable will be aggregated.
+ Accepted values are constants defined in the class
+ `tf.VariableAggregation`.
+
+ This set may grow over time, so it's important that the signature of creators
+ matches the one described above.
+
+ Args:
+ variable_creator: the passed creator
+
+ Yields:
+ A scope in which the creator is active
+ """
+ with ops.get_default_graph()._variable_creator_scope(variable_creator): # pylint: disable=protected-access
+ yield
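As exported here, a creator passed to `variable_creator_scope` receives `next_creator` plus the keyword arguments listed above and is expected to delegate to `next_creator` when it does create a variable. A usage sketch, assuming the v1 export `tf.variable_creator_scope` added by this change and TF 1.x graph mode:

```python
import tensorflow as tf

def logging_creator(next_creator, **kwargs):
    # Observe every variable creation, then defer to the next creator in line.
    print("creating variable:", kwargs.get("name"))
    return next_creator(**kwargs)

with tf.variable_creator_scope(logging_creator):
    v = tf.Variable(1.0, name="v")  # prints "creating variable: v"
```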
diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py
index 7a46157739..8da1e9fe56 100644
--- a/tensorflow/python/ops/variables.py
+++ b/tensorflow/python/ops/variables.py
@@ -46,6 +46,11 @@ def default_variable_creator(_, **kwds):
raise NotImplementedError("variable_scope needs to be imported")
+def default_variable_creator_v2(_, **kwds):
+ del kwds
+ raise NotImplementedError("variable_scope needs to be imported")
+
+
def _make_getter(captured_getter, captured_previous):
"""To avoid capturing loop variables."""
def getter(**kwargs):
@@ -101,21 +106,21 @@ class VariableAggregation(enum.Enum):
class VariableMetaclass(type):
"""Metaclass to allow construction of tf.Variable to be overridden."""
- def _variable_call(cls,
- initial_value=None,
- trainable=None,
- collections=None,
- validate_shape=True,
- caching_device=None,
- name=None,
- variable_def=None,
- dtype=None,
- expected_shape=None,
- import_scope=None,
- constraint=None,
- use_resource=None,
- synchronization=VariableSynchronization.AUTO,
- aggregation=VariableAggregation.NONE):
+ def _variable_v1_call(cls,
+ initial_value=None,
+ trainable=None,
+ collections=None,
+ validate_shape=True,
+ caching_device=None,
+ name=None,
+ variable_def=None,
+ dtype=None,
+ expected_shape=None,
+ import_scope=None,
+ constraint=None,
+ use_resource=None,
+ synchronization=VariableSynchronization.AUTO,
+ aggregation=VariableAggregation.NONE):
"""Call on Variable class. Useful to force the signature."""
previous_getter = lambda **kwargs: default_variable_creator(None, **kwargs)
for getter in ops.get_default_graph()._variable_creator_stack: # pylint: disable=protected-access
@@ -140,14 +145,49 @@ class VariableMetaclass(type):
synchronization=synchronization,
aggregation=aggregation)
+ def _variable_v2_call(cls,
+ initial_value=None,
+ trainable=None,
+ validate_shape=True,
+ caching_device=None,
+ name=None,
+ variable_def=None,
+ dtype=None,
+ import_scope=None,
+ constraint=None,
+ synchronization=VariableSynchronization.AUTO,
+ aggregation=VariableAggregation.NONE):
+ """Call on Variable class. Useful to force the signature."""
+ previous_getter = lambda **kws: default_variable_creator_v2(None, **kws)
+ for getter in ops.get_default_graph()._variable_creator_stack: # pylint: disable=protected-access
+ previous_getter = _make_getter(getter, previous_getter)
+
+ # Reset `aggregation` that is explicitly set as `None` to the enum NONE.
+ if aggregation is None:
+ aggregation = VariableAggregation.NONE
+ return previous_getter(
+ initial_value=initial_value,
+ trainable=trainable,
+ validate_shape=validate_shape,
+ caching_device=caching_device,
+ name=name,
+ variable_def=variable_def,
+ dtype=dtype,
+ import_scope=import_scope,
+ constraint=constraint,
+ synchronization=synchronization,
+ aggregation=aggregation)
+
def __call__(cls, *args, **kwargs):
- if cls is Variable:
- return cls._variable_call(*args, **kwargs)
+ if cls is VariableV1:
+ return cls._variable_v1_call(*args, **kwargs)
+ elif cls is Variable:
+ return cls._variable_v2_call(*args, **kwargs)
else:
return super(VariableMetaclass, cls).__call__(*args, **kwargs)
-@tf_export("Variable")
+@tf_export(v2=["Variable"])
class Variable(six.with_metaclass(VariableMetaclass,
checkpointable.CheckpointableBase)):
"""See the [Variables Guide](https://tensorflow.org/guide/variables).
@@ -267,16 +307,13 @@ class Variable(six.with_metaclass(VariableMetaclass,
def __init__(self,
initial_value=None,
trainable=True,
- collections=None,
validate_shape=True,
caching_device=None,
name=None,
variable_def=None,
dtype=None,
- expected_shape=None,
import_scope=None,
constraint=None,
- use_resource=None,
synchronization=VariableSynchronization.AUTO,
aggregation=VariableAggregation.NONE):
"""Creates a new variable with value `initial_value`.
@@ -297,11 +334,8 @@ class Variable(six.with_metaclass(VariableMetaclass,
callable with no argument that returns the initial value when called. In
that case, `dtype` must be specified. (Note that initializer functions
from init_ops.py must first be bound to a shape before being used here.)
- trainable: If `True`, the default, also adds the variable to the graph
- collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
- the default list of variables to use by the `Optimizer` classes.
- collections: List of graph collections keys. The new variable is added to
- these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
+ trainable: If `True`, the default, GradientTapes automatically watch uses
+ of this variable.
validate_shape: If `False`, allows the variable to be initialized with a
value of unknown shape. If `True`, the default, the shape of
`initial_value` must be known.
@@ -319,8 +353,6 @@ class Variable(six.with_metaclass(VariableMetaclass,
dtype: If set, initial_value will be converted to the given type.
If `None`, either the datatype will be kept (if `initial_value` is
a Tensor), or `convert_to_tensor` will decide.
- expected_shape: A TensorShape. If set, initial_value is expected
- to have this shape.
import_scope: Optional `string`. Name scope to add to the
`Variable.` Only used when initializing from protocol buffer.
constraint: An optional projection function to be applied to the variable
@@ -330,9 +362,6 @@ class Variable(six.with_metaclass(VariableMetaclass,
variable and return the Tensor for the projected value
(which must have the same shape). Constraints are not safe to
use when doing asynchronous distributed training.
- use_resource: if True, a ResourceVariable is created; otherwise an
- old-style ref-based variable is created. When eager execution is enabled
- a resource variable is always created.
synchronization: Indicates when a distributed variable will be
aggregated. Accepted values are constants defined in the class
`tf.VariableSynchronization`. By default the synchronization is set to
@@ -1009,11 +1038,207 @@ class Variable(six.with_metaclass(VariableMetaclass,
raise NotImplementedError
+@tf_export(v1=["Variable"])
+class VariableV1(Variable):
+ """See the [Variables Guide](https://tensorflow.org/guide/variables).
+
+ A variable maintains state in the graph across calls to `run()`. You add a
+ variable to the graph by constructing an instance of the class `Variable`.
+
+ The `Variable()` constructor requires an initial value for the variable,
+ which can be a `Tensor` of any type and shape. The initial value defines the
+ type and shape of the variable. After construction, the type and shape of
+ the variable are fixed. The value can be changed using one of the assign
+ methods.
+
+ If you want to change the shape of a variable later you have to use an
+ `assign` Op with `validate_shape=False`.
+
+ Just like any `Tensor`, variables created with `Variable()` can be used as
+ inputs for other Ops in the graph. Additionally, all the operators
+ overloaded for the `Tensor` class are carried over to variables, so you can
+ also add nodes to the graph by just doing arithmetic on variables.
+
+ ```python
+ import tensorflow as tf
+
+ # Create a variable.
+ w = tf.Variable(<initial-value>, name=<optional-name>)
+
+ # Use the variable in the graph like any Tensor.
+ y = tf.matmul(w, ...another variable or tensor...)
+
+ # The overloaded operators are available too.
+ z = tf.sigmoid(w + y)
+
+ # Assign a new value to the variable with `assign()` or a related method.
+ w.assign(w + 1.0)
+ w.assign_add(1.0)
+ ```
+
+ When you launch the graph, variables have to be explicitly initialized before
+ you can run Ops that use their value. You can initialize a variable by
+ running its *initializer op*, restoring the variable from a save file, or
+ simply running an `assign` Op that assigns a value to the variable. In fact,
+ the variable *initializer op* is just an `assign` Op that assigns the
+ variable's initial value to the variable itself.
+
+ ```python
+ # Launch the graph in a session.
+ with tf.Session() as sess:
+ # Run the variable initializer.
+ sess.run(w.initializer)
+ # ...you now can run ops that use the value of 'w'...
+ ```
+
+ The most common initialization pattern is to use the convenience function
+ `global_variables_initializer()` to add an Op to the graph that initializes
+ all the variables. You then run that Op after launching the graph.
+
+ ```python
+ # Add an Op to initialize global variables.
+ init_op = tf.global_variables_initializer()
+
+ # Launch the graph in a session.
+ with tf.Session() as sess:
+ # Run the Op that initializes global variables.
+ sess.run(init_op)
+ # ...you can now run any Op that uses variable values...
+ ```
+
+ If you need to create a variable with an initial value dependent on another
+ variable, use the other variable's `initialized_value()`. This ensures that
+ variables are initialized in the right order.
+
+ All variables are automatically collected in the graph where they are
+ created. By default, the constructor adds the new variable to the graph
+ collection `GraphKeys.GLOBAL_VARIABLES`. The convenience function
+ `global_variables()` returns the contents of that collection.
+
+ When building a machine learning model it is often convenient to distinguish
+ between variables holding the trainable model parameters and other variables
+ such as a `global step` variable used to count training steps. To make this
+ easier, the variable constructor supports a `trainable=<bool>` parameter. If
+ `True`, the new variable is also added to the graph collection
+ `GraphKeys.TRAINABLE_VARIABLES`. The convenience function
+ `trainable_variables()` returns the contents of this collection. The
+ various `Optimizer` classes use this collection as the default list of
+ variables to optimize.
+
+ WARNING: tf.Variable objects by default have a non-intuitive memory model. A
+ Variable is represented internally as a mutable Tensor which can
+ non-deterministically alias other Tensors in a graph. The set of operations
+ which consume a Variable and can lead to aliasing is undetermined and can
+ change across TensorFlow versions. Avoid writing code which relies on the
+ value of a Variable either changing or not changing as other operations
+ happen. For example, using Variable objects or simple functions thereof as
+ predicates in a `tf.cond` is dangerous and error-prone:
+
+ ```
+ v = tf.Variable(True)
+ tf.cond(v, lambda: v.assign(False), my_false_fn) # Note: this is broken.
+ ```
+
+ Here, adding `use_resource=True` when constructing the variable will
+ fix any nondeterminism issues:
+ ```
+ v = tf.Variable(True, use_resource=True)
+ tf.cond(v, lambda: v.assign(False), my_false_fn)
+ ```
+
+ To use the replacement for variables, which does not have these issues,
+ do one of the following:
+
+ * Add `use_resource=True` when constructing `tf.Variable`;
+ * Call `tf.get_variable_scope().set_use_resource(True)` inside a
+ `tf.variable_scope` before the `tf.get_variable()` call.
+ """
+
+ def __init__(self, # pylint: disable=super-init-not-called
+ initial_value=None,
+ trainable=True,
+ collections=None,
+ validate_shape=True,
+ caching_device=None,
+ name=None,
+ variable_def=None,
+ dtype=None,
+ expected_shape=None,
+ import_scope=None,
+ constraint=None,
+ use_resource=None,
+ synchronization=VariableSynchronization.AUTO,
+ aggregation=VariableAggregation.NONE):
+ """Creates a new variable with value `initial_value`.
+
+ The new variable is added to the graph collections listed in `collections`,
+ which defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
+
+ If `trainable` is `True` the variable is also added to the graph collection
+ `GraphKeys.TRAINABLE_VARIABLES`.
+
+ This constructor creates both a `variable` Op and an `assign` Op to set the
+ variable to its initial value.
+
+ Args:
+ initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
+ which is the initial value for the Variable. The initial value must have
+ a shape specified unless `validate_shape` is set to False. Can also be a
+ callable with no argument that returns the initial value when called. In
+ that case, `dtype` must be specified. (Note that initializer functions
+ from init_ops.py must first be bound to a shape before being used here.)
+ trainable: If `True`, the default, also adds the variable to the graph
+ collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
+ the default list of variables to use by the `Optimizer` classes.
+ collections: List of graph collections keys. The new variable is added to
+ these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
+ validate_shape: If `False`, allows the variable to be initialized with a
+ value of unknown shape. If `True`, the default, the shape of
+ `initial_value` must be known.
+ caching_device: Optional device string describing where the Variable
+ should be cached for reading. Defaults to the Variable's device.
+ If not `None`, caches on another device. Typical use is to cache
+ on the device where the Ops using the Variable reside, to deduplicate
+ copying through `Switch` and other conditional statements.
+ name: Optional name for the variable. Defaults to `'Variable'` and gets
+ uniquified automatically.
+ variable_def: `VariableDef` protocol buffer. If not `None`, recreates
+ the Variable object with its contents, referencing the variable's nodes
+ in the graph, which must already exist. The graph is not changed.
+ `variable_def` and the other arguments are mutually exclusive.
+ dtype: If set, initial_value will be converted to the given type.
+ If `None`, either the datatype will be kept (if `initial_value` is
+ a Tensor), or `convert_to_tensor` will decide.
+ expected_shape: A TensorShape. If set, initial_value is expected
+ to have this shape.
+ import_scope: Optional `string`. Name scope to add to the
+ `Variable.` Only used when initializing from protocol buffer.
+ constraint: An optional projection function to be applied to the variable
+ after being updated by an `Optimizer` (e.g. used to implement norm
+ constraints or value constraints for layer weights). The function must
+ take as input the unprojected Tensor representing the value of the
+ variable and return the Tensor for the projected value
+ (which must have the same shape). Constraints are not safe to
+ use when doing asynchronous distributed training.
+ use_resource: whether to use resource variables.
+ synchronization: unused
+ aggregation: unused
+
+ Raises:
+ ValueError: If both `variable_def` and initial_value are specified.
+ ValueError: If the initial value is not specified, or does not have a
+ shape and `validate_shape` is `True`.
+ RuntimeError: If eager execution is enabled.
+ """
+
+ SaveSliceInfo = Variable.SaveSliceInfo
+
+
# TODO(apassos): do not repeat all comments here
-class RefVariable(Variable):
+class RefVariable(VariableV1):
"""Ref-based implementation of variables."""
- def __init__(self,
+ def __init__(self, # pylint: disable=super-init-not-called
initial_value=None,
trainable=True,
collections=None,
@@ -1873,7 +2098,7 @@ class RefVariable(Variable):
def _OverloadAllOperators(): # pylint: disable=invalid-name
"""Register overloads for all operators."""
for operator in ops.Tensor.OVERLOADABLE_OPERATORS:
- Variable._OverloadOperator(operator)
+ Variable._OverloadOperator(operator) # pylint: disable=protected-access
# For slicing, bind getitem differently than a tensor (use SliceHelperVar
# instead)
# pylint: disable=protected-access
@@ -2401,7 +2626,7 @@ class PartitionedVariable(object):
"assign() has not been implemented for PartitionedVariable.")
-@tf_export("global_variables")
+@tf_export(v1=["global_variables"])
def global_variables(scope=None):
"""Returns global variables.
@@ -2427,7 +2652,7 @@ def global_variables(scope=None):
return ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, scope)
-@tf_export("all_variables")
+@tf_export(v1=["all_variables"])
@deprecated("2017-03-02", "Please use tf.global_variables instead.")
def all_variables():
"""See `tf.global_variables`."""
@@ -2452,7 +2677,7 @@ def _all_saveable_objects(scope=None):
ops.get_collection(ops.GraphKeys.SAVEABLE_OBJECTS, scope))
-@tf_export("local_variables")
+@tf_export(v1=["local_variables"])
def local_variables(scope=None):
"""Returns local variables.
@@ -2480,7 +2705,7 @@ def local_variables(scope=None):
return ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES, scope)
-@tf_export("model_variables")
+@tf_export(v1=["model_variables"])
def model_variables(scope=None):
"""Returns all variables in the MODEL_VARIABLES collection.
@@ -2497,7 +2722,7 @@ def model_variables(scope=None):
return ops.get_collection(ops.GraphKeys.MODEL_VARIABLES, scope)
-@tf_export("trainable_variables")
+@tf_export(v1=["trainable_variables"])
def trainable_variables(scope=None):
"""Returns all variables created with `trainable=True`.
@@ -2519,7 +2744,7 @@ def trainable_variables(scope=None):
return ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES, scope)
-@tf_export("moving_average_variables")
+@tf_export(v1=["moving_average_variables"])
def moving_average_variables(scope=None):
"""Returns all variables that maintain their moving averages.
@@ -2541,7 +2766,7 @@ def moving_average_variables(scope=None):
return ops.get_collection(ops.GraphKeys.MOVING_AVERAGE_VARIABLES, scope)
-@tf_export("initializers.variables", "variables_initializer")
+@tf_export(v1=["initializers.variables", "variables_initializer"])
def variables_initializer(var_list, name="init"):
"""Returns an Op that initializes a list of variables.
@@ -2567,7 +2792,7 @@ def variables_initializer(var_list, name="init"):
return control_flow_ops.no_op(name=name)
-@tf_export("initialize_variables")
+@tf_export(v1=["initialize_variables"])
@tf_should_use.should_use_result
@deprecated("2017-03-02", "Use `tf.variables_initializer` instead.")
def initialize_variables(var_list, name="init"):
@@ -2575,7 +2800,7 @@ def initialize_variables(var_list, name="init"):
return variables_initializer(var_list, name=name)
-@tf_export("initializers.global_variables", "global_variables_initializer")
+@tf_export(v1=["initializers.global_variables", "global_variables_initializer"])
def global_variables_initializer():
"""Returns an Op that initializes global variables.
@@ -2589,7 +2814,7 @@ def global_variables_initializer():
return variables_initializer(global_variables())
-@tf_export("initialize_all_variables")
+@tf_export(v1=["initialize_all_variables"])
@tf_should_use.should_use_result
@deprecated("2017-03-02", "Use `tf.global_variables_initializer` instead.")
def initialize_all_variables():
@@ -2597,7 +2822,7 @@ def initialize_all_variables():
return global_variables_initializer()
-@tf_export("initializers.local_variables", "local_variables_initializer")
+@tf_export(v1=["initializers.local_variables", "local_variables_initializer"])
def local_variables_initializer():
"""Returns an Op that initializes all local variables.
@@ -2611,7 +2836,7 @@ def local_variables_initializer():
return variables_initializer(local_variables())
-@tf_export("initialize_local_variables")
+@tf_export(v1=["initialize_local_variables"])
@tf_should_use.should_use_result
@deprecated("2017-03-02", "Use `tf.local_variables_initializer` instead.")
def initialize_local_variables():
@@ -2619,7 +2844,7 @@ def initialize_local_variables():
return local_variables_initializer()
-@tf_export("is_variable_initialized")
+@tf_export(v1=["is_variable_initialized"])
@tf_should_use.should_use_result
def is_variable_initialized(variable):
"""Tests if a variable has been initialized.
@@ -2634,7 +2859,7 @@ def is_variable_initialized(variable):
return state_ops.is_variable_initialized(variable)
-@tf_export("assert_variables_initialized")
+@tf_export(v1=["assert_variables_initialized"])
@tf_should_use.should_use_result
def assert_variables_initialized(var_list=None):
"""Returns an Op to check if variables are initialized.
@@ -2677,7 +2902,7 @@ def assert_variables_initialized(var_list=None):
return array_ops.stack(ranks)
-@tf_export("report_uninitialized_variables")
+@tf_export(v1=["report_uninitialized_variables"])
@tf_should_use.should_use_result
def report_uninitialized_variables(var_list=None,
name="report_uninitialized_variables"):
diff --git a/tensorflow/python/ops/while_v2.py b/tensorflow/python/ops/while_v2.py
new file mode 100644
index 0000000000..6791e1cd61
--- /dev/null
+++ b/tensorflow/python/ops/while_v2.py
@@ -0,0 +1,584 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""while_v2 and gradient.
+
+This is a version of while_loop that emits a single While op, as well as the
+gradient function for While ops produced by while_loop. This will eventually
+replace the current tf.while_loop implementation once it reaches feature and
+performance parity.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import sys
+
+from tensorflow.core.framework import attr_value_pb2
+from tensorflow.python.eager import function
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import function_def_to_graph
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import cond_v2_impl as cond_v2
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import control_flow_util
+from tensorflow.python.ops import gen_functional_ops
+from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import list_ops
+from tensorflow.python.util import nest
+
+# pylint: disable=protected-access
+
+control_flow_ops._while_v2 = sys.modules[__name__]
+
+# TODO(b/79881896): Handle external control dependencies. tf.while_loop allows
+# control dependencies on external nodes with at least 1 output.
+# Another idea is to create const nodes outside the loop and add control edges
+# to them and then pass those in as data inputs. This should probably be
+# handled in the CapturingGraph itself.
+
+
+def while_loop(cond, body, loop_vars, name=None):
+ """Like tf.while_loop, except emits a single While op."""
+ if not name:
+ name = "while"
+
+ with ops.name_scope(name) as scope:
+ with ops.name_scope(None):
+ cond_name = _get_unique_name(("%scond" % scope).replace("/", "_"))
+ body_name = _get_unique_name(("%sbody" % scope).replace("/", "_"))
+
+ flattened_loop_vars = nest.flatten(loop_vars)
+ num_outputs = len(flattened_loop_vars)
+
+ # Add loop counter needed for computing gradients.
+ flattened_loop_vars = [constant_op.constant(0., name="loop_counter")
+ ] + flattened_loop_vars
+
+ # Build a `cond` wrapper that can handle the extra counter loop_var.
+ def wrapped_cond(unused_loop_counter, *loop_vars):
+ return cond(*loop_vars)
+
+ cond_graph = function.func_graph_from_py_func(cond_name, wrapped_cond,
+ flattened_loop_vars, {})
+
+ # Add external_captures of cond to the list of loop vars.
+ # Note that external tensors will be treated as loop invariants, i.e.,
+ # the value of that tensor in each iteration is the same as it was at the
+ # beginning of the loop execution.
+ flattened_loop_vars = flattened_loop_vars + cond_graph.external_captures
+
+ def wrapped_body(loop_counter, *args):
+ """Loop body augmented with counter update.
+
+ Args:
+ loop_counter: Loop counter which needs to be incremented in the body.
+ *args: List of args
+ args[:num_outputs] - Args for the original loop body.
+ args[num_outputs:] - External captures of cond. These get passed
+ through as is.
+
+ Returns:
+ A list of tensors the same length as args.
+ """
+ outputs = body(*args[:num_outputs])
+ if not isinstance(outputs, collections.Sequence):
+ outputs = [outputs]
+
+ # Return the external_captures of cond_graph as is, i.e., treat them as
+ # loop invariants.
+ # TODO(srbs): Update lowering code to create _Enter nodes with
+ # is_constant=True for inputs that are directly passed to outputs.
+ return [loop_counter + 1] + list(outputs) + list(args[num_outputs:])
+
+ body_graph = function.func_graph_from_py_func(body_name, wrapped_body,
+ flattened_loop_vars, {})
+ # Add external captures of body to the list of loop vars.
+ # Note that external tensors will be treated as loop invariants, i.e.,
+ # the value of that tensor in each iteration is the same as it was at the
+ # beginning of the loop execution.
+ flattened_loop_vars = flattened_loop_vars + body_graph.external_captures
+ # TODO(srbs): Update lowering code to create _Enter nodes with
+ # is_constant=True for inputs that are directly passed to outputs.
+ body_graph.outputs.extend(body_graph.internal_captures)
+
+ # Capture `external_captures` of `body_graph` in `cond_graph` so that it
+ # expects to receive those as arguments.
+ # TODO(srbs): Dedup tensors that are captured in both the cond and body.
+ # This logic already exists in cond_v2.
+ with cond_graph.as_default():
+ for external_capture in body_graph.external_captures:
+ cond_graph.capture(external_capture)
+
+ # Export all tensors in the loop body that may be needed for gradient
+ # computation. We do this by accumulating the intermediate values in
+ # TensorLists.
+ intermediate_tensors = _get_intermediates(body_graph)
+
+ for intermediate_tensor in intermediate_tensors:
+ # TODO(srbs): Cache and re-use empty tensor lists.
+ tensor_list = list_ops.empty_tensor_list(
+ element_dtype=intermediate_tensor.dtype,
+ element_shape=_get_tensor_convertible_shape(
+ intermediate_tensor.shape))
+ flattened_loop_vars.append(tensor_list)
+ with cond_graph.as_default():
+ # Add a placeholder to cond_graph's inputs corresponding to the
+ # tensor_list.
+ cond_graph.capture(tensor_list)
+ with body_graph.as_default():
+ # Push the intermediate tensor to the tensor list. This captures the
+ # `tensor_list` as well.
+ appended_tensor_list = list_ops.tensor_list_push_back(
+ tensor_list,
+ intermediate_tensor)
+ # Add this modified tensor list to the list of outputs.
+ body_graph.outputs.append(appended_tensor_list)
+
+ outputs = gen_functional_ops._while(
+ flattened_loop_vars,
+ cond_v2._create_new_tf_function(cond_graph),
+ cond_v2._create_new_tf_function(body_graph),
+ name=scope)
+
+ _copy_handle_data(body_graph.outputs, outputs)
+ _maybe_set_lowering_attr(outputs[0].op)
+
+ # First var is loop counter.
+ if num_outputs == 1:
+ return outputs[1]
+ else:
+ return nest.pack_sequence_as(loop_vars, outputs[1:1 + num_outputs])
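As a hedged end-to-end sketch (graph mode, assuming the module is importable from the path added in this change), the new `while_loop` can be driven directly:

```python
import tensorflow as tf
from tensorflow.python.ops import while_v2

x = tf.constant(1.0)
result = while_v2.while_loop(lambda v: v < 10.0,   # cond
                             lambda v: v * 2.0,    # body
                             [x])                  # loop_vars

with tf.Session() as sess:
  print(sess.run(result))  # 16.0: 1 -> 2 -> 4 -> 8 -> 16
```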
+
+
+@ops.RegisterGradient("While")
+def _WhileGrad(op, *grads): # pylint: disable=invalid-name
+ """The gradient of a While op produced by while_loop."""
+ body_graph = _get_body_graph(op)
+
+ # Replace None gradients with zeros. This is needed because `grads` could have
+ # None incoming gradients for the TensorLists. If we pass None's through, the
+ # custom gradient of TensorListPopBack will create an EmptyTensorList inside
+ # the FuncGraph which is undesirable.
+ # TODO(b/80444525): There might be an issue with treating no gradient as zero
+ # gradient in certain cases. Consider replacing None gradients with Zeros
+ # for accumulators only.
+ grads = [
+ g if g is not None else array_ops.zeros_like(output)
+ for g, output in zip(grads, op.outputs)
+ ]
+
+ body_grad_graph, args = _create_grad_func(
+ body_graph, grads,
+ _get_unique_name("%s_grad" % body_graph.name), op)
+
+ intermediate_tensors = _get_intermediates(body_grad_graph)
+
+ for intermediate_tensor in intermediate_tensors:
+ tensor_list = list_ops.empty_tensor_list(
+ element_dtype=intermediate_tensor.dtype,
+ element_shape=_get_tensor_convertible_shape(intermediate_tensor.shape))
+ with body_grad_graph.as_default():
+ tensor_list_ph = body_grad_graph.capture(tensor_list, whitelisted=True)
+ # Push the intermediate tensor to the tensor list.
+ appended_tensor_list = list_ops.tensor_list_push_back(tensor_list_ph,
+ intermediate_tensor)
+ # Add this modified tensor list to the list of outputs.
+ body_grad_graph.outputs.append(appended_tensor_list)
+
+ def grad_cond(counter, max_iters, *unused_args):
+ return counter < max_iters
+
+ loop_vars = args + body_grad_graph.external_captures
+ cond_grad_graph = function.func_graph_from_py_func(
+ _get_unique_name("%s_grad_cond" % op.name),
+ grad_cond, loop_vars, {})
+
+ assert len(loop_vars) == len(body_grad_graph.inputs)
+ assert len(loop_vars) == len(body_grad_graph.outputs)
+ assert len(loop_vars) == len(cond_grad_graph.inputs)
+
+ outputs = gen_functional_ops._while(
+ loop_vars,
+ cond_v2._create_new_tf_function(cond_grad_graph),
+ cond_v2._create_new_tf_function(body_grad_graph),
+ name=_get_unique_name("%s_grad" % op.name))
+
+ _copy_handle_data(body_grad_graph.outputs, outputs)
+ _maybe_set_lowering_attr(outputs[0].op)
+
+ # outputs[0] is the loop counter.
+ # outputs[1] is the total number of loop iterations.
+ return outputs[2:2 + len(op.inputs)]
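Continuing the sketch after `while_loop` above, the registered gradient lets ordinary `tf.gradients` flow through the single `While` op (hedged; same assumptions):

```python
# The forward sketch computed result = x * 2**4, so d(result)/dx = 16.0.
grad = tf.gradients(result, [x])[0]
with tf.Session() as sess:
  print(sess.run(grad))  # 16.0
```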
+
+
+# TODO(srbs): Pull this into common utils for cond_v2 and while_v2.
+def _get_body_graph(while_op):
+ """Returns `FuncGraph` for the while body.
+
+ Args:
+ while_op: The While Operation.
+
+ Returns:
+ `FuncGraph` for the while body.
+ """
+ extra_inputs = list(while_op.inputs)
+ input_shapes = [t.shape for t in extra_inputs]
+ func_name = while_op.get_attr("body").name
+ fdef = while_op.graph._get_function(func_name).definition
+ func_graph = function_def_to_graph.function_def_to_graph(fdef, input_shapes)
+ func_graph._while = while_op
+ return func_graph
+
+
+def _create_grad_func(func_graph, grads, name, while_op):
+ """Builds and returns the gradient FuncGraph of `func_graph` and its args.
+
+ The returned grad_func_graph must be called with the returned
+ args + grad_func_graph.captures.
+
+ Args:
+ func_graph: FuncGraph for the forward body function.
+ grads: The incoming grads for `func_graph`'s outputs.
+ name: Name of the returned gradient function.
+ while_op: The forward While op.
+
+ Returns:
+ 2-tuple of (grad_func_graph, args).
+ """
+ assert len(func_graph.outputs) == len(grads)
+
+ loop_counter = constant_op.constant(0.)
+ # TODO(srbs): For nested while loops, this value will need to be looked up
+ # from the accumulator of the enclosing while loop. For now use it as is,
+ # assuming there is no nesting.
+ num_iters_t = while_op.outputs[0]
+
+ args = [loop_counter, num_iters_t] + grads
+
+ # Note: The returned function does not have `args` in the list of
+ # `external_captures`.
+ grad_func_graph = function.func_graph_from_py_func(
+ name,
+ lambda *args: _grad_fn(func_graph, args),
+ args, {},
+ func_graph=_WhileBodyGradFuncGraph(name, func_graph))
+
+ # Add the popped accumulators to the list of outputs.
+ for internal_capture in grad_func_graph.internal_captures:
+ grad_func_graph.outputs.append(
+ grad_func_graph.popped_tensor_lists[internal_capture])
+
+ return grad_func_graph, args
+
+
+def _grad_fn(func_graph, args):
+ """Computes the gradient of `func_graph` in the current graph.
+
+ This function builds the gradient graph of the corresponding forward-pass
+ `func_graph` by differentiating `func_graph`'s outputs w.r.t. its inputs.
+
+ Args:
+ func_graph: function.FuncGraph. The corresponding forward-pass function.
+ args: The input arguments.
+ args[0] - Loop counter.
+ args[1] - Total number of iterations.
+ args[2:] - Incoming gradients for `func_graph.outputs`.
+
+ Returns:
+ The output gradient Tensors.
+ """
+ xs = func_graph.inputs
+ ys = func_graph.outputs
+ grad_ys = args[2:]
+
+ # Build the gradient graph. Note that this builds the gradient computation of
+ # func_graph in the current graph, which requires capturing tensors from
+ # func_graph. The captured func_graph tensors are resolved to external tensors
+ # in _resolve_grad_inputs.
+ # TODO(srbs): Mark GradientsHelper as public?
+ grad_outs = gradients_impl._GradientsHelper(
+ ys, xs, grad_ys=grad_ys, src_graph=func_graph)
+
+ assert all([g is not None for g in grad_outs])
+ counter = args[0]
+ total_iters = args[1]
+ return [counter + 1, total_iters] + grad_outs
+
+
+def _get_intermediates(func_graph):
+ """Returns all tensors in `func_graph` that should be accumulated."""
+ # We currently accumulate output tensors of most ops in the function and rely
+ # on the pruning pass to get rid of the unused accumulators at runtime.
+ # However, this can bloat the GraphDef and make debugging harder so we perform
+ # some optimizations.
+ #
+ # Optimizations we currently perform:
+ # 1. We do not accumulate tensors which already have an accumulator
+ # in the loop body.
+ # 2. We do not accumulate outputs of Identity nodes. When building the
+ # FuncGraph, we add an Identity node for each output (see
+ # `AutomaticControlDependencies.mark_as_return`). Accumulating outputs
+ # of all these nodes bloats the GraphDef quite a bit so we remove those.
+ # Since the gradient of an Identity node does not rely on its forward op's
+ # input this is safe to do.
+ #
+ # Other possible optimizations:
+ # 1. Only accumulate tensors that will be required by the backward pass.
+ # This will require running the gradient pass and hence would increase the
+ # graph building time for the forward pass.
+ # 2. Do not accumulate Const nodes created inside the loop body.
+ # 3. Do not accumulate inputs that are passed as-is, e.g. loop invariants.
+ # TODO(srbs): 2 and 3 may be hard optimizations for the runtime optimizer
+ # since it requires knowledge of the while loop semantics. If so, consider
+ # doing those here.
+ intermediates = []
+
+ for op in func_graph.get_operations():
+ if op.type == "Identity":
+ continue
+ for o in op.outputs:
+ if (o != func_graph.inputs[0] and # Skip the loop counter.
+ _get_accumulator(o) is None): # Skip if it already has an accumulator.
+ intermediates.append(o)
+ return intermediates
+
+
+def _get_accumulator(tensor):
+ r"""Returns TensorList if any containing accumulated values of tensor.
+
+ We try to find a pattern of the form:
+
+ input_tl tensor
+ \ /
+ (TensorListPushBack)
+ |
+ output_tl
+
+ which satisfies the following conditions:
+
+ 1. input_tl must be in tensor.graph.inputs.
+ 2. output_tl or Identity(output_tl) must be in tensor.graph.outputs.
+ 3. tensor.graph.input_index(input_tl) == tensor.graph.output_index(output_tl).
+
+ output_tl or Identity(output_tl) (whichever is in tensor.graph.outputs) is
+ returned if such a pattern is found else None is returned.
+
+ Args:
+ tensor: The Tensor to be accumulated.
+
+ Returns:
+ A variant tensor in the same graph as `tensor` or None if no accumulator is
+ found.
+ """
+ assert isinstance(tensor.graph, function.FuncGraph)
+
+ def get_func_graph_output(t):
+ """Returns t or Identity(t) whichever exists in graph outputs else None."""
+ if t in tensor.graph.outputs:
+ return t
+ # tf.defun adds an Identity for each output; check whether that is the case.
+ identity_op = t.consumers()[0]
+ if (identity_op.type == "Identity" and
+ identity_op.outputs[0] in tensor.graph.outputs):
+ return identity_op.outputs[0]
+ return None
+
+ for consumer in tensor.consumers():
+ # Find the consumer that is a TensorListPushBack node whose TensorList input
+ # is in the list of function inputs.
+ if (consumer.type != "TensorListPushBack" or
+ consumer.inputs[0] not in tensor.graph.inputs):
+ continue
+
+ output = get_func_graph_output(consumer.outputs[0])
+ if output is None:
+ # The TensorList output of `consumer` is not in the list of function
+ # outputs.
+ continue
+
+ accum_input_idx = tensor.graph.inputs.index(consumer.inputs[0])
+ accum_output_idx = tensor.graph.outputs.index(output)
+ if accum_input_idx == accum_output_idx:
+ return output
+ return None
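For intuition, the push-back pattern that `_get_accumulator` searches for looks roughly like this inside a forward body graph (a hand-written illustration, not code from this change):

```python
from tensorflow.python.ops import list_ops

def body(loop_counter, x, acc_tl):
  y = x * 2.0                                      # intermediate worth saving
  acc_tl = list_ops.tensor_list_push_back(acc_tl, y)
  # `acc_tl` enters and leaves the body at the same argument index, which is
  # exactly condition 3 above, so `y` is treated as already accumulated.
  return loop_counter + 1, y, acc_tl
```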
+
+
+# TODO(srbs): Add to common utils for cond_v2 and while_v2.
+def _get_unique_name(name):
+ """Returns a name that is unique in the root graph of `func_graph`.
+
+ Args:
+ name: String to uniquify.
+
+ Returns:
+ A string.
+ """
+ with ops.init_scope():
+ return ops.get_default_graph().unique_name(name)
+
+
+class _WhileBodyGradFuncGraph(function.FuncGraph):
+ """FuncGraph for the gradient function of the body of a While op.
+
+ Contains the logic for capturing the tensors from the body of the forward
+ While op which is as follows:
+ 1. Find the accumulator for that tensor.
+ 2. Capture the forward While op output tensor corresponding to the
+ accumulator in this FuncGraph.
+ 3. Pop a value from the captured placeholder and use it as the captured value
+ for the forward pass tensor.
+
+ This only allows capturing tensors in the forward graph. A ValueError is
+ raised if an attempt is made to capture a tensor not in the forward graph.
+ To manually capture a tensor that is not in the forward graph, call
+ `capture` with `whitelisted=True`.
+
+ Note: The `captures` dict does not contain the forward tensor since it is not
+ directly captured. It contains the accumulator corresponding to this forward
+ tensor.
+
+ Attributes:
+ popped_tensor_lists: Dict from the captured accumulator placeholder to the
+ TensorList obtained after popping the intermediate tensor from it. The
+ values of this dict need to be added to the list of outputs.
+ """
+
+ def __init__(self, name, forward_graph):
+ super(_WhileBodyGradFuncGraph, self).__init__(name)
+ self.popped_tensor_lists = {}
+ # FuncGraph for the body of the forward While op.
+ self._forward_graph = forward_graph
+ # Dict from forward intermediate tensor to the corresponding "popped" tensor
+ # in this graph.
+ self._indirect_captures = {}
+ # Dict from forward graph tensor to the While op output corresponding to its
+ # accumulator.
+ self._tensor_to_accumulator = {}
+
+ def capture(self, tensor, name=None, whitelisted=False):
+ """Selectively captures external tensors.
+
+ If `whitelisted` is False only allows capturing tensors in the
+ `_forward_graph`.
+
+ Args:
+ tensor: Tensor. May be from this FuncGraph or a different graph.
+ name: Optional name if a placeholder is created.
+ whitelisted: If False (default), only allows capturing tensors from the
+ forward graph.
+
+ Returns:
+ The placeholder in this graph for the tensor.
+
+ Raises:
+ ValueError: If attempting to capture an external tensor not in the forward
+ graph with `whitelisted` set to False.
+ """
+ if (not whitelisted and tensor.graph is not self and
+ tensor.graph != self._forward_graph):
+ raise ValueError("Attempting to capture tensor", str(tensor),
+ " which is not in the forward graph but in ",
+ _graph_name(tensor.graph), ".")
+ return super(_WhileBodyGradFuncGraph, self).capture(tensor, name)
+
+ def _capture_helper(self, tensor, name):
+ if tensor.graph is not self._forward_graph:
+ return super(_WhileBodyGradFuncGraph, self)._capture_helper(tensor, name)
+
+ captured_tensor = self._indirect_captures.get(tensor)
+ if captured_tensor is not None:
+ # For GradientTape housekeeping.
+ assert self._tensor_to_accumulator[tensor] in self.captures
+ super(_WhileBodyGradFuncGraph, self)._capture_helper(
+ self._tensor_to_accumulator[tensor], name)
+ return captured_tensor
+
+ assert tensor not in self._tensor_to_accumulator
+
+ accumulator = None
+
+ # Find the TensorList that was used to accumulate the tensors of this
+ # intermediate tensor.
+ accumulator = _get_accumulator(tensor)
+ if accumulator is None:
+ raise ValueError("Reference to un-accumulated intermediate tensor: ",
+ tensor.name)
+ assert accumulator.graph == self._forward_graph
+ # Get the While op output corresponding to the accumulator.
+ accumulator = self._forward_graph._while.outputs[self._forward_graph.outputs
+ .index(accumulator)]
+
+ assert accumulator.graph == self._forward_graph.outer_graph
+ self._tensor_to_accumulator[tensor] = accumulator
+
+ # Capture the `accumulator`.
+ accumulator_ph = super(_WhileBodyGradFuncGraph, self)._capture_helper(
+ accumulator, name)
+ new_tensor_list, captured_tensor = list_ops.tensor_list_pop_back(
+ accumulator_ph, element_dtype=tensor.dtype)
+ self._indirect_captures[tensor] = captured_tensor
+ self.popped_tensor_lists[accumulator_ph] = new_tensor_list
+ return captured_tensor
+
+
+def _copy_handle_data(src_tensors, tgt_tensors):
+ for src_t, tgt_t in zip(src_tensors, tgt_tensors):
+ function._copy_handle_data(src_t, tgt_t)
+
+
+# TODO(srbs): Move to common utils for cond_v2 and while_v2.
+def _maybe_set_lowering_attr(op):
+ """Sets the flag to enable lowering on the `While` op if necessary.
+
+ Lowering allows while_v2 to avoid some of the limitations of Functions,
+ allowing users to specify devices & colocation inside of while_v2
+ branches, and enabling non-strict evaluation & partial pruning of while_v2
+ branches. This brings while_v2 closer to feature parity with
+ tf.while_loop.
+
+ However, we do not lower `While` in the XLA context because it is easier
+ for XLA to apply its own optimizations when dealing with un-lowered
+ `While` operators than with low-level control flow primitives.
+
+ Args:
+ op: The While op.
+ """
+ if not control_flow_util.IsInXLAContext(op):
+ # pylint: disable=protected-access
+ op._set_attr("_lower_using_switch_merge", attr_value_pb2.AttrValue(b=True))
+ # pylint: enable=protected-access
+
+
+def _get_tensor_convertible_shape(shape):
+ assert isinstance(shape, tensor_shape.TensorShape)
+ if shape.is_fully_defined():
+ return shape
+ if not shape: # Unknown shape.
+ return -1
+ # Partially defined shape.
+ shape_list = shape.as_list()
+ shape_list = [s if s is not None else -1 for s in shape_list]
+ return ops.convert_to_tensor(shape_list)
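Illustrative (hedged) outcomes of the helper above for the three kinds of shape it distinguishes:

```python
from tensorflow.python.framework import tensor_shape

_get_tensor_convertible_shape(tensor_shape.TensorShape([2, 3]))     # TensorShape([2, 3])
_get_tensor_convertible_shape(tensor_shape.TensorShape(None))       # -1 (fully unknown)
_get_tensor_convertible_shape(tensor_shape.TensorShape([None, 3]))  # tensor([-1, 3])
```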
+
+
+def _graph_name(graph):
+ if isinstance(graph, function.FuncGraph):
+ return graph.name
+ return "Base"
+
+
+# pylint: enable=protected-access
diff --git a/tensorflow/python/profiler/model_analyzer_test.py b/tensorflow/python/profiler/model_analyzer_test.py
index c0e16ca536..94c685274a 100644
--- a/tensorflow/python/profiler/model_analyzer_test.py
+++ b/tensorflow/python/profiler/model_analyzer_test.py
@@ -52,13 +52,19 @@ builder = option_builder.ProfileOptionBuilder
class PrintModelAnalysisTest(test.TestCase):
+ def _no_rewrite_session_config(self):
+ rewriter_config = rewriter_config_pb2.RewriterConfig(
+ pin_to_host_optimization=rewriter_config_pb2.RewriterConfig.OFF)
+ graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config)
+ return config_pb2.ConfigProto(graph_options=graph_options)
+
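The same ConfigProto pattern can switch off other Grappler passes when a test needs the graph left untouched; a hedged variant that also disables constant folding:

```python
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2

def _no_rewrite_config():
  # Turn off both constant folding and the pin-to-host optimization.
  rewriter_config = rewriter_config_pb2.RewriterConfig(
      constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
      pin_to_host_optimization=rewriter_config_pb2.RewriterConfig.OFF)
  graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config)
  return config_pb2.ConfigProto(graph_options=graph_options)
```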
def testDumpToFile(self):
ops.reset_default_graph()
outfile = os.path.join(test.get_temp_dir(), 'dump')
opts = builder(builder.trainable_variables_parameter()
).with_file_output(outfile).build()
- with session.Session() as sess:
+ with session.Session(config=self._no_rewrite_session_config()) as sess:
_ = lib.BuildSmallModel()
model_analyzer.profile(sess.graph, options=opts)
@@ -83,7 +89,8 @@ class PrintModelAnalysisTest(test.TestCase):
with profile_context.ProfileContext(test.get_temp_dir(),
trace_steps=[],
dump_steps=[]) as pctx:
- with session.Session() as sess, ops.device(dev):
+ with session.Session(
+ config=self._no_rewrite_session_config()) as sess, ops.device(dev):
x = lib.BuildSmallModel()
sess.run(variables.global_variables_initializer())
@@ -149,11 +156,8 @@ class PrintModelAnalysisTest(test.TestCase):
.select(['params', 'float_ops', 'occurrence', 'device', 'op_types',
'input_shapes']).build())
- rewriter_config = rewriter_config_pb2.RewriterConfig(
- disable_model_pruning=True)
- graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config)
- config = config_pb2.ConfigProto(graph_options=graph_options)
- with session.Session(config=config) as sess, ops.device('/device:CPU:0'):
+ with session.Session(config=self._no_rewrite_session_config()
+ ) as sess, ops.device('/device:CPU:0'):
x = lib.BuildSmallModel()
sess.run(variables.global_variables_initializer())
@@ -179,7 +183,7 @@ class PrintModelAnalysisTest(test.TestCase):
.select(['bytes', 'params', 'float_ops', 'num_hidden_ops', 'device',
'input_shapes']).build())
- with session.Session() as sess:
+ with session.Session(config=self._no_rewrite_session_config()) as sess:
x = lib.BuildSmallModel()
sess.run(variables.global_variables_initializer())
@@ -213,7 +217,7 @@ class PrintModelAnalysisTest(test.TestCase):
with profile_context.ProfileContext(test.get_temp_dir(),
trace_steps=[],
dump_steps=[]) as pctx:
- with session.Session() as sess:
+ with session.Session(config=self._no_rewrite_session_config()) as sess:
x = lib.BuildFullModel()
sess.run(variables.global_variables_initializer())
@@ -274,7 +278,7 @@ class PrintModelAnalysisTest(test.TestCase):
.account_displayed_op_only(False)
.select(['bytes', 'params', 'float_ops', 'device']).build())
- with session.Session() as sess:
+ with session.Session(config=self._no_rewrite_session_config()) as sess:
x = lib.BuildSmallModel()
sess.run(variables.global_variables_initializer())
@@ -302,7 +306,7 @@ class PrintModelAnalysisTest(test.TestCase):
.with_timeline_output(outfile)
.with_accounted_types(['.*']).build())
- with session.Session() as sess:
+ with session.Session(config=self._no_rewrite_session_config()) as sess:
x = lib.BuildFullModel()
sess.run(variables.global_variables_initializer())
@@ -338,7 +342,7 @@ class PrintModelAnalysisTest(test.TestCase):
'peak_bytes', 'residual_bytes',
'output_bytes', 'occurrence', 'input_shapes']).build())
- with session.Session() as sess:
+ with session.Session(config=self._no_rewrite_session_config()) as sess:
x = lib.BuildFullModel()
sess.run(variables.global_variables_initializer())
@@ -384,7 +388,7 @@ class PrintModelAnalysisTest(test.TestCase):
def testAdvisor(self):
ops.reset_default_graph()
- with session.Session() as sess:
+ with session.Session(config=self._no_rewrite_session_config()) as sess:
x = lib.BuildFullModel()
sess.run(variables.global_variables_initializer())
@@ -417,7 +421,7 @@ class PrintModelAnalysisTest(test.TestCase):
.with_node_names(trim_name_regexes=['ops.py.*'])
.with_pprof_output(outfile).build())
- with session.Session() as sess:
+ with session.Session(config=self._no_rewrite_session_config()) as sess:
x = lib.BuildFullModel()
sess.run(variables.global_variables_initializer())
@@ -484,7 +488,7 @@ class PrintModelAnalysisTest(test.TestCase):
self.assertGreaterEqual(n.output_bytes, mob)
check_min(n.children, mm, mam, mcm, mb, mpb, mrb, mob)
- with session.Session() as sess:
+ with session.Session(config=self._no_rewrite_session_config()) as sess:
x = lib.BuildSmallModel()
sess.run(variables.global_variables_initializer())
run_meta = config_pb2.RunMetadata()
@@ -549,7 +553,7 @@ class PrintModelAnalysisTest(test.TestCase):
for attr in not_selected:
self.assertFalse(s.find(attr) > 0, s)
- with session.Session() as sess:
+ with session.Session(config=self._no_rewrite_session_config()) as sess:
x = lib.BuildSmallModel()
sess.run(variables.global_variables_initializer())
run_meta = config_pb2.RunMetadata()
@@ -582,7 +586,7 @@ class PrintModelAnalysisTest(test.TestCase):
def _trainLoop(self, train_op, train_steps, time_dir, time_step,
memory_dir, memory_step, profile_dir, dump_step):
- with session.Session() as sess:
+ with session.Session(config=self._no_rewrite_session_config()) as sess:
sess.run(variables.global_variables_initializer())
# start from 1 because variable_initializer took one step.
for i in range(1, train_steps + 1):
@@ -655,7 +659,7 @@ class PrintModelAnalysisTest(test.TestCase):
c = a * b
try:
- with session.Session() as sess:
+ with session.Session(config=self._no_rewrite_session_config()) as sess:
sess.run(c, options=config_pb2.RunOptions(
report_tensor_allocations_upon_oom=True))
except Exception as e: # pylint: disable=broad-except
@@ -758,7 +762,7 @@ class PrintModelAnalysisTest(test.TestCase):
grad = gradients.gradients(y, [x1])
- with session.Session() as sess:
+ with session.Session(config=self._no_rewrite_session_config()) as sess:
run_options = config_pb2.RunOptions(
trace_level=config_pb2.RunOptions.FULL_TRACE)
run_metadata = config_pb2.RunMetadata()
diff --git a/tensorflow/python/profiler/pprof_profiler_test.py b/tensorflow/python/profiler/pprof_profiler_test.py
index c2469f012d..11a3487360 100644
--- a/tensorflow/python/profiler/pprof_profiler_test.py
+++ b/tensorflow/python/profiler/pprof_profiler_test.py
@@ -141,7 +141,7 @@ comment: 9
run_metadata = config_pb2.RunMetadata()
num_iters = 5
- with self.test_session() as sess:
+ with self.cached_session() as sess:
i = constant_op.constant(0)
c = lambda i: math_ops.less(i, num_iters)
b = lambda i: math_ops.add(i, 1)
diff --git a/tensorflow/python/pywrap_tensorflow.py b/tensorflow/python/pywrap_tensorflow.py
index 5c0c5783dc..f0724277d3 100644
--- a/tensorflow/python/pywrap_tensorflow.py
+++ b/tensorflow/python/pywrap_tensorflow.py
@@ -68,7 +68,7 @@ try:
sys.setdlopenflags(_default_dlopen_flags)
except ImportError:
msg = """%s\n\nFailed to load the native TensorFlow runtime.\n
-See https://www.tensorflow.org/install/install_sources#common_installation_problems\n
+See https://www.tensorflow.org/install/errors\n
for some common reasons and solutions. Include the entire stack trace
above this error message when asking for help.""" % traceback.format_exc()
raise ImportError(msg)
diff --git a/tensorflow/python/saved_model/loader_test.py b/tensorflow/python/saved_model/loader_test.py
index b7e217a35b..924b2e7c06 100644
--- a/tensorflow/python/saved_model/loader_test.py
+++ b/tensorflow/python/saved_model/loader_test.py
@@ -47,8 +47,8 @@ class SavedModelLoaderTest(test.TestCase):
def setUp(self):
"""Write test SavedModels to a temp directory."""
with session.Session(graph=ops.Graph()) as sess:
- x = variables.Variable(5, name="x")
- y = variables.Variable(11, name="y")
+ x = variables.VariableV1(5, name="x")
+ y = variables.VariableV1(11, name="y")
z = x + y
sess.run(variables.global_variables_initializer())
@@ -134,8 +134,8 @@ class SavedModelLoaderTest(test.TestCase):
def test_restore_variables(self):
loader = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP)
with self.session(graph=ops.Graph()) as sess:
- x = variables.Variable(0, name="x")
- y = variables.Variable(0, name="y")
+ x = variables.VariableV1(0, name="x")
+ y = variables.VariableV1(0, name="y")
z = x * y
sess.run(variables.global_variables_initializer())
@@ -186,8 +186,10 @@ class SavedModelLoaderTest(test.TestCase):
"""
path = _get_export_dir("no_variable_saved_model")
with session.Session(graph=ops.Graph()) as sess:
- x = variables.Variable(5, name="x", collections=["not_global_variable"])
- y = variables.Variable(11, name="y", collections=["not_global_variable"])
+ x = variables.VariableV1(
+ 5, name="x", collections=["not_global_variable"])
+ y = variables.VariableV1(
+ 11, name="y", collections=["not_global_variable"])
self.assertFalse(variables._all_saveable_objects())
z = x + y
sess.run(variables.variables_initializer([x, y]))
diff --git a/tensorflow/python/saved_model/saved_model_test.py b/tensorflow/python/saved_model/saved_model_test.py
index 49d52d3bee..80b75b7ee6 100644
--- a/tensorflow/python/saved_model/saved_model_test.py
+++ b/tensorflow/python/saved_model/saved_model_test.py
@@ -60,7 +60,7 @@ class SavedModelTest(test.TestCase):
return os.path.join(test.get_temp_dir(), label)
def _init_and_validate_variable(self, sess, variable_name, variable_value):
- v = variables.Variable(variable_value, name=variable_name)
+ v = variables.VariableV1(variable_value, name=variable_name)
sess.run(variables.global_variables_initializer())
self.assertEqual(variable_value, v.eval())
@@ -458,7 +458,7 @@ class SavedModelTest(test.TestCase):
# Graph with a single variable added to a collection. SavedModel invoked to:
# - add with weights.
with self.session(graph=ops.Graph()) as sess:
- v = variables.Variable(42, name="v")
+ v = variables.VariableV1(42, name="v")
ops.add_to_collection("foo_vars", v)
sess.run(variables.global_variables_initializer())
self.assertEqual(42, v.eval())
@@ -468,7 +468,7 @@ class SavedModelTest(test.TestCase):
# SavedModel invoked to:
# - simply add the model (weights are not updated).
with self.session(graph=ops.Graph()) as sess:
- v = variables.Variable(43, name="v")
+ v = variables.VariableV1(43, name="v")
ops.add_to_collection("bar_vars", v)
sess.run(variables.global_variables_initializer())
self.assertEqual(43, v.eval())
@@ -780,13 +780,13 @@ class SavedModelTest(test.TestCase):
with self.session(graph=ops.Graph()) as sess:
# Add `v1` and `v2` variables to the graph.
- v1 = variables.Variable(1, name="v1")
+ v1 = variables.VariableV1(1, name="v1")
ops.add_to_collection("v", v1)
- v2 = variables.Variable(2, name="v2")
+ v2 = variables.VariableV1(2, name="v2")
ops.add_to_collection("v", v2)
# Initialize another variable `v3` to 42.
- v3 = variables.Variable(42, name="v3")
+ v3 = variables.VariableV1(42, name="v3")
ops.add_to_collection("v", v3)
# Set up an assignment op to be run as part of the main_op.
@@ -815,13 +815,13 @@ class SavedModelTest(test.TestCase):
with self.session(graph=ops.Graph()) as sess:
# Add `v1` and `v2` variables to the graph.
- v1 = variables.Variable(1, name="v1")
+ v1 = variables.VariableV1(1, name="v1")
ops.add_to_collection("v", v1)
- v2 = variables.Variable(2, name="v2")
+ v2 = variables.VariableV1(2, name="v2")
ops.add_to_collection("v", v2)
# Initialize another variable `v3` to 42.
- v3 = variables.Variable(42, name="v3", trainable=False, collections=[])
+ v3 = variables.VariableV1(42, name="v3", trainable=False, collections=[])
ops.add_to_collection("v", v3)
# Set up an assignment op to be run as part of the legacy_init_op.
@@ -860,11 +860,11 @@ class SavedModelTest(test.TestCase):
g = ops.Graph()
with self.session(graph=g) as sess:
# Initialize variable `v1` to 1.
- v1 = variables.Variable(1, name="v1")
+ v1 = variables.VariableV1(1, name="v1")
ops.add_to_collection("v", v1)
# Initialize another variable `v2` to 42.
- v2 = variables.Variable(42, name="v2", trainable=False, collections=[])
+ v2 = variables.VariableV1(42, name="v2", trainable=False, collections=[])
ops.add_to_collection("v", v2)
# Set up an assignment op to be run as part of the init op.
@@ -889,9 +889,9 @@ class SavedModelTest(test.TestCase):
with self.session(graph=ops.Graph()) as sess:
# Add `v1` and `v2` variables to the graph.
- v1 = variables.Variable(1, name="v1")
+ v1 = variables.VariableV1(1, name="v1")
ops.add_to_collection("v", v1)
- v2 = variables.Variable(2, name="v2")
+ v2 = variables.VariableV1(2, name="v2")
ops.add_to_collection("v", v2)
sess.run(variables.global_variables_initializer())
@@ -918,9 +918,9 @@ class SavedModelTest(test.TestCase):
with self.session(graph=ops.Graph()) as sess:
# Add `v1` and `v2` variables to the graph.
- v1 = variables.Variable(1, name="v1")
+ v1 = variables.VariableV1(1, name="v1")
ops.add_to_collection("v", v1)
- v2 = variables.Variable(2, name="v2")
+ v2 = variables.VariableV1(2, name="v2")
ops.add_to_collection("v", v2)
sess.run(variables.global_variables_initializer())
@@ -947,9 +947,9 @@ class SavedModelTest(test.TestCase):
with self.session(graph=ops.Graph()) as sess:
# Add `v1` and `v2` variables to the graph.
- v1 = variables.Variable(1, name="v1")
+ v1 = variables.VariableV1(1, name="v1")
ops.add_to_collection("v", v1)
- v2 = variables.Variable(2, name="v2")
+ v2 = variables.VariableV1(2, name="v2")
ops.add_to_collection("v", v2)
sess.run(variables.global_variables_initializer())
@@ -1071,13 +1071,13 @@ class SavedModelTest(test.TestCase):
graph=ops.Graph(),
config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
with sess.graph.device("/cpu:0"):
- v1 = variables.Variable(1, name="v1")
+ v1 = variables.VariableV1(1, name="v1")
with sess.graph.device("/cpu:1"):
- v2 = variables.Variable(2, name="v2")
+ v2 = variables.VariableV1(2, name="v2")
# v3 is an unsaved variable derived from v1 and v2. It is used to
# exercise the ability to run an init op when restoring a graph.
- v3 = variables.Variable(1, name="v3", trainable=False, collections=[])
+ v3 = variables.VariableV1(1, name="v3", trainable=False, collections=[])
assign_v3 = state_ops.assign(v3, math_ops.add(v1, v2))
init_op = control_flow_ops.group(assign_v3, name="init_op")
@@ -1140,7 +1140,7 @@ class SavedModelTest(test.TestCase):
builder = saved_model_builder.SavedModelBuilder(export_dir)
with self.session(graph=ops.Graph()) as sess:
- variables.Variable(1, name="v1")
+ variables.VariableV1(1, name="v1")
sess.run(variables.global_variables_initializer())
custom_saver = training.Saver(name="my_saver")
builder.add_meta_graph_and_variables(sess, ["tag"], saver=custom_saver)
@@ -1162,7 +1162,7 @@ class SavedModelTest(test.TestCase):
builder = saved_model_builder.SavedModelBuilder(export_dir)
with self.session(graph=ops.Graph()) as sess:
- variables.Variable(1, name="v1")
+ variables.VariableV1(1, name="v1")
sess.run(variables.global_variables_initializer())
training.Saver(name="my_saver")
builder.add_meta_graph_and_variables(sess, ["tag"])
@@ -1184,7 +1184,7 @@ class SavedModelTest(test.TestCase):
builder = saved_model_builder.SavedModelBuilder(export_dir)
with self.session(graph=ops.Graph()) as sess:
- variables.Variable(1, name="v1")
+ variables.VariableV1(1, name="v1")
sess.run(variables.global_variables_initializer())
builder.add_meta_graph_and_variables(sess, ["tag_0"])
@@ -1293,8 +1293,8 @@ class SavedModelTest(test.TestCase):
# Add a graph with two float32 variables and a Complex Op composing them
# with strip_default_attrs enabled.
with session.Session(graph=ops.Graph()) as sess:
- real_num = variables.Variable(1.0, dtype=dtypes.float32, name="real")
- imag_num = variables.Variable(2.0, dtype=dtypes.float32, name="imag")
+ real_num = variables.VariableV1(1.0, dtype=dtypes.float32, name="real")
+ imag_num = variables.VariableV1(2.0, dtype=dtypes.float32, name="imag")
math_ops.complex(real_num, imag_num, name="complex")
sess.run(variables.global_variables_initializer())
builder.add_meta_graph_and_variables(
@@ -1303,8 +1303,8 @@ class SavedModelTest(test.TestCase):
# Add a graph with the same float32 variables and a Complex Op composing
# them with strip_default_attrs disabled.
with session.Session(graph=ops.Graph()) as sess:
- real_num = variables.Variable(1.0, dtype=dtypes.float32, name="real")
- imag_num = variables.Variable(2.0, dtype=dtypes.float32, name="imag")
+ real_num = variables.VariableV1(1.0, dtype=dtypes.float32, name="real")
+ imag_num = variables.VariableV1(2.0, dtype=dtypes.float32, name="imag")
math_ops.complex(real_num, imag_num, name="complex")
sess.run(variables.global_variables_initializer())
builder.add_meta_graph(["bar"], strip_default_attrs=False)
@@ -1366,7 +1366,7 @@ class SavedModelTest(test.TestCase):
# Add a graph with a single variable and a test op with a defaultless
# float32 attr, "test_attr".
with session.Session(graph=ops.Graph()) as sess:
- variables.Variable(1.0, dtype=dtypes.float64, name="var")
+ variables.VariableV1(1.0, dtype=dtypes.float64, name="var")
test_ops.test_attr(T=dtypes.float32, name="test_attr")
sess.run(variables.global_variables_initializer())
builder.add_meta_graph_and_variables(sess, ["foo"])
diff --git a/tensorflow/python/summary/writer/writer_test.py b/tensorflow/python/summary/writer/writer_test.py
index dc990c2602..670230e917 100644
--- a/tensorflow/python/summary/writer/writer_test.py
+++ b/tensorflow/python/summary/writer/writer_test.py
@@ -286,7 +286,7 @@ class FileWriterTestCase(test.TestCase):
def testAddingSummariesFromSessionRunCalls(self):
test_dir = self._CleanTestDir("global_step")
sw = self._FileWriter(test_dir)
- with self.test_session():
+ with self.cached_session():
i = constant_op.constant(1, dtype=dtypes.int32, shape=[])
l = constant_op.constant(2, dtype=dtypes.int64, shape=[])
# Test the summary can be passed serialized.
@@ -437,7 +437,7 @@ class SessionBasedFileWriterTestCase(FileWriterTestCase):
# Pass in test_session() as the session. It will be cached during this
# test method invocation so that any other use of test_session() with no
# graph should result in re-using the same underlying Session.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
kwargs["session"] = sess
return writer.FileWriter(*args, **kwargs)
return writer.FileWriter(*args, **kwargs)
diff --git a/tensorflow/python/tools/BUILD b/tensorflow/python/tools/BUILD
index 1c1a1a54cd..384c7a82d2 100644
--- a/tensorflow/python/tools/BUILD
+++ b/tensorflow/python/tools/BUILD
@@ -8,6 +8,7 @@ licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
load("//tensorflow:tensorflow.bzl", "py_test")
+load("//tensorflow:tensorflow.bzl", "py_binary")
# Transitive dependencies of this target will be included in the pip package.
py_library(
@@ -21,6 +22,13 @@ py_library(
":saved_model_cli",
":saved_model_utils",
":strip_unused",
+ # The following py_library targets are needed because
+ # py_binary may not depend on them when --define=no_tensorflow_py_deps=true
+ # is specified. See https://github.com/tensorflow/tensorflow/issues/22390
+ ":freeze_graph_lib",
+ ":optimize_for_inference_lib",
+ ":selective_registration_header_lib",
+ ":strip_unused_lib",
],
)
@@ -44,6 +52,7 @@ py_library(
"//tensorflow/python:parsing_ops",
"//tensorflow/python:platform",
"//tensorflow/python:training",
+ "//tensorflow/python/estimator:estimator_py",
"//tensorflow/python/saved_model:loader",
"@six_archive//:six",
],
diff --git a/tensorflow/python/tools/api/generator/create_python_api.py b/tensorflow/python/tools/api/generator/create_python_api.py
index 67cfd799ff..ab749f28cd 100644
--- a/tensorflow/python/tools/api/generator/create_python_api.py
+++ b/tensorflow/python/tools/api/generator/create_python_api.py
@@ -181,7 +181,6 @@ class _ModuleInitCodeBuilder(object):
_names_with_underscore = [%s]
__all__ = [_s for _s in dir() if not _s.startswith('_')]
__all__.extend([_s for _s in _names_with_underscore])
-__all__.remove('print_function')
''' % underscore_names_str
return module_text_map
diff --git a/tensorflow/python/tools/freeze_graph_test.py b/tensorflow/python/tools/freeze_graph_test.py
index e38945fabc..5dc14a6961 100644
--- a/tensorflow/python/tools/freeze_graph_test.py
+++ b/tensorflow/python/tools/freeze_graph_test.py
@@ -60,7 +60,7 @@ class FreezeGraphTest(test_util.TensorFlowTestCase):
# We'll create an input graph that has a single variable containing 1.0,
# and that then multiplies it by 2.
with ops.Graph().as_default():
- variable_node = variables.Variable(1.0, name="variable_node")
+ variable_node = variables.VariableV1(1.0, name="variable_node")
output_node = math_ops.multiply(variable_node, 2.0, name="output_node")
sess = session.Session()
init = variables.global_variables_initializer()
@@ -138,7 +138,7 @@ class FreezeGraphTest(test_util.TensorFlowTestCase):
features = parsing_ops.parse_example(examples, feature_configs)
feature = features[feature_name]
- variable_node = variables.Variable(1.0, name="variable_node")
+ variable_node = variables.VariableV1(1.0, name="variable_node")
scores = math_ops.multiply(variable_node, feature, name="output_node")
class_feature = array_ops.fill(array_ops.shape(feature),
"class_%s" % feature_name)
@@ -174,7 +174,7 @@ class FreezeGraphTest(test_util.TensorFlowTestCase):
output_graph_filename = os.path.join(tmp_dir, "output_graph.pb")
with ops.Graph().as_default():
- variable_node = variables.Variable(1.0, name="variable_node")
+ variable_node = variables.VariableV1(1.0, name="variable_node")
output_node = math_ops.multiply(variable_node, 2.0, name="output_node")
sess = session.Session()
init = variables.global_variables_initializer()
diff --git a/tensorflow/python/tools/optimize_for_inference_test.py b/tensorflow/python/tools/optimize_for_inference_test.py
index fcb3ceac82..a39c046761 100644
--- a/tensorflow/python/tools/optimize_for_inference_test.py
+++ b/tensorflow/python/tools/optimize_for_inference_test.py
@@ -129,7 +129,7 @@ class OptimizeForInferenceTest(test.TestCase):
self.assertProtoEquals(expected_output, output)
def testFoldBatchNorms(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
inputs = [1, 4, 2, 5, 3, 6, -1, -4, -2, -5, -3, -6]
input_op = constant_op.constant(
np.array(inputs), shape=[1, 1, 6, 2], dtype=dtypes.float32)
@@ -161,7 +161,7 @@ class OptimizeForInferenceTest(test.TestCase):
optimized_graph_def = optimize_for_inference_lib.fold_batch_norms(
original_graph_def)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_ = importer.import_graph_def(
optimized_graph_def, input_map={}, name="optimized")
optimized_result = sess.run(["optimized/output:0"])
@@ -224,7 +224,7 @@ class OptimizeForInferenceTest(test.TestCase):
self.assertNotEqual("FusedBatchNorm", node.op)
def testFuseResizePadAndConv(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
inputs = [1, 4, 2, 5, 3, 6, -1, -4, -2, -5, -3, -6]
input_op = constant_op.constant(
np.array(inputs), shape=[1, 2, 3, 2], dtype=dtypes.float32)
@@ -242,7 +242,7 @@ class OptimizeForInferenceTest(test.TestCase):
optimized_graph_def = optimize_for_inference_lib.fuse_resize_and_conv(
original_graph_def, ["output"])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_ = importer.import_graph_def(
optimized_graph_def, input_map={}, name="optimized")
optimized_result = sess.run(["optimized/output:0"])
@@ -255,7 +255,7 @@ class OptimizeForInferenceTest(test.TestCase):
self.assertNotEqual("ResizeBilinear", node.op)
def testFuseResizeAndConv(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
inputs = [1, 4, 2, 5, 3, 6, -1, -4, -2, -5, -3, -6]
input_op = constant_op.constant(
np.array(inputs), shape=[1, 2, 3, 2], dtype=dtypes.float32)
@@ -271,7 +271,7 @@ class OptimizeForInferenceTest(test.TestCase):
optimized_graph_def = optimize_for_inference_lib.fuse_resize_and_conv(
original_graph_def, ["output"])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_ = importer.import_graph_def(
optimized_graph_def, input_map={}, name="optimized")
optimized_result = sess.run(["optimized/output:0"])
@@ -284,7 +284,7 @@ class OptimizeForInferenceTest(test.TestCase):
def testFusePadAndConv(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
inputs = [1, 4, 2, 5, 3, 6, -1, -4, -2, -5, -3, -6]
input_op = constant_op.constant(
np.array(inputs), shape=[1, 2, 3, 2], dtype=dtypes.float32)
@@ -300,7 +300,7 @@ class OptimizeForInferenceTest(test.TestCase):
optimized_graph_def = optimize_for_inference_lib.fuse_resize_and_conv(
original_graph_def, ["output"])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_ = importer.import_graph_def(
optimized_graph_def, input_map={}, name="optimized")
optimized_result = sess.run(["optimized/output:0"])
diff --git a/tensorflow/python/tools/saved_model_cli.py b/tensorflow/python/tools/saved_model_cli.py
index d8ba13d8d2..3dbccd1409 100644
--- a/tensorflow/python/tools/saved_model_cli.py
+++ b/tensorflow/python/tools/saved_model_cli.py
@@ -15,7 +15,7 @@
"""Command-line interface to inspect and execute a graph in a SavedModel.
For detailed usages and examples, please refer to:
-https://www.tensorflow.org/guide/saved_model_cli
+https://www.tensorflow.org/guide/saved_model#cli_to_inspect_and_execute_savedmodel
"""
diff --git a/tensorflow/python/training/adagrad.py b/tensorflow/python/training/adagrad.py
index 3508b98475..cc0da26b27 100644
--- a/tensorflow/python/training/adagrad.py
+++ b/tensorflow/python/training/adagrad.py
@@ -34,7 +34,7 @@ class AdagradOptimizer(optimizer.Optimizer):
See this [paper](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf)
or this
- [intro](http://cs.stanford.edu/~ppasupat/a9online/uploads/proximal_notes.pdf).
+ [intro](https://ppasupat.github.io/a9online/uploads/proximal_notes.pdf).
"""
def __init__(self, learning_rate, initial_accumulator_value=0.1,
diff --git a/tensorflow/python/training/basic_session_run_hooks.py b/tensorflow/python/training/basic_session_run_hooks.py
index 3bd4bd75bd..1efabcd854 100644
--- a/tensorflow/python/training/basic_session_run_hooks.py
+++ b/tensorflow/python/training/basic_session_run_hooks.py
@@ -344,7 +344,7 @@ class _MultiStepStopAtStepHook(session_run_hook.SessionRunHook):
raise ValueError("steps_per_run should be greater than 0")
self._num_steps = num_steps
self._last_step = last_step
- self._steps_per_run = steps_per_run
+ self._steps_per_run_initial_value = steps_per_run
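+ # The requested steps-per-run is kept as a plain Python value; begin()
+ # creates the graph Variable that each run call actually reads.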
def begin(self):
self._global_step_tensor = training_util.get_global_step()
@@ -353,7 +353,8 @@ class _MultiStepStopAtStepHook(session_run_hook.SessionRunHook):
self._steps_per_run_variable = get_or_create_steps_per_run_variable()
def _update_steps_per_run_variable(self, global_step, session):
- steps = min(self._last_step - global_step, self._steps_per_run)
+ steps = min(self._last_step - global_step,
+ self._steps_per_run_initial_value)
self._steps_per_run_variable.load(steps, session=session)
def after_create_session(self, session, coord):
diff --git a/tensorflow/python/training/checkpointable/util.py b/tensorflow/python/training/checkpointable/util.py
index 56c4043d9d..eff15b24ce 100644
--- a/tensorflow/python/training/checkpointable/util.py
+++ b/tensorflow/python/training/checkpointable/util.py
@@ -247,7 +247,7 @@ def _default_getter(name, shape, dtype, initializer=None,
def initial_value():
return initializer(
shape_object.as_list(), dtype=dtype, partition_info=partition_info)
- return variables.Variable(
+ return variables.VariableV1(
initial_value=initial_value,
name=name,
dtype=variable_dtype,
diff --git a/tensorflow/python/training/distribute.py b/tensorflow/python/training/distribute.py
index 21ca1735e0..419a9ec12b 100644
--- a/tensorflow/python/training/distribute.py
+++ b/tensorflow/python/training/distribute.py
@@ -195,6 +195,10 @@ class _SameScopeAgainContext(object):
class DistributionStrategy(object):
"""A list of devices with a state & compute distribution policy.
+ See [tensorflow/contrib/distribute/README.md](
+ https://www.tensorflow.org/code/tensorflow/contrib/distribute/README.md)
+ for an overview and examples.
+
The intent is that you can write an algorithm in a stylized way and
it will be usable with a variety of different `DistributionStrategy`
implementations. Each descendant will implement a different strategy
diff --git a/tensorflow/python/training/evaluation.py b/tensorflow/python/training/evaluation.py
index b36444a14c..2c4eb02d53 100644
--- a/tensorflow/python/training/evaluation.py
+++ b/tensorflow/python/training/evaluation.py
@@ -18,13 +18,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import time
import math
+import time
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
@@ -77,6 +78,59 @@ def _get_latest_eval_step_value(update_ops):
return array_ops.identity(_get_or_create_eval_step().read_value())
+class _MultiStepStopAfterNEvalsHook(session_run_hook.SessionRunHook):
+ """Run hook used by the evaluation routines to run the `eval_ops` N times."""
+
+ def __init__(self, num_evals, steps_per_run=1):
+ """Constructs the run hook.
+
+ Args:
+ num_evals: The number of evaluations to run. If set to None, will
+ iterate the dataset until all inputs are exhausted.
+ steps_per_run: Number of steps executed per run call.
+ """
+ self._num_evals = num_evals
+ self._evals_completed = None
+ self._steps_per_run_initial_value = steps_per_run
+
+ def _set_evals_completed_tensor(self, updated_eval_step):
+ self._evals_completed = updated_eval_step
+
+ def begin(self):
+ self._steps_per_run_variable = \
+ basic_session_run_hooks.get_or_create_steps_per_run_variable()
+
+ def after_create_session(self, session, coord):
+ # Update number of steps to run in the first run call
+ if self._num_evals is None:
+ steps = self._steps_per_run_initial_value
+ else:
+ steps = min(self._steps_per_run_initial_value, self._num_evals)
+ self._steps_per_run_variable.load(steps, session=session)
+
+ def before_run(self, run_context):
+ return session_run_hook.SessionRunArgs({
+ 'evals_completed': self._evals_completed
+ })
+
+ def after_run(self, run_context, run_values):
+ evals_completed = run_values.results['evals_completed']
+ # Update number of steps to run in the next iteration
+ if self._num_evals is None:
+ steps = self._steps_per_run_initial_value
+ else:
+ steps = min(self._num_evals - evals_completed,
+ self._steps_per_run_initial_value)
+ self._steps_per_run_variable.load(steps, session=run_context.session)
+
+ if self._num_evals is None:
+ logging.info('Evaluation [%d]', evals_completed)
+ else:
+ logging.info('Evaluation [%d/%d]', evals_completed, self._num_evals)
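+ # Request a stop once the requested number of evaluations has completed.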
+ if self._num_evals is not None and evals_completed >= self._num_evals:
+ run_context.request_stop()
+
+
class _StopAfterNEvalsHook(session_run_hook.SessionRunHook):
"""Run hook used by the evaluation routines to run the `eval_ops` N times."""
@@ -176,7 +230,15 @@ def _evaluate_once(checkpoint_path,
hooks = list(hooks or [])
if eval_ops is not None:
- update_eval_step = state_ops.assign_add(eval_step, 1, use_locking=True)
+ if any([isinstance(h, _MultiStepStopAfterNEvalsHook) for h in hooks]):
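+ # With a multi-step hook, each run call performs `steps_per_run`
+ # evaluations, so advance the eval step counter by that amount
+ # rather than by 1.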
+ steps_per_run_variable = \
+ basic_session_run_hooks.get_or_create_steps_per_run_variable()
+ update_eval_step = state_ops.assign_add(
+ eval_step,
+ math_ops.cast(steps_per_run_variable, dtype=eval_step.dtype),
+ use_locking=True)
+ else:
+ update_eval_step = state_ops.assign_add(eval_step, 1, use_locking=True)
if isinstance(eval_ops, dict):
eval_ops['update_eval_step'] = update_eval_step
@@ -188,7 +250,7 @@ def _evaluate_once(checkpoint_path,
eval_step_value = _get_latest_eval_step_value(eval_ops)
for h in hooks:
- if isinstance(h, _StopAfterNEvalsHook):
+ if isinstance(h, (_StopAfterNEvalsHook, _MultiStepStopAfterNEvalsHook)):
h._set_evals_completed_tensor(eval_step_value) # pylint: disable=protected-access
logging.info('Starting evaluation at ' + time.strftime('%Y-%m-%d-%H:%M:%S',
diff --git a/tensorflow/python/training/ftrl_test.py b/tensorflow/python/training/ftrl_test.py
index 09d6fe36d3..15c50bc878 100644
--- a/tensorflow/python/training/ftrl_test.py
+++ b/tensorflow/python/training/ftrl_test.py
@@ -218,7 +218,7 @@ class FtrlOptimizerTest(test.TestCase):
def testFtrlWithL1_L2_L2ShrinkageSparse(self):
"""Tests the new FTRL op with support for l2 shrinkage on sparse grads."""
for dtype in [dtypes.half, dtypes.float32]:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
var0 = variables.Variable([[1.0], [2.0]], dtype=dtype)
var1 = variables.Variable([[4.0], [3.0]], dtype=dtype)
grads0 = ops.IndexedSlices(
@@ -252,7 +252,7 @@ class FtrlOptimizerTest(test.TestCase):
def testFtrlWithL2ShrinkageDoesNotChangeLrSchedule(self):
"""Verifies that l2 shrinkage in FTRL does not change lr schedule."""
for dtype in [dtypes.half, dtypes.float32]:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
var1 = variables.Variable([1.0, 2.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
diff --git a/tensorflow/python/training/gradient_descent_test.py b/tensorflow/python/training/gradient_descent_test.py
index 56d82a5b88..1ddea598e5 100644
--- a/tensorflow/python/training/gradient_descent_test.py
+++ b/tensorflow/python/training/gradient_descent_test.py
@@ -252,12 +252,12 @@ class GradientDescentOptimizerTest(test.TestCase):
optimizer = gradient_descent.GradientDescentOptimizer(1.0)
def step():
- v = resource_variable_ops.ResourceVariable(1.0)
+ self.v = resource_variable_ops.ResourceVariable(1.0)
with backprop.GradientTape() as tape:
- loss = v ** 2
- grad = tape.gradient(loss, v)
- optimizer.apply_gradients([(grad, v)])
- return v.read_value()
+ loss = self.v ** 2
+ grad = tape.gradient(loss, self.v)
+ optimizer.apply_gradients([(grad, self.v)])
+ return self.v.read_value()
compiled_step = function.defun(step)
diff --git a/tensorflow/python/training/learning_rate_decay_test.py b/tensorflow/python/training/learning_rate_decay_test.py
index 5a9215730e..03a32f6ca0 100644
--- a/tensorflow/python/training/learning_rate_decay_test.py
+++ b/tensorflow/python/training/learning_rate_decay_test.py
@@ -63,7 +63,7 @@ class LRDecayTest(test_util.TensorFlowTestCase):
def testVariables(self):
with self.cached_session():
- step = variables.Variable(1)
+ step = variables.VariableV1(1)
assign_1 = step.assign(1)
assign_2 = step.assign(2)
assign_100 = step.assign(100)
@@ -121,7 +121,7 @@ class LRDecayTest(test_util.TensorFlowTestCase):
# Test that ref types are valid.
if not context.executing_eagerly():
- x = variables.Variable(0.0)
+ x = variables.VariableV1(0.0)
x_ref = x.op.outputs[0] # float32_ref tensor should be accepted
boundaries, values = [1.0, 2.0], [1, 2, 3]
learning_rate_decay.piecewise_constant(x_ref, boundaries, values)
diff --git a/tensorflow/python/training/learning_rate_decay_v2_test.py b/tensorflow/python/training/learning_rate_decay_v2_test.py
index 0f2d60dafc..b2ac93f06f 100644
--- a/tensorflow/python/training/learning_rate_decay_v2_test.py
+++ b/tensorflow/python/training/learning_rate_decay_v2_test.py
@@ -62,7 +62,7 @@ class LRDecayTestV2(test_util.TensorFlowTestCase):
self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)
def testVariables(self):
- with self.test_session():
+ with self.cached_session():
step = variables.Variable(1)
assign_1 = step.assign(1)
assign_2 = step.assign(2)
diff --git a/tensorflow/python/training/monitored_session.py b/tensorflow/python/training/monitored_session.py
index 0e0125a956..82f0e3be52 100644
--- a/tensorflow/python/training/monitored_session.py
+++ b/tensorflow/python/training/monitored_session.py
@@ -1114,7 +1114,11 @@ class _RecoverableSession(_WrappedSession):
logging.info('An error was raised while a session was being created. '
'This may be due to a preemption of a connected worker '
'or parameter server. A new session will be created. '
- 'Error: %s', e)
+ 'This error may also occur due to a gRPC failure caused '
+ 'by high memory or network bandwidth usage in the '
+ 'parameter servers. If this error occurs repeatedly, try '
+ 'increasing the number of parameter servers assigned to '
+ 'the job. Error: %s', e)
def _check_stop(self):
try:
@@ -1127,7 +1131,11 @@ class _RecoverableSession(_WrappedSession):
'session is complete. This may be due to a preemption in '
'a connected worker or parameter server. The current '
'session will be closed and a new session will be '
- 'created. Error: %s', e)
+ 'created. This error may also occur due to a gRPC failure '
+ 'caused by high memory or network bandwidth usage in the '
+ 'parameter servers. If this error occurs repeatedly, try '
+ 'increasing the number of parameter servers assigned to '
+ 'the job. Error: %s', e)
self.close()
self._sess = self._create_session()
# Since we have just recreated the session, the overall computation should
@@ -1150,7 +1158,11 @@ class _RecoverableSession(_WrappedSession):
logging.info('An error was raised. This may be due to a preemption in '
'a connected worker or parameter server. The current '
'session will be closed and a new session will be '
- 'created. Error: %s', e)
+ 'created. This error may also occur due to a gRPC failure '
+ 'caused by high memory or network bandwidth usage in the '
+ 'parameter servers. If this error occurs repeatedly, try '
+ 'increasing the number of parameter servers assigned to '
+ 'the job. Error: %s', e)
self.close()
self._sess = None
@@ -1166,7 +1178,11 @@ class _RecoverableSession(_WrappedSession):
logging.info('An error was raised. This may be due to a preemption in '
'a connected worker or parameter server. The current '
'session will be closed and a new session will be '
- 'created. Error: %s', e)
+ 'created. This error may also occur due to a gRPC failure '
+ 'caused by high memory or network bandwidth usage in the '
+ 'parameter servers. If this error occurs repeatedly, try '
+ 'increasing the number of parameter servers assigned to '
+ 'the job. Error: %s', e)
self.close()
self._sess = None
diff --git a/tensorflow/python/training/monitored_session_test.py b/tensorflow/python/training/monitored_session_test.py
index 2d7799d66a..c870d99de9 100644
--- a/tensorflow/python/training/monitored_session_test.py
+++ b/tensorflow/python/training/monitored_session_test.py
@@ -69,8 +69,8 @@ class ScaffoldTest(test.TestCase):
def test_defaults_empty_graph(self):
with ops.Graph().as_default():
scaffold = monitored_session.Scaffold()
- variables.Variable(1, name='my_var')
- variables.Variable(
+ variables.VariableV1(1, name='my_var')
+ variables.VariableV1(
2, name='my_local_var', collections=[ops.GraphKeys.LOCAL_VARIABLES])
scaffold.finalize()
self.assertTrue(isinstance(scaffold.init_op, ops.Operation))
@@ -105,7 +105,7 @@ class ScaffoldTest(test.TestCase):
def test_caches_values(self):
with ops.Graph().as_default():
- variables.Variable([1])
+ variables.VariableV1([1])
scaffold1 = monitored_session.Scaffold()
scaffold1.finalize()
scaffold2 = monitored_session.Scaffold()
@@ -119,7 +119,7 @@ class ScaffoldTest(test.TestCase):
def test_raise_error_if_more_than_one_cached_item(self):
with ops.Graph().as_default():
- variables.Variable([1])
+ variables.VariableV1([1])
ops.add_to_collection(ops.GraphKeys.SAVERS, saver_lib.Saver())
ops.add_to_collection(ops.GraphKeys.SAVERS, saver_lib.Saver())
with self.assertRaisesRegexp(RuntimeError, 'More than one item'):
@@ -127,7 +127,7 @@ class ScaffoldTest(test.TestCase):
def test_uses_passed_values(self):
with ops.Graph().as_default():
- variables.Variable([1])
+ variables.VariableV1([1])
saver = saver_lib.Saver()
scaffold = monitored_session.Scaffold(
init_op=2,
@@ -148,7 +148,7 @@ class ScaffoldTest(test.TestCase):
def test_graph_is_finalized(self):
with ops.Graph().as_default():
- variables.Variable([1])
+ variables.VariableV1([1])
monitored_session.Scaffold().finalize()
with self.assertRaisesRegexp(RuntimeError,
'Graph is finalized and cannot be modified'):
@@ -157,7 +157,7 @@ class ScaffoldTest(test.TestCase):
def test_new_scaffold_from_default_scaffold(self):
scaffold1 = monitored_session.Scaffold()
with ops.Graph().as_default():
- variables.Variable([1])
+ variables.VariableV1([1])
saver = saver_lib.Saver()
scaffold2 = monitored_session.Scaffold(
init_op=2,
@@ -180,7 +180,7 @@ class ScaffoldTest(test.TestCase):
def test_new_scaffold_from_existing_scaffold(self):
with ops.Graph().as_default():
- variables.Variable([1])
+ variables.VariableV1([1])
saver = saver_lib.Saver()
scaffold1 = monitored_session.Scaffold(
init_op=2,
@@ -1374,7 +1374,7 @@ class MonitoredSessionTest(test.TestCase):
def test_defaults(self):
with ops.Graph().as_default():
- a_var = variables.Variable(0)
+ a_var = variables.VariableV1(0)
with monitored_session.MonitoredSession() as session:
self.assertEqual(0, session.run(a_var))
@@ -1700,7 +1700,7 @@ class MonitoredSessionTest(test.TestCase):
def test_graph_finalized_during_run_unfinalized_after_exit(self):
with ops.Graph().as_default() as g:
- a_var = variables.Variable(0)
+ a_var = variables.VariableV1(0)
with monitored_session.MonitoredSession() as session:
self.assertEqual(0, session.run(a_var))
self.assertTrue(g.finalized)
@@ -1708,7 +1708,7 @@ class MonitoredSessionTest(test.TestCase):
def test_keep_finalized_graph_as_finalized(self):
with ops.Graph().as_default() as g:
- a_var = variables.Variable(0)
+ a_var = variables.VariableV1(0)
monitored_session.Scaffold().finalize()
with monitored_session.MonitoredSession() as session:
self.assertEqual(0, session.run(a_var))
@@ -2032,7 +2032,7 @@ class MonitoredSessionTest(test.TestCase):
with ops.Graph().as_default():
c = array_ops.placeholder(dtypes.float32)
v = array_ops.identity(c)
- graph_state = variables.Variable(0.0)
+ graph_state = variables.VariableV1(0.0)
graph_side_effect = state_ops.assign_add(graph_state, 0.31)
def step_fn(step_context):
@@ -2088,7 +2088,7 @@ class MonitoredSessionTest(test.TestCase):
c = array_ops.placeholder(dtypes.float32)
v = array_ops.identity(c)
vv = constant_op.constant(3.2)
- graph_state = variables.Variable(0.0)
+ graph_state = variables.VariableV1(0.0)
graph_side_effect = state_ops.assign_add(graph_state, 0.31)
class Hook(session_run_hook.SessionRunHook):
@@ -2125,7 +2125,7 @@ class SingularMonitoredSessionTest(test.TestCase):
def test_handles_initialization(self):
with ops.Graph().as_default():
- a_var = variables.Variable(0)
+ a_var = variables.VariableV1(0)
with monitored_session.SingularMonitoredSession() as session:
# If it's not initialized, following statement raises an error.
self.assertEqual(0, session.run(a_var))
diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py
index 699162b30c..f004f3944a 100644
--- a/tensorflow/python/training/optimizer.py
+++ b/tensorflow/python/training/optimizer.py
@@ -585,7 +585,7 @@ class Optimizer(
var_list = [v for g, v, _ in converted_grads_and_vars if g is not None]
if not var_list:
raise ValueError("No gradients provided for any variable: %s." %
- ([str(v) for _, _, v in converted_grads_and_vars],))
+ ([str(v) for _, v, _ in converted_grads_and_vars],))
with ops.init_scope():
self._create_slots(var_list)
update_ops = []
diff --git a/tensorflow/python/training/quantize_training.i b/tensorflow/python/training/quantize_training.i
index 41e62e0252..1ab600bb22 100644
--- a/tensorflow/python/training/quantize_training.i
+++ b/tensorflow/python/training/quantize_training.i
@@ -55,6 +55,13 @@ PyObject* DoQuantizeTrainingOnGraphDefHelper(
%insert("python") %{
+from tensorflow.python.util import deprecation
+from tensorflow.python.util.tf_export import tf_export
+
+@deprecation.deprecated(
+ None,
+ "GraphDef quantized training rewriter is deprecated in the long term")
+@tf_export(v1=["train.do_quantize_training_on_graphdef"])
def do_quantize_training_on_graphdef(input_graph, num_bits):
"""A general quantization scheme is being developed in `tf.contrib.quantize`.
diff --git a/tensorflow/python/training/quantize_training_test.py b/tensorflow/python/training/quantize_training_test.py
index 9754adea85..6edbf7665f 100644
--- a/tensorflow/python/training/quantize_training_test.py
+++ b/tensorflow/python/training/quantize_training_test.py
@@ -58,7 +58,8 @@ class PywrapQuantizeTrainingTest(test.TestCase):
g = ops.Graph()
with session.Session(graph=g) as sess:
a = constant_op.constant(6.0, shape=[1, 1], name='a')
- b = variables.Variable(constant_op.constant(7.0, shape=[1, 1]), name='b')
+ b = variables.VariableV1(
+ constant_op.constant(7.0, shape=[1, 1]), name='b')
c = math_ops.matmul(a, b, name='matmul')
init_op = variables.global_variables_initializer()
diff --git a/tensorflow/python/training/queue_runner_test.py b/tensorflow/python/training/queue_runner_test.py
index 9b9e28af2b..15fe42bbd8 100644
--- a/tensorflow/python/training/queue_runner_test.py
+++ b/tensorflow/python/training/queue_runner_test.py
@@ -44,7 +44,7 @@ class QueueRunnerTest(test.TestCase):
with self.cached_session() as sess:
# CountUpTo will raise OUT_OF_RANGE when it reaches the count.
zero64 = constant_op.constant(0, dtype=dtypes.int64)
- var = variables.Variable(zero64)
+ var = variables.VariableV1(zero64)
count_up_to = var.count_up_to(3)
queue = data_flow_ops.FIFOQueue(10, dtypes.float32)
variables.global_variables_initializer().run()
@@ -64,9 +64,9 @@ class QueueRunnerTest(test.TestCase):
with self.cached_session() as sess:
# CountUpTo will raise OUT_OF_RANGE when it reaches the count.
zero64 = constant_op.constant(0, dtype=dtypes.int64)
- var0 = variables.Variable(zero64)
+ var0 = variables.VariableV1(zero64)
count_up_to_3 = var0.count_up_to(3)
- var1 = variables.Variable(zero64)
+ var1 = variables.VariableV1(zero64)
count_up_to_30 = var1.count_up_to(30)
queue = data_flow_ops.FIFOQueue(10, dtypes.float32)
qr = queue_runner_impl.QueueRunner(queue, [count_up_to_3, count_up_to_30])
@@ -131,7 +131,7 @@ class QueueRunnerTest(test.TestCase):
with self.cached_session() as sess:
# CountUpTo will raise OUT_OF_RANGE when it reaches the count.
zero64 = constant_op.constant(0, dtype=dtypes.int64)
- var = variables.Variable(zero64)
+ var = variables.VariableV1(zero64)
count_up_to = var.count_up_to(3)
queue = data_flow_ops.FIFOQueue(10, dtypes.float32)
variables.global_variables_initializer().run()
@@ -184,7 +184,7 @@ class QueueRunnerTest(test.TestCase):
with self.cached_session() as sess:
with session.Session() as other_sess:
zero64 = constant_op.constant(0, dtype=dtypes.int64)
- var = variables.Variable(zero64)
+ var = variables.VariableV1(zero64)
count_up_to = var.count_up_to(3)
queue = data_flow_ops.FIFOQueue(10, dtypes.float32)
variables.global_variables_initializer().run()
@@ -199,7 +199,7 @@ class QueueRunnerTest(test.TestCase):
with self.cached_session() as sess:
# CountUpTo will raise OUT_OF_RANGE when it reaches the count.
zero64 = constant_op.constant(0, dtype=dtypes.int64)
- var = variables.Variable(zero64)
+ var = variables.VariableV1(zero64)
count_up_to = var.count_up_to(3)
queue = data_flow_ops.FIFOQueue(10, dtypes.float32)
variables.global_variables_initializer().run()
@@ -215,7 +215,7 @@ class QueueRunnerTest(test.TestCase):
with self.cached_session() as sess:
# CountUpTo will raise OUT_OF_RANGE when it reaches the count.
zero64 = constant_op.constant(0, dtype=dtypes.int64)
- var = variables.Variable(zero64)
+ var = variables.VariableV1(zero64)
count_up_to = var.count_up_to(3)
queue = data_flow_ops.FIFOQueue(10, dtypes.float32)
variables.global_variables_initializer().run()
@@ -250,7 +250,7 @@ class QueueRunnerTest(test.TestCase):
def testStartQueueRunners(self):
# CountUpTo will raise OUT_OF_RANGE when it reaches the count.
zero64 = constant_op.constant(0, dtype=dtypes.int64)
- var = variables.Variable(zero64)
+ var = variables.VariableV1(zero64)
count_up_to = var.count_up_to(3)
queue = data_flow_ops.FIFOQueue(10, dtypes.float32)
init_op = variables.global_variables_initializer()
@@ -267,7 +267,7 @@ class QueueRunnerTest(test.TestCase):
def testStartQueueRunnersRaisesIfNotASession(self):
zero64 = constant_op.constant(0, dtype=dtypes.int64)
- var = variables.Variable(zero64)
+ var = variables.VariableV1(zero64)
count_up_to = var.count_up_to(3)
queue = data_flow_ops.FIFOQueue(10, dtypes.float32)
init_op = variables.global_variables_initializer()
@@ -280,7 +280,7 @@ class QueueRunnerTest(test.TestCase):
def testStartQueueRunnersIgnoresMonitoredSession(self):
zero64 = constant_op.constant(0, dtype=dtypes.int64)
- var = variables.Variable(zero64)
+ var = variables.VariableV1(zero64)
count_up_to = var.count_up_to(3)
queue = data_flow_ops.FIFOQueue(10, dtypes.float32)
init_op = variables.global_variables_initializer()
@@ -297,7 +297,7 @@ class QueueRunnerTest(test.TestCase):
graph = ops.Graph()
with graph.as_default():
zero64 = constant_op.constant(0, dtype=dtypes.int64)
- var = variables.Variable(zero64)
+ var = variables.VariableV1(zero64)
count_up_to = var.count_up_to(3)
queue = data_flow_ops.FIFOQueue(10, dtypes.float32)
init_op = variables.global_variables_initializer()
diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py
index 274c856686..5b2b19e913 100644
--- a/tensorflow/python/training/saver.py
+++ b/tensorflow/python/training/saver.py
@@ -622,6 +622,14 @@ class BaseSaverBuilder(object):
yield BaseSaverBuilder.ResourceVariableSaveable(
variable, variable._save_slice_info.spec, name)
# pylint: enable=protected-access
+ elif isinstance(op, checkpointable.CheckpointableBase) and not isinstance(
+ op, variables.Variable):
+ # pylint: disable=protected-access
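+ # Checkpointable objects that are not Variables expose their saveables
+ # through _gather_saveables_for_checkpoint(); each value may be a
+ # factory callable producing a SaveableObject or an already-built one.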
+ for attr, factory in op._gather_saveables_for_checkpoint().items():
+ op = (factory(name + "_" + attr) if callable(factory) else factory)
+ for op in BaseSaverBuilder.SaveableObjectsForOp(op, op.name):
+ yield op
+ # pylint: enable=protected-access
else:
# A variable or tensor.
if context.executing_eagerly():
diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py
index 0ac84813c8..49e6e6546d 100644
--- a/tensorflow/python/training/saver_test.py
+++ b/tensorflow/python/training/saver_test.py
@@ -311,8 +311,8 @@ class SaverTest(test.TestCase):
# Build a graph with 2 parameter nodes, and Save and
# Restore nodes for them.
- v0 = variables.Variable(10.0, name="v0")
- v1 = variables.Variable(20.0, name="v1")
+ v0 = variables.VariableV1(10.0, name="v0")
+ v1 = variables.VariableV1(20.0, name="v1")
v2 = saver_test_utils.CheckpointedOp(name="v2")
v2_init = v2.insert("k1", 30.0)
save = saver_module.Saver(
@@ -350,8 +350,8 @@ class SaverTest(test.TestCase):
# Start a second session. In that session the parameter nodes
# have not been initialized either.
with self.cached_session() as sess:
- v0 = variables.Variable(-1.0, name="v0")
- v1 = variables.Variable(-1.0, name="v1")
+ v0 = variables.VariableV1(-1.0, name="v0")
+ v1 = variables.VariableV1(-1.0, name="v1")
v2 = saver_test_utils.CheckpointedOp(name="v2")
save = saver_module.Saver({"v0": v0, "v1": v1, "v2": v2.saveable})
@@ -370,7 +370,7 @@ class SaverTest(test.TestCase):
self.assertEqual(30.0, v2.values().eval())
def testFilenameTensor(self):
- v0 = variables.Variable(0, name="v0")
+ v0 = variables.VariableV1(0, name="v0")
filename = b"somerandomfilename"
save = saver_module.Saver({"v0": v0}, filename=filename)
with self.cached_session() as sess:
@@ -379,7 +379,7 @@ class SaverTest(test.TestCase):
self.assertEqual(sess.run(tensor), filename)
def testInvalidPath(self):
- v0 = variables.Variable(0, name="v0")
+ v0 = variables.VariableV1(0, name="v0")
for ver in (saver_pb2.SaverDef.V1, saver_pb2.SaverDef.V2):
with self.cached_session() as sess:
save = saver_module.Saver({"v0": v0}, write_version=ver)
@@ -392,7 +392,7 @@ class SaverTest(test.TestCase):
with self.cached_session() as sess:
# Build a graph with 1 node, and save and restore for them.
- v = variables.Variable(np.int64(15), name="v")
+ v = variables.VariableV1(np.int64(15), name="v")
save = saver_module.Saver({"v": v}, restore_sequentially=True)
variables.global_variables_initializer().run()
@@ -402,7 +402,7 @@ class SaverTest(test.TestCase):
self.assertEqual(save_path, val)
with self.cached_session() as sess:
- v = variables.Variable(np.int64(-1), name="v")
+ v = variables.VariableV1(np.int64(-1), name="v")
save = saver_module.Saver({"v": v})
with self.assertRaisesWithPredicateMatch(
@@ -416,9 +416,9 @@ class SaverTest(test.TestCase):
def testSomeErrors(self):
with ops_lib.Graph().as_default():
- v0 = variables.Variable([10.0], name="v0")
- v1 = variables.Variable([20.0], name="v1")
- v2 = variables.Variable([20.0], name="v2")
+ v0 = variables.VariableV1([10.0], name="v0")
+ v1 = variables.VariableV1([20.0], name="v1")
+ v2 = variables.VariableV1([20.0], name="v2")
v2._set_save_slice_info(
variables.Variable.SaveSliceInfo("v1", [1], [0], [1]))
@@ -446,7 +446,7 @@ class SaverTest(test.TestCase):
def testSameName(self):
with ops_lib.Graph().as_default():
- v0 = variables.Variable([10.0], name="v0")
+ v0 = variables.VariableV1([10.0], name="v0")
v2 = saver_test_utils.CheckpointedOp(name="v2")
# Saving one variable under two names raises an error.
@@ -468,8 +468,8 @@ class SaverTest(test.TestCase):
with self.session(graph=ops_lib.Graph()) as sess:
# Build a graph with 2 parameter nodes, and Save and
# Restore nodes for them.
- v0 = variables.Variable(10.0, name="v0")
- v1 = variables.Variable(20.0, name="v1")
+ v0 = variables.VariableV1(10.0, name="v0")
+ v1 = variables.VariableV1(20.0, name="v1")
v2 = saver_test_utils.CheckpointedOp(name="v2")
v2_init = v2.insert("k1", 30.0)
save = saver_module.Saver([v0, v1, v2.saveable])
@@ -490,8 +490,8 @@ class SaverTest(test.TestCase):
# Start a second session. In that session the variables
# have not been initialized either.
with self.session(graph=ops_lib.Graph()) as sess:
- v0 = variables.Variable(-1.0, name="v0")
- v1 = variables.Variable(-1.0, name="v1")
+ v0 = variables.VariableV1(-1.0, name="v0")
+ v1 = variables.VariableV1(-1.0, name="v1")
v2 = saver_test_utils.CheckpointedOp(name="v2")
save = saver_module.Saver([v0, v1, v2.saveable])
@@ -515,8 +515,8 @@ class SaverTest(test.TestCase):
# Build another graph with 2 nodes, initialized
# differently, and a Restore node for them.
with self.session(graph=ops_lib.Graph()) as sess:
- v0_2 = variables.Variable(1000.0, name="v0")
- v1_2 = variables.Variable(2000.0, name="v1")
+ v0_2 = variables.VariableV1(1000.0, name="v0")
+ v1_2 = variables.VariableV1(2000.0, name="v1")
v2_2 = saver_test_utils.CheckpointedOp(name="v2")
save2 = saver_module.Saver([v0_2, v1_2, v2_2.saveable])
v2_2.insert("k1000", 3000.0).run()
@@ -574,14 +574,14 @@ class SaverTest(test.TestCase):
save_path = os.path.join(self.get_temp_dir(), "gpu")
with session.Session("", graph=ops_lib.Graph()) as sess:
with sess.graph.device(test.gpu_device_name()):
- v0_1 = variables.Variable(123.45)
+ v0_1 = variables.VariableV1(123.45)
save = saver_module.Saver({"v0": v0_1})
variables.global_variables_initializer().run()
save.save(sess, save_path)
with session.Session("", graph=ops_lib.Graph()) as sess:
with sess.graph.device(test.gpu_device_name()):
- v0_2 = variables.Variable(543.21)
+ v0_2 = variables.VariableV1(543.21)
save = saver_module.Saver({"v0": v0_2})
variables.global_variables_initializer().run()
@@ -591,22 +591,22 @@ class SaverTest(test.TestCase):
save_path = os.path.join(self.get_temp_dir(), "gpu")
with session.Session("", graph=ops_lib.Graph()) as sess:
with sess.graph.device(test.gpu_device_name()):
- v0_1 = variables.Variable(123.45)
+ v0_1 = variables.VariableV1(123.45)
save = saver_module.Saver({"v0": v0_1}, sharded=True, allow_empty=True)
variables.global_variables_initializer().run()
save.save(sess, save_path)
with session.Session("", graph=ops_lib.Graph()) as sess:
with sess.graph.device(test.gpu_device_name()):
- v0_2 = variables.Variable(543.21)
+ v0_2 = variables.VariableV1(543.21)
save = saver_module.Saver({"v0": v0_2}, sharded=True, allow_empty=True)
variables.global_variables_initializer().run()
def testVariables(self):
save_path = os.path.join(self.get_temp_dir(), "variables")
with session.Session("", graph=ops_lib.Graph()) as sess:
- one = variables.Variable(1.0)
- twos = variables.Variable([2.0, 2.0, 2.0])
+ one = variables.VariableV1(1.0)
+ twos = variables.VariableV1([2.0, 2.0, 2.0])
v2 = saver_test_utils.CheckpointedOp(name="v2")
init = variables.global_variables_initializer()
save = saver_module.Saver()
@@ -615,8 +615,8 @@ class SaverTest(test.TestCase):
save.save(sess, save_path)
with session.Session("", graph=ops_lib.Graph()) as sess:
- one = variables.Variable(0.0)
- twos = variables.Variable([0.0, 0.0, 0.0])
+ one = variables.VariableV1(0.0)
+ twos = variables.VariableV1([0.0, 0.0, 0.0])
v2 = saver_test_utils.CheckpointedOp(name="v2")
# Saver with no arg, defaults to 'all variables'.
save = saver_module.Saver()
@@ -628,14 +628,14 @@ class SaverTest(test.TestCase):
def testVarListShouldBeEmptyInDeferredBuild(self):
with ops_lib.Graph().as_default():
- v = variables.Variable(1.0)
+ v = variables.VariableV1(1.0)
with self.assertRaisesRegexp(ValueError, "defer_build"):
saver_module.Saver([v], defer_build=True)
def testBuildShouldBeCalledBeforeSaveInCaseOfDeferBuild(self):
save_path = os.path.join(self.get_temp_dir(), "error_deferred_build")
with ops_lib.Graph().as_default(), session.Session() as sess:
- variables.Variable(1.0)
+ variables.VariableV1(1.0)
saver = saver_module.Saver(defer_build=True)
with self.assertRaisesRegexp(RuntimeError, "build"):
saver.save(sess, save_path)
@@ -643,18 +643,18 @@ class SaverTest(test.TestCase):
def testDeferredBuild(self):
save_path = os.path.join(self.get_temp_dir(), "deferred_build")
with session.Session("", graph=ops_lib.Graph()) as sess:
- one = variables.Variable(1.0)
+ one = variables.VariableV1(1.0)
save = saver_module.Saver(defer_build=True)
# if build is not deferred, saver cannot save the `twos`.
- twos = variables.Variable([2.0, 2.0, 2.0])
+ twos = variables.VariableV1([2.0, 2.0, 2.0])
init = variables.global_variables_initializer()
save.build()
init.run()
save.save(sess, save_path)
with session.Session("", graph=ops_lib.Graph()) as sess:
- one = variables.Variable(0.0)
- twos = variables.Variable([0.0, 0.0, 0.0])
+ one = variables.VariableV1(0.0)
+ twos = variables.VariableV1([0.0, 0.0, 0.0])
# Saver with no arg, defaults to 'all variables'.
save = saver_module.Saver()
save.restore(sess, save_path)
@@ -664,7 +664,7 @@ class SaverTest(test.TestCase):
def testReshape(self):
save_path = os.path.join(self.get_temp_dir(), "variables_reshape")
with session.Session("", graph=ops_lib.Graph()) as sess:
- var = variables.Variable([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
+ var = variables.VariableV1([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
init = variables.global_variables_initializer()
save = saver_module.Saver()
init.run()
@@ -672,7 +672,7 @@ class SaverTest(test.TestCase):
# Error when restoring with default reshape=False
with session.Session("", graph=ops_lib.Graph()) as sess:
- var = variables.Variable([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]])
+ var = variables.VariableV1([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]])
save = saver_module.Saver()
with self.assertRaisesRegexp(
errors_impl.InvalidArgumentError,
@@ -681,7 +681,7 @@ class SaverTest(test.TestCase):
# Restored to new shape with reshape=True
with session.Session("", graph=ops_lib.Graph()) as sess:
- var = variables.Variable([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]])
+ var = variables.VariableV1([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]])
save = saver_module.Saver(reshape=True)
save.restore(sess, save_path)
self.assertAllClose([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], var.eval())
@@ -731,8 +731,8 @@ class SaverTest(test.TestCase):
for save_path in paths:
# Build a graph with 2 parameter nodes, and Save and
# Restore nodes for them.
- v0 = variables.Variable(10.0, name="v0")
- v1 = variables.Variable(20.0, name="v1")
+ v0 = variables.VariableV1(10.0, name="v0")
+ v1 = variables.VariableV1(20.0, name="v1")
save = saver_module.Saver({"v0": v0, "v1": v1}, restore_sequentially=True)
init_all_op = variables.global_variables_initializer()
@@ -770,8 +770,8 @@ class SaverTest(test.TestCase):
# Build a graph with 2 parameter nodes, and Save and
# Restore nodes for them.
- v0 = variables.Variable(10.0, name="v0")
- v1 = variables.Variable(20.0, name="v1")
+ v0 = variables.VariableV1(10.0, name="v0")
+ v1 = variables.VariableV1(20.0, name="v1")
save = saver_module.Saver({"v0": v0, "v1": v1}, restore_sequentially=True)
init_all_op = variables.global_variables_initializer()
@@ -859,10 +859,10 @@ class SaveRestoreShardedTest(test.TestCase):
target="",
config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
with sess.graph.device("/cpu:0"):
- v0 = variables.Variable(10, name="v0")
+ v0 = variables.VariableV1(10, name="v0")
t0 = saver_test_utils.CheckpointedOp(name="t0")
with sess.graph.device("/cpu:1"):
- v1 = variables.Variable(20, name="v1")
+ v1 = variables.VariableV1(20, name="v1")
t1 = saver_test_utils.CheckpointedOp(name="t1")
save = saver_module.Saver(
{
@@ -890,7 +890,7 @@ class SaveRestoreShardedTest(test.TestCase):
target="",
config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
with sess.graph.device("/cpu:0"):
- v0 = variables.Variable(111, name="v0")
+ v0 = variables.VariableV1(111, name="v0")
t0 = saver_test_utils.CheckpointedOp(name="t0")
save = saver_module.Saver(
{
@@ -914,7 +914,7 @@ class SaveRestoreShardedTest(test.TestCase):
target="",
config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
with sess.graph.device("/cpu:0"):
- v1 = variables.Variable(222)
+ v1 = variables.VariableV1(222)
t1 = saver_test_utils.CheckpointedOp(name="t1")
save = saver_module.Saver(
{
@@ -938,10 +938,10 @@ class SaveRestoreShardedTest(test.TestCase):
target="",
config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
with sess.graph.device("/cpu:0"):
- v0 = variables.Variable(111, name="v0")
+ v0 = variables.VariableV1(111, name="v0")
t0 = saver_test_utils.CheckpointedOp(name="t0")
with sess.graph.device("/cpu:1"):
- v1 = variables.Variable(222, name="v1")
+ v1 = variables.VariableV1(222, name="v1")
t1 = saver_test_utils.CheckpointedOp(name="t1")
save = saver_module.Saver(
{
@@ -984,7 +984,7 @@ class SaveRestoreShardedTest(test.TestCase):
def testSaverDef(self):
with self.cached_session():
- v0 = variables.Variable(123, name="v0")
+ v0 = variables.VariableV1(123, name="v0")
save = saver_module.Saver({"v0": v0}, sharded=True)
sd = save.as_saver_def()
self.assertTrue(sd.sharded)
@@ -1023,7 +1023,7 @@ class SaveRestoreShardedTest(test.TestCase):
if use_resource:
vs = [resource_variable_ops.ResourceVariable(rnd, name=var_name)]
else:
- vs = [variables.Variable(rnd, name=var_name)]
+ vs = [variables.VariableV1(rnd, name=var_name)]
variables.global_variables_initializer().run()
if call_saver_with_dict:
@@ -1054,7 +1054,7 @@ class SaveRestoreShardedTest(test.TestCase):
]
else:
new_vs = [
- variables.Variable(
+ variables.VariableV1(
array_ops.zeros(
shape=var_full_shape), # != original contents.
name=var_name)
@@ -1210,7 +1210,7 @@ class MaxToKeepTest(test.TestCase):
save_dir = self._get_test_dir("max_to_keep_non_sharded")
with self.cached_session() as sess:
- v = variables.Variable(10.0, name="v")
+ v = variables.VariableV1(10.0, name="v")
save = saver_module.Saver({"v": v}, max_to_keep=2)
variables.global_variables_initializer().run()
self.assertEqual([], save.last_checkpoints)
@@ -1389,9 +1389,9 @@ class MaxToKeepTest(test.TestCase):
target="",
config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
with sess.graph.device("/cpu:0"):
- v0 = variables.Variable(111, name="v0")
+ v0 = variables.VariableV1(111, name="v0")
with sess.graph.device("/cpu:1"):
- v1 = variables.Variable(222, name="v1")
+ v1 = variables.VariableV1(222, name="v1")
save = saver_module.Saver(
{
"v0": v0,
@@ -1448,7 +1448,7 @@ class MaxToKeepTest(test.TestCase):
save_dir2 = self._get_test_dir("max_to_keep_0")
with self.cached_session() as sess:
- v = variables.Variable(10.0, name="v")
+ v = variables.VariableV1(10.0, name="v")
variables.global_variables_initializer().run()
# Test max_to_keep being None.
@@ -1475,7 +1475,7 @@ class MaxToKeepTest(test.TestCase):
save_dir = self._get_test_dir("no_meta_graph")
with self.cached_session() as sess:
- v = variables.Variable(10.0, name="v")
+ v = variables.VariableV1(10.0, name="v")
save = saver_module.Saver({"v": v})
variables.global_variables_initializer().run()
@@ -1632,13 +1632,13 @@ class MetaGraphTest(test.TestCase):
filename = os.path.join(test_dir, "metafile")
with self.cached_session():
# Creates a graph.
- v0 = variables.Variable(1.0, name="v0")
+ v0 = variables.VariableV1(1.0, name="v0")
control_flow_ops.cond(
math_ops.less(v0, 10), lambda: math_ops.add(v0, 1),
lambda: math_ops.subtract(v0, 1))
control_flow_ops.while_loop(lambda i: math_ops.less(i, 10),
lambda i: math_ops.add(i, 1), [v0])
- var = variables.Variable(constant_op.constant(0, dtype=dtypes.int64))
+ var = variables.VariableV1(constant_op.constant(0, dtype=dtypes.int64))
count_up_to = var.count_up_to(3)
input_queue = data_flow_ops.FIFOQueue(
30, dtypes.float32, shared_name="collection_queue")
@@ -1687,7 +1687,7 @@ class MetaGraphTest(test.TestCase):
def testAddCollectionDefFails(self):
with self.cached_session():
# Creates a graph.
- v0 = variables.Variable(10.0, name="v0")
+ v0 = variables.VariableV1(10.0, name="v0")
# Creates a saver.
save = saver_module.Saver({"v0": v0})
# Generates MetaGraphDef.
@@ -1711,8 +1711,8 @@ class MetaGraphTest(test.TestCase):
saver1_ckpt = os.path.join(test_dir, "saver1.ckpt")
with self.session(graph=ops_lib.Graph()) as sess:
# Creates a graph.
- v0 = variables.Variable([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], name="v0")
- v1 = variables.Variable(11.0, name="v1")
+ v0 = variables.VariableV1([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], name="v0")
+ v1 = variables.VariableV1(11.0, name="v1")
# Creates 2 savers.
saver0 = saver_module.Saver({"v0": v0}, name="saver0")
saver1 = saver_module.Saver({"v1": v1}, name="saver1")
@@ -1788,8 +1788,8 @@ class MetaGraphTest(test.TestCase):
saver1_ckpt = os.path.join(test_dir, "saver1.ckpt")
with self.session(graph=ops_lib.Graph()) as sess:
# Creates a graph.
- v0 = variables.Variable([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], name="v0")
- v1 = variables.Variable(11.0, name="v1")
+ v0 = variables.VariableV1([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], name="v0")
+ v1 = variables.VariableV1(11.0, name="v1")
# Creates 2 savers.
saver0 = saver_module.Saver({"v0": v0}, name="saver0")
@@ -1840,7 +1840,7 @@ class MetaGraphTest(test.TestCase):
filename = os.path.join(test_dir, "metafile")
with self.session(graph=ops_lib.Graph()):
# Creates a graph.
- variables.Variable(10.0, name="v0")
+ variables.VariableV1(10.0, name="v0")
# Exports the graph as binary format.
saver_module.export_meta_graph(filename, as_text=False)
with self.session(graph=ops_lib.Graph()):
@@ -1871,8 +1871,8 @@ class MetaGraphTest(test.TestCase):
test_dir = self._get_test_dir("slice_saver")
filename = os.path.join(test_dir, "metafile")
with self.cached_session():
- v1 = variables.Variable([20.0], name="v1")
- v2 = variables.Variable([20.0], name="v2")
+ v1 = variables.VariableV1([20.0], name="v1")
+ v2 = variables.VariableV1([20.0], name="v2")
v2._set_save_slice_info(
variables.Variable.SaveSliceInfo("v1", [1], [0], [1]))
@@ -1899,7 +1899,7 @@ class MetaGraphTest(test.TestCase):
# Hidden 1
images = constant_op.constant(1.2, dtypes.float32, shape=[100, 28])
with ops_lib.name_scope("hidden1"):
- weights = variables.Variable(
+ weights = variables.VariableV1(
random_ops.truncated_normal(
[28, 128], stddev=1.0 / math.sqrt(float(28))),
name="weights")
@@ -1907,7 +1907,7 @@ class MetaGraphTest(test.TestCase):
# the save and restore of control flow context (which doesn't make any
# sense here from a machine learning perspective). The typical biases is
# a simple Variable without the conditions.
- biases = variables.Variable(
+ biases = variables.VariableV1(
control_flow_ops.cond(
math_ops.less(random.random(), 0.5),
lambda: array_ops.ones([128]), lambda: array_ops.zeros([128])),
@@ -1915,7 +1915,7 @@ class MetaGraphTest(test.TestCase):
hidden1 = nn_ops.relu(math_ops.matmul(images, weights) + biases)
# Hidden 2
with ops_lib.name_scope("hidden2"):
- weights = variables.Variable(
+ weights = variables.VariableV1(
random_ops.truncated_normal(
[128, 32], stddev=1.0 / math.sqrt(float(128))),
name="weights")
@@ -1933,15 +1933,16 @@ class MetaGraphTest(test.TestCase):
_, biases = control_flow_ops.while_loop(
loop_cond, loop_body,
- [constant_op.constant(0), variables.Variable(array_ops.zeros([32]))])
+ [constant_op.constant(0),
+ variables.VariableV1(array_ops.zeros([32]))])
hidden2 = nn_ops.relu(math_ops.matmul(hidden1, weights) + biases)
# Linear
with ops_lib.name_scope("softmax_linear"):
- weights = variables.Variable(
+ weights = variables.VariableV1(
random_ops.truncated_normal(
[32, 10], stddev=1.0 / math.sqrt(float(32))),
name="weights")
- biases = variables.Variable(array_ops.zeros([10]), name="biases")
+ biases = variables.VariableV1(array_ops.zeros([10]), name="biases")
logits = math_ops.matmul(hidden2, weights) + biases
ops_lib.add_to_collection("logits", logits)
init_all_op = variables.global_variables_initializer()
@@ -2028,7 +2029,7 @@ class MetaGraphTest(test.TestCase):
# Create while loop using `outer_body_fn`.
with ops_lib.Graph().as_default():
- var = variables.Variable(0.0)
+ var = variables.VariableV1(0.0)
var_name = var.name
output = graph_fn(var)
output_name = output.name
@@ -2122,8 +2123,8 @@ class MetaGraphTest(test.TestCase):
def testStrippedOpListDef(self):
with self.cached_session():
# Creates a graph.
- v0 = variables.Variable(0.0)
- var = variables.Variable(10.0)
+ v0 = variables.VariableV1(0.0)
+ var = variables.VariableV1(10.0)
math_ops.add(v0, var)
@function.Defun(dtypes.float32)
@@ -2161,8 +2162,8 @@ class MetaGraphTest(test.TestCase):
# With strip_default_attrs enabled, attributes "T" (float32) and "Tout"
# (complex64) in the "Complex" op must be removed.
with self.cached_session():
- real_num = variables.Variable(1.0, dtype=dtypes.float32, name="real")
- imag_num = variables.Variable(2.0, dtype=dtypes.float32, name="imag")
+ real_num = variables.VariableV1(1.0, dtype=dtypes.float32, name="real")
+ imag_num = variables.VariableV1(2.0, dtype=dtypes.float32, name="imag")
math_ops.complex(real_num, imag_num, name="complex")
save = saver_module.Saver({"real_num": real_num, "imag_num": imag_num})
@@ -2178,8 +2179,8 @@ class MetaGraphTest(test.TestCase):
# (complex64) in the "Complex" op must *not* be removed, even if they map
# to their defaults.
with self.session(graph=ops_lib.Graph()):
- real_num = variables.Variable(1.0, dtype=dtypes.float32, name="real")
- imag_num = variables.Variable(2.0, dtype=dtypes.float32, name="imag")
+ real_num = variables.VariableV1(1.0, dtype=dtypes.float32, name="real")
+ imag_num = variables.VariableV1(2.0, dtype=dtypes.float32, name="imag")
math_ops.complex(real_num, imag_num, name="complex")
save = saver_module.Saver({"real_num": real_num, "imag_num": imag_num})
@@ -2198,9 +2199,9 @@ class MetaGraphTest(test.TestCase):
image = array_ops.placeholder(dtypes.float32, [None, 784], name="image")
label = array_ops.placeholder(dtypes.float32, [None, 10], name="label")
with session.Session() as sess:
- weights = variables.Variable(
+ weights = variables.VariableV1(
random_ops.random_uniform([784, 10]), name="weights")
- bias = variables.Variable(array_ops.zeros([10]), name="bias")
+ bias = variables.VariableV1(array_ops.zeros([10]), name="bias")
logit = nn_ops.relu(math_ops.matmul(image, weights) + bias, name="logits")
nn_ops.softmax(logit, name="prediction")
cost = nn_ops.softmax_cross_entropy_with_logits(labels=label,
@@ -2243,7 +2244,7 @@ class MetaGraphTest(test.TestCase):
self.assertIsNone(new_saver_1)
# Create a variable in graph_2 under scope "my_scope".
- variables.Variable(array_ops.zeros([10]), name="my_scope/my_var")
+ variables.VariableV1(array_ops.zeros([10]), name="my_scope/my_var")
sess.run(variables.global_variables_initializer())
# Restore the checkpoint into a different scope "subgraph_2".
new_saver_2 = saver_module.import_meta_graph(
@@ -2268,9 +2269,9 @@ class MetaGraphTest(test.TestCase):
image = array_ops.placeholder(dtypes.float32, [None, 784], name="image")
label = array_ops.placeholder(dtypes.float32, [None, 10], name="label")
with session.Session() as sess:
- weights = variables.Variable(
+ weights = variables.VariableV1(
random_ops.random_uniform([784, 10]), name="weights")
- bias = variables.Variable(array_ops.zeros([10]), name="bias")
+ bias = variables.VariableV1(array_ops.zeros([10]), name="bias")
logit = nn_ops.relu(math_ops.matmul(image, weights) + bias, name="logits")
nn_ops.softmax(logit, name="prediction")
cost = nn_ops.softmax_cross_entropy_with_logits(labels=label,
@@ -2299,9 +2300,9 @@ class MetaGraphTest(test.TestCase):
with ops_lib.device("/job:ps/replica:0/task:0/device:GPU:0"):
image = array_ops.placeholder(dtypes.float32, [None, 784], name="image")
label = array_ops.placeholder(dtypes.float32, [None, 10], name="label")
- weights = variables.Variable(
+ weights = variables.VariableV1(
random_ops.random_uniform([784, 10]), name="weights")
- bias = variables.Variable(array_ops.zeros([10]), name="bias")
+ bias = variables.VariableV1(array_ops.zeros([10]), name="bias")
logit = nn_ops.relu(math_ops.matmul(image, weights) + bias)
nn_ops.softmax(logit, name="prediction")
cost = nn_ops.softmax_cross_entropy_with_logits(labels=label,
@@ -2332,9 +2333,9 @@ class MetaGraphTest(test.TestCase):
with ops_lib.device("/job:ps/replica:0/task:0/device:GPU:0"):
image = array_ops.placeholder(dtypes.float32, [None, 784], name="image")
label = array_ops.placeholder(dtypes.float32, [None, 10], name="label")
- weights = variables.Variable(
+ weights = variables.VariableV1(
random_ops.random_uniform([784, 10]), name="weights")
- bias = variables.Variable(array_ops.zeros([10]), name="bias")
+ bias = variables.VariableV1(array_ops.zeros([10]), name="bias")
logit = nn_ops.relu(math_ops.matmul(image, weights) + bias)
nn_ops.softmax(logit, name="prediction")
cost = nn_ops.softmax_cross_entropy_with_logits(labels=label,
@@ -2385,9 +2386,9 @@ class CheckpointReaderTest(test.TestCase):
def testDebugString(self):
# Builds a graph.
- v0 = variables.Variable(
+ v0 = variables.VariableV1(
[[1, 2, 3], [4, 5, 6]], dtype=dtypes.float32, name="v0")
- v1 = variables.Variable(
+ v1 = variables.VariableV1(
[[[1], [2]], [[3], [4]], [[5], [6]]], dtype=dtypes.float32, name="v1")
init_all_op = variables.global_variables_initializer()
save = saver_module.Saver(
@@ -2444,7 +2445,8 @@ class WriteGraphTest(test.TestCase):
def testWriteGraph(self):
test_dir = self._get_test_dir("write_graph_dir")
- variables.Variable([[1, 2, 3], [4, 5, 6]], dtype=dtypes.float32, name="v0")
+ variables.VariableV1(
+ [[1, 2, 3], [4, 5, 6]], dtype=dtypes.float32, name="v0")
path = graph_io.write_graph(ops_lib.get_default_graph(),
os.path.join(test_dir, "l1"), "graph.pbtxt")
truth = os.path.join(test_dir, "l1", "graph.pbtxt")
@@ -2453,7 +2455,8 @@ class WriteGraphTest(test.TestCase):
def testRecursiveCreate(self):
test_dir = self._get_test_dir("deep_dir")
- variables.Variable([[1, 2, 3], [4, 5, 6]], dtype=dtypes.float32, name="v0")
+ variables.VariableV1(
+ [[1, 2, 3], [4, 5, 6]], dtype=dtypes.float32, name="v0")
path = graph_io.write_graph(ops_lib.get_default_graph().as_graph_def(),
os.path.join(test_dir, "l1", "l2", "l3"),
"graph.pbtxt")
@@ -2477,7 +2480,7 @@ class ScopedGraphTest(test.TestCase):
images = constant_op.constant(
1.2, dtypes.float32, shape=[100, 28], name="images")
with ops_lib.name_scope("hidden1"):
- weights1 = variables.Variable(
+ weights1 = variables.VariableV1(
random_ops.truncated_normal(
[28, 128], stddev=1.0 / math.sqrt(float(28))),
name="weights")
@@ -2485,7 +2488,7 @@ class ScopedGraphTest(test.TestCase):
# coverage the save and restore of control flow context (which doesn't
# make any sense here from a machine learning perspective). The typical
# biases is a simple Variable without the conditions.
- biases1 = variables.Variable(
+ biases1 = variables.VariableV1(
control_flow_ops.cond(
math_ops.less(random.random(), 0.5),
lambda: array_ops.ones([128]), lambda: array_ops.zeros([128])),
@@ -2494,7 +2497,7 @@ class ScopedGraphTest(test.TestCase):
# Hidden 2
with ops_lib.name_scope("hidden2"):
- weights2 = variables.Variable(
+ weights2 = variables.VariableV1(
random_ops.truncated_normal(
[128, 32], stddev=1.0 / math.sqrt(float(128))),
name="weights")
@@ -2511,16 +2514,16 @@ class ScopedGraphTest(test.TestCase):
return it + 1, biases2
_, biases2 = control_flow_ops.while_loop(loop_cond, loop_body, [
- constant_op.constant(0), variables.Variable(array_ops.zeros([32]))
+ constant_op.constant(0), variables.VariableV1(array_ops.zeros([32]))
])
hidden2 = nn_ops.relu(math_ops.matmul(hidden1, weights2) + biases2)
# Linear
with ops_lib.name_scope("softmax_linear"):
- weights3 = variables.Variable(
+ weights3 = variables.VariableV1(
random_ops.truncated_normal(
[32, 10], stddev=1.0 / math.sqrt(float(32))),
name="weights")
- biases3 = variables.Variable(array_ops.zeros([10]), name="biases")
+ biases3 = variables.VariableV1(array_ops.zeros([10]), name="biases")
logits = math_ops.matmul(hidden2, weights3) + biases3
ops_lib.add_to_collection("logits", logits)
@@ -2566,7 +2569,7 @@ class ScopedGraphTest(test.TestCase):
with graph.as_default():
# Hidden 2
with ops_lib.name_scope("hidden2"):
- weights = variables.Variable(
+ weights = variables.VariableV1(
random_ops.truncated_normal(
[128, 32], stddev=1.0 / math.sqrt(float(128))),
name="weights")
@@ -2583,16 +2586,16 @@ class ScopedGraphTest(test.TestCase):
return it + 1, biases
_, biases = control_flow_ops.while_loop(loop_cond, loop_body, [
- constant_op.constant(0), variables.Variable(array_ops.zeros([32]))
+ constant_op.constant(0), variables.VariableV1(array_ops.zeros([32]))
])
hidden2 = nn_ops.relu(math_ops.matmul(hidden1, weights) + biases)
# Linear
with ops_lib.name_scope("softmax_linear"):
- weights = variables.Variable(
+ weights = variables.VariableV1(
random_ops.truncated_normal(
[32, 10], stddev=1.0 / math.sqrt(float(32))),
name="weights")
- biases = variables.Variable(array_ops.zeros([10]), name="biases")
+ biases = variables.VariableV1(array_ops.zeros([10]), name="biases")
logits = math_ops.matmul(hidden2, weights) + biases
ops_lib.add_to_collection("logits", logits)
@@ -2629,9 +2632,9 @@ class ScopedGraphTest(test.TestCase):
with ops_lib.name_scope("hidden1"):
images = constant_op.constant(
1.0, dtypes.float32, shape=[3, 2], name="images")
- weights1 = variables.Variable(
+ weights1 = variables.VariableV1(
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], name="weights")
- biases1 = variables.Variable([0.1] * 3, name="biases")
+ biases1 = variables.VariableV1([0.1] * 3, name="biases")
nn_ops.relu(math_ops.matmul(images, weights1) + biases1, name="relu")
# Run the graph and save scoped checkpoint.
@@ -2685,9 +2688,9 @@ class ScopedGraphTest(test.TestCase):
with ops_lib.name_scope("hidden1"):
images = constant_op.constant(
1.0, dtypes.float32, shape=[3, 2], name="images")
- weights1 = variables.Variable(
+ weights1 = variables.VariableV1(
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], name="weights")
- biases1 = variables.Variable([0.1] * 3, name="biases")
+ biases1 = variables.VariableV1([0.1] * 3, name="biases")
nn_ops.relu(math_ops.matmul(images, weights1) + biases1, name="relu")
# Run the graph and save scoped checkpoint.
@@ -2720,12 +2723,12 @@ class ScopedGraphTest(test.TestCase):
graph = ops_lib.Graph()
with graph.as_default():
with ops_lib.name_scope("hidden1"):
- variable1 = variables.Variable([1.0], name="variable1")
+ variable1 = variables.VariableV1([1.0], name="variable1")
saver1 = saver_module.Saver(var_list=[variable1])
graph.add_to_collection(ops_lib.GraphKeys.SAVERS, saver1)
with ops_lib.name_scope("hidden2"):
- variable2 = variables.Variable([2.0], name="variable2")
+ variable2 = variables.VariableV1([2.0], name="variable2")
saver2 = saver_module.Saver(var_list=[variable2], name="hidden2/")
graph.add_to_collection(ops_lib.GraphKeys.SAVERS, saver2)
@@ -2850,30 +2853,32 @@ class CheckpointableCompatibilityTests(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def testNotSaveableButIsCheckpointable(self):
v = _OwnsAVariableSimple()
- saver = saver_module.Saver(var_list=[v])
test_dir = self.get_temp_dir()
prefix = os.path.join(test_dir, "ckpt")
- with self.cached_session() as sess:
- self.evaluate(v.non_dep_variable.assign(42.))
- save_path = saver.save(sess, prefix)
- self.evaluate(v.non_dep_variable.assign(43.))
- saver.restore(sess, save_path)
- self.assertEqual(42., self.evaluate(v.non_dep_variable))
+ for saver in (saver_module.Saver(var_list=[v]),
+ saver_module.Saver(var_list={"v": v})):
+ with self.cached_session() as sess:
+ self.evaluate(v.non_dep_variable.assign(42.))
+ save_path = saver.save(sess, prefix)
+ self.evaluate(v.non_dep_variable.assign(43.))
+ saver.restore(sess, save_path)
+ self.assertEqual(42., self.evaluate(v.non_dep_variable))
@test_util.run_in_graph_and_eager_modes
def testMoreComplexSaveableReturned(self):
v = _OwnsMirroredVariables()
- saver = saver_module.Saver(var_list=[v])
test_dir = self.get_temp_dir()
prefix = os.path.join(test_dir, "ckpt")
self.evaluate(v.non_dep_variable.assign(42.))
- with self.cached_session() as sess:
- save_path = saver.save(sess, prefix)
- self.evaluate(v.non_dep_variable.assign(43.))
- self.evaluate(v.mirrored.assign(44.))
- saver.restore(sess, save_path)
- self.assertEqual(42., self.evaluate(v.non_dep_variable))
- self.assertEqual(42., self.evaluate(v.mirrored))
+ for saver in (saver_module.Saver(var_list=[v]),
+ saver_module.Saver(var_list={"v": v})):
+ with self.cached_session() as sess:
+ save_path = saver.save(sess, prefix)
+ self.evaluate(v.non_dep_variable.assign(43.))
+ self.evaluate(v.mirrored.assign(44.))
+ saver.restore(sess, save_path)
+ self.assertEqual(42., self.evaluate(v.non_dep_variable))
+ self.assertEqual(42., self.evaluate(v.mirrored))
def testSingleTensorEvaluation(self):
@@ -2976,7 +2981,7 @@ class CheckpointableCompatibilityTests(test.TestCase):
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
with ops_lib.Graph().as_default() as g:
- a = variables.Variable(1., name="a")
+ a = variables.VariableV1(1., name="a")
a_saver = saver_module.Saver([a])
with self.session(graph=g) as sess:
@@ -2984,7 +2989,7 @@ class CheckpointableCompatibilityTests(test.TestCase):
save_path = a_saver.save(sess=sess, save_path=checkpoint_prefix)
with ops_lib.Graph().as_default() as g:
- a = variables.Variable([1.], name="a")
+ a = variables.VariableV1([1.], name="a")
a_saver = saver_module.Saver([a])
with self.session(graph=g) as sess:
with self.assertRaisesRegexp(
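For reference, a minimal sketch of the two var_list spellings the updated tests now exercise, reusing saver_module and _OwnsAVariableSimple as named in the test file above (the names are taken from that file, not importable on their own):

    v = _OwnsAVariableSimple()
    # Both forms now round-trip through save/restore in the loop above:
    list_saver = saver_module.Saver(var_list=[v])        # positional list
    dict_saver = saver_module.Saver(var_list={"v": v})   # dict keyed by checkpoint name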
diff --git a/tensorflow/python/training/server_lib_same_variables_no_clear_test.py b/tensorflow/python/training/server_lib_same_variables_no_clear_test.py
index c7e84e9ba1..5aa7f45c2b 100644
--- a/tensorflow/python/training/server_lib_same_variables_no_clear_test.py
+++ b/tensorflow/python/training/server_lib_same_variables_no_clear_test.py
@@ -37,8 +37,8 @@ class SameVariablesNoClearTest(test.TestCase):
server = server_lib.Server.create_local_server()
with session.Session(server.target) as sess_1:
- v0 = variables.Variable([[2, 1]], name="v0")
- v1 = variables.Variable([[1], [2]], name="v1")
+ v0 = variables.VariableV1([[2, 1]], name="v0")
+ v1 = variables.VariableV1([[1], [2]], name="v1")
v2 = math_ops.matmul(v0, v1)
sess_1.run([v0.initializer, v1.initializer])
self.assertAllEqual([[4]], sess_1.run(v2))
diff --git a/tensorflow/python/training/server_lib_test.py b/tensorflow/python/training/server_lib_test.py
index 063044f0d0..cf995707fc 100644
--- a/tensorflow/python/training/server_lib_test.py
+++ b/tensorflow/python/training/server_lib_test.py
@@ -76,9 +76,9 @@ class GrpcServerTest(test.TestCase):
def testResetFails(self):
# Creates variable with container name.
with ops.container("test0"):
- v0 = variables.Variable(1.0, name="v0")
+ v0 = variables.VariableV1(1.0, name="v0")
# Creates variable with default container.
- v1 = variables.Variable(2.0, name="v1")
+ v1 = variables.VariableV1(2.0, name="v1")
# Verifies resetting the non-existent target returns error.
with self.assertRaises(errors_impl.NotFoundError):
session.Session.reset("nonexistent", ["test0"])
@@ -234,8 +234,8 @@ class GrpcServerTest(test.TestCase):
[0.], dtype=dtypes.float32))
self.assertIsNotNone(input_queue)
- var = variables.Variable(1., dtype=dtypes.float32, trainable=False,
- name="var")
+ var = variables.VariableV1(1., dtype=dtypes.float32, trainable=False,
+ name="var")
sess.run(variables.global_variables_initializer())
queue_runner_impl.start_queue_runners(sess)
@@ -245,7 +245,7 @@ class GrpcServerTest(test.TestCase):
server = self._cached_server
init_value = array_ops.placeholder(dtypes.int32)
- v = variables.Variable(init_value, validate_shape=False, name="v")
+ v = variables.VariableV1(init_value, validate_shape=False, name="v")
sharing_config = config_pb2.ConfigProto(isolate_session_state=False)
sharing_sess_0 = session.Session(server.target, config=sharing_config)
@@ -302,7 +302,7 @@ class GrpcServerTest(test.TestCase):
isolate_config = config_pb2.ConfigProto(isolate_session_state=True)
with ops.Graph().as_default():
- w_vector = variables.Variable([1, 2, 3], name="w")
+ w_vector = variables.VariableV1([1, 2, 3], name="w")
with session.Session(server.target, config=sharing_config) as sess:
with self.assertRaises(errors_impl.FailedPreconditionError):
sess.run(w_vector)
@@ -310,20 +310,20 @@ class GrpcServerTest(test.TestCase):
self.assertAllEqual([1, 2, 3], sess.run(w_vector))
with ops.Graph().as_default():
- w_vector = variables.Variable([4, 5, 6], name="w")
+ w_vector = variables.VariableV1([4, 5, 6], name="w")
with session.Session(server.target, config=sharing_config) as sess:
self.assertAllEqual([1, 2, 3], sess.run(w_vector))
sess.run(w_vector.initializer)
self.assertAllEqual([4, 5, 6], sess.run(w_vector))
with ops.Graph().as_default():
- w_scalar = variables.Variable(86, name="w")
+ w_scalar = variables.VariableV1(86, name="w")
with session.Session(server.target, config=sharing_config) as sess:
with self.assertRaises(errors_impl.InvalidArgumentError):
sess.run(w_scalar.initializer)
with ops.Graph().as_default():
- w_scalar = variables.Variable(37, name="w")
+ w_scalar = variables.VariableV1(37, name="w")
with session.Session(server.target, config=isolate_config) as sess:
with self.assertRaises(errors_impl.FailedPreconditionError):
sess.run(w_scalar)
diff --git a/tensorflow/python/training/session_manager_test.py b/tensorflow/python/training/session_manager_test.py
index f1d18f7704..2b5c3b01de 100644
--- a/tensorflow/python/training/session_manager_test.py
+++ b/tensorflow/python/training/session_manager_test.py
@@ -40,7 +40,7 @@ class SessionManagerTest(test.TestCase):
def testPrepareSessionSucceeds(self):
with ops.Graph().as_default():
- v = variables.Variable([1.0, 2.0, 3.0], name="v")
+ v = variables.VariableV1([1.0, 2.0, 3.0], name="v")
sm = session_manager.SessionManager(
ready_op=variables.report_uninitialized_variables())
sess = sm.prepare_session(
@@ -50,7 +50,7 @@ class SessionManagerTest(test.TestCase):
def testPrepareSessionSucceedsWithInitFeedDict(self):
with ops.Graph().as_default():
p = array_ops.placeholder(dtypes.float32, shape=(3,))
- v = variables.Variable(p, name="v")
+ v = variables.VariableV1(p, name="v")
sm = session_manager.SessionManager(
ready_op=variables.report_uninitialized_variables())
sess = sm.prepare_session(
@@ -61,7 +61,7 @@ class SessionManagerTest(test.TestCase):
def testPrepareSessionSucceedsWithInitFn(self):
with ops.Graph().as_default():
- v = variables.Variable([125], name="v")
+ v = variables.VariableV1([125], name="v")
sm = session_manager.SessionManager(
ready_op=variables.report_uninitialized_variables())
sess = sm.prepare_session(
@@ -79,7 +79,7 @@ class SessionManagerTest(test.TestCase):
gfile.MakeDirs(checkpoint_dir)
with ops.Graph().as_default():
- v = variables.Variable([1.0, 2.0, 3.0], name="v")
+ v = variables.VariableV1([1.0, 2.0, 3.0], name="v")
sm = session_manager.SessionManager(
ready_op=variables.report_uninitialized_variables())
saver = saver_lib.Saver({"v": v})
@@ -97,7 +97,7 @@ class SessionManagerTest(test.TestCase):
# Renames the checkpoint directory.
os.rename(checkpoint_dir, checkpoint_dir2)
gfile.MakeDirs(checkpoint_dir)
- v = variables.Variable([6.0, 7.0, 8.0], name="v")
+ v = variables.VariableV1([6.0, 7.0, 8.0], name="v")
with self.cached_session():
self.assertEqual(False, variables.is_variable_initialized(v).eval())
session_manager.SessionManager(
@@ -134,7 +134,7 @@ class SessionManagerTest(test.TestCase):
checkpoint_filename_with_path=None):
# Create a new Graph and SessionManager and recover from a checkpoint.
with ops.Graph().as_default():
- v = variables.Variable(2, name="v")
+ v = variables.VariableV1(2, name="v")
with session_lib.Session():
self.assertEqual(False, variables.is_variable_initialized(v).eval())
sm2 = session_manager.SessionManager(
@@ -162,7 +162,7 @@ class SessionManagerTest(test.TestCase):
gfile.MakeDirs(checkpoint_dir)
with ops.Graph().as_default():
- v = variables.Variable(1, name="v")
+ v = variables.VariableV1(1, name="v")
sm = session_manager.SessionManager(
ready_op=variables.report_uninitialized_variables())
saver = saver_lib.Saver({"v": v})
@@ -186,7 +186,7 @@ class SessionManagerTest(test.TestCase):
def testWaitForSessionReturnsNoneAfterTimeout(self):
with ops.Graph().as_default():
- variables.Variable(1, name="v")
+ variables.VariableV1(1, name="v")
sm = session_manager.SessionManager(
ready_op=variables.report_uninitialized_variables(),
recovery_wait_secs=1)
@@ -217,7 +217,7 @@ class SessionManagerTest(test.TestCase):
gfile.MakeDirs(checkpoint_dir)
with ops.Graph().as_default():
- v = variables.Variable(1, name="v")
+ v = variables.VariableV1(1, name="v")
sm = session_manager.SessionManager(
ready_op=variables.report_uninitialized_variables())
saver = saver_lib.Saver({"v": v})
@@ -230,8 +230,8 @@ class SessionManagerTest(test.TestCase):
os.path.join(checkpoint_dir, "recover_session_checkpoint"))
# Create a new Graph and SessionManager and recover.
with ops.Graph().as_default():
- v = variables.Variable(2, name="v")
- w = variables.Variable(
+ v = variables.VariableV1(2, name="v")
+ w = variables.VariableV1(
v,
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES],
@@ -275,7 +275,7 @@ class SessionManagerTest(test.TestCase):
gfile.MakeDirs(checkpoint_dir)
with ops.Graph().as_default():
- v = variables.Variable(1, name="v")
+ v = variables.VariableV1(1, name="v")
sm = session_manager.SessionManager(
ready_op=variables.report_uninitialized_variables())
saver = saver_lib.Saver({"v": v})
@@ -288,8 +288,8 @@ class SessionManagerTest(test.TestCase):
os.path.join(checkpoint_dir, "recover_session_checkpoint"))
# Create a new Graph and SessionManager and recover.
with ops.Graph().as_default():
- v = variables.Variable(2, name="v")
- w = variables.Variable(
+ v = variables.VariableV1(2, name="v")
+ w = variables.VariableV1(
v,
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES],
@@ -321,7 +321,7 @@ class SessionManagerTest(test.TestCase):
# local_init_op exactly once, regardless of whether the session was
# successfully recovered.
with ops.Graph().as_default():
- w = variables.Variable(
+ w = variables.VariableV1(
1,
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES],
@@ -356,8 +356,8 @@ class SessionManagerTest(test.TestCase):
# Create a new Graph and SessionManager and recover.
with ops.Graph().as_default():
- v = variables.Variable(2, name="v")
- w = variables.Variable(
+ v = variables.VariableV1(2, name="v")
+ w = variables.VariableV1(
1,
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES],
@@ -389,8 +389,8 @@ class SessionManagerTest(test.TestCase):
def testWaitForSessionLocalInit(self):
server = server_lib.Server.create_local_server()
with ops.Graph().as_default() as graph:
- v = variables.Variable(1, name="v")
- w = variables.Variable(
+ v = variables.VariableV1(1, name="v")
+ w = variables.VariableV1(
v,
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES],
@@ -420,8 +420,8 @@ class SessionManagerTest(test.TestCase):
def testWaitForSessionWithReadyForLocalInitOpFailsToReadyLocal(self):
with ops.Graph().as_default() as graph:
- v = variables.Variable(1, name="v")
- w = variables.Variable(
+ v = variables.VariableV1(1, name="v")
+ w = variables.VariableV1(
v,
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES],
@@ -439,8 +439,8 @@ class SessionManagerTest(test.TestCase):
def testWaitForSessionInsufficientReadyForLocalInitCheck(self):
with ops.Graph().as_default() as graph:
- v = variables.Variable(1, name="v")
- w = variables.Variable(
+ v = variables.VariableV1(1, name="v")
+ w = variables.VariableV1(
v,
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES],
@@ -456,13 +456,13 @@ class SessionManagerTest(test.TestCase):
def testPrepareSessionWithReadyForLocalInitOp(self):
with ops.Graph().as_default():
- v = variables.Variable(1, name="v")
- w = variables.Variable(
+ v = variables.VariableV1(1, name="v")
+ w = variables.VariableV1(
v,
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES],
name="w")
- x = variables.Variable(
+ x = variables.VariableV1(
3 * v,
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES],
@@ -495,25 +495,25 @@ class SessionManagerTest(test.TestCase):
def testPrepareSessionWithPartialInitOp(self):
with ops.Graph().as_default():
- v = variables.Variable(1, name="v")
- w = variables.Variable(
+ v = variables.VariableV1(1, name="v")
+ w = variables.VariableV1(
v,
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES],
name="w")
- x = variables.Variable(
+ x = variables.VariableV1(
3 * v,
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES],
name="x")
# TODO(b/70206927): Use ResourceVariables once they are handled properly.
- v_res = variables.Variable(1, name="v_res")
- w_res = variables.Variable(
+ v_res = variables.VariableV1(1, name="v_res")
+ w_res = variables.VariableV1(
v_res,
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES],
name="w_res")
- x_res = variables.Variable(
+ x_res = variables.VariableV1(
3 * v_res,
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES],
@@ -565,7 +565,7 @@ class SessionManagerTest(test.TestCase):
# cyclic dependencies.
with ops.Graph().as_default():
i = control_flow_ops.while_loop(lambda i: i < 1, lambda i: i + 1, [0])
- v = variables.Variable(array_ops.identity(i), name="v")
+ v = variables.VariableV1(array_ops.identity(i), name="v")
with self.cached_session():
self.assertEqual(False, variables.is_variable_initialized(v).eval())
sm = session_manager.SessionManager(
@@ -579,8 +579,8 @@ class SessionManagerTest(test.TestCase):
def testPrepareSessionDidNotInitLocalVariable(self):
with ops.Graph().as_default():
- v = variables.Variable(1, name="v")
- w = variables.Variable(
+ v = variables.VariableV1(1, name="v")
+ w = variables.VariableV1(
v,
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES],
@@ -596,8 +596,8 @@ class SessionManagerTest(test.TestCase):
def testPrepareSessionDidNotInitLocalVariableList(self):
with ops.Graph().as_default():
- v = variables.Variable(1, name="v")
- w = variables.Variable(
+ v = variables.VariableV1(1, name="v")
+ w = variables.VariableV1(
v,
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES],
@@ -613,8 +613,8 @@ class SessionManagerTest(test.TestCase):
def testPrepareSessionWithReadyNotReadyForLocal(self):
with ops.Graph().as_default():
- v = variables.Variable(1, name="v")
- w = variables.Variable(
+ v = variables.VariableV1(1, name="v")
+ w = variables.VariableV1(
v,
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES],
@@ -634,8 +634,8 @@ class SessionManagerTest(test.TestCase):
def testPrepareSessionWithInsufficientReadyForLocalInitCheck(self):
with ops.Graph().as_default():
- v = variables.Variable(1, name="v")
- w = variables.Variable(
+ v = variables.VariableV1(1, name="v")
+ w = variables.VariableV1(
v,
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES],
@@ -656,7 +656,7 @@ class ObsoleteSessionManagerTest(test.TestCase):
def testPrepareSessionSucceeds(self):
with ops.Graph().as_default():
- v = variables.Variable([1.0, 2.0, 3.0], name="v")
+ v = variables.VariableV1([1.0, 2.0, 3.0], name="v")
sm = session_manager.SessionManager(
ready_op=variables.assert_variables_initialized())
sess = sm.prepare_session(
@@ -666,7 +666,7 @@ class ObsoleteSessionManagerTest(test.TestCase):
def testPrepareSessionSucceedsWithInitFeedDict(self):
with ops.Graph().as_default():
p = array_ops.placeholder(dtypes.float32, shape=(3,))
- v = variables.Variable(p, name="v")
+ v = variables.VariableV1(p, name="v")
sm = session_manager.SessionManager(
ready_op=variables.assert_variables_initialized())
sess = sm.prepare_session(
@@ -677,7 +677,7 @@ class ObsoleteSessionManagerTest(test.TestCase):
def testPrepareSessionSucceedsWithInitFn(self):
with ops.Graph().as_default():
- v = variables.Variable([125], name="v")
+ v = variables.VariableV1([125], name="v")
sm = session_manager.SessionManager(
ready_op=variables.assert_variables_initialized())
sess = sm.prepare_session(
@@ -695,7 +695,7 @@ class ObsoleteSessionManagerTest(test.TestCase):
gfile.MakeDirs(checkpoint_dir)
with ops.Graph().as_default():
- v = variables.Variable([1.0, 2.0, 3.0], name="v")
+ v = variables.VariableV1([1.0, 2.0, 3.0], name="v")
sm = session_manager.SessionManager(
ready_op=variables.assert_variables_initialized())
saver = saver_lib.Saver({"v": v})
@@ -713,7 +713,7 @@ class ObsoleteSessionManagerTest(test.TestCase):
# Renames the checkpoint directory.
os.rename(checkpoint_dir, checkpoint_dir2)
gfile.MakeDirs(checkpoint_dir)
- v = variables.Variable([6.0, 7.0, 8.0], name="v")
+ v = variables.VariableV1([6.0, 7.0, 8.0], name="v")
with self.cached_session():
self.assertEqual(False, variables.is_variable_initialized(v).eval())
session_manager.SessionManager(
@@ -755,7 +755,7 @@ class ObsoleteSessionManagerTest(test.TestCase):
gfile.MakeDirs(checkpoint_dir)
with ops.Graph().as_default():
- v = variables.Variable(1, name="v")
+ v = variables.VariableV1(1, name="v")
sm = session_manager.SessionManager(
ready_op=variables.assert_variables_initialized())
saver = saver_lib.Saver({"v": v})
@@ -768,7 +768,7 @@ class ObsoleteSessionManagerTest(test.TestCase):
os.path.join(checkpoint_dir, "recover_session_checkpoint"))
# Create a new Graph and SessionManager and recover.
with ops.Graph().as_default():
- v = variables.Variable(2, name="v")
+ v = variables.VariableV1(2, name="v")
with self.cached_session():
self.assertEqual(False, variables.is_variable_initialized(v).eval())
sm2 = session_manager.SessionManager(
@@ -785,7 +785,7 @@ class ObsoleteSessionManagerTest(test.TestCase):
def testWaitForSessionReturnsNoneAfterTimeout(self):
with ops.Graph().as_default():
- variables.Variable(1, name="v")
+ variables.VariableV1(1, name="v")
sm = session_manager.SessionManager(
ready_op=variables.assert_variables_initialized(),
recovery_wait_secs=1)
diff --git a/tensorflow/python/training/supervisor.py b/tensorflow/python/training/supervisor.py
index 0755364bbe..a5e626d320 100644
--- a/tensorflow/python/training/supervisor.py
+++ b/tensorflow/python/training/supervisor.py
@@ -242,10 +242,9 @@ class Supervisor(object):
ready_for_local_init_op: 1-D string `Tensor`. This tensor is evaluated by
supervisors in `prepare_or_wait_for_session()` to check if the model is
ready to run the local_init_op.
- The model is considered ready if it returns an empty array. Defaults to
- the tensor returned from
- `tf.report_uninitialized_variables(tf.global_variables())`. If `None`,
- the model is not checked for readiness before running local_init_op.
+ The model is considered ready if it returns an empty array. Defaults to
+ `None`. If `None`, the model is not checked for readiness before running
+ local_init_op.
is_chief: If True, create a chief supervisor in charge of initializing
and restoring the model. If False, create a supervisor that relies
on a chief supervisor for inits and restore.
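Since the default is now `None` (no readiness check before local_init_op runs), here is a minimal sketch of opting back into the previous behavior explicitly; the logdir is illustrative:

    from tensorflow.python.ops import variables
    from tensorflow.python.training import supervisor

    sv = supervisor.Supervisor(
        logdir="/tmp/sup_logdir",
        ready_for_local_init_op=variables.report_uninitialized_variables(
            variables.global_variables()))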
diff --git a/tensorflow/python/training/supervisor_test.py b/tensorflow/python/training/supervisor_test.py
index caf6eba3e0..7cd99d8680 100644
--- a/tensorflow/python/training/supervisor_test.py
+++ b/tensorflow/python/training/supervisor_test.py
@@ -423,7 +423,7 @@ class SupervisorTest(test.TestCase):
def testLogdirButExplicitlyNoSummaryWriter(self):
logdir = self._test_dir("explicit_no_summary_writer")
with ops.Graph().as_default():
- variables.Variable([1.0], name="foo")
+ variables.VariableV1([1.0], name="foo")
summary.scalar("c1", constant_op.constant(1))
summary.scalar("c2", constant_op.constant(2))
summary.scalar("c3", constant_op.constant(3))
@@ -491,7 +491,7 @@ class SupervisorTest(test.TestCase):
def testNoLogdirSucceeds(self):
with ops.Graph().as_default():
- variables.Variable([1.0, 2.0, 3.0])
+ variables.VariableV1([1.0, 2.0, 3.0])
sv = supervisor.Supervisor(logdir="", summary_op=None)
sess = sv.prepare_or_wait_for_session("")
sess.close()
@@ -499,7 +499,7 @@ class SupervisorTest(test.TestCase):
def testUseSessionManager(self):
with ops.Graph().as_default():
- variables.Variable([1.0, 2.0, 3.0])
+ variables.VariableV1([1.0, 2.0, 3.0])
sm = session_manager_lib.SessionManager()
# Pass in session_manager. The additional init_op is ignored.
sv = supervisor.Supervisor(logdir="", session_manager=sm)
@@ -508,7 +508,7 @@ class SupervisorTest(test.TestCase):
def testInitOp(self):
logdir = self._test_dir("default_init_op")
with ops.Graph().as_default():
- v = variables.Variable([1.0, 2.0, 3.0])
+ v = variables.VariableV1([1.0, 2.0, 3.0])
sv = supervisor.Supervisor(logdir=logdir)
sess = sv.prepare_or_wait_for_session("")
self.assertAllClose([1.0, 2.0, 3.0], sess.run(v))
@@ -517,7 +517,7 @@ class SupervisorTest(test.TestCase):
def testInitFn(self):
logdir = self._test_dir("default_init_op")
with ops.Graph().as_default():
- v = variables.Variable([1.0, 2.0, 3.0])
+ v = variables.VariableV1([1.0, 2.0, 3.0])
def _init_fn(sess):
sess.run(v.initializer)
@@ -531,7 +531,7 @@ class SupervisorTest(test.TestCase):
logdir = self._test_dir("feed_dict_init_op")
with ops.Graph().as_default():
p = array_ops.placeholder(dtypes.float32, shape=(3,))
- v = variables.Variable(p, name="v")
+ v = variables.VariableV1(p, name="v")
sv = supervisor.Supervisor(
logdir=logdir,
init_op=variables.global_variables_initializer(),
@@ -550,10 +550,10 @@ class SupervisorTest(test.TestCase):
g = ops.Graph()
with g.as_default():
with ops.device("/job:local"):
- v = variables.Variable(
+ v = variables.VariableV1(
1, name="default_ready_for_local_init_op_v_" + str(uid))
vadd = v.assign_add(1)
- w = variables.Variable(
+ w = variables.VariableV1(
v,
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES],
@@ -590,7 +590,7 @@ class SupervisorTest(test.TestCase):
# Create a checkpoint.
with ops.Graph().as_default():
- v = variables.Variable(
+ v = variables.VariableV1(
10.0, name="ready_for_local_init_op_restore_v_" + str(uid))
summary.scalar("ready_for_local_init_op_restore_v_" + str(uid), v)
sv = supervisor.Supervisor(logdir=logdir)
@@ -607,10 +607,10 @@ class SupervisorTest(test.TestCase):
g = ops.Graph()
with g.as_default():
with ops.device("/job:local"):
- v = variables.Variable(
+ v = variables.VariableV1(
1.0, name="ready_for_local_init_op_restore_v_" + str(uid))
vadd = v.assign_add(1)
- w = variables.Variable(
+ w = variables.VariableV1(
v,
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES],
@@ -642,13 +642,13 @@ class SupervisorTest(test.TestCase):
logdir = self._test_dir("default_local_init_op")
with ops.Graph().as_default():
# A local variable.
- v = variables.Variable(
+ v = variables.VariableV1(
[1.0, 2.0, 3.0],
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES])
# An entity which is initialized through a TABLE_INITIALIZER.
- w = variables.Variable([4, 5, 6], trainable=False, collections=[])
+ w = variables.VariableV1([4, 5, 6], trainable=False, collections=[])
ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, w.initializer)
# This shouldn't add a variable to the VARIABLES collection responsible
@@ -668,7 +668,7 @@ class SupervisorTest(test.TestCase):
with ops.Graph().as_default():
with ops.device("/job:localhost"):
# A local variable.
- v = variables.Variable(
+ v = variables.VariableV1(
[1.0, 2.0, 3.0],
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES])
@@ -687,8 +687,8 @@ class SupervisorTest(test.TestCase):
server = server_lib.Server.create_local_server()
logdir = self._test_dir("default_init_op_fails")
with ops.Graph().as_default():
- v = variables.Variable([1.0, 2.0, 3.0], name="v")
- variables.Variable([4.0, 5.0, 6.0], name="w")
+ v = variables.VariableV1([1.0, 2.0, 3.0], name="v")
+ variables.VariableV1([4.0, 5.0, 6.0], name="w")
# w will not be initialized.
sv = supervisor.Supervisor(logdir=logdir, init_op=v.initializer)
with self.assertRaisesRegexp(RuntimeError,
@@ -699,11 +699,11 @@ class SupervisorTest(test.TestCase):
server = server_lib.Server.create_local_server()
logdir = self._test_dir("default_init_op_fails_for_local_variable")
with ops.Graph().as_default():
- v = variables.Variable(
+ v = variables.VariableV1(
[1.0, 2.0, 3.0],
name="v",
collections=[ops.GraphKeys.LOCAL_VARIABLES])
- variables.Variable(
+ variables.VariableV1(
[1.0, 2.0, 3.0],
name="w",
collections=[ops.GraphKeys.LOCAL_VARIABLES])
@@ -716,17 +716,17 @@ class SupervisorTest(test.TestCase):
def testSetupFail(self):
logdir = self._test_dir("setup_fail")
with ops.Graph().as_default():
- variables.Variable([1.0, 2.0, 3.0], name="v")
+ variables.VariableV1([1.0, 2.0, 3.0], name="v")
with self.assertRaisesRegexp(ValueError, "must have their device set"):
supervisor.Supervisor(logdir=logdir, is_chief=False)
with ops.Graph().as_default(), ops.device("/job:ps"):
- variables.Variable([1.0, 2.0, 3.0], name="v")
+ variables.VariableV1([1.0, 2.0, 3.0], name="v")
supervisor.Supervisor(logdir=logdir, is_chief=False)
def testDefaultGlobalStep(self):
logdir = self._test_dir("default_global_step")
with ops.Graph().as_default():
- variables.Variable(287, name="global_step")
+ variables.VariableV1(287, name="global_step")
sv = supervisor.Supervisor(logdir=logdir)
sess = sv.prepare_or_wait_for_session("")
self.assertEquals(287, sess.run(sv.global_step))
@@ -735,7 +735,7 @@ class SupervisorTest(test.TestCase):
def testRestoreFromMetaGraph(self):
logdir = self._test_dir("restore_from_meta_graph")
with ops.Graph().as_default():
- variables.Variable(1, name="v0")
+ variables.VariableV1(1, name="v0")
sv = supervisor.Supervisor(logdir=logdir)
sess = sv.prepare_or_wait_for_session("")
filename = sv.saver.save(sess, sv.save_path)
@@ -757,7 +757,7 @@ class SupervisorTest(test.TestCase):
logdir = self._test_dir("standard_services_without_global_step")
# Create a checkpoint.
with ops.Graph().as_default():
- v = variables.Variable([1.0], name="foo")
+ v = variables.VariableV1([1.0], name="foo")
summary.scalar("v", v[0])
sv = supervisor.Supervisor(logdir=logdir)
meta_graph_def = meta_graph.create_meta_graph_def(
@@ -796,7 +796,7 @@ class SupervisorTest(test.TestCase):
self.assertRaises(StopIteration, lambda: next(rr))
# There should be a checkpoint file with the variable "foo"
with ops.Graph().as_default(), self.cached_session() as sess:
- v = variables.Variable([10.10], name="foo")
+ v = variables.VariableV1([10.10], name="foo")
sav = saver_lib.Saver([v])
sav.restore(sess, save_path)
self.assertEqual(1.0, v.eval()[0])
@@ -807,7 +807,7 @@ class SupervisorTest(test.TestCase):
logdir = self._test_dir("standard_services_with_global_step")
# Create a checkpoint.
with ops.Graph().as_default():
- v = variables.Variable([123], name="global_step")
+ v = variables.VariableV1([123], name="global_step")
sv = supervisor.Supervisor(logdir=logdir)
meta_graph_def = meta_graph.create_meta_graph_def(
saver_def=sv.saver.saver_def)
@@ -860,7 +860,7 @@ class SupervisorTest(test.TestCase):
self.assertRaises(StopIteration, lambda: next(rr))
# There should be a checkpoint file with the variable "foo"
with ops.Graph().as_default(), self.cached_session() as sess:
- v = variables.Variable([-12], name="global_step")
+ v = variables.VariableV1([-12], name="global_step")
sav = saver_lib.Saver([v])
sav.restore(sess, save_path)
self.assertEqual(123, v.eval()[0])
diff --git a/tensorflow/python/training/sync_replicas_optimizer_test.py b/tensorflow/python/training/sync_replicas_optimizer_test.py
index fff17402e2..1ef8756ef6 100644
--- a/tensorflow/python/training/sync_replicas_optimizer_test.py
+++ b/tensorflow/python/training/sync_replicas_optimizer_test.py
@@ -40,11 +40,12 @@ def get_workers(num_workers, replicas_to_aggregate, workers):
is_chief = (worker_id == 0)
with graph.as_default():
with ops.device("/job:ps/task:0"):
- global_step = variables.Variable(0, name="global_step", trainable=False)
- var_0 = variables.Variable(0.0, name="v0")
+ global_step = variables.VariableV1(
+ 0, name="global_step", trainable=False)
+ var_0 = variables.VariableV1(0.0, name="v0")
with ops.device("/job:ps/task:1"):
- var_1 = variables.Variable(1.0, name="v1")
- var_sparse = variables.Variable([[3.0], [4.0]], name="v_sparse")
+ var_1 = variables.VariableV1(1.0, name="v1")
+ var_sparse = variables.VariableV1([[3.0], [4.0]], name="v_sparse")
with ops.device("/job:worker/task:" + str(worker_id)):
grads_0 = constant_op.constant(0.1 + worker_id * 0.2)
@@ -272,8 +273,8 @@ class SyncReplicasOptimizerHookTest(test.TestCase):
replicas_to_aggregate=1,
total_num_replicas=1)
hook = opt.make_session_run_hook(True)
- v = variables.Variable([0.])
- global_step = variables.Variable(0, name="global_step", trainable=False)
+ v = variables.VariableV1([0.])
+ global_step = variables.VariableV1(0, name="global_step", trainable=False)
opt.minimize(v, global_step=global_step)
hook.begin()
@@ -282,8 +283,8 @@ class SyncReplicasOptimizerHookTest(test.TestCase):
opt=adam.AdamOptimizer(0.01),
replicas_to_aggregate=1,
total_num_replicas=1)
- v = variables.Variable([0.], name="fetch_variable_test")
- global_step = variables.Variable(0, name="global_step", trainable=False)
+ v = variables.VariableV1([0.], name="fetch_variable_test")
+ global_step = variables.VariableV1(0, name="global_step", trainable=False)
opt.minimize(v, global_step=global_step)
opt_variables = opt.variables()
beta1_power, beta2_power = opt._opt._get_beta_accumulators()
diff --git a/tensorflow/python/training/training_ops_test.py b/tensorflow/python/training/training_ops_test.py
index d131a11067..f410ceaaff 100644
--- a/tensorflow/python/training/training_ops_test.py
+++ b/tensorflow/python/training/training_ops_test.py
@@ -51,7 +51,7 @@ class TrainingOpsTest(TensorFlowTestCase):
def _testTypes(self, x, alpha, delta, use_gpu=None):
self.setUp()
with self.test_session(use_gpu=use_gpu):
- var = variables.Variable(x)
+ var = variables.VariableV1(x)
variables.global_variables_initializer().run()
self.assertAllCloseAccordingToType(x, var.eval())
apply_sgd = training_ops.apply_gradient_descent(var, alpha, delta)
@@ -70,8 +70,8 @@ class TrainingOpsTest(TensorFlowTestCase):
def _testTypesForAdagrad(self, x, y, lr, grad, use_gpu=None):
self.setUp()
with self.test_session(use_gpu=use_gpu):
- var = variables.Variable(x)
- accum = variables.Variable(y)
+ var = variables.VariableV1(x)
+ accum = variables.VariableV1(y)
variables.global_variables_initializer().run()
self.assertAllCloseAccordingToType(x, var.eval())
@@ -94,9 +94,9 @@ class TrainingOpsTest(TensorFlowTestCase):
lr_power=-0.5):
self.setUp()
with self.test_session(use_gpu=use_gpu):
- var = variables.Variable(x)
- accum = variables.Variable(y)
- linear = variables.Variable(z)
+ var = variables.VariableV1(x)
+ accum = variables.VariableV1(y)
+ linear = variables.VariableV1(z)
variables.global_variables_initializer().run()
self.assertAllCloseAccordingToType(x, var.eval())
@@ -148,8 +148,8 @@ class TrainingOpsTest(TensorFlowTestCase):
def _testTypesForSparseAdagrad(self, x, y, lr, grad, indices):
self.setUp()
with self.test_session(use_gpu=False):
- var = variables.Variable(x)
- accum = variables.Variable(y)
+ var = variables.VariableV1(x)
+ accum = variables.VariableV1(y)
variables.global_variables_initializer().run()
self.assertAllCloseAccordingToType(x, var.eval())
@@ -178,9 +178,9 @@ class TrainingOpsTest(TensorFlowTestCase):
lr_power=-0.5):
self.setUp()
with self.test_session(use_gpu=False):
- var = variables.Variable(x)
- accum = variables.Variable(y)
- linear = variables.Variable(z)
+ var = variables.VariableV1(x)
+ accum = variables.VariableV1(y)
+ linear = variables.VariableV1(z)
variables.global_variables_initializer().run()
self.assertAllCloseAccordingToType(x, var.eval())
@@ -257,9 +257,9 @@ class TrainingOpsTest(TensorFlowTestCase):
def _testTypesForAdam(self, var, m, v, grad, use_gpu):
self.setUp()
with self.test_session(use_gpu=use_gpu):
- var_t = variables.Variable(var)
- m_t = variables.Variable(m)
- v_t = variables.Variable(v)
+ var_t = variables.VariableV1(var)
+ m_t = variables.VariableV1(m)
+ v_t = variables.VariableV1(v)
t = 1
beta1 = np.array(0.9, dtype=var.dtype)
@@ -270,8 +270,8 @@ class TrainingOpsTest(TensorFlowTestCase):
epsilon = np.array(1e-8, dtype=var.dtype)
beta1_t = constant_op.constant(beta1, self._toType(var.dtype), [])
beta2_t = constant_op.constant(beta2, self._toType(var.dtype), [])
- beta1_power_t = variables.Variable(beta1_power)
- beta2_power_t = variables.Variable(beta2_power)
+ beta1_power_t = variables.VariableV1(beta1_power)
+ beta2_power_t = variables.VariableV1(beta2_power)
lr_t = constant_op.constant(lr, self._toType(var.dtype), [])
epsilon_t = constant_op.constant(epsilon, self._toType(var.dtype), [])
variables.global_variables_initializer().run()
diff --git a/tensorflow/python/training/training_util_test.py b/tensorflow/python/training/training_util_test.py
index 6cc177e0e8..ba64e785ac 100644
--- a/tensorflow/python/training/training_util_test.py
+++ b/tensorflow/python/training/training_util_test.py
@@ -49,7 +49,7 @@ class GlobalStepTest(test.TestCase):
def test_invalid_shape(self):
with ops.Graph().as_default() as g:
self.assertIsNone(training_util.get_global_step())
- variables.Variable(
+ variables.VariableV1(
[0],
trainable=False,
dtype=dtypes.int32,
@@ -73,7 +73,7 @@ class GlobalStepTest(test.TestCase):
def test_get_global_step(self):
with ops.Graph().as_default() as g:
self.assertIsNone(training_util.get_global_step())
- variables.Variable(
+ variables.VariableV1(
0,
trainable=False,
dtype=dtypes.int32,
diff --git a/tensorflow/python/training/warm_starting_util_test.py b/tensorflow/python/training/warm_starting_util_test.py
index 6c860cd452..3eddf79e34 100644
--- a/tensorflow/python/training/warm_starting_util_test.py
+++ b/tensorflow/python/training/warm_starting_util_test.py
@@ -203,7 +203,7 @@ class WarmStartingUtilTest(test.TestCase):
"new_vocab")
# New session and new graph.
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
fruit_output_layer = variable_scope.get_variable(
"fruit_output_layer",
initializer=[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.],
@@ -279,7 +279,7 @@ class WarmStartingUtilTest(test.TestCase):
"new_vocab")
# New session and new graph.
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
fruit_output_layer = variable_scope.get_variable(
"fruit_output_layer",
initializer=[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.],
@@ -337,7 +337,7 @@ class WarmStartingUtilTest(test.TestCase):
"new_vocab")
# New session and new graph.
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
fruit_output_layer = variable_scope.get_variable(
"fruit_output_layer",
shape=[4, 3],
@@ -403,7 +403,7 @@ class WarmStartingUtilTest(test.TestCase):
"new_vocab")
# New session and new graph.
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
fruit_output_layer = variable_scope.get_variable(
"fruit_output_layer",
shape=[4, 3],
diff --git a/tensorflow/python/util/function_utils.py b/tensorflow/python/util/function_utils.py
index 4e9b07e20a..a56dfbff8e 100644
--- a/tensorflow/python/util/function_utils.py
+++ b/tensorflow/python/util/function_utils.py
@@ -59,6 +59,29 @@ def fn_args(fn):
return tuple(args)
+def has_kwargs(fn):
+ """Returns whether the passed callable has **kwargs in its signature.
+
+ Args:
+ fn: Function, or function-like object (e.g., result of `functools.partial`).
+
+ Returns:
+ `bool`: whether `fn` has **kwargs in its signature.
+
+ Raises:
+ `TypeError`: If `fn` is not a function or function-like object.
+ """
+ if isinstance(fn, functools.partial):
+ fn = fn.func
+ elif _is_callable_object(fn):
+ fn = fn.__call__
+ elif not callable(fn):
+ raise TypeError(
+ 'fn should be a function-like object, but is of type {}.'.format(
+ type(fn)))
+ return tf_inspect.getfullargspec(fn).varkw is not None
+
+
def get_func_name(func):
"""Returns name of passed callable."""
_, func = tf_decorator.unwrap(func)
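A short usage sketch of the new helper; the functions below are illustrative only:

    import functools
    from tensorflow.python.util import function_utils

    def fn(a, **kwargs):
      return a

    function_utils.has_kwargs(fn)                          # True
    function_utils.has_kwargs(functools.partial(fn, a=1))  # True (partials are unwrapped)
    function_utils.has_kwargs(lambda a: a)                 # False
    function_utils.has_kwargs("not callable")              # raises TypeError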
diff --git a/tensorflow/python/util/function_utils_test.py b/tensorflow/python/util/function_utils_test.py
index 1588328c26..e5b0843e4b 100644
--- a/tensorflow/python/util/function_utils_test.py
+++ b/tensorflow/python/util/function_utils_test.py
@@ -135,6 +135,101 @@ class FnArgsTest(test.TestCase):
self.assertEqual(3, double_wrapped_fn(a=3))
+class HasKwargsTest(test.TestCase):
+
+ def test_simple_function(self):
+
+ fn_has_kwargs = lambda **x: x
+ self.assertTrue(function_utils.has_kwargs(fn_has_kwargs))
+
+ fn_has_no_kwargs = lambda x: x
+ self.assertFalse(function_utils.has_kwargs(fn_has_no_kwargs))
+
+ def test_callable(self):
+
+ class FooHasKwargs(object):
+
+ def __call__(self, **x):
+ del x
+ self.assertTrue(function_utils.has_kwargs(FooHasKwargs()))
+
+ class FooHasNoKwargs(object):
+
+ def __call__(self, x):
+ del x
+ self.assertFalse(function_utils.has_kwargs(FooHasNoKwargs()))
+
+ def test_bounded_method(self):
+
+ class FooHasKwargs(object):
+
+ def fn(self, **x):
+ del x
+ self.assertTrue(function_utils.has_kwargs(FooHasKwargs().fn))
+
+ class FooHasNoKwargs(object):
+
+ def fn(self, x):
+ del x
+ self.assertFalse(function_utils.has_kwargs(FooHasNoKwargs().fn))
+
+ def test_partial_function(self):
+ expected_test_arg = 123
+
+ def fn_has_kwargs(test_arg, **x):
+ if test_arg != expected_test_arg:
+ raise ValueError('partial fn does not work correctly')
+ return x
+
+ wrapped_fn = functools.partial(fn_has_kwargs, test_arg=123)
+ self.assertTrue(function_utils.has_kwargs(wrapped_fn))
+ some_kwargs = dict(x=1, y=2, z=3)
+ self.assertEqual(wrapped_fn(**some_kwargs), some_kwargs)
+
+ def fn_has_no_kwargs(x, test_arg):
+ if test_arg != expected_test_arg:
+ raise ValueError('partial fn does not work correctly')
+ return x
+
+ wrapped_fn = functools.partial(fn_has_no_kwargs, test_arg=123)
+ self.assertFalse(function_utils.has_kwargs(wrapped_fn))
+ some_arg = 1
+ self.assertEqual(wrapped_fn(some_arg), some_arg)
+
+ def test_double_partial(self):
+ expected_test_arg1 = 123
+ expected_test_arg2 = 456
+
+ def fn_has_kwargs(test_arg1, test_arg2, **x):
+ if test_arg1 != expected_test_arg1 or test_arg2 != expected_test_arg2:
+ raise ValueError('partial does not work correctly')
+ return x
+
+ wrapped_fn = functools.partial(fn_has_kwargs, test_arg2=456)
+ double_wrapped_fn = functools.partial(wrapped_fn, test_arg1=123)
+
+ self.assertTrue(function_utils.has_kwargs(double_wrapped_fn))
+ some_kwargs = dict(x=1, y=2, z=3)
+ self.assertEqual(double_wrapped_fn(**some_kwargs), some_kwargs)
+
+ def fn_has_no_kwargs(x, test_arg1, test_arg2):
+ if test_arg1 != expected_test_arg1 or test_arg2 != expected_test_arg2:
+ raise ValueError('partial does not work correctly')
+ return x
+
+ wrapped_fn = functools.partial(fn_has_no_kwargs, test_arg2=456)
+ double_wrapped_fn = functools.partial(wrapped_fn, test_arg1=123)
+
+ self.assertFalse(function_utils.has_kwargs(double_wrapped_fn))
+ some_arg = 1
+ self.assertEqual(double_wrapped_fn(some_arg), some_arg)
+
+ def test_raises_type_error(self):
+ with self.assertRaisesRegexp(
+ TypeError, 'fn should be a function-like object'):
+ function_utils.has_kwargs('not a function')
+
+
class GetFuncNameTest(test.TestCase):
def testWithSimpleFunction(self):
diff --git a/tensorflow/python/util/nest.py b/tensorflow/python/util/nest.py
index 2968ca9c07..758cba7487 100644
--- a/tensorflow/python/util/nest.py
+++ b/tensorflow/python/util/nest.py
@@ -19,6 +19,9 @@ This module can perform operations on nested structures. A nested structure is a
Python sequence, tuple (including `namedtuple`), or dict that can contain
further sequences, tuples, and dicts.
+attr.s decorated classes (http://www.attrs.org) are also supported, in the
+same way as `namedtuple`.
+
The utilities here assume (and do not check) that the nested structures form a
'tree', i.e., no references in the structure of the input of these functions
should be recursive.
@@ -38,6 +41,12 @@ import six as _six
from tensorflow.python import pywrap_tensorflow as _pywrap_tensorflow
+def _get_attrs_values(obj):
+ """Returns the list of values from an attrs instance."""
+ attrs = getattr(obj.__class__, "__attrs_attrs__")
+ return [getattr(obj, a.name) for a in attrs]
+
+
def _sorted(dict_):
"""Returns a sorted list of the dict keys, with error if keys not sortable."""
try:
@@ -64,6 +73,7 @@ def _is_namedtuple(instance, strict=False):
# See the swig file (util.i) for documentation.
_is_mapping = _pywrap_tensorflow.IsMapping
+_is_attrs = _pywrap_tensorflow.IsAttrs
def _sequence_like(instance, args):
@@ -85,7 +95,7 @@ def _sequence_like(instance, args):
# corresponding `OrderedDict` to pack it back).
result = dict(zip(_sorted(instance), args))
return type(instance)((key, result[key]) for key in _six.iterkeys(instance))
- elif _is_namedtuple(instance):
+ elif _is_namedtuple(instance) or _is_attrs(instance):
return type(instance)(*args)
else:
# Not a namedtuple
@@ -93,6 +103,7 @@ def _sequence_like(instance, args):
def _yield_value(iterable):
+ """Yields the next value from the given iterable."""
if _is_mapping(iterable):
# Iterate through dictionaries in a deterministic order by sorting the
# keys. Notice this means that we ignore the original order of `OrderedDict`
@@ -101,6 +112,9 @@ def _yield_value(iterable):
# corresponding `OrderedDict` to pack it back).
for key in _sorted(iterable):
yield iterable[key]
+ elif _is_attrs(iterable):
+ for value in _get_attrs_values(iterable):
+ yield value
else:
for value in iterable:
yield value
@@ -118,6 +132,18 @@ flatten = _pywrap_tensorflow.Flatten
_same_namedtuples = _pywrap_tensorflow.SameNamedtuples
+class _DotString(object):
+
+ def __str__(self):
+ return "."
+
+ def __repr__(self):
+ return "."
+
+
+_DOT = _DotString()
+
+
def assert_same_structure(nest1, nest2, check_types=True):
"""Asserts that two structures are nested in the same way.
@@ -149,7 +175,15 @@ def assert_same_structure(nest1, nest2, check_types=True):
TypeError: If the two structures differ in the type of sequence in any of
their substructures. Only possible if `check_types` is `True`.
"""
- _pywrap_tensorflow.AssertSameStructure(nest1, nest2, check_types)
+ try:
+ _pywrap_tensorflow.AssertSameStructure(nest1, nest2, check_types)
+ except (ValueError, TypeError) as e:
+ str1 = str(map_structure(lambda _: _DOT, nest1))
+ str2 = str(map_structure(lambda _: _DOT, nest2))
+ raise type(e)("%s\n"
+ "Entire first structure:\n%s\n"
+ "Entire second structure:\n%s"
+ % (str(e), str1, str2))
def flatten_dict_items(dictionary):
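A sketch of the richer error message, with structures chosen purely for illustration; the "." placeholders come from the new _DotString helper:

    from tensorflow.python.util import nest

    nest.assert_same_structure(((1, 2), 3), ("spam", "eggs"))
    # ValueError: ... substructure "type=str str=spam" is not
    # Entire first structure:
    # ((., .), .)
    # Entire second structure:
    # (., .)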
diff --git a/tensorflow/python/util/nest_test.py b/tensorflow/python/util/nest_test.py
index ef503137d1..e03a8daaa1 100644
--- a/tensorflow/python/util/nest_test.py
+++ b/tensorflow/python/util/nest_test.py
@@ -33,6 +33,11 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
from tensorflow.python.util import nest
+try:
+ import attr # pylint:disable=g-import-not-at-top
+except ImportError:
+ attr = None
+
class _CustomMapping(collections.Mapping):
@@ -53,6 +58,35 @@ class NestTest(parameterized.TestCase, test.TestCase):
PointXY = collections.namedtuple("Point", ["x", "y"]) # pylint: disable=invalid-name
+ if attr:
+ class BadAttr(object):
+ """Class that has a non-iterable __attrs_attrs__."""
+ __attrs_attrs__ = None
+
+ @attr.s
+ class SampleAttr(object):
+ field1 = attr.ib()
+ field2 = attr.ib()
+
+ @test_util.assert_no_new_pyobjects_executing_eagerly
+ def testAttrsFlattenAndPack(self):
+ if attr is None:
+ self.skipTest("attr module is unavailable.")
+
+ field_values = [1, 2]
+ sample_attr = NestTest.SampleAttr(*field_values)
+ self.assertFalse(nest._is_attrs(field_values))
+ self.assertTrue(nest._is_attrs(sample_attr))
+ flat = nest.flatten(sample_attr)
+ self.assertEqual(field_values, flat)
+ restructured_from_flat = nest.pack_sequence_as(sample_attr, flat)
+ self.assertIsInstance(restructured_from_flat, NestTest.SampleAttr)
+ self.assertEqual(restructured_from_flat, sample_attr)
+
+ # Check that flatten fails if attributes are not iterable
+ with self.assertRaisesRegexp(TypeError, "object is not iterable"):
+ flat = nest.flatten(NestTest.BadAttr())
+
@test_util.assert_no_new_pyobjects_executing_eagerly
def testFlattenAndPack(self):
structure = ((3, 4), 5, (6, 7, (9, 10), 8))
@@ -264,7 +298,11 @@ class NestTest(parameterized.TestCase, test.TestCase):
"Second structure:.*\n\n"
"More specifically: Substructure "
r'"type=tuple str=\(\(1, 2\), 3\)" is a sequence, while '
- 'substructure "type=str str=spam" is not')):
+ 'substructure "type=str str=spam" is not\n'
+ "Entire first structure:\n"
+ r"\(\(\(\., \.\), \.\), \., \(\., \.\)\)\n"
+ "Entire second structure:\n"
+ r"\(\., \.\)")):
nest.assert_same_structure(structure1, structure_different_num_elements)
with self.assertRaisesRegexp(
diff --git a/tensorflow/python/util/util.cc b/tensorflow/python/util/util.cc
index 562bbdcfeb..38b8491c66 100644
--- a/tensorflow/python/util/util.cc
+++ b/tensorflow/python/util/util.cc
@@ -15,9 +15,11 @@ limitations under the License.
#include "tensorflow/python/util/util.h"
#include <functional>
+#include <memory>
#include <unordered_map>
#include <vector>
+#include "absl/memory/memory.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
@@ -190,6 +192,19 @@ int IsMappingHelper(PyObject* o) {
return check_cache->CachedLookup(o);
}
+// Returns 1 if `o` is an instance of an attrs-decorated class.
+// Returns 0 otherwise.
+int IsAttrsHelper(PyObject* o) {
+ Safe_PyObjectPtr cls(PyObject_GetAttrString(o, "__class__"));
+ if (cls) {
+ return PyObject_HasAttrString(cls.get(), "__attrs_attrs__");
+ } else {
+ // PyObject_GetAttrString returns null on error
+ PyErr_Clear();
+ return 0;
+ }
+}
+
// Returns 1 if `o` is considered a sequence for the purposes of Flatten().
// Returns 0 otherwise.
// Returns -1 if an error occurred.
@@ -204,6 +219,7 @@ int IsSequenceHelper(PyObject* o) {
});
// We treat dicts and other mappings as special cases of sequences.
if (IsMappingHelper(o)) return true;
+ if (IsAttrsHelper(o)) return true;
if (PySet_Check(o) && !WarnedThatSetIsNotSequence) {
LOG(WARNING) << "Sets are not currently considered sequences, "
"but this may change in the future, "
@@ -222,93 +238,168 @@ int IsSequenceHelper(PyObject* o) {
return check_cache->CachedLookup(o);
}
-// Implements the same idea as tensorflow.util.nest._yield_value
-// During construction we check if the iterable is a dictionary.
-// If so, we construct a sequence from its sorted keys that will be used
-// for iteration.
-// If not, we construct a sequence directly from the iterable.
-// At each step, we get the next element from the sequence and use it
-// either as a key or return it directly.
-//
-// 'iterable' must not be modified while ValIterator is used.
-class ValIterator {
+// ValueIterator interface
+class ValueIterator {
public:
- explicit ValIterator(PyObject* iterable)
- : dict_(nullptr),
- mapping_(nullptr),
- last_mapping_element_(nullptr),
- seq_(nullptr),
- index_(0) {
- if (PyDict_Check(iterable)) {
- dict_ = iterable;
- // PyDict_Keys returns a list, which can be used with
- // PySequence_Fast_GET_ITEM.
- seq_ = PyDict_Keys(iterable);
- // Iterate through dictionaries in a deterministic order by sorting the
- // keys. Notice this means that we ignore the original order of
- // `OrderedDict` instances. This is intentional, to avoid potential
- // bugs caused by mixing ordered and plain dicts (e.g., flattening
- // a dict but using a corresponding `OrderedDict` to pack it back).
- PyList_Sort(seq_);
- } else if (IsMappingHelper(iterable)) {
- mapping_ = iterable;
- seq_ = MappingKeys(iterable);
- PyList_Sort(seq_);
+ virtual ~ValueIterator() {}
+ virtual Safe_PyObjectPtr next() = 0;
+
+ bool valid() const { return is_valid_; }
+
+ protected:
+ void invalidate() { is_valid_ = false; }
+
+ private:
+ bool is_valid_ = true;
+};
+
+using ValueIteratorPtr = std::unique_ptr<ValueIterator>;
+
+// Iterate through dictionaries in a deterministic order by sorting the
+// keys. Notice this means that we ignore the original order of
+// `OrderedDict` instances. This is intentional, to avoid potential
+// bugs caused by mixing ordered and plain dicts (e.g., flattening
+// a dict but using a corresponding `OrderedDict` to pack it back).
+class DictValueIterator : public ValueIterator {
+ public:
+ explicit DictValueIterator(PyObject* dict)
+ : dict_(dict), keys_(PyDict_Keys(dict)) {
+ if (PyList_Sort(keys_.get()) == -1) {
+ invalidate();
} else {
- seq_ = PySequence_Fast(iterable, "");
+ iter_.reset(PyObject_GetIter(keys_.get()));
}
- size_ = PySequence_Fast_GET_SIZE(seq_);
}
- ~ValIterator() { Py_DECREF(seq_); }
-
- // Return a borrowed reference to the next element from iterable.
- // Return nullptr when iteration is over.
- PyObject* next() {
- if (TF_PREDICT_FALSE(seq_ == nullptr)) {
- return nullptr;
- }
- PyObject* element = nullptr;
- if (index_ < size_) {
- // Both PySequence_Fast_GET_ITEM and PyDict_GetItem return borrowed
- // references. For general mappings, ValIterator keeps a reference to the
- // last retrieved element (and decrefs it before producing the next
- // element) to abstract away the borrowed/new difference.
- element = PySequence_Fast_GET_ITEM(seq_, index_);
- ++index_;
- if (dict_ != nullptr) {
- element = PyDict_GetItem(dict_, element);
- if (element == nullptr) {
- PyErr_SetString(PyExc_RuntimeError,
- "Dictionary was modified during iteration over it");
- return nullptr;
- }
- } else if (mapping_ != nullptr) {
- element = PyObject_GetItem(mapping_, element);
- if (element == nullptr) {
- PyErr_SetString(PyExc_RuntimeError,
- "Mapping was modified during iteration over it");
- return nullptr;
- }
- last_mapping_element_.reset(element);
+ Safe_PyObjectPtr next() override {
+ Safe_PyObjectPtr result;
+ Safe_PyObjectPtr key(PyIter_Next(iter_.get()));
+ if (key) {
+ // PyDict_GetItem returns a borrowed reference.
+ PyObject* elem = PyDict_GetItem(dict_, key.get());
+ if (elem) {
+ Py_INCREF(elem);
+ result.reset(elem);
+ } else {
+ PyErr_SetString(PyExc_RuntimeError,
+ "Dictionary was modified during iteration over it");
}
}
- return element;
+ return result;
}
private:
- // Special casing for things that pass PyDict_Check (faster, no Python calls)
PyObject* dict_;
+ Safe_PyObjectPtr keys_;
+ Safe_PyObjectPtr iter_;
+};
- // General mappings which have custom Python logic
+// Iterate over mapping objects by sorting the keys first
+class MappingValueIterator : public ValueIterator {
+ public:
+ explicit MappingValueIterator(PyObject* mapping)
+ : mapping_(mapping), keys_(MappingKeys(mapping)) {
+ if (!keys_ || PyList_Sort(keys_.get()) == -1) {
+ invalidate();
+ } else {
+ iter_.reset(PyObject_GetIter(keys_.get()));
+ }
+ }
+
+ Safe_PyObjectPtr next() override {
+ Safe_PyObjectPtr result;
+ Safe_PyObjectPtr key(PyIter_Next(iter_.get()));
+ if (key) {
+ // Unlike PyDict_GetItem, PyObject_GetItem returns a new reference.
+ PyObject* elem = PyObject_GetItem(mapping_, key.get());
+ if (elem) {
+ result.reset(elem);
+ } else {
+ PyErr_SetString(PyExc_RuntimeError,
+ "Mapping was modified during iteration over it");
+ }
+ }
+ return result;
+ }
+
+ private:
PyObject* mapping_;
- Safe_PyObjectPtr last_mapping_element_;
+ Safe_PyObjectPtr keys_;
+ Safe_PyObjectPtr iter_;
+};
+
+// Iterate over a sequence, by index.
+class SequenceValueIterator : public ValueIterator {
+ public:
+ explicit SequenceValueIterator(PyObject* iterable)
+ : seq_(PySequence_Fast(iterable, "")),
+ size_(PySequence_Fast_GET_SIZE(seq_.get())),
+ index_(0) {}
+
+ Safe_PyObjectPtr next() override {
+ Safe_PyObjectPtr result;
+ if (index_ < size_) {
+ // PySequence_Fast_GET_ITEM returns a borrowed reference.
+ PyObject* elem = PySequence_Fast_GET_ITEM(seq_.get(), index_);
+ ++index_;
+ Py_INCREF(elem);
+ result.reset(elem);
+ }
- PyObject* seq_;
- Py_ssize_t size_;
+ return result;
+ }
+
+ private:
+ Safe_PyObjectPtr seq_;
+ const Py_ssize_t size_;
Py_ssize_t index_;
};
+// Just returns the wrapped value itself as a single item.
+class SparseTensorValueIterator : public ValueIterator {
+ public:
+ explicit SparseTensorValueIterator(PyObject* tensor) : tensor_(tensor) {
+ Py_INCREF(tensor);
+ }
+
+ Safe_PyObjectPtr next() override { return std::move(tensor_); }
+
+ private:
+ Safe_PyObjectPtr tensor_;
+};
+
+class AttrsValueIterator : public ValueIterator {
+ public:
+ explicit AttrsValueIterator(PyObject* nested) : nested_(nested) {
+ Py_INCREF(nested);
+ cls_.reset(PyObject_GetAttrString(nested_.get(), "__class__"));
+ if (cls_) {
+ attrs_.reset(PyObject_GetAttrString(cls_.get(), "__attrs_attrs__"));
+ if (attrs_) {
+ iter_.reset(PyObject_GetIter(attrs_.get()));
+ }
+ }
+ if (!iter_ || PyErr_Occurred()) invalidate();
+ }
+
+ Safe_PyObjectPtr next() override {
+ Safe_PyObjectPtr result;
+ Safe_PyObjectPtr item(PyIter_Next(iter_.get()));
+ if (item) {
+ Safe_PyObjectPtr name(PyObject_GetAttrString(item.get(), "name"));
+ result.reset(PyObject_GetAttr(nested_.get(), name.get()));
+ }
+
+ return result;
+ }
+
+ private:
+ Safe_PyObjectPtr nested_;
+ Safe_PyObjectPtr cls_;
+ Safe_PyObjectPtr attrs_;
+ Safe_PyObjectPtr iter_;
+};
+
bool IsSparseTensorValueType(PyObject* o) {
if (TF_PREDICT_FALSE(SparseTensorValueType == nullptr)) {
return false;
@@ -322,93 +413,37 @@ int IsSequenceForDataHelper(PyObject* o) {
!IsSparseTensorValueType(o);
}
-bool GetNextValuesForDict(PyObject* nested,
- std::vector<Safe_PyObjectPtr>* next_values) {
- Safe_PyObjectPtr keys(PyDict_Keys(nested));
- if (PyList_Sort(keys.get()) == -1) return false;
- Py_ssize_t size = PyList_Size(keys.get());
- for (Py_ssize_t i = 0; i < size; ++i) {
- // We know that key and item will not be deleted because nested owns
- // a reference to them and callers of flatten must not modify nested
- // while the method is running.
- PyObject* key = PyList_GET_ITEM(keys.get(), i);
- PyObject* item = PyDict_GetItem(nested, key);
- Py_INCREF(item);
- next_values->emplace_back(item);
- }
- return true;
-}
-
-bool GetNextValuesForMapping(PyObject* nested,
- std::vector<Safe_PyObjectPtr>* next_values) {
- Safe_PyObjectPtr keys(MappingKeys(nested));
- if (keys.get() == nullptr) {
- return false;
- }
- if (PyList_Sort(keys.get()) == -1) return false;
- Py_ssize_t size = PyList_Size(keys.get());
- for (Py_ssize_t i = 0; i < size; ++i) {
- PyObject* key = PyList_GET_ITEM(keys.get(), i);
- // Unlike PyDict_GetItem, PyObject_GetItem returns a new reference.
- PyObject* item = PyObject_GetItem(nested, key);
- next_values->emplace_back(item);
- }
- return true;
-}
-
-bool GetNextValuesForIterable(PyObject* nested,
- std::vector<Safe_PyObjectPtr>* next_values) {
- PyObject* item;
- PyObject* iterator = PyObject_GetIter(nested);
- if (iterator == nullptr || PyErr_Occurred()) {
- return false;
- }
- while ((item = PyIter_Next(iterator)) != nullptr) {
- next_values->emplace_back(item);
- }
- Py_DECREF(iterator);
- return true;
-}
-
-// GetNextValues returns the values that the FlattenHelper function will recurse
-// over next.
-bool GetNextValues(PyObject* nested,
- std::vector<Safe_PyObjectPtr>* next_values) {
+ValueIteratorPtr GetValueIterator(PyObject* nested) {
if (PyDict_Check(nested)) {
- // if nested is dictionary, sort it by key and recurse on each value
- return GetNextValuesForDict(nested, next_values);
+ return absl::make_unique<DictValueIterator>(nested);
} else if (IsMappingHelper(nested)) {
- // same treatment as dictionaries, but for custom mapping types
- return GetNextValuesForMapping(nested, next_values);
+ return absl::make_unique<MappingValueIterator>(nested);
+ } else if (IsAttrsHelper(nested)) {
+ return absl::make_unique<AttrsValueIterator>(nested);
+ } else {
+ return absl::make_unique<SequenceValueIterator>(nested);
}
- // iterate and recurse
- return GetNextValuesForIterable(nested, next_values);
}
-// Similar to above, just specialized for the functions in the data pacakage.
-bool GetNextValuesForData(PyObject* nested,
- std::vector<Safe_PyObjectPtr>* next_values) {
+// Similar to above, just specialized for the functions in the data package.
+ValueIteratorPtr GetValueIteratorForData(PyObject* nested) {
if (PyDict_Check(nested)) {
- // if nested is dictionary, sort it by key and recurse on each value
- return GetNextValuesForDict(nested, next_values);
+ return absl::make_unique<DictValueIterator>(nested);
} else if (IsMappingHelper(nested)) {
- // same treatment as dictionaries, but for custom mapping types
- return GetNextValuesForMapping(nested, next_values);
+ return absl::make_unique<MappingValueIterator>(nested);
+ } else if (IsAttrsHelper(nested)) {
+ return absl::make_unique<AttrsValueIterator>(nested);
} else if (IsSparseTensorValueType(nested)) {
- // if nested is a SparseTensorValue, just return itself as a single item
- Py_INCREF(nested);
- next_values->emplace_back(nested);
- return true;
+ return absl::make_unique<SparseTensorValueIterator>(nested);
+ } else {
+ return absl::make_unique<SequenceValueIterator>(nested);
}
- // iterate and recurse
- return GetNextValuesForIterable(nested, next_values);
}
bool FlattenHelper(
PyObject* nested, PyObject* list,
const std::function<int(PyObject*)>& is_sequence_helper,
- const std::function<bool(PyObject*, std::vector<Safe_PyObjectPtr>*)>&
- next_values_getter) {
+ const std::function<ValueIteratorPtr(PyObject*)>& value_iterator_getter) {
// if nested is not a sequence, append itself and exit
int is_seq = is_sequence_helper(nested);
if (is_seq == -1) return false;
@@ -416,16 +451,15 @@ bool FlattenHelper(
return PyList_Append(list, nested) != -1;
}
- std::vector<Safe_PyObjectPtr> next_values;
- // Get the next values to recurse over.
- if (!next_values_getter(nested, &next_values)) return false;
+ ValueIteratorPtr iter = value_iterator_getter(nested);
+ if (!iter->valid()) return false;
- for (const auto& item : next_values) {
+ for (Safe_PyObjectPtr item = iter->next(); item; item = iter->next()) {
if (Py_EnterRecursiveCall(" in flatten")) {
return false;
}
- const bool success =
- FlattenHelper(item.get(), list, is_sequence_helper, next_values_getter);
+ const bool success = FlattenHelper(item.get(), list, is_sequence_helper,
+ value_iterator_getter);
Py_LeaveRecursiveCall();
if (!success) {
return false;
@@ -579,22 +613,25 @@ bool AssertSameStructureHelper(
}
}
- ValIterator iter1(o1);
- ValIterator iter2(o2);
+ ValueIteratorPtr iter1 = GetValueIterator(o1);
+ ValueIteratorPtr iter2 = GetValueIterator(o2);
+
+ if (!iter1->valid() || !iter2->valid()) return false;
while (true) {
- PyObject* v1 = iter1.next();
- PyObject* v2 = iter2.next();
- if (v1 != nullptr && v2 != nullptr) {
+ Safe_PyObjectPtr v1 = iter1->next();
+ Safe_PyObjectPtr v2 = iter2->next();
+ if (v1 && v2) {
if (Py_EnterRecursiveCall(" in assert_same_structure")) {
return false;
}
- bool no_internal_errors = AssertSameStructureHelper(
- v1, v2, check_types, error_msg, is_type_error, is_sequence_helper);
+ bool no_internal_errors =
+ AssertSameStructureHelper(v1.get(), v2.get(), check_types, error_msg,
+ is_type_error, is_sequence_helper);
Py_LeaveRecursiveCall();
if (!no_internal_errors) return false;
if (!error_msg->empty()) return true;
- } else if (v1 == nullptr && v2 == nullptr) {
+ } else if (!v1 && !v2) {
// Done with all recursive calls. Structure matched.
return true;
} else {
@@ -652,10 +689,11 @@ void RegisterSparseTensorValueClass(PyObject* sparse_tensor_value_class) {
bool IsSequence(PyObject* o) { return IsSequenceHelper(o) == 1; }
bool IsMapping(PyObject* o) { return IsMappingHelper(o) == 1; }
+bool IsAttrs(PyObject* o) { return IsAttrsHelper(o) == 1; }
PyObject* Flatten(PyObject* nested) {
PyObject* list = PyList_New(0);
- if (FlattenHelper(nested, list, IsSequenceHelper, GetNextValues)) {
+ if (FlattenHelper(nested, list, IsSequenceHelper, GetValueIterator)) {
return list;
} else {
Py_DECREF(list);
@@ -668,7 +706,7 @@ bool IsSequenceForData(PyObject* o) { return IsSequenceForDataHelper(o) == 1; }
PyObject* FlattenForData(PyObject* nested) {
PyObject* list = PyList_New(0);
if (FlattenHelper(nested, list, IsSequenceForDataHelper,
- GetNextValuesForData)) {
+ GetValueIteratorForData)) {
return list;
} else {
Py_DECREF(list);
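For reference, a hedged Python-level sketch of the flatten behavior that the ValueIterator classes above implement (assuming the tensorflow.python.util.nest entry point in this tree; dictionary keys are visited in sorted order, sequences by index, and non-structures pass through unchanged):

    # Illustrative only; not part of this change.
    from tensorflow.python.util import nest

    structure = {"b": (2, 3), "a": [1]}
    print(nest.flatten(structure))      # [1, 2, 3]: keys sorted, then values recursed
    print(nest.is_sequence(structure))  # True: dicts count as structures
    print(nest.flatten("scalar"))       # ['scalar']: strings are not flattened further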
diff --git a/tensorflow/python/util/util.h b/tensorflow/python/util/util.h
index 343605285e..01f85ea1dc 100644
--- a/tensorflow/python/util/util.h
+++ b/tensorflow/python/util/util.h
@@ -56,6 +56,15 @@ PyObject* IsNamedtuple(PyObject* o, bool strict);
// True if the sequence subclasses mapping.
bool IsMapping(PyObject* o);
+// Returns true if its input is an instance of an attr.s decorated class.
+//
+// Args:
+// o: the input to be checked.
+//
+// Returns:
+// True if the object is an instance of an attr.s decorated class.
+bool IsAttrs(PyObject* o);
+
// Implements the same interface as tensorflow.util.nest._same_namedtuples
// Returns Py_True iff the two namedtuples have the same name and fields.
// Raises RuntimeError if `o1` or `o2` don't look like namedtuples (don't have
diff --git a/tensorflow/python/util/util.i b/tensorflow/python/util/util.i
index 104a615636..32a6e684fa 100644
--- a/tensorflow/python/util/util.i
+++ b/tensorflow/python/util/util.i
@@ -65,6 +65,18 @@ Returns:
%unignore tensorflow::swig::IsMapping;
%noexception tensorflow::swig::IsMapping;
+%feature("docstring") tensorflow::swig::IsAttrs
+"""Returns True iff `instance` is an instance of an `attr.s` decorated class.
+
+Args:
+ instance: An instance of a Python object.
+
+Returns:
+ True if `instance` is an instance of an `attr.s` decorated class.
+"""
+%unignore tensorflow::swig::IsAttrs;
+%noexception tensorflow::swig::IsAttrs;
+
%feature("docstring") tensorflow::swig::SameNamedtuples
"Returns True if the two namedtuples have the same name and fields."
%unignore tensorflow::swig::SameNamedtuples;
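A hedged sketch of what IsAttrs and AttrsValueIterator accept on the Python side (assumes the third-party attr package is available, that the sequence check earlier in this diff also treats attrs instances as structures, and the Point class is purely hypothetical):

    # Illustrative only; not part of this change.
    import attr
    from tensorflow.python.util import nest

    @attr.s
    class Point(object):  # hypothetical attrs-decorated class
        x = attr.ib()
        y = attr.ib()

    p = Point(x=1, y=2)
    # AttrsValueIterator walks __attrs_attrs__ in declaration order.
    print(nest.flatten(p))  # [1, 2]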
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc
index 3c533c7f99..ca90c383f9 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.cc
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc
@@ -35,6 +35,7 @@ limitations under the License.
#include "tensorflow/stream_executor/lib/env.h"
#include "tensorflow/stream_executor/lib/error.h"
#include "tensorflow/stream_executor/lib/initialize.h"
+#include "tensorflow/stream_executor/lib/mathutil.h"
#include "tensorflow/stream_executor/lib/strcat.h"
#include "tensorflow/stream_executor/lib/stringpiece.h"
#include "tensorflow/stream_executor/lib/threadpool.h"
@@ -132,23 +133,42 @@ string ToString(cudnnStatus_t status) {
}
template <typename T>
-cudnnDataType_t GetCudnnDataType();
+cudnnDataType_t GetCudnnDataType(
+ dnn::DataLayout = dnn::DataLayout::kBatchDepthYX);
template <>
-cudnnDataType_t GetCudnnDataType<double>() {
+cudnnDataType_t GetCudnnDataType<double>(dnn::DataLayout) {
return CUDNN_DATA_DOUBLE;
}
template <>
-cudnnDataType_t GetCudnnDataType<float>() {
+cudnnDataType_t GetCudnnDataType<float>(dnn::DataLayout) {
return CUDNN_DATA_FLOAT;
}
template <>
-cudnnDataType_t GetCudnnDataType<Eigen::half>() {
+cudnnDataType_t GetCudnnDataType<Eigen::half>(dnn::DataLayout) {
return CUDNN_DATA_HALF;
}
+template <>
+cudnnDataType_t GetCudnnDataType<int8>(dnn::DataLayout layout) {
+ switch (layout) {
+ case dnn::DataLayout::kYXDepthBatch:
+ case dnn::DataLayout::kYXBatchDepth:
+ case dnn::DataLayout::kBatchYXDepth:
+ case dnn::DataLayout::kBatchDepthYX:
+ return CUDNN_DATA_INT8;
+ case dnn::DataLayout::kBatchDepthYX4:
+ return CUDNN_DATA_INT8x4;
+ }
+}
+
+template <>
+cudnnDataType_t GetCudnnDataType<int32>(dnn::DataLayout) {
+ return CUDNN_DATA_INT32;
+}
+
// RAII wrapper for all calls to cuDNN with a cuDNN handle argument.
//
// See CudnnAccess::GetHandle() for details.
@@ -2387,6 +2407,33 @@ cudnnDataType_t GetRnnComputeType(dnn::DataType data_type) {
}
}
+// Determines whether we can safely perform a winograd non-fused convolution for
+// the given input and output shapes. This works around b/68264959, an integer
+// overflow in cuDNNv5 and cuDNNv6.
+#if CUDNN_VERSION >= 7000
+bool ShouldIncludeWinogradNonfusedAlgo(const dnn::BatchDescriptor&,
+ const dnn::BatchDescriptor&) {
+ return true;
+}
+#else
+bool ShouldIncludeWinogradNonfusedAlgo(
+ const dnn::BatchDescriptor& input_desc,
+ const dnn::BatchDescriptor& output_desc) {
+ int64 batch = input_desc.count();
+ int64 in_depths = input_desc.feature_map_count();
+ int64 in_rows = input_desc.height();
+ int64 in_cols = input_desc.ndims() == 1 ? 1 : input_desc.width();
+ int64 out_depths = output_desc.feature_map_count();
+
+ int64 total_size = port::MathUtil::CeilOfRatio(batch, int64{16}) *
+ std::max(in_depths, out_depths) * in_cols * in_rows *
+ sizeof(float);
+
+ const int64 threshold = 1L << 31;
+ return total_size < threshold;
+}
+#endif
+
} // namespace
template <class T>
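A hedged Python mirror of the pre-cuDNN-7 size check above, useful for sanity-checking shapes by hand (sizeof(float) == 4, threshold 2^31):

    # Illustrative only: mirrors ShouldIncludeWinogradNonfusedAlgo for CUDNN_VERSION < 7000.
    import math

    def should_include_winograd_nonfused(batch, in_depths, in_rows, in_cols, out_depths):
        total_size = (math.ceil(batch / 16.0) *
                      max(in_depths, out_depths) * in_cols * in_rows * 4)
        return total_size < 2 ** 31

    print(should_include_winograd_nonfused(32, 64, 56, 56, 64))       # True: well under 2^31
    print(should_include_winograd_nonfused(256, 512, 256, 256, 512))  # False: reaches the threshold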
@@ -2465,6 +2512,13 @@ port::Status CudnnSupport::DoConvolveImpl(
return port::Status::OK();
}());
+ if (algo_desc.algo_id() == CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED &&
+ !ShouldIncludeWinogradNonfusedAlgo(input_descriptor, output_descriptor)) {
+ return port::Status(port::error::FAILED_PRECONDITION,
+ "This configuration has potential integer overflow in "
+ "cuDNNv5 and cuDNNv6. See b/68264959.");
+ }
+
RETURN_IF_CUDNN_ERROR(cudnnConvolutionForward(
cudnn.handle(),
/*alpha=*/alpha, /*srcDesc=*/input_nd.handle(),
@@ -2486,19 +2540,19 @@ port::Status CudnnSupport::DoConvolveImpl(
return port::Status::OK();
}
-template <typename Type, typename BiasType, typename ScaleType,
- int cudnn_data_type, int cudnn_compute_type>
+template <typename AccumulatorType, typename ElementType, typename BiasType,
+ typename ScaleType>
port::Status CudnnSupport::DoFusedConvolveImpl(
Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
- const DeviceMemory<Type>& conv_input_data, ScaleType conv_input_scale,
- const dnn::FilterDescriptor& filter_descriptor,
- const DeviceMemory<Type>& filter_data,
+ const DeviceMemory<ElementType>& conv_input_data,
+ ScaleType conv_input_scale, const dnn::FilterDescriptor& filter_descriptor,
+ const DeviceMemory<ElementType>& filter_data,
const dnn::ConvolutionDescriptor& convolution_descriptor,
- const DeviceMemory<Type>& side_input_data, ScaleType side_input_scale,
- const dnn::BatchDescriptor& bias_descriptor,
+ const DeviceMemory<ElementType>& side_input_data,
+ ScaleType side_input_scale, const dnn::BatchDescriptor& bias_descriptor,
const DeviceMemory<BiasType>& biases, dnn::ActivationMode activation_mode,
const dnn::BatchDescriptor& output_descriptor,
- DeviceMemory<Type>* output_data, ScratchAllocator* scratch_allocator,
+ DeviceMemory<ElementType>* output_data, ScratchAllocator* scratch_allocator,
const dnn::AlgorithmConfig& algorithm_config,
dnn::ProfileResult* output_profile_result) {
if (activation_mode != dnn::ActivationMode::kRelu &&
@@ -2509,14 +2563,17 @@ port::Status CudnnSupport::DoFusedConvolveImpl(
}
CudnnTensorDescriptor conv_input_nd(
- conv_input_descriptor, static_cast<cudnnDataType_t>(cudnn_data_type));
+ conv_input_descriptor,
+ GetCudnnDataType<ElementType>(conv_input_descriptor.layout()));
CudnnTensorDescriptor output_nd(
- output_descriptor, static_cast<cudnnDataType_t>(cudnn_data_type));
- CudnnFilterDescriptor filter(filter_descriptor,
- static_cast<cudnnDataType_t>(cudnn_data_type));
- CudnnTensorDescriptor bias_nd(bias_descriptor, CUDNN_DATA_FLOAT);
- CudnnConvolutionDescriptor conv(
- convolution_descriptor, static_cast<cudnnDataType_t>(cudnn_compute_type));
+ output_descriptor,
+ GetCudnnDataType<ElementType>(conv_input_descriptor.layout()));
+ CudnnFilterDescriptor filter(
+ filter_descriptor,
+ GetCudnnDataType<ElementType>(conv_input_descriptor.layout()));
+ CudnnTensorDescriptor bias_nd(bias_descriptor, GetCudnnDataType<BiasType>());
+ CudnnConvolutionDescriptor conv(convolution_descriptor,
+ GetCudnnDataType<AccumulatorType>());
auto cudnn = cudnn_->GetHandle(parent_, stream);
@@ -2566,6 +2623,14 @@ port::Status CudnnSupport::DoFusedConvolveImpl(
<< "\noutput_nd.handle() = " << output_nd.handle()
<< "\noutput_data->opaque() = " << output_data->opaque();
+ if (algo_desc.algo_id() == CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED &&
+ !ShouldIncludeWinogradNonfusedAlgo(conv_input_descriptor,
+ output_descriptor)) {
+ return port::Status(port::error::FAILED_PRECONDITION,
+ "This configuration has potential integer overflow in "
+                        "cuDNNv5 and cuDNNv6. See b/68264959.");
+ }
+
RETURN_IF_CUDNN_ERROR(cudnnConvolutionBiasActivationForward(
cudnn.handle(),
/*alpha1=*/&conv_input_scale,
@@ -2933,8 +2998,7 @@ bool CudnnSupport::DoFusedConvolve(
const dnn::AlgorithmConfig& algorithm_config,
dnn::ProfileResult* output_profile_result) {
return IsStatusOk(
- DoFusedConvolveImpl<double, double, double, CUDNN_DATA_DOUBLE,
- CUDNN_DATA_DOUBLE>(
+ DoFusedConvolveImpl<double>(
stream, conv_input_descriptor, conv_input_data, conv_input_scale,
filter_descriptor, filter_data, convolution_descriptor,
side_input_data, side_input_scale, bias_descriptor, biases,
@@ -2957,8 +3021,7 @@ bool CudnnSupport::DoFusedConvolve(
const dnn::AlgorithmConfig& algorithm_config,
dnn::ProfileResult* output_profile_result) {
return IsStatusOk(
- DoFusedConvolveImpl<float, float, float, CUDNN_DATA_FLOAT,
- CUDNN_DATA_FLOAT>(
+ DoFusedConvolveImpl<float>(
stream, conv_input_descriptor, conv_input_data, conv_input_scale,
filter_descriptor, filter_data, convolution_descriptor,
side_input_data, side_input_scale, bias_descriptor, biases,
@@ -2982,8 +3045,7 @@ bool CudnnSupport::DoFusedConvolve(
const dnn::AlgorithmConfig& algorithm_config,
dnn::ProfileResult* output_profile_result) {
return IsStatusOk(
- DoFusedConvolveImpl<Eigen::half, Eigen::half, float, CUDNN_DATA_HALF,
- CUDNN_DATA_FLOAT>(
+ DoFusedConvolveImpl<float>(
stream, conv_input_descriptor, conv_input_data, conv_input_scale,
filter_descriptor, filter_data, convolution_descriptor,
side_input_data, side_input_scale, bias_descriptor, biases,
@@ -3014,8 +3076,7 @@ bool CudnnSupport::DoFusedConvolve(
return false;
}
return IsStatusOk(
- DoFusedConvolveImpl<int8, float, float, CUDNN_DATA_INT8x4,
- CUDNN_DATA_INT32>(
+ DoFusedConvolveImpl<int32>(
stream, conv_input_descriptor, conv_input_data, conv_input_scale,
filter_descriptor, filter_data, convolution_descriptor,
side_input_data, side_input_scale, bias_descriptor, biases,
@@ -3096,6 +3157,13 @@ port::Status CudnnSupport::DoConvolveBackwardDataImpl(
}
}
+ if (algo_desc.algo_id() == CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED &&
+ !ShouldIncludeWinogradNonfusedAlgo(input_descriptor, output_descriptor)) {
+ return port::Status(port::error::FAILED_PRECONDITION,
+ "This configuration has potential integer overflow in "
+ "cuDNNv5 and cuDNNv6. See b/68264959.");
+ }
+
// Cudnn 7.1.4 has a bug if the workspace of the following convolution is not
// zero-initialized, nvbugs/2254619.
if (CUDNN_VERSION >= 7000 &&
@@ -3275,6 +3343,13 @@ port::Status CudnnSupport::DoConvolveBackwardFilterImpl(
"This configuration potentially produces incorrect results.");
}());
+ if (algo_desc.algo_id() == CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED &&
+ !ShouldIncludeWinogradNonfusedAlgo(input_descriptor, output_descriptor)) {
+ return port::Status(port::error::FAILED_PRECONDITION,
+ "This configuration has potential integer overflow in "
+ "cuDNNv5 and cuDNNv6. See b/68264959.");
+ }
+
// Zero out the result buffer for strided conv backward filter for NHWC
// layouts. cuDNN 7.1.4 and 7.2 has non-determinisic bug if the buffer is not
// zeroed.
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.h b/tensorflow/stream_executor/cuda/cuda_dnn.h
index 9d88f971bb..74f6f935b8 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.h
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.h
@@ -674,19 +674,21 @@ class CudnnSupport : public dnn::DnnSupport {
const dnn::AlgorithmConfig& algorithm_config,
dnn::ProfileResult* output_profile_result);
- template <typename Type, typename BiasType, typename ScaleType,
- int cudnn_data_type, int cudnn_compute_type>
+ template <typename AccumulatorType, typename ElementType, typename BiasType,
+ typename ScaleType>
port::Status DoFusedConvolveImpl(
Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
- const DeviceMemory<Type>& conv_input_data, ScaleType conv_input_scale,
+ const DeviceMemory<ElementType>& conv_input_data,
+ ScaleType conv_input_scale,
const dnn::FilterDescriptor& filter_descriptor,
- const DeviceMemory<Type>& filter_data,
+ const DeviceMemory<ElementType>& filter_data,
const dnn::ConvolutionDescriptor& convolution_descriptor,
- const DeviceMemory<Type>& side_input_data, ScaleType side_input_scale,
- const dnn::BatchDescriptor& bias_descriptor,
+ const DeviceMemory<ElementType>& side_input_data,
+ ScaleType side_input_scale, const dnn::BatchDescriptor& bias_descriptor,
const DeviceMemory<BiasType>& biases, dnn::ActivationMode activation_mode,
const dnn::BatchDescriptor& output_descriptor,
- DeviceMemory<Type>* output_data, ScratchAllocator* scratch_allocator,
+ DeviceMemory<ElementType>* output_data,
+ ScratchAllocator* scratch_allocator,
const dnn::AlgorithmConfig& algorithm_config,
dnn::ProfileResult* output_profile_result);
diff --git a/tensorflow/stream_executor/device_description.h b/tensorflow/stream_executor/device_description.h
index 7f99d81ef3..a4580d6462 100644
--- a/tensorflow/stream_executor/device_description.h
+++ b/tensorflow/stream_executor/device_description.h
@@ -22,8 +22,7 @@ limitations under the License.
#include <map>
#include <memory>
-#include "tensorflow/stream_executor/platform/port.h"
-
+#include "absl/base/macros.h"
#include "tensorflow/stream_executor/launch_dim.h"
#include "tensorflow/stream_executor/platform/port.h"
@@ -359,9 +358,8 @@ class DeviceDescriptionBuilder {
bool ThreadDimOk(const DeviceDescription &device_description,
const ThreadDim &thread_dim);
-// [deprecated] Use MathUtil::CeilOfRatio directly instead.
-//
// Equivalent to ceil(double(element_count) / threads_per_block).
+ABSL_DEPRECATED("Use MathUtil::CeilOfRatio directly instead.")
uint64 DivideCeil(uint64 x, uint64 y);
// Calculate the number of threads/blocks required to process element_count
diff --git a/tensorflow/stream_executor/dnn.h b/tensorflow/stream_executor/dnn.h
index 9abfa1db6a..621b155240 100644
--- a/tensorflow/stream_executor/dnn.h
+++ b/tensorflow/stream_executor/dnn.h
@@ -873,7 +873,7 @@ class NormalizeDescriptor {
// Describes a kind of non-linearity (threshold-like mathematical function).
enum class ActivationMode {
- kNone,
+ kNone = 0,
kSigmoid,
// Rectified linear activation: f(x) = x < 0 ? 0 : x
kRelu,
@@ -885,6 +885,8 @@ enum class ActivationMode {
kTanh,
// Like ReluX, but passes all values in the range [-X,X].
kBandPass,
+
+ kNumActivationModes, // Always in the end.
};
// Returns a string representation of the given activation mode.
diff --git a/tensorflow/stream_executor/lib/array_slice.h b/tensorflow/stream_executor/lib/array_slice.h
index 8e3c4ca047..5f4e586762 100644
--- a/tensorflow/stream_executor/lib/array_slice.h
+++ b/tensorflow/stream_executor/lib/array_slice.h
@@ -16,13 +16,15 @@ limitations under the License.
#ifndef TENSORFLOW_STREAM_EXECUTOR_LIB_ARRAY_SLICE_H_
#define TENSORFLOW_STREAM_EXECUTOR_LIB_ARRAY_SLICE_H_
-#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "absl/types/span.h"
namespace stream_executor {
namespace port {
-using tensorflow::gtl::ArraySlice;
-using tensorflow::gtl::MutableArraySlice;
+template <typename T>
+using ArraySlice = absl::Span<const T>;
+template <typename T>
+using MutableArraySlice = absl::Span<T>;
} // namespace port
} // namespace stream_executor
diff --git a/tensorflow/stream_executor/lib/inlined_vector.h b/tensorflow/stream_executor/lib/inlined_vector.h
index 40bdddb180..0198947e5b 100644
--- a/tensorflow/stream_executor/lib/inlined_vector.h
+++ b/tensorflow/stream_executor/lib/inlined_vector.h
@@ -16,12 +16,12 @@ limitations under the License.
#ifndef TENSORFLOW_STREAM_EXECUTOR_LIB_INLINED_VECTOR_H_
#define TENSORFLOW_STREAM_EXECUTOR_LIB_INLINED_VECTOR_H_
-#include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "absl/container/inlined_vector.h"
namespace stream_executor {
namespace port {
-using tensorflow::gtl::InlinedVector;
+using absl::InlinedVector;
} // namespace port
} // namespace stream_executor
diff --git a/tensorflow/stream_executor/lib/strcat.h b/tensorflow/stream_executor/lib/strcat.h
index c959e4df5b..3688d7b4eb 100644
--- a/tensorflow/stream_executor/lib/strcat.h
+++ b/tensorflow/stream_executor/lib/strcat.h
@@ -18,13 +18,13 @@ limitations under the License.
#ifndef TENSORFLOW_STREAM_EXECUTOR_LIB_STRCAT_H_
#define TENSORFLOW_STREAM_EXECUTOR_LIB_STRCAT_H_
-#include "tensorflow/core/lib/strings/strcat.h"
+#include "absl/strings/str_cat.h"
namespace stream_executor {
namespace port {
-using tensorflow::strings::StrCat;
-using tensorflow::strings::StrAppend;
+using absl::StrAppend;
+using absl::StrCat;
} // namespace port
} // namespace stream_executor
diff --git a/tensorflow/stream_executor/lib/stringpiece.h b/tensorflow/stream_executor/lib/stringpiece.h
index b80de5df30..7624910129 100644
--- a/tensorflow/stream_executor/lib/stringpiece.h
+++ b/tensorflow/stream_executor/lib/stringpiece.h
@@ -16,13 +16,12 @@ limitations under the License.
#ifndef TENSORFLOW_STREAM_EXECUTOR_LIB_STRINGPIECE_H_
#define TENSORFLOW_STREAM_EXECUTOR_LIB_STRINGPIECE_H_
-#include "tensorflow/core/lib/core/stringpiece.h"
-#include "tensorflow/stream_executor/platform/port.h"
+#include "absl/strings/string_view.h"
namespace stream_executor {
namespace port {
-using tensorflow::StringPiece;
+using StringPiece = absl::string_view;
} // namespace port
} // namespace stream_executor
diff --git a/tensorflow/stream_executor/plugin_registry.h b/tensorflow/stream_executor/plugin_registry.h
index 49628ecd24..3065b5cb77 100644
--- a/tensorflow/stream_executor/plugin_registry.h
+++ b/tensorflow/stream_executor/plugin_registry.h
@@ -18,6 +18,7 @@ limitations under the License.
#include <map>
+#include "absl/base/macros.h"
#include "tensorflow/stream_executor/blas.h"
#include "tensorflow/stream_executor/dnn.h"
#include "tensorflow/stream_executor/fft.h"
@@ -97,6 +98,7 @@ class PluginRegistry {
// TODO(b/22689637): Deprecated/temporary. Will be deleted once all users are
// on MultiPlatformManager / PlatformId.
template <typename FactoryT>
+ ABSL_DEPRECATED("Use MultiPlatformManager / PlatformId instead.")
port::StatusOr<FactoryT> GetFactory(PlatformKind platform_kind,
PluginId plugin_id);
diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc
index 19d3b2389a..69558fd14b 100644
--- a/tensorflow/stream_executor/stream.cc
+++ b/tensorflow/stream_executor/stream.cc
@@ -587,6 +587,44 @@ Stream &Stream::ThenConvolveWithScratch(
Stream &Stream::ThenFusedConvolveWithAlgorithm(
const dnn::BatchDescriptor &conv_input_descriptor,
+ const DeviceMemory<double> &conv_input_data, double conv_input_scale,
+ const dnn::FilterDescriptor &filter_descriptor,
+ const DeviceMemory<double> &filter_data,
+ const dnn::ConvolutionDescriptor &convolution_descriptor,
+ const DeviceMemory<double> &side_input_data, double side_input_scale,
+ const dnn::BatchDescriptor &bias_descriptor,
+ const DeviceMemory<double> &biases, dnn::ActivationMode activation_mode,
+ const dnn::BatchDescriptor &output_descriptor, DeviceMemory<double> *output,
+ ScratchAllocator *scratch_allocator,
+ const dnn::AlgorithmConfig &algorithm_config,
+ dnn::ProfileResult *output_profile_result) {
+ VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data),
+ PARAM(conv_input_scale), PARAM(filter_descriptor),
+ PARAM(filter_data), PARAM(convolution_descriptor), PARAM(biases),
+ PARAM(side_input_data), PARAM(side_input_scale),
+ PARAM(activation_mode), PARAM(output_descriptor), PARAM(output),
+ PARAM(algorithm_config));
+
+ if (ok()) {
+ if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
+ auto status = dnn->DoFusedConvolve(
+ this, conv_input_descriptor, conv_input_data, conv_input_scale,
+ filter_descriptor, filter_data, convolution_descriptor,
+ side_input_data, side_input_scale, bias_descriptor, biases,
+ activation_mode, output_descriptor, output, scratch_allocator,
+ algorithm_config, output_profile_result);
+ if (!status && !output_profile_result) {
+ SetError();
+ }
+ } else {
+ SetErrorAndLogNoDnnSupport();
+ }
+ }
+ return *this;
+}
+
+Stream &Stream::ThenFusedConvolveWithAlgorithm(
+ const dnn::BatchDescriptor &conv_input_descriptor,
const DeviceMemory<float> &conv_input_data, float conv_input_scale,
const dnn::FilterDescriptor &filter_descriptor,
const DeviceMemory<float> &filter_data,
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.h b/tensorflow/stream_executor/stream_executor_pimpl.h
index d04025b681..4a8a270afa 100644
--- a/tensorflow/stream_executor/stream_executor_pimpl.h
+++ b/tensorflow/stream_executor/stream_executor_pimpl.h
@@ -22,6 +22,7 @@ limitations under the License.
#include <tuple>
#include <vector>
+#include "absl/base/macros.h"
#include "tensorflow/stream_executor/lib/status.h"
#include "tensorflow/stream_executor/lib/statusor.h"
#include "tensorflow/stream_executor/lib/strcat.h"
@@ -81,8 +82,8 @@ class StreamExecutor {
port::Status Init();
port::Status Init(int device_ordinal, DeviceOptions device_options);
- // DEPRECATED: Do not use; use platform() instead.
// Returns the platform that this StreamExecutor is acting upon.
+ ABSL_DEPRECATED("Use platform() instead.")
PlatformKind platform_kind() const { return platform_kind_; }
// Returns a reference to the platform that created this executor.
@@ -255,15 +256,15 @@ class StreamExecutor {
// [deprecated] Blocks the caller while a data segment of the given size is
// copied from the host source to the device destination.
- //
- // Deprecation: prefer explicit H2D below, to avoid error-prone API usage.
+ ABSL_DEPRECATED(
+ "Prefer SynchronousMemcpyH2D, to avoid error-prone API usage.")
bool SynchronousMemcpy(DeviceMemoryBase *device_dst, const void *host_src,
uint64 size) SE_MUST_USE_RESULT;
// [deprecated] Blocks the caller while a data segment of the given size is
// copied from the device source to the host destination.
- //
- // Deprecation: prefer explicit D2H below, to avoid error-prone API usage.
+ ABSL_DEPRECATED(
+ "Prefer SynchronousMemcpyD2H, to avoid error-prone API usage.")
bool SynchronousMemcpy(void *host_dst, const DeviceMemoryBase &device_src,
uint64 size) SE_MUST_USE_RESULT;
diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl
index adac895a17..cad5de1b0c 100644
--- a/tensorflow/tensorflow.bzl
+++ b/tensorflow/tensorflow.bzl
@@ -19,9 +19,18 @@ load(
"@local_config_cuda//cuda:build_defs.bzl",
"cuda_default_copts",
"if_cuda",
+ "if_cuda_is_configured",
+)
+load(
+ "@local_config_rocm//rocm:build_defs.bzl",
+ "if_rocm",
+ "if_rocm_is_configured",
+ "rocm_copts",
+ "rocm_default_copts",
)
load(
"//third_party/mkl:build_defs.bzl",
+ "if_enable_mkl",
"if_mkl",
"if_mkl_lnx_x64",
"if_mkl_ml",
@@ -38,6 +47,8 @@ load(
def register_extension_info(**kwargs):
pass
+# if_cuda_is_configured def placeholder
+
# Given a source file, generate a test name.
# i.e. "common_runtime/direct_session_test.cc" becomes
# "common_runtime_direct_session_test"
@@ -237,6 +248,7 @@ def tf_copts(android_optimization_level_override = "-O2", is_external = False):
if_tensorrt(["-DGOOGLE_TENSORRT=1"]) +
if_mkl(["-DINTEL_MKL=1", "-DEIGEN_USE_VML"]) +
if_mkl_open_source_only(["-DINTEL_MKL_DNN_ONLY"]) +
+ if_enable_mkl(["-DENABLE_MKL"]) +
if_ngraph(["-DINTEL_NGRAPH=1"]) +
if_mkl_lnx_x64(["-fopenmp"]) +
if_android_arm(["-mfpu=neon"]) +
@@ -448,7 +460,7 @@ def tf_gen_op_wrapper_cc(
tf_cc_binary(
name = tool,
copts = tf_copts(),
- linkopts = if_not_windows(["-lm"]),
+ linkopts = if_not_windows(["-lm", "-Wl,-ldl"]),
linkstatic = 1, # Faster to link this one-time-use binary dynamically
deps = [op_gen] + deps,
)
@@ -602,6 +614,7 @@ def tf_gen_op_wrappers_cc(
# is invalid to specify both "hidden" and "op_whitelist".
# cc_linkopts: Optional linkopts to be added to tf_cc_binary that contains the
# specified ops.
+
def tf_gen_op_wrapper_py(
name,
out = None,
@@ -623,7 +636,7 @@ def tf_gen_op_wrapper_py(
deps = [str(Label("//tensorflow/core:" + name + "_op_lib"))]
tf_cc_binary(
name = tool_name,
- linkopts = if_not_windows(["-lm"]) + cc_linkopts,
+ linkopts = if_not_windows(["-lm", "-Wl,-ldl"]) + cc_linkopts,
copts = tf_copts(),
linkstatic = 1, # Faster to link this one-time-use binary dynamically
deps = ([
@@ -860,12 +873,16 @@ def tf_cuda_only_cc_test(
srcs = srcs + tf_binary_additional_srcs(),
size = size,
args = args,
- copts = _cuda_copts() + tf_copts(),
+ copts = _cuda_copts() + rocm_copts() + tf_copts(),
data = data + tf_binary_dynamic_kernel_dsos(kernels),
- deps = deps + tf_binary_dynamic_kernel_deps(kernels) + if_cuda([
- clean_dep("//tensorflow/core:cuda"),
- clean_dep("//tensorflow/core:gpu_lib"),
- ]),
+ deps = deps + tf_binary_dynamic_kernel_deps(kernels) +
+ if_cuda_is_configured([
+ clean_dep("//tensorflow/core:cuda"),
+ clean_dep("//tensorflow/core:gpu_lib"),
+ ]) +
+ if_rocm_is_configured([
+ clean_dep("//tensorflow/core:gpu_lib"),
+ ]),
linkopts = if_not_windows(["-lpthread", "-lm"]) + linkopts + _rpath_linkopts(name),
linkstatic = linkstatic or select({
# cc_tests with ".so"s in srcs incorrectly link on Darwin
@@ -1000,7 +1017,7 @@ register_extension_info(
label_regex_for_dep = "{extension_name}",
)
-def _cuda_copts():
+def _cuda_copts(opts = []):
"""Gets the appropriate set of copts for (maybe) CUDA compilation.
If we're doing CUDA compilation, returns copts for our particular CUDA
@@ -1016,13 +1033,17 @@ def _cuda_copts():
"@local_config_cuda//cuda:using_clang": ([
"-fcuda-flush-denormals-to-zero",
]),
- })
+ }) + if_cuda_is_configured(opts)
# Build defs for TensorFlow kernels
# When this target is built using --config=cuda, a cc_library is built
# that passes -DGOOGLE_CUDA=1 and '-x cuda', linking in additional
# libraries needed by GPU kernels.
+#
+# When this target is built using --config=rocm, a cc_library is built
+# that passes -DTENSORFLOW_USE_ROCM and '-x rocm', linking in additional
+# libraries needed by GPU kernels.
def tf_gpu_kernel_library(
srcs,
copts = [],
@@ -1030,16 +1051,18 @@ def tf_gpu_kernel_library(
deps = [],
hdrs = [],
**kwargs):
- copts = copts + _cuda_copts() + if_cuda(cuda_copts) + tf_copts()
+ copts = copts + tf_copts() + _cuda_copts(opts = cuda_copts) + rocm_copts(opts = cuda_copts)
kwargs["features"] = kwargs.get("features", []) + ["-use_header_modules"]
native.cc_library(
srcs = srcs,
hdrs = hdrs,
copts = copts,
- deps = deps + if_cuda([
+ deps = deps + if_cuda_is_configured([
clean_dep("//tensorflow/core:cuda"),
clean_dep("//tensorflow/core:gpu_lib"),
+ ]) + if_rocm_is_configured([
+ clean_dep("//tensorflow/core:gpu_lib"),
]),
alwayslink = 1,
**kwargs
@@ -1078,9 +1101,12 @@ def tf_cuda_library(deps = None, cuda_deps = None, copts = tf_copts(), **kwargs)
deps = deps + if_cuda(cuda_deps + [
clean_dep("//tensorflow/core:cuda"),
"@local_config_cuda//cuda:cuda_headers",
+ ]) + if_rocm_is_configured(cuda_deps + [
+ # rocm_header placeholder
]),
- copts = (copts + if_cuda(["-DGOOGLE_CUDA=1"]) + if_mkl(["-DINTEL_MKL=1"]) +
+ copts = (copts + if_cuda(["-DGOOGLE_CUDA=1"]) + if_rocm(["-DTENSORFLOW_USE_ROCM=1"]) + if_mkl(["-DINTEL_MKL=1"]) +
if_mkl_open_source_only(["-DINTEL_MKL_DNN_ONLY"]) +
+ if_enable_mkl(["-DENABLE_MKL"]) +
if_tensorrt(["-DGOOGLE_TENSORRT=1"])),
**kwargs
)
@@ -1215,9 +1241,11 @@ def tf_mkl_kernel_library(
if prefix:
srcs = srcs + native.glob(
[prefix + "*.cc"],
+ exclude = [prefix + "*test*"],
)
hdrs = hdrs + native.glob(
[prefix + "*.h"],
+ exclude = [prefix + "*test*"],
)
# -fno-exceptions in nocopts breaks compilation if header modules are enabled.
@@ -1459,6 +1487,9 @@ def tf_custom_op_library(name, srcs = [], gpu_srcs = [], deps = [], linkopts = [
"@local_config_cuda//cuda:cuda_headers",
"@local_config_cuda//cuda:cudart_static",
]
+ rocm_deps = [
+ clean_dep("//tensorflow/core:stream_executor_headers_lib"),
+ ]
deps = deps + tf_custom_op_library_additional_deps()
if gpu_srcs:
basename = name.split(".")[0]
@@ -1467,13 +1498,14 @@ def tf_custom_op_library(name, srcs = [], gpu_srcs = [], deps = [], linkopts = [
srcs = gpu_srcs,
copts = _cuda_copts() + if_tensorrt(["-DGOOGLE_TENSORRT=1"]),
features = if_cuda(["-use_header_modules"]),
- deps = deps + if_cuda(cuda_deps),
+ deps = deps + if_cuda_is_configured(cuda_deps) + if_rocm_is_configured(rocm_deps),
)
cuda_deps.extend([":" + basename + "_gpu"])
+ rocm_deps.extend([":" + basename + "_gpu"])
check_deps(
name = name + "_check_deps",
- deps = deps + if_cuda(cuda_deps),
+ deps = deps + if_cuda_is_configured(cuda_deps) + if_rocm_is_configured(rocm_deps),
disallowed_deps = [
clean_dep("//tensorflow/core:framework"),
clean_dep("//tensorflow/core:lib"),
@@ -1482,7 +1514,7 @@ def tf_custom_op_library(name, srcs = [], gpu_srcs = [], deps = [], linkopts = [
tf_cc_shared_object(
name = name,
srcs = srcs,
- deps = deps + if_cuda(cuda_deps),
+ deps = deps + if_cuda_is_configured(cuda_deps) + if_rocm_is_configured(rocm_deps),
data = if_static([name + "_check_deps"]),
copts = tf_copts(is_external = True),
features = ["windows_export_all_symbols"],
@@ -1674,7 +1706,7 @@ def py_test(deps = [], data = [], kernels = [], **kwargs):
deps = select({
"//conditions:default": deps,
clean_dep("//tensorflow:no_tensorflow_py_deps"): [],
- }) + tf_binary_dynamic_kernel_deps(kernels),
+ }),
data = data + select({
"//conditions:default": [],
clean_dep("//tensorflow:no_tensorflow_py_deps"): ["//tensorflow/tools/pip_package:win_pip_package_marker"],
@@ -1687,6 +1719,29 @@ register_extension_info(
label_regex_for_dep = "{extension_name}",
)
+# Similar to py_test above, this macro is used to exclude dependencies for some py_binary
+# targets in order to reduce the size of //tensorflow/tools/pip_package:simple_console_windows.
+# See https://github.com/tensorflow/tensorflow/issues/22390
+def py_binary(name, deps = [], **kwargs):
+ # Add an extra target for dependencies to avoid nested select statement.
+ native.py_library(
+ name = name + "_deps",
+ deps = deps,
+ )
+ native.py_binary(
+ name = name,
+ deps = select({
+ "//conditions:default": [":" + name + "_deps"],
+ clean_dep("//tensorflow:no_tensorflow_py_deps"): [],
+ }),
+ **kwargs
+ )
+
+register_extension_info(
+ extension_name = "py_binary",
+ label_regex_for_dep = "{extension_name}",
+)
+
def tf_py_test(
name,
srcs,
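Usage of the py_binary wrapper above is the same as native.py_binary; a hedged BUILD-file sketch in Starlark (the target, source, and load path shown are hypothetical examples):

    # Illustrative only: deps are routed through the generated ":<name>_deps"
    # py_library and dropped when //tensorflow:no_tensorflow_py_deps is set.
    load("//tensorflow:tensorflow.bzl", "py_binary")

    py_binary(
        name = "example_tool",         # hypothetical target
        srcs = ["example_tool.py"],    # hypothetical source
        deps = ["//tensorflow:tensorflow_py"],
    )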
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-variable.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-variable.pbtxt
index 05698b03ee..af7fc9d4ef 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.-variable.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.-variable.pbtxt
@@ -1,5 +1,6 @@
path: "tensorflow.Variable"
tf_class {
+ is_instance: "<class \'tensorflow.python.ops.variables.VariableV1\'>"
is_instance: "<class \'tensorflow.python.ops.variables.Variable\'>"
is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
is_instance: "<type \'object\'>"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt
index 87745420ee..825afb622f 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt
@@ -91,6 +91,10 @@ tf_class {
argspec: "args=[], varargs=args, keywords=None, defaults=None"
}
member_method {
+ name: "reduce"
+ argspec: "args=[\'self\', \'initial_state\', \'reduce_func\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "repeat"
argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -111,6 +115,10 @@ tf_class {
argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "window"
+ argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "
+ }
+ member_method {
name: "zip"
argspec: "args=[\'datasets\'], varargs=None, keywords=None, defaults=None"
}
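The reduce and window members recorded above correspond to Python usage along these lines (a hedged sketch; values in comments assume the tensors are evaluated):

    # Illustrative only: matches the argspecs in the golden file.
    import tensorflow as tf

    ds = tf.data.Dataset.range(10)

    # reduce(initial_state, reduce_func): fold the dataset into a single value.
    total = ds.reduce(tf.constant(0, dtype=tf.int64), lambda state, x: state + x)  # 45

    # window(size, shift=None, stride=1, drop_remainder=False): nested datasets
    # of up to `size` elements; batch each window to get dense tensors.
    windows = ds.window(3, shift=3, drop_remainder=True)
    flat = windows.flat_map(lambda w: w.batch(3))  # [0 1 2], [3 4 5], [6 7 8]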
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt
index 6dd46365b0..cdad5f6360 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt
@@ -92,6 +92,10 @@ tf_class {
argspec: "args=[], varargs=args, keywords=None, defaults=None"
}
member_method {
+ name: "reduce"
+ argspec: "args=[\'self\', \'initial_state\', \'reduce_func\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "repeat"
argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -112,6 +116,10 @@ tf_class {
argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "window"
+ argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "
+ }
+ member_method {
name: "zip"
argspec: "args=[\'datasets\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt
index 35b7105eba..df41bff1b5 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt
@@ -92,6 +92,10 @@ tf_class {
argspec: "args=[], varargs=args, keywords=None, defaults=None"
}
member_method {
+ name: "reduce"
+ argspec: "args=[\'self\', \'initial_state\', \'reduce_func\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "repeat"
argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -112,6 +116,10 @@ tf_class {
argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "window"
+ argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "
+ }
+ member_method {
name: "zip"
argspec: "args=[\'datasets\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt
index 8ae370af98..028bcc2ce9 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt
@@ -92,6 +92,10 @@ tf_class {
argspec: "args=[], varargs=args, keywords=None, defaults=None"
}
member_method {
+ name: "reduce"
+ argspec: "args=[\'self\', \'initial_state\', \'reduce_func\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "repeat"
argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -112,6 +116,10 @@ tf_class {
argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "window"
+ argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "
+ }
+ member_method {
name: "zip"
argspec: "args=[\'datasets\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-boosted-trees-classifier.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-boosted-trees-classifier.pbtxt
index 7027e78df4..ef3409b1b5 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-boosted-trees-classifier.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-boosted-trees-classifier.pbtxt
@@ -1,6 +1,7 @@
path: "tensorflow.estimator.BoostedTreesClassifier"
tf_class {
is_instance: "<class \'tensorflow.python.estimator.canned.boosted_trees.BoostedTreesClassifier\'>"
+ is_instance: "<class \'tensorflow.python.estimator.canned.boosted_trees._BoostedTreesBase\'>"
is_instance: "<class \'tensorflow.python.estimator.estimator.Estimator\'>"
is_instance: "<type \'object\'>"
member {
@@ -32,6 +33,14 @@ tf_class {
argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
+ name: "experimental_feature_importances"
+ argspec: "args=[\'self\', \'normalize\'], varargs=None, keywords=None, defaults=[\'False\'], "
+ }
+ member_method {
+ name: "experimental_predict_with_explanations"
+ argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ }
+ member_method {
name: "export_saved_model"
argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-boosted-trees-regressor.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-boosted-trees-regressor.pbtxt
index d8167ea7cb..775130468f 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-boosted-trees-regressor.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-boosted-trees-regressor.pbtxt
@@ -1,6 +1,7 @@
path: "tensorflow.estimator.BoostedTreesRegressor"
tf_class {
is_instance: "<class \'tensorflow.python.estimator.canned.boosted_trees.BoostedTreesRegressor\'>"
+ is_instance: "<class \'tensorflow.python.estimator.canned.boosted_trees._BoostedTreesBase\'>"
is_instance: "<class \'tensorflow.python.estimator.estimator.Estimator\'>"
is_instance: "<type \'object\'>"
member {
@@ -32,6 +33,14 @@ tf_class {
argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
+ name: "experimental_feature_importances"
+ argspec: "args=[\'self\', \'normalize\'], varargs=None, keywords=None, defaults=[\'False\'], "
+ }
+ member_method {
+ name: "experimental_predict_with_explanations"
+ argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ }
+ member_method {
name: "export_saved_model"
argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.math.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.math.pbtxt
index a308c76ebc..72856466ec 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.math.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.math.pbtxt
@@ -233,6 +233,14 @@ tf_module {
argspec: "args=[\'data\', \'segment_ids\', \'num_segments\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "xdivy"
+ argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "xlogy"
+ argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "zeta"
argspec: "args=[\'x\', \'q\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
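xdivy and xlogy, added above, are the forms of x / y and x * log(y) that define the x == 0 case as 0; a hedged sketch (values in comments assume evaluation):

    # Illustrative only.
    import tensorflow as tf

    x = tf.constant([0.0, 2.0])
    y = tf.constant([0.0, 4.0])

    safe_div = tf.math.xdivy(x, y)  # [0.0, 0.5]   instead of [nan, 0.5]
    safe_log = tf.math.xlogy(x, y)  # [0.0, 2.77]  instead of [nan, 2.77]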
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt
index dd9f7c49e0..509ceff9df 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt
@@ -1093,6 +1093,10 @@ tf_module {
argspec: "args=[\'images\', \'ksizes\', \'strides\', \'rates\', \'padding\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "extract_volume_patches"
+ argspec: "args=[\'input\', \'ksizes\', \'strides\', \'padding\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "eye"
argspec: "args=[\'num_rows\', \'num_columns\', \'batch_shape\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \"<dtype: \'float32\'>\", \'None\'], "
}
@@ -1373,6 +1377,10 @@ tf_module {
argspec: "args=[\'library_filename\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "load_library"
+ argspec: "args=[\'library_location\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "load_op_library"
argspec: "args=[\'library_filename\'], varargs=None, keywords=None, defaults=None"
}
@@ -1426,7 +1434,7 @@ tf_module {
}
member_method {
name: "map_fn"
- argspec: "args=[\'fn\', \'elems\', \'dtype\', \'parallel_iterations\', \'back_prop\', \'swap_memory\', \'infer_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'True\', \'False\', \'True\', \'None\'], "
+ argspec: "args=[\'fn\', \'elems\', \'dtype\', \'parallel_iterations\', \'back_prop\', \'swap_memory\', \'infer_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\', \'False\', \'True\', \'None\'], "
}
member_method {
name: "matching_files"
@@ -1589,6 +1597,10 @@ tf_module {
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "print"
+ argspec: "args=[], varargs=inputs, keywords=kwargs, defaults=None"
+ }
+ member_method {
name: "py_func"
argspec: "args=[\'func\', \'inp\', \'Tout\', \'stateful\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
}
@@ -1797,6 +1809,10 @@ tf_module {
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
}
member_method {
+ name: "searchsorted"
+ argspec: "args=[\'sorted_sequence\', \'values\', \'side\', \'out_type\', \'name\'], varargs=None, keywords=None, defaults=[\'left\', \"<dtype: \'int32\'>\", \'None\'], "
+ }
+ member_method {
name: "segment_max"
argspec: "args=[\'data\', \'segment_ids\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -2205,6 +2221,10 @@ tf_module {
argspec: "args=[\'max_shard_bytes\', \'axis\', \'bytes_per_string_element\', \'max_shards\'], varargs=None, keywords=None, defaults=[\'0\', \'16\', \'None\'], "
}
member_method {
+ name: "variable_creator_scope"
+ argspec: "args=[\'variable_creator\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "variable_op_scope"
argspec: "args=[\'values\', \'name_or_scope\', \'default_name\', \'initializer\', \'regularizer\', \'caching_device\', \'partitioner\', \'custom_getter\', \'reuse\', \'dtype\', \'use_resource\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
}
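tf.searchsorted, newly exported above, returns insertion indices that keep each row of sorted_sequence sorted, following the numpy.searchsorted convention; a hedged sketch (values in comments assume evaluation):

    # Illustrative only: 'left' side and int32 output are the recorded defaults.
    import tensorflow as tf

    sorted_sequence = tf.constant([[1, 3, 5, 7]])
    values = tf.constant([[4, 7]])

    left = tf.searchsorted(sorted_sequence, values)                 # [[2, 3]]
    right = tf.searchsorted(sorted_sequence, values, side="right")  # [[2, 4]]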
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.strings.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.strings.pbtxt
index 018be7b9f9..312e94b41d 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.strings.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.strings.pbtxt
@@ -1,12 +1,16 @@
path: "tensorflow.strings"
tf_module {
member_method {
+ name: "format"
+ argspec: "args=[\'template\', \'inputs\', \'placeholder\', \'summarize\', \'name\'], varargs=None, keywords=None, defaults=[\'{}\', \'3\', \'None\'], "
+ }
+ member_method {
name: "join"
argspec: "args=[\'inputs\', \'separator\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
}
member_method {
name: "length"
- argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ argspec: "args=[\'input\', \'name\', \'unit\'], varargs=None, keywords=None, defaults=[\'None\', \'BYTE\'], "
}
member_method {
name: "regex_full_match"
@@ -44,4 +48,8 @@ tf_module {
name: "to_number"
argspec: "args=[\'string_tensor\', \'out_type\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
}
+ member_method {
+ name: "unicode_script"
+ argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
}
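The strings additions above map to Python usage along these lines (a hedged sketch; values in comments assume evaluation, and UTF8_CHAR is the non-default length unit):

    # Illustrative only.
    import tensorflow as tf

    msg = tf.strings.format("value: {}", (tf.constant([1, 2, 3]),))         # 'value: [1 2 3]'
    n_bytes = tf.strings.length(tf.constant(["héllo"]))                     # [6] (unit='BYTE')
    n_chars = tf.strings.length(tf.constant(["héllo"]), unit="UTF8_CHAR")   # [5]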
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-variable-scope.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-variable-scope.pbtxt
deleted file mode 100644
index c13eb7b8bb..0000000000
--- a/tensorflow/tools/api/golden/v2/tensorflow.-variable-scope.pbtxt
+++ /dev/null
@@ -1,105 +0,0 @@
-path: "tensorflow.VariableScope"
-tf_class {
- is_instance: "<class \'tensorflow.python.ops.variable_scope.VariableScope\'>"
- is_instance: "<type \'object\'>"
- member {
- name: "caching_device"
- mtype: "<type \'property\'>"
- }
- member {
- name: "constraint"
- mtype: "<type \'property\'>"
- }
- member {
- name: "custom_getter"
- mtype: "<type \'property\'>"
- }
- member {
- name: "dtype"
- mtype: "<type \'property\'>"
- }
- member {
- name: "initializer"
- mtype: "<type \'property\'>"
- }
- member {
- name: "name"
- mtype: "<type \'property\'>"
- }
- member {
- name: "original_name_scope"
- mtype: "<type \'property\'>"
- }
- member {
- name: "partitioner"
- mtype: "<type \'property\'>"
- }
- member {
- name: "regularizer"
- mtype: "<type \'property\'>"
- }
- member {
- name: "reuse"
- mtype: "<type \'property\'>"
- }
- member {
- name: "use_resource"
- mtype: "<type \'property\'>"
- }
- member_method {
- name: "__init__"
- argspec: "args=[\'self\', \'reuse\', \'name\', \'initializer\', \'regularizer\', \'caching_device\', \'partitioner\', \'custom_getter\', \'name_scope\', \'dtype\', \'use_resource\', \'constraint\'], varargs=None, keywords=None, defaults=[\'\', \'None\', \'None\', \'None\', \'None\', \'None\', \'\', \"<dtype: \'float32\'>\", \'None\', \'None\'], "
- }
- member_method {
- name: "get_collection"
- argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_variable"
- argspec: "args=[\'self\', \'var_store\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'reuse\', \'trainable\', \'collections\', \'caching_device\', \'partitioner\', \'validate_shape\', \'use_resource\', \'custom_getter\', \'constraint\', \'synchronization\', \'aggregation\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
- }
- member_method {
- name: "global_variables"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "local_variables"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "reuse_variables"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "set_caching_device"
- argspec: "args=[\'self\', \'caching_device\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "set_custom_getter"
- argspec: "args=[\'self\', \'custom_getter\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "set_dtype"
- argspec: "args=[\'self\', \'dtype\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "set_initializer"
- argspec: "args=[\'self\', \'initializer\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "set_partitioner"
- argspec: "args=[\'self\', \'partitioner\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "set_regularizer"
- argspec: "args=[\'self\', \'regularizer\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "set_use_resource"
- argspec: "args=[\'self\', \'use_resource\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "trainable_variables"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
-}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-variable.-save-slice-info.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-variable.-save-slice-info.pbtxt
deleted file mode 100644
index ac3ccd468b..0000000000
--- a/tensorflow/tools/api/golden/v2/tensorflow.-variable.-save-slice-info.pbtxt
+++ /dev/null
@@ -1,17 +0,0 @@
-path: "tensorflow.Variable.SaveSliceInfo"
-tf_class {
- is_instance: "<class \'tensorflow.python.ops.variables.SaveSliceInfo\'>"
- is_instance: "<type \'object\'>"
- member {
- name: "spec"
- mtype: "<type \'property\'>"
- }
- member_method {
- name: "__init__"
- argspec: "args=[\'self\', \'full_name\', \'full_shape\', \'var_offset\', \'var_shape\', \'save_slice_info_def\', \'import_scope\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
- }
- member_method {
- name: "to_proto"
- argspec: "args=[\'self\', \'export_scope\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
-}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-variable.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-variable.pbtxt
deleted file mode 100644
index 05698b03ee..0000000000
--- a/tensorflow/tools/api/golden/v2/tensorflow.-variable.pbtxt
+++ /dev/null
@@ -1,130 +0,0 @@
-path: "tensorflow.Variable"
-tf_class {
- is_instance: "<class \'tensorflow.python.ops.variables.Variable\'>"
- is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
- is_instance: "<type \'object\'>"
- member {
- name: "SaveSliceInfo"
- mtype: "<type \'type\'>"
- }
- member {
- name: "constraint"
- mtype: "<type \'property\'>"
- }
- member {
- name: "device"
- mtype: "<type \'property\'>"
- }
- member {
- name: "dtype"
- mtype: "<type \'property\'>"
- }
- member {
- name: "graph"
- mtype: "<type \'property\'>"
- }
- member {
- name: "initial_value"
- mtype: "<type \'property\'>"
- }
- member {
- name: "initializer"
- mtype: "<type \'property\'>"
- }
- member {
- name: "name"
- mtype: "<type \'property\'>"
- }
- member {
- name: "op"
- mtype: "<type \'property\'>"
- }
- member {
- name: "shape"
- mtype: "<type \'property\'>"
- }
- member {
- name: "trainable"
- mtype: "<type \'property\'>"
- }
- member_method {
- name: "__init__"
- argspec: "args=[\'self\', \'initial_value\', \'trainable\', \'collections\', \'validate_shape\', \'caching_device\', \'name\', \'variable_def\', \'dtype\', \'expected_shape\', \'import_scope\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
- }
- member_method {
- name: "assign"
- argspec: "args=[\'self\', \'value\', \'use_locking\', \'name\', \'read_value\'], varargs=None, keywords=None, defaults=[\'False\', \'None\', \'True\'], "
- }
- member_method {
- name: "assign_add"
- argspec: "args=[\'self\', \'delta\', \'use_locking\', \'name\', \'read_value\'], varargs=None, keywords=None, defaults=[\'False\', \'None\', \'True\'], "
- }
- member_method {
- name: "assign_sub"
- argspec: "args=[\'self\', \'delta\', \'use_locking\', \'name\', \'read_value\'], varargs=None, keywords=None, defaults=[\'False\', \'None\', \'True\'], "
- }
- member_method {
- name: "count_up_to"
- argspec: "args=[\'self\', \'limit\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "eval"
- argspec: "args=[\'self\', \'session\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "from_proto"
- argspec: "args=[\'variable_def\', \'import_scope\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "get_shape"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "initialized_value"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "load"
- argspec: "args=[\'self\', \'value\', \'session\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "read_value"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "scatter_add"
- argspec: "args=[\'self\', \'sparse_delta\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
- }
- member_method {
- name: "scatter_nd_add"
- argspec: "args=[\'self\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "scatter_nd_sub"
- argspec: "args=[\'self\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "scatter_nd_update"
- argspec: "args=[\'self\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "scatter_sub"
- argspec: "args=[\'self\', \'sparse_delta\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
- }
- member_method {
- name: "scatter_update"
- argspec: "args=[\'self\', \'sparse_delta\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
- }
- member_method {
- name: "set_shape"
- argspec: "args=[\'self\', \'shape\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "to_proto"
- argspec: "args=[\'self\', \'export_scope\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "value"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
-}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt
index 87745420ee..825afb622f 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt
@@ -91,6 +91,10 @@ tf_class {
argspec: "args=[], varargs=args, keywords=None, defaults=None"
}
member_method {
+ name: "reduce"
+ argspec: "args=[\'self\', \'initial_state\', \'reduce_func\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "repeat"
argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -111,6 +115,10 @@ tf_class {
argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "window"
+ argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "
+ }
+ member_method {
name: "zip"
argspec: "args=[\'datasets\'], varargs=None, keywords=None, defaults=None"
}
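The two new entries above are the v2 goldens picking up `tf.data.Dataset.reduce` and `tf.data.Dataset.window` (the same hunks repeat below for the concrete dataset subclasses). A minimal sketch of what those signatures allow, assuming eager execution; it is an illustration only, not part of this change:

```
import tensorflow as tf

ds = tf.data.Dataset.range(10)

# reduce(initial_state, reduce_func): fold the whole dataset into a single value.
total = ds.reduce(tf.constant(0, dtype=tf.int64), lambda state, x: state + x)

# window(size, shift=None, stride=1, drop_remainder=False): group consecutive
# elements into nested datasets of up to `size` elements each.
windows = ds.window(size=3, shift=3, drop_remainder=True)
```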
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt
index 6dd46365b0..cdad5f6360 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt
@@ -92,6 +92,10 @@ tf_class {
argspec: "args=[], varargs=args, keywords=None, defaults=None"
}
member_method {
+ name: "reduce"
+ argspec: "args=[\'self\', \'initial_state\', \'reduce_func\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "repeat"
argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -112,6 +116,10 @@ tf_class {
argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "window"
+ argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "
+ }
+ member_method {
name: "zip"
argspec: "args=[\'datasets\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt
index 35b7105eba..df41bff1b5 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt
@@ -92,6 +92,10 @@ tf_class {
argspec: "args=[], varargs=args, keywords=None, defaults=None"
}
member_method {
+ name: "reduce"
+ argspec: "args=[\'self\', \'initial_state\', \'reduce_func\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "repeat"
argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -112,6 +116,10 @@ tf_class {
argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "window"
+ argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "
+ }
+ member_method {
name: "zip"
argspec: "args=[\'datasets\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt
index 8ae370af98..028bcc2ce9 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt
@@ -92,6 +92,10 @@ tf_class {
argspec: "args=[], varargs=args, keywords=None, defaults=None"
}
member_method {
+ name: "reduce"
+ argspec: "args=[\'self\', \'initial_state\', \'reduce_func\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "repeat"
argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -112,6 +116,10 @@ tf_class {
argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "window"
+ argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "
+ }
+ member_method {
name: "zip"
argspec: "args=[\'datasets\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-boosted-trees-classifier.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-boosted-trees-classifier.pbtxt
index 7027e78df4..ef3409b1b5 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-boosted-trees-classifier.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-boosted-trees-classifier.pbtxt
@@ -1,6 +1,7 @@
path: "tensorflow.estimator.BoostedTreesClassifier"
tf_class {
is_instance: "<class \'tensorflow.python.estimator.canned.boosted_trees.BoostedTreesClassifier\'>"
+ is_instance: "<class \'tensorflow.python.estimator.canned.boosted_trees._BoostedTreesBase\'>"
is_instance: "<class \'tensorflow.python.estimator.estimator.Estimator\'>"
is_instance: "<type \'object\'>"
member {
@@ -32,6 +33,14 @@ tf_class {
argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
+ name: "experimental_feature_importances"
+ argspec: "args=[\'self\', \'normalize\'], varargs=None, keywords=None, defaults=[\'False\'], "
+ }
+ member_method {
+ name: "experimental_predict_with_explanations"
+ argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ }
+ member_method {
name: "export_saved_model"
argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-boosted-trees-regressor.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-boosted-trees-regressor.pbtxt
index d8167ea7cb..775130468f 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-boosted-trees-regressor.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-boosted-trees-regressor.pbtxt
@@ -1,6 +1,7 @@
path: "tensorflow.estimator.BoostedTreesRegressor"
tf_class {
is_instance: "<class \'tensorflow.python.estimator.canned.boosted_trees.BoostedTreesRegressor\'>"
+ is_instance: "<class \'tensorflow.python.estimator.canned.boosted_trees._BoostedTreesBase\'>"
is_instance: "<class \'tensorflow.python.estimator.estimator.Estimator\'>"
is_instance: "<type \'object\'>"
member {
@@ -32,6 +33,14 @@ tf_class {
argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
+ name: "experimental_feature_importances"
+ argspec: "args=[\'self\', \'normalize\'], varargs=None, keywords=None, defaults=[\'False\'], "
+ }
+ member_method {
+ name: "experimental_predict_with_explanations"
+ argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ }
+ member_method {
name: "export_saved_model"
argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
}
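Both boosted-trees estimators now inherit from `_BoostedTreesBase` and gain the two `experimental_*` methods above. A hedged sketch of how they might be called; `feature_columns` and `input_fn` are assumed to exist, and the returned structures (name/importance pair, `dfc` key) are my reading of the explanations API rather than anything stated in this diff:

```
import tensorflow as tf

# Sketch only: feature_columns and input_fn are assumed to be defined elsewhere.
est = tf.estimator.BoostedTreesClassifier(
    feature_columns, n_batches_per_layer=1, n_trees=10)
est.train(input_fn, max_steps=100)

# Gain-based importances, optionally normalized to sum to 1 (return shape assumed).
names, importances = est.experimental_feature_importances(normalize=True)

# Per-example explanations alongside the usual prediction keys.
for pred in est.experimental_predict_with_explanations(input_fn):
  dfc = pred['dfc']  # directional feature contributions (assumed key)
```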
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.initializers.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.initializers.pbtxt
index d499c67d89..e3c63fe737 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.initializers.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.initializers.pbtxt
@@ -49,10 +49,6 @@ tf_module {
mtype: "<type \'type\'>"
}
member_method {
- name: "global_variables"
- argspec: "args=[], varargs=None, keywords=None, defaults=None"
- }
- member_method {
name: "he_normal"
argspec: "args=[\'seed\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -68,12 +64,4 @@ tf_module {
name: "lecun_uniform"
argspec: "args=[\'seed\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
- member_method {
- name: "local_variables"
- argspec: "args=[], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "variables"
- argspec: "args=[\'var_list\', \'name\'], varargs=None, keywords=None, defaults=[\'init\'], "
- }
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.math.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.math.pbtxt
index a308c76ebc..72856466ec 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.math.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.math.pbtxt
@@ -233,6 +233,14 @@ tf_module {
argspec: "args=[\'data\', \'segment_ids\', \'num_segments\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "xdivy"
+ argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "xlogy"
+ argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "zeta"
argspec: "args=[\'x\', \'q\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
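`tf.math.xdivy` and `tf.math.xlogy`, newly listed above, are the "safe" variants that return 0 where `x == 0` instead of producing NaN/Inf. A small illustrative sketch (values in comments assume eager execution):

```
import tensorflow as tf

x = tf.constant([0.0, 2.0])
y = tf.constant([3.0, 4.0])

# xdivy(x, y): 0 where x == 0, otherwise x / y.
tf.math.xdivy(x, y)   # -> [0.0, 0.5]

# xlogy(x, y): 0 where x == 0, otherwise x * log(y).
tf.math.xlogy(x, y)   # -> [0.0, 2 * ln(4) ~= 2.77]
```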
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt
index 9332e16bf6..d2dc8bc85f 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt
@@ -1,10 +1,6 @@
path: "tensorflow"
tf_module {
member {
- name: "AUTO_REUSE"
- mtype: "<enum \'_ReuseMode\'>"
- }
- member {
name: "AggregationMethod"
mtype: "<type \'type\'>"
}
@@ -233,18 +229,10 @@ tf_module {
mtype: "<type \'type\'>"
}
member {
- name: "Variable"
- mtype: "<class \'tensorflow.python.ops.variables.VariableMetaclass\'>"
- }
- member {
name: "VariableAggregation"
mtype: "<class \'enum.EnumMeta\'>"
}
member {
- name: "VariableScope"
- mtype: "<type \'type\'>"
- }
- member {
name: "VariableSynchronization"
mtype: "<class \'enum.EnumMeta\'>"
}
@@ -553,10 +541,6 @@ tf_module {
mtype: "<type \'module\'>"
}
member {
- name: "variable_scope"
- mtype: "<type \'type\'>"
- }
- member {
name: "variance_scaling_initializer"
mtype: "<type \'type\'>"
}
@@ -581,10 +565,6 @@ tf_module {
argspec: "args=[\'op_type\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "Print"
- argspec: "args=[\'input_\', \'data\', \'message\', \'first_n\', \'summarize\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
- }
- member_method {
name: "abs"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -621,10 +601,6 @@ tf_module {
argspec: "args=[\'names\', \'value\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "all_variables"
- argspec: "args=[], varargs=None, keywords=None, defaults=None"
- }
- member_method {
name: "angle"
argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -737,10 +713,6 @@ tf_module {
argspec: "args=[\'tensor\', \'tf_type\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
member_method {
- name: "assert_variables_initialized"
- argspec: "args=[\'var_list\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
name: "atan"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -1041,6 +1013,10 @@ tf_module {
argspec: "args=[\'images\', \'ksizes\', \'strides\', \'rates\', \'padding\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "extract_volume_patches"
+ argspec: "args=[\'input\', \'ksizes\', \'strides\', \'padding\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "eye"
argspec: "args=[\'num_rows\', \'num_columns\', \'batch_shape\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \"<dtype: \'float32\'>\", \'None\'], "
}
@@ -1137,10 +1113,6 @@ tf_module {
argspec: "args=[], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "get_local_variable"
- argspec: "args=[\'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'collections\', \'caching_device\', \'partitioner\', \'validate_shape\', \'use_resource\', \'custom_getter\', \'constraint\', \'synchronization\', \'aggregation\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'False\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
- }
- member_method {
name: "get_seed"
argspec: "args=[\'op_seed\'], varargs=None, keywords=None, defaults=None"
}
@@ -1153,26 +1125,10 @@ tf_module {
argspec: "args=[\'handle\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
- name: "get_variable"
- argspec: "args=[\'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'collections\', \'caching_device\', \'partitioner\', \'validate_shape\', \'use_resource\', \'custom_getter\', \'constraint\', \'synchronization\', \'aggregation\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
- }
- member_method {
- name: "get_variable_scope"
- argspec: "args=[], varargs=None, keywords=None, defaults=None"
- }
- member_method {
name: "global_norm"
argspec: "args=[\'t_list\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
- name: "global_variables"
- argspec: "args=[\'scope\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "global_variables_initializer"
- argspec: "args=[], varargs=None, keywords=None, defaults=None"
- }
- member_method {
name: "gradients"
argspec: "args=[\'ys\', \'xs\', \'grad_ys\', \'name\', \'colocate_gradients_with_ops\', \'gate_gradients\', \'aggregation_method\', \'stop_gradients\'], varargs=None, keywords=None, defaults=[\'None\', \'gradients\', \'False\', \'False\', \'None\', \'None\'], "
}
@@ -1249,18 +1205,6 @@ tf_module {
argspec: "args=[\'name\'], varargs=None, keywords=None, defaults=[\'init_all_tables\'], "
}
member_method {
- name: "initialize_all_variables"
- argspec: "args=[], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "initialize_local_variables"
- argspec: "args=[], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "initialize_variables"
- argspec: "args=[\'var_list\', \'name\'], varargs=None, keywords=None, defaults=[\'init\'], "
- }
- member_method {
name: "invert_permutation"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -1289,10 +1233,6 @@ tf_module {
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
- name: "is_variable_initialized"
- argspec: "args=[\'variable\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
name: "lbeta"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -1321,16 +1261,12 @@ tf_module {
argspec: "args=[\'library_filename\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "load_op_library"
- argspec: "args=[\'library_filename\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "local_variables"
- argspec: "args=[\'scope\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ name: "load_library"
+ argspec: "args=[\'library_location\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "local_variables_initializer"
- argspec: "args=[], varargs=None, keywords=None, defaults=None"
+ name: "load_op_library"
+ argspec: "args=[\'library_filename\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "log"
@@ -1374,7 +1310,7 @@ tf_module {
}
member_method {
name: "map_fn"
- argspec: "args=[\'fn\', \'elems\', \'dtype\', \'parallel_iterations\', \'back_prop\', \'swap_memory\', \'infer_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'True\', \'False\', \'True\', \'None\'], "
+ argspec: "args=[\'fn\', \'elems\', \'dtype\', \'parallel_iterations\', \'back_prop\', \'swap_memory\', \'infer_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\', \'False\', \'True\', \'None\'], "
}
member_method {
name: "matching_files"
@@ -1445,14 +1381,6 @@ tf_module {
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
- name: "model_variables"
- argspec: "args=[\'scope\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "moving_average_variables"
- argspec: "args=[\'scope\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
name: "multinomial"
argspec: "args=[\'logits\', \'num_samples\', \'seed\', \'name\', \'output_dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
}
@@ -1537,6 +1465,10 @@ tf_module {
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "print"
+ argspec: "args=[], varargs=inputs, keywords=kwargs, defaults=None"
+ }
+ member_method {
name: "py_func"
argspec: "args=[\'func\', \'inp\', \'Tout\', \'stateful\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
}
@@ -1649,10 +1581,6 @@ tf_module {
argspec: "args=[\'base_type\', \'conversion_func\', \'priority\'], varargs=None, keywords=None, defaults=[\'100\'], "
}
member_method {
- name: "report_uninitialized_variables"
- argspec: "args=[\'var_list\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'report_uninitialized_variables\'], "
- }
- member_method {
name: "required_space_to_batch_paddings"
argspec: "args=[\'input_shape\', \'block_shape\', \'base_paddings\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
@@ -1721,6 +1649,10 @@ tf_module {
argspec: "args=[\'indices\', \'updates\', \'shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "searchsorted"
+ argspec: "args=[\'sorted_sequence\', \'values\', \'side\', \'out_type\', \'name\'], varargs=None, keywords=None, defaults=[\'left\', \"<dtype: \'int32\'>\", \'None\'], "
+ }
+ member_method {
name: "segment_max"
argspec: "args=[\'data\', \'segment_ids\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -2057,10 +1989,6 @@ tf_module {
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
- name: "trainable_variables"
- argspec: "args=[\'scope\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
name: "transpose"
argspec: "args=[\'a\', \'perm\', \'name\', \'conjugate\'], varargs=None, keywords=None, defaults=[\'None\', \'transpose\', \'False\'], "
}
@@ -2129,14 +2057,6 @@ tf_module {
argspec: "args=[\'max_shard_bytes\', \'axis\', \'bytes_per_string_element\', \'max_shards\'], varargs=None, keywords=None, defaults=[\'0\', \'16\', \'None\'], "
}
member_method {
- name: "variable_op_scope"
- argspec: "args=[\'values\', \'name_or_scope\', \'default_name\', \'initializer\', \'regularizer\', \'caching_device\', \'partitioner\', \'custom_getter\', \'reuse\', \'dtype\', \'use_resource\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
- }
- member_method {
- name: "variables_initializer"
- argspec: "args=[\'var_list\', \'name\'], varargs=None, keywords=None, defaults=[\'init\'], "
- }
- member_method {
name: "verify_tensor_all_finite"
argspec: "args=[\'t\', \'msg\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
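Of the top-level changes above, `tf.print` (added) replaces the removed `tf.Print`, and `tf.searchsorted` exposes binary search along the innermost dimension of a sorted tensor. A brief sketch; the commented outputs are my reading of the semantics, not taken from this diff:

```
import tensorflow as tf

x = tf.constant([[1.0, 2.0], [3.0, 4.0]])
# tf.print is a stateful op that accepts arbitrary inputs, unlike the removed
# tf.Print identity op.
tf.print("x =", x)

sorted_sequence = tf.constant([[0, 3, 9, 9, 10]])
values = tf.constant([[2, 4, 9]])
# Leftmost insertion index for each value, computed per batch row.
tf.searchsorted(sorted_sequence, values, side='left')   # -> [[1, 2, 2]]
```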
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.strings.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.strings.pbtxt
index 018be7b9f9..312e94b41d 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.strings.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.strings.pbtxt
@@ -1,12 +1,16 @@
path: "tensorflow.strings"
tf_module {
member_method {
+ name: "format"
+ argspec: "args=[\'template\', \'inputs\', \'placeholder\', \'summarize\', \'name\'], varargs=None, keywords=None, defaults=[\'{}\', \'3\', \'None\'], "
+ }
+ member_method {
name: "join"
argspec: "args=[\'inputs\', \'separator\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
}
member_method {
name: "length"
- argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ argspec: "args=[\'input\', \'name\', \'unit\'], varargs=None, keywords=None, defaults=[\'None\', \'BYTE\'], "
}
member_method {
name: "regex_full_match"
@@ -44,4 +48,8 @@ tf_module {
name: "to_number"
argspec: "args=[\'string_tensor\', \'out_type\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
}
+ member_method {
+ name: "unicode_script"
+ argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
}
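The `tf.strings` surface grows `format`, a `unit` argument on `length`, and `unicode_script`. A minimal sketch of the new pieces (the example string and the commented values are mine, assuming eager execution):

```
import tensorflow as tf

s = tf.constant(["héllo"])

tf.strings.length(s)                     # bytes: [6] ('é' is two bytes in UTF-8)
tf.strings.length(s, unit="UTF8_CHAR")   # characters: [5]

# Templated formatting of tensors into a single string scalar.
msg = tf.strings.format("value: {}", (s,))

# Unicode script code for each code point (expects integer code points).
tf.strings.unicode_script(tf.constant([ord("h"), ord("é")]))
```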
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.pbtxt
index b21dabbde7..cb6da5088b 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.train.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.train.pbtxt
@@ -265,10 +265,6 @@ tf_module {
argspec: "args=[\'graph\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
- name: "do_quantize_training_on_graphdef"
- argspec: "args=[\'input_graph\', \'num_bits\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
name: "exponential_decay"
argspec: "args=[\'learning_rate\', \'global_step\', \'decay_steps\', \'decay_rate\', \'staircase\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.variable_scope.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.variable_scope.pbtxt
deleted file mode 100644
index e62dec93e6..0000000000
--- a/tensorflow/tools/api/golden/v2/tensorflow.variable_scope.pbtxt
+++ /dev/null
@@ -1,9 +0,0 @@
-path: "tensorflow.variable_scope"
-tf_class {
- is_instance: "<class \'tensorflow.python.ops.variable_scope.variable_scope\'>"
- is_instance: "<type \'object\'>"
- member_method {
- name: "__init__"
- argspec: "args=[\'self\', \'name_or_scope\', \'default_name\', \'values\', \'initializer\', \'regularizer\', \'caching_device\', \'partitioner\', \'custom_getter\', \'reuse\', \'dtype\', \'use_resource\', \'constraint\', \'auxiliary_name_scope\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'True\'], "
- }
-}
diff --git a/tensorflow/tools/api/tests/BUILD b/tensorflow/tools/api/tests/BUILD
index 4efa4a9651..3cbea41dca 100644
--- a/tensorflow/tools/api/tests/BUILD
+++ b/tensorflow/tools/api/tests/BUILD
@@ -19,6 +19,7 @@ py_test(
"api_compatibility_test.py",
"//tensorflow:tf_python_api_gen_v2",
],
+ args = ["--only_test_core_api=true"],
data = [
"//tensorflow/tools/api/golden:api_golden_v1",
"//tensorflow/tools/api/golden:api_golden_v2",
diff --git a/tensorflow/tools/api/tests/api_compatibility_test.py b/tensorflow/tools/api/tests/api_compatibility_test.py
index d06c7f2d49..6487a6267e 100644
--- a/tensorflow/tools/api/tests/api_compatibility_test.py
+++ b/tensorflow/tools/api/tests/api_compatibility_test.py
@@ -56,6 +56,14 @@ _UPDATE_GOLDENS_HELP = """
have to be authorized by TensorFlow leads.
"""
+# DEFINE_boolean, only_test_core_api, default False:
+_ONLY_TEST_CORE_API_HELP = """
+ Some TF APIs are being moved outside of the tensorflow/ directory. There is
+ no guarantee which versions of these APIs will be present when running this
+ test. Therefore, do not error out on API changes in non-core TF code
+ if this flag is set.
+"""
+
# DEFINE_boolean, verbose_diffs, default True:
_VERBOSE_DIFFS_HELP = """
If set to true, print line by line diffs on all libraries. If set to
@@ -67,6 +75,8 @@ _API_GOLDEN_FOLDER_V2 = 'tensorflow/tools/api/golden/v2'
_TEST_README_FILE = 'tensorflow/tools/api/tests/README.txt'
_UPDATE_WARNING_FILE = 'tensorflow/tools/api/tests/API_UPDATE_WARNING.txt'
+_NON_CORE_PACKAGES = ['estimator']
+
def _KeyToFilePath(key, api_version):
"""From a given key, construct a filepath.
@@ -111,6 +121,19 @@ def _VerifyNoSubclassOfMessageVisitor(path, parent, unused_children):
'They are not yet supported by the API tools.' % path)
+def _FilterNonCoreGoldenFiles(golden_file_list):
+ """Filter out non-core API pbtxt files."""
+ filtered_file_list = []
+ filtered_package_prefixes = [
+ 'tensorflow.%s.' % p for p in _NON_CORE_PACKAGES]
+ for f in golden_file_list:
+ if any([f.rsplit('/')[-1].startswith(pre)
+ for pre in filtered_package_prefixes]):
+ continue
+ filtered_file_list.append(f)
+ return filtered_file_list
+
+
class ApiCompatibilityTest(test.TestCase):
def __init__(self, *args, **kwargs):
@@ -233,6 +256,9 @@ class ApiCompatibilityTest(test.TestCase):
return
visitor = public_api.PublicAPIVisitor(_VerifyNoSubclassOfMessageVisitor)
visitor.do_not_descend_map['tf'].append('contrib')
+ if FLAGS.only_test_core_api:
+ visitor.do_not_descend_map['tf'].extend(
+ _NON_CORE_PACKAGES)
traverse.traverse(tf_v2.compat.v1, visitor)
def testNoSubclassOfMessageV2(self):
@@ -240,6 +266,9 @@ class ApiCompatibilityTest(test.TestCase):
return
visitor = public_api.PublicAPIVisitor(_VerifyNoSubclassOfMessageVisitor)
visitor.do_not_descend_map['tf'].append('contrib')
+ if FLAGS.only_test_core_api:
+ visitor.do_not_descend_map['tf'].extend(
+ _NON_CORE_PACKAGES)
traverse.traverse(tf_v2, visitor)
def _checkBackwardsCompatibility(
@@ -252,6 +281,9 @@ class ApiCompatibilityTest(test.TestCase):
public_api_visitor.do_not_descend_map['tf'].append('contrib')
public_api_visitor.do_not_descend_map['tf.GPUOptions'] = [
'Experimental']
+ if FLAGS.only_test_core_api:
+ public_api_visitor.do_not_descend_map['tf'].extend(
+ _NON_CORE_PACKAGES)
if additional_private_map:
public_api_visitor.private_map.update(additional_private_map)
@@ -260,6 +292,8 @@ class ApiCompatibilityTest(test.TestCase):
# Read all golden files.
golden_file_list = file_io.get_matching_files(golden_file_pattern)
+ if FLAGS.only_test_core_api:
+ golden_file_list = _FilterNonCoreGoldenFiles(golden_file_list)
def _ReadFileToProto(filename):
"""Read a filename, create a protobuf from its contents."""
@@ -325,6 +359,11 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
'--update_goldens', type=bool, default=False, help=_UPDATE_GOLDENS_HELP)
+ # TODO(mikecase): Create Estimator's own API compatibility test or
+ # a more general API compatibility test for use for TF components.
+ parser.add_argument(
+ '--only_test_core_api', type=bool, default=False,
+ help=_ONLY_TEST_CORE_API_HELP)
parser.add_argument(
'--verbose_diffs', type=bool, default=True, help=_VERBOSE_DIFFS_HELP)
FLAGS, unparsed = parser.parse_known_args()
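As context for `_FilterNonCoreGoldenFiles` above, a standalone sketch of the filtering it performs, using hypothetical golden file paths:

```
_NON_CORE_PACKAGES = ['estimator']

def _filter_non_core(golden_file_list):
  """Drop goldens whose basename belongs to a non-core package."""
  prefixes = ['tensorflow.%s.' % p for p in _NON_CORE_PACKAGES]
  return [f for f in golden_file_list
          if not any(f.rsplit('/')[-1].startswith(pre) for pre in prefixes)]

files = ['golden/v2/tensorflow.math.pbtxt',
         'golden/v2/tensorflow.estimator.-estimator.pbtxt']
print(_filter_non_core(files))  # only the tensorflow.math.pbtxt entry survives
```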
diff --git a/tensorflow/tools/benchmark/README.md b/tensorflow/tools/benchmark/README.md
index e64af2bfe1..dee1a20f3f 100644
--- a/tensorflow/tools/benchmark/README.md
+++ b/tensorflow/tools/benchmark/README.md
@@ -32,7 +32,7 @@ adb push bazel-bin/tensorflow/tools/benchmark/benchmark_model /data/local/tmp
(4) Run the benchmark. For example:
```
-adb shell "/data/local/tmp/benchmark_model \
+adb shell /data/local/tmp/benchmark_model \
--graph=/data/local/tmp/tensorflow_inception_graph.pb \
--input_layer="input:0" \
--input_layer_shape="1,224,224,3" \
diff --git a/tensorflow/tools/ci_build/Dockerfile.rocm b/tensorflow/tools/ci_build/Dockerfile.rocm
new file mode 100644
index 0000000000..aadaa8bac1
--- /dev/null
+++ b/tensorflow/tools/ci_build/Dockerfile.rocm
@@ -0,0 +1,97 @@
+# This Dockerfile provides a starting point for a ROCm installation of
+# MIOpen and tensorflow.
+FROM ubuntu:xenial
+MAINTAINER Jeff Poznanovic <jeffrey.poznanovic@amd.com>
+
+ARG DEB_ROCM_REPO=http://repo.radeon.com/rocm/apt/debian/
+ARG ROCM_PATH=/opt/rocm
+
+ENV DEBIAN_FRONTEND noninteractive
+ENV TF_NEED_ROCM 1
+ENV HOME /root/
+RUN apt update && apt install -y wget software-properties-common
+
+# Add rocm repository
+RUN apt-get clean all
+RUN wget -qO - $DEB_ROCM_REPO/rocm.gpg.key | apt-key add -
+RUN sh -c "echo deb [arch=amd64] $DEB_ROCM_REPO xenial main > /etc/apt/sources.list.d/rocm.list"
+
+# Install misc pkgs
+RUN apt-get update --allow-insecure-repositories && DEBIAN_FRONTEND=noninteractive apt-get install -y \
+ build-essential \
+ clang-3.8 \
+ clang-format-3.8 \
+ clang-tidy-3.8 \
+ cmake \
+ cmake-qt-gui \
+ ssh \
+ curl \
+ apt-utils \
+ pkg-config \
+ g++-multilib \
+ git \
+ libunwind-dev \
+ libfftw3-dev \
+ libelf-dev \
+ libncurses5-dev \
+ libpthread-stubs0-dev \
+ vim \
+ gfortran \
+ libboost-program-options-dev \
+ libssl-dev \
+ libboost-dev \
+ libboost-system-dev \
+ libboost-filesystem-dev \
+ rpm \
+ libnuma-dev \
+ virtualenv \
+ python-pip \
+ python3-pip \
+ wget && \
+ apt-get clean && \
+ rm -rf /var/lib/apt/lists/*
+
+# Install rocm pkgs
+RUN apt-get update --allow-insecure-repositories && \
+ DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \
+ rocm-dev rocm-libs rocm-utils \
+ rocfft miopen-hip miopengemm rocblas hipblas rocrand \
+ rocm-profiler cxlactivitylogger && \
+ apt-get clean && \
+ rm -rf /var/lib/apt/lists/*
+
+RUN cd ~ && git clone https://github.com/GPUOpen-ProfessionalCompute-Tools/HIP.git
+RUN cd ~/HIP && mkdir -p build && cd build && cmake .. && make package -j && dpkg -i *.deb
+
+ENV HCC_HOME=$ROCM_PATH/hcc
+ENV HIP_PATH=$ROCM_PATH/hip
+ENV OPENCL_ROOT=$ROCM_PATH/opencl
+ENV PATH="$HCC_HOME/bin:$HIP_PATH/bin:${PATH}"
+ENV PATH="$ROCM_PATH/bin:${PATH}"
+ENV PATH="$OPENCL_ROOT/bin:${PATH}"
+
+# Add target file to help determine which device(s) to build for
+RUN echo -e "gfx803\ngfx900" >> /opt/rocm/bin/target.lst
+
+# Set up environment variables, and add those environment variables at the end of ~/.bashrc
+ARG HCC_HOME=/opt/rocm/hcc
+ARG HIP_PATH=/opt/rocm/hip
+ARG PATH=$HCC_HOME/bin:$HIP_PATH/bin:$PATH
+
+# Copy and run the install scripts.
+COPY install/*.sh /install/
+ARG DEBIAN_FRONTEND=noninteractive
+RUN /install/install_bootstrap_deb_packages.sh
+RUN add-apt-repository -y ppa:openjdk-r/ppa && \
+ add-apt-repository -y ppa:george-edison55/cmake-3.x
+RUN /install/install_deb_packages.sh
+RUN /install/install_pip_packages.sh
+RUN /install/install_bazel.sh
+RUN /install/install_golang.sh
+
+# Set up the master bazelrc configuration file.
+COPY install/.bazelrc /etc/bazel.bazelrc
+
+# Configure the build for our ROCm configuration.
+ENV TF_NEED_ROCM 1
+
diff --git a/tensorflow/tools/ci_build/README.md b/tensorflow/tools/ci_build/README.md
index f2161b700a..e2fd977f50 100644
--- a/tensorflow/tools/ci_build/README.md
+++ b/tensorflow/tools/ci_build/README.md
@@ -24,7 +24,7 @@ natively on your system.
### Run TensorFlow CI Scripts Natively on your Machine
-1. Follow the instructions at https://www.tensorflow.org/install/install_sources,
+1. Follow the instructions at https://www.tensorflow.org/install/source,
but stop when you get to the section "Configure the installation". You do not
need to configure the installation to run the CI scripts.
diff --git a/tensorflow/tools/ci_build/builds/docker_test.sh b/tensorflow/tools/ci_build/builds/docker_test.sh
index e337ea4b05..38891b60e5 100755
--- a/tensorflow/tools/ci_build/builds/docker_test.sh
+++ b/tensorflow/tools/ci_build/builds/docker_test.sh
@@ -19,7 +19,7 @@
#
# Usage: docker_test.sh <IMAGE_TYPE> <TAG> <WHL_PATH>
# Arguments:
-# IMAGE_TYPE : Type of the image: (CPU|GPU)
+# IMAGE_TYPE : Type of the image: (CPU|GPU|ROCM)
# TAG : Docker image tag
# WHL_PATH : Path to the whl file to be installed inside the docker image
#
@@ -60,6 +60,8 @@ if [[ "${IMAGE_TYPE}" == "cpu" ]]; then
DOCKERFILE="tensorflow/tools/docker/Dockerfile"
elif [[ "${IMAGE_TYPE}" == "gpu" ]]; then
DOCKERFILE="tensorflow/tools/docker/Dockerfile.gpu"
+elif [[ "${IMAGE_TYPE}" == "rocm" ]]; then
+ DOCKERFILE="tensorflow/tools/docker/Dockerfile.rocm"
else
die "Unrecognized image type: $1"
fi
@@ -106,13 +108,16 @@ if [ "${IMAGE_TYPE}" == "gpu" ]; then
devices=$(\ls /dev/nvidia* | xargs -I{} echo '--device {}:{}')
libs=$(\ls /usr/lib/x86_64-linux-gnu/libcuda.* | xargs -I{} echo '-v {}:{}')
GPU_EXTRA_PARAMS="${devices} ${libs}"
+elif [ "${IMAGE_TYPE}" == "rocm" ]; then
+ ROCM_EXTRA_PARAMS="--device=/dev/kfd --device=/dev/dri --group-add video"
else
GPU_EXTRA_PARAMS=""
+ ROCM_EXTRA_PARAMS=""
fi
# Run docker image with source directory mapped
docker run -v ${BASE_DIR}:/tensorflow-src -w /tensorflow-src \
-${GPU_EXTRA_PARAMS} \
+${GPU_EXTRA_PARAMS} ${ROCM_EXTRA_PARAMS} \
"${DOCKER_IMG_TAG}" \
/bin/bash -c "tensorflow/tools/ci_build/builds/run_pip_tests.sh && "\
"tensorflow/tools/ci_build/builds/test_tutorials.sh && "\
diff --git a/tensorflow/tools/ci_build/builds/pip.sh b/tensorflow/tools/ci_build/builds/pip.sh
index fef121ab5a..6543779022 100755
--- a/tensorflow/tools/ci_build/builds/pip.sh
+++ b/tensorflow/tools/ci_build/builds/pip.sh
@@ -132,6 +132,7 @@ echo "Using Bazel flags: ${BAZEL_FLAGS}"
PIP_BUILD_TARGET="//tensorflow/tools/pip_package:build_pip_package"
GPU_FLAG=""
if [[ ${CONTAINER_TYPE} == "cpu" ]] || \
+ [[ ${CONTAINER_TYPE} == "rocm" ]] || \
[[ ${CONTAINER_TYPE} == "debian.jessie.cpu" ]]; then
bazel build ${BAZEL_FLAGS} ${PIP_BUILD_TARGET} || \
die "Build failed."
@@ -255,7 +256,8 @@ if [[ $(uname) == "Linux" ]]; then
die "ERROR: Cannot find repaired wheel."
fi
# Copy and rename for gpu manylinux as we do not want auditwheel to package in libcudart.so
- elif [[ ${CONTAINER_TYPE} == "gpu" ]]; then
+ elif [[ ${CONTAINER_TYPE} == "gpu" ]] || \
+ [[ ${CONTAINER_TYPE} == "rocm" ]]; then
WHL_PATH=${AUDITED_WHL_NAME}
cp ${WHL_DIR}/${WHL_BASE_NAME} ${WHL_PATH}
echo "Copied manylinx1 wheel file at ${WHL_PATH}"
diff --git a/tensorflow/tools/ci_build/builds/run_pip_tests.sh b/tensorflow/tools/ci_build/builds/run_pip_tests.sh
index 4b762bf258..7d5cf3f843 100755
--- a/tensorflow/tools/ci_build/builds/run_pip_tests.sh
+++ b/tensorflow/tools/ci_build/builds/run_pip_tests.sh
@@ -64,7 +64,7 @@ while true; do
fi
done
-TF_GPU_COUNT=${TF_GPU_COUNT:-8}
+TF_GPU_COUNT=${TF_GPU_COUNT:-4}
# PIP tests should have a "different" path. Different than the one we place
# virtualenv, because we are deleting and recreating it here.
@@ -111,7 +111,6 @@ bazel clean
# virtualenv.
export TF_NEED_GCP=0
export TF_NEED_HDFS=0
-export TF_ENABLE_XLA=0
# Obtain the path to Python binary
if [[ ${IS_VIRTUALENV} == "1" ]]; then
diff --git a/tensorflow/tools/ci_build/builds/with_the_same_user b/tensorflow/tools/ci_build/builds/with_the_same_user
index b216e3549f..1cc5aed15d 100755
--- a/tensorflow/tools/ci_build/builds/with_the_same_user
+++ b/tensorflow/tools/ci_build/builds/with_the_same_user
@@ -48,6 +48,12 @@ getent passwd "${CI_BUILD_UID}" || adduser ${ADDUSER_OPTS} \
usermod -a -G sudo "${CI_BUILD_USER}"
echo "${CI_BUILD_USER} ALL=(ALL) NOPASSWD:ALL" > /etc/sudoers.d/90-nopasswd-sudo
+if [[ "${TF_NEED_ROCM}" -eq 1 ]]; then
+ # ROCm requires the video group in order to use the GPU for compute. If it
+ # exists on the host, add it to the container.
+ getent group video || addgroup video && adduser "${CI_BUILD_USER}" video
+fi
+
if [ -e /root/.bazelrc ]; then
cp /root/.bazelrc "${CI_BUILD_HOME}/.bazelrc"
chown "${CI_BUILD_UID}:${CI_BUILD_GID}" "${CI_BUILD_HOME}/.bazelrc"
diff --git a/tensorflow/tools/ci_build/ci_build.sh b/tensorflow/tools/ci_build/ci_build.sh
index 77265e0f50..eab0616513 100755
--- a/tensorflow/tools/ci_build/ci_build.sh
+++ b/tensorflow/tools/ci_build/ci_build.sh
@@ -18,7 +18,7 @@
# <COMMAND>
#
# CONTAINER_TYPE: Type of the docker container used the run the build:
-# e.g., (cpu | gpu | android | tensorboard)
+# e.g., (cpu | gpu | rocm | android | tensorboard)
#
# DOCKERFILE_PATH: (Optional) Path to the Dockerfile used for docker build.
# If this optional value is not supplied (via the
@@ -103,6 +103,14 @@ if [[ "${CONTAINER_TYPE}" != gpu* ]]; then
GPU_EXTRA_PARAMS=""
fi
+# Add extra params for rocm devices and libraries for ROCm container.
+if [[ "${CONTAINER_TYPE}" == "rocm" ]]; then
+ ROCM_EXTRA_PARAMS="--device=/dev/kfd --device=/dev/dri --group-add video"
+else
+ ROCM_EXTRA_PARAMS=""
+fi
+
+
# Determine the docker image name
DOCKER_IMG_NAME="${BUILD_TAG}.${CONTAINER_TYPE}"
@@ -159,6 +167,7 @@ ${DOCKER_BINARY} run --rm --pid=host \
-v ${WORKSPACE}:/workspace \
-w /workspace \
${GPU_EXTRA_PARAMS} \
+ ${ROCM_EXTRA_PARAMS} \
${CI_DOCKER_EXTRA_PARAMS[@]} \
"${DOCKER_IMG_NAME}" \
${CI_COMMAND_PREFIX[@]} \
diff --git a/tensorflow/tools/ci_build/ci_parameterized_build.sh b/tensorflow/tools/ci_build/ci_parameterized_build.sh
index cc09784c1d..49a9048c03 100755
--- a/tensorflow/tools/ci_build/ci_parameterized_build.sh
+++ b/tensorflow/tools/ci_build/ci_parameterized_build.sh
@@ -147,7 +147,7 @@ PIP_INTEGRATION_TESTS_FLAG="--integration_tests"
ANDROID_CMD="${CI_BUILD_DIR}/builds/android.sh"
ANDROID_FULL_CMD="${CI_BUILD_DIR}/builds/android_full.sh"
-TF_GPU_COUNT=${TF_GPU_COUNT:-8}
+TF_GPU_COUNT=${TF_GPU_COUNT:-4}
PARALLEL_GPU_TEST_CMD='//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute'
BENCHMARK_CMD="${CI_BUILD_DIR}/builds/benchmark.sh"
diff --git a/tensorflow/tools/ci_build/gpu_build/parallel_gpu_execute.sh b/tensorflow/tools/ci_build/gpu_build/parallel_gpu_execute.sh
index 03a2a07fb1..cd7206baf8 100755
--- a/tensorflow/tools/ci_build/gpu_build/parallel_gpu_execute.sh
+++ b/tensorflow/tools/ci_build/gpu_build/parallel_gpu_execute.sh
@@ -21,8 +21,8 @@
# Required environment variables:
# TF_GPU_COUNT = Number of GPUs available.
-TF_GPU_COUNT=${TF_GPU_COUNT:-8}
-TF_TESTS_PER_GPU=${TF_TESTS_PER_GPU:-4}
+TF_GPU_COUNT=${TF_GPU_COUNT:-4}
+TF_TESTS_PER_GPU=${TF_TESTS_PER_GPU:-8}
# We want to allow running one of the following configs:
# - 4 tests per GPU on k80
# - 8 tests per GPU on p100
diff --git a/tensorflow/tools/ci_build/install/install_pip_packages.sh b/tensorflow/tools/ci_build/install/install_pip_packages.sh
index a9ae715c6a..4ced96f90b 100755
--- a/tensorflow/tools/ci_build/install/install_pip_packages.sh
+++ b/tensorflow/tools/ci_build/install/install_pip_packages.sh
@@ -68,8 +68,8 @@ else
pip3 install --upgrade numpy==1.14.5
fi
-pip2 install scipy==0.18.1
-pip3 install scipy==0.18.1
+pip2 install scipy==1.1.0
+pip3 install scipy==1.1.0
pip2 install scikit-learn==0.18.1
pip3 install scikit-learn==0.18.1
diff --git a/tensorflow/tools/ci_build/linux/cpu/run_cc_core.sh b/tensorflow/tools/ci_build/linux/cpu/run_cc_core.sh
index 8eeddcdb82..3b5c92d148 100755
--- a/tensorflow/tools/ci_build/linux/cpu/run_cc_core.sh
+++ b/tensorflow/tools/ci_build/linux/cpu/run_cc_core.sh
@@ -26,6 +26,7 @@ echo ""
# Run configure.
export TF_NEED_CUDA=0
+export TF_NEED_ROCM=0
export CC_OPT_FLAGS='-mavx'
# Only running cc tests, python version does not matter.
export PYTHON_BIN_PATH=`which python`
diff --git a/tensorflow/tools/ci_build/linux/cpu/run_py2_core.sh b/tensorflow/tools/ci_build/linux/cpu/run_py2_core.sh
index 8eca1987f0..52eff6330f 100755
--- a/tensorflow/tools/ci_build/linux/cpu/run_py2_core.sh
+++ b/tensorflow/tools/ci_build/linux/cpu/run_py2_core.sh
@@ -26,6 +26,7 @@ echo ""
# Run configure.
export TF_NEED_CUDA=0
+export TF_NEED_ROCM=0
export CC_OPT_FLAGS='-mavx'
export PYTHON_BIN_PATH=`which python2`
yes "" | $PYTHON_BIN_PATH configure.py
diff --git a/tensorflow/tools/ci_build/linux/cpu/run_py3_contrib.sh b/tensorflow/tools/ci_build/linux/cpu/run_py3_contrib.sh
index f6fa9251d4..d12027599a 100755
--- a/tensorflow/tools/ci_build/linux/cpu/run_py3_contrib.sh
+++ b/tensorflow/tools/ci_build/linux/cpu/run_py3_contrib.sh
@@ -26,6 +26,7 @@ echo ""
# Run configure.
export TF_NEED_CUDA=0
+export TF_NEED_ROCM=0
export CC_OPT_FLAGS='-mavx'
export PYTHON_BIN_PATH=`which python3`
yes "" | $PYTHON_BIN_PATH configure.py
diff --git a/tensorflow/tools/ci_build/linux/cpu/run_py3_core.sh b/tensorflow/tools/ci_build/linux/cpu/run_py3_core.sh
index 51eb2cd7e6..7c531a4d68 100755
--- a/tensorflow/tools/ci_build/linux/cpu/run_py3_core.sh
+++ b/tensorflow/tools/ci_build/linux/cpu/run_py3_core.sh
@@ -26,6 +26,7 @@ echo ""
# Run configure.
export TF_NEED_CUDA=0
+export TF_NEED_ROCM=0
export CC_OPT_FLAGS='-mavx'
export PYTHON_BIN_PATH=`which python3`
yes "" | $PYTHON_BIN_PATH configure.py
diff --git a/tensorflow/tools/ci_build/linux/libtensorflow.sh b/tensorflow/tools/ci_build/linux/libtensorflow.sh
index beef8e063b..3b6e15feb9 100755
--- a/tensorflow/tools/ci_build/linux/libtensorflow.sh
+++ b/tensorflow/tools/ci_build/linux/libtensorflow.sh
@@ -27,5 +27,8 @@ SUFFIX="-cpu-linux-"
if [ "${TF_NEED_CUDA}" == "1" ]; then
SUFFIX="-gpu-linux-"
fi
+if [ "${TF_NEED_ROCM}" == "1" ]; then
+ SUFFIX="-rocm-linux-"
+fi
build_libtensorflow_tarball "${SUFFIX}$(uname -m)"
diff --git a/tensorflow/tools/ci_build/linux/libtensorflow_cpu.sh b/tensorflow/tools/ci_build/linux/libtensorflow_cpu.sh
index 4bf34dd299..b76262b6e9 100755
--- a/tensorflow/tools/ci_build/linux/libtensorflow_cpu.sh
+++ b/tensorflow/tools/ci_build/linux/libtensorflow_cpu.sh
@@ -19,4 +19,5 @@
set -ex
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
export TF_NEED_CUDA=0
+export TF_NEED_ROCM=0
"${SCRIPT_DIR}/libtensorflow_docker.sh"
diff --git a/tensorflow/tools/ci_build/linux/libtensorflow_docker.sh b/tensorflow/tools/ci_build/linux/libtensorflow_docker.sh
index 60c974c36b..467b8dc808 100755
--- a/tensorflow/tools/ci_build/linux/libtensorflow_docker.sh
+++ b/tensorflow/tools/ci_build/linux/libtensorflow_docker.sh
@@ -38,6 +38,11 @@ if [ "${TF_NEED_CUDA}" == "1" ]; then
DOCKER_BINARY="nvidia-docker"
DOCKER_FILE="Dockerfile.gpu"
fi
+if [ "${TF_NEED_ROCM}" == "1" ]; then
+ DOCKER_IMAGE="tf-tensorflow-rocm"
+ DOCKER_BINARY="docker"
+ DOCKER_FILE="Dockerfile.rocm"
+fi
docker build \
-t "${DOCKER_IMAGE}" \
@@ -53,6 +58,7 @@ ${DOCKER_BINARY} run \
-e "TF_NEED_HDFS=0" \
-e "TF_NEED_CUDA=${TF_NEED_CUDA}" \
-e "TF_NEED_TENSORRT=${TF_NEED_CUDA}" \
+ -e "TF_NEED_ROCM=${TF_NEED_ROCM}" \
-e "TF_NEED_OPENCL_SYCL=0" \
"${DOCKER_IMAGE}" \
"/workspace/tensorflow/tools/ci_build/linux/libtensorflow.sh"
diff --git a/tensorflow/contrib/linalg/python/__init__.py b/tensorflow/tools/ci_build/linux/libtensorflow_rocm.sh
index c5ca3a623f..c1ebbe3630 100644..100755
--- a/tensorflow/contrib/linalg/python/__init__.py
+++ b/tensorflow/tools/ci_build/linux/libtensorflow_rocm.sh
@@ -1,3 +1,4 @@
+#!/usr/bin/env bash
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -12,8 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""ops module."""
+#
+# Script to build binary releases of libtensorflow with ROCm support.
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
+set -ex
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+export TF_NEED_ROCM=1
+"${SCRIPT_DIR}/libtensorflow_docker.sh"
diff --git a/tensorflow/tools/ci_build/linux/rocm/run_cc_core.sh b/tensorflow/tools/ci_build/linux/rocm/run_cc_core.sh
new file mode 100755
index 0000000000..200089f90e
--- /dev/null
+++ b/tensorflow/tools/ci_build/linux/rocm/run_cc_core.sh
@@ -0,0 +1,39 @@
+#!/usr/bin/env bash
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ==============================================================================
+
+set -e
+set -x
+
+N_JOBS=$(grep -c ^processor /proc/cpuinfo)
+
+echo ""
+echo "Bazel will use ${N_JOBS} concurrent job(s)."
+echo ""
+
+# Run configure.
+export PYTHON_BIN_PATH=`which python3`
+export CC_OPT_FLAGS='-mavx'
+
+export TF_NEED_ROCM=1
+
+yes "" | $PYTHON_BIN_PATH configure.py
+
+# Run bazel test command. Double test timeouts to avoid flakes.
+bazel test --config=rocm --test_tag_filters=-no_oss,-oss_serial,-no_gpu,-benchmark-test -k \
+ --test_lang_filters=cc --jobs=${N_JOBS} --test_timeout 300,450,1200,3600 \
+ --build_tests_only --test_output=errors --local_test_jobs=1 --config=opt \
+ //tensorflow/... -//tensorflow/compiler/... -//tensorflow/contrib/...
diff --git a/tensorflow/tools/ci_build/linux/rocm/run_py3_core.sh b/tensorflow/tools/ci_build/linux/rocm/run_py3_core.sh
new file mode 100755
index 0000000000..1d0b838c1b
--- /dev/null
+++ b/tensorflow/tools/ci_build/linux/rocm/run_py3_core.sh
@@ -0,0 +1,39 @@
+#!/usr/bin/env bash
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ==============================================================================
+
+set -e
+set -x
+
+N_JOBS=$(grep -c ^processor /proc/cpuinfo)
+
+echo ""
+echo "Bazel will use ${N_JOBS} concurrent job(s)."
+echo ""
+
+# Run configure.
+export PYTHON_BIN_PATH=`which python3`
+export CC_OPT_FLAGS='-mavx'
+
+export TF_NEED_ROCM=1
+
+yes "" | $PYTHON_BIN_PATH configure.py
+
+# Run bazel test command. Double test timeouts to avoid flakes.
+bazel test --config=rocm --test_tag_filters=-no_oss,-oss_serial,-no_gpu,-benchmark-test -k \
+ --test_lang_filters=py --jobs=${N_JOBS} --test_timeout 300,450,1200,3600 \
+ --build_tests_only --test_output=errors --local_test_jobs=1 --config=opt \
+ //tensorflow/... -//tensorflow/compiler/... -//tensorflow/contrib/...
diff --git a/tensorflow/tools/ci_build/osx/cpu/run_py2_cc_core.sh b/tensorflow/tools/ci_build/osx/cpu/run_py2_cc_core.sh
index c7cc16e669..adee0d3171 100755
--- a/tensorflow/tools/ci_build/osx/cpu/run_py2_cc_core.sh
+++ b/tensorflow/tools/ci_build/osx/cpu/run_py2_cc_core.sh
@@ -27,6 +27,7 @@ echo ""
# Run configure.
export TF_NEED_CUDA=0
+export TF_NEED_ROCM=0
export CC_OPT_FLAGS='-mavx'
export PYTHON_BIN_PATH=$(which python2)
yes "" | $PYTHON_BIN_PATH configure.py
diff --git a/tensorflow/tools/ci_build/osx/libtensorflow_cpu.sh b/tensorflow/tools/ci_build/osx/libtensorflow_cpu.sh
index 9ae5fc6bea..06798adc03 100755
--- a/tensorflow/tools/ci_build/osx/libtensorflow_cpu.sh
+++ b/tensorflow/tools/ci_build/osx/libtensorflow_cpu.sh
@@ -26,6 +26,7 @@ source "${SCRIPT_DIR}/../builds/libtensorflow.sh"
export PYTHON_BIN_PATH="/usr/bin/python"
export TF_NEED_HDFS=0
export TF_NEED_CUDA=0
+export TF_NEED_ROCM=0
export TF_NEED_OPENCL_SYCL=0
export TF_NEED_MKL=0
export COMPUTECPP_PATH="/usr/local"
diff --git a/tensorflow/tools/ci_build/osx/libtensorflow_gpu.sh b/tensorflow/tools/ci_build/osx/libtensorflow_gpu.sh
index d95fcdeb85..95f1992d7d 100755
--- a/tensorflow/tools/ci_build/osx/libtensorflow_gpu.sh
+++ b/tensorflow/tools/ci_build/osx/libtensorflow_gpu.sh
@@ -27,6 +27,7 @@ export TF_NEED_CUDA=1
export LD_LIBRARY_PATH="/usr/local/cuda/lib:/usr/local/cuda/extras/CUPTI/lib:${LD_LIBRARY_PATH}"
export PYTHON_BIN_PATH="/usr/bin/python"
export TF_NEED_HDFS=0
+export TF_NEED_ROCM=0
export TF_NEED_OPENCL_SYCL=0
export TF_NEED_MKL=0
export COMPUTECPP_PATH="/usr/local"
diff --git a/tensorflow/contrib/tensorboard/plugins/trace/__init__.py b/tensorflow/tools/ci_build/osx/libtensorflow_rocm.sh
index 2c99f4077e..aeabc0e39e 100644..100755
--- a/tensorflow/contrib/tensorboard/plugins/trace/__init__.py
+++ b/tensorflow/tools/ci_build/osx/libtensorflow_rocm.sh
@@ -1,3 +1,4 @@
+#!/usr/bin/env bash
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -12,13 +13,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Public API for the Trace plugin."""
+#
+# Script to produce a binary release of libtensorflow (C API, Java jars etc.).
+
+set -ex
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+
+# See comments at the top of this file for details.
+source "${SCRIPT_DIR}/../builds/libtensorflow.sh"
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
+# Configure script
+export TF_NEED_ROCM=1
+export PYTHON_BIN_PATH="/usr/bin/python"
+export TF_NEED_GCP=0
+export TF_NEED_HDFS=0
+export TF_NEED_CUDA=0
+export TF_NEED_OPENCL_SYCL=0
+export TF_NEED_MKL=0
+export COMPUTECPP_PATH="/usr/local"
-# pylint: disable=wildcard-import
-from tensorflow.contrib.tensorboard.plugins.trace.trace import *
-from tensorflow.contrib.tensorboard.plugins.trace.trace_info_pb2 import *
-# pylint: enable=wildcard-import
+export PATH="/usr/local/cuda/bin:/usr/local/bin:/usr/bin:/bin:/usr/sbin:/sbin"
+build_libtensorflow_tarball "-gpu-darwin-$(uname -m)"
diff --git a/tensorflow/tools/ci_build/windows/gpu/pip/build_tf_windows.sh b/tensorflow/tools/ci_build/windows/gpu/pip/build_tf_windows.sh
index 28d5565b98..34847e637a 100644
--- a/tensorflow/tools/ci_build/windows/gpu/pip/build_tf_windows.sh
+++ b/tensorflow/tools/ci_build/windows/gpu/pip/build_tf_windows.sh
@@ -122,7 +122,7 @@ fi
PIP_NAME=$(ls ${PY_TEST_DIR}/tensorflow_gpu-*.whl)
reinstall_tensorflow_pip ${PIP_NAME}
-TF_GPU_COUNT=${TF_GPU_COUNT:-8}
+TF_GPU_COUNT=${TF_GPU_COUNT:-4}
# Define no_tensorflow_py_deps=true so that every py_test has no deps anymore,
# which will result testing system installed tensorflow
diff --git a/tensorflow/tools/ci_build/xla/linux/rocm/run_py3.sh b/tensorflow/tools/ci_build/xla/linux/rocm/run_py3.sh
new file mode 100755
index 0000000000..a0de128020
--- /dev/null
+++ b/tensorflow/tools/ci_build/xla/linux/rocm/run_py3.sh
@@ -0,0 +1,41 @@
+#!/usr/bin/env bash
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ==============================================================================
+
+set -e
+set -x
+
+N_JOBS=$(grep -c ^processor /proc/cpuinfo)
+
+echo ""
+echo "Bazel will use ${N_JOBS} concurrent job(s)."
+echo ""
+
+# Run configure.
+export PYTHON_BIN_PATH=`which python3`
+
+export TF_NEED_ROCM=1
+
+yes "" | $PYTHON_BIN_PATH configure.py
+echo "build --distinct_host_configuration=false" >> .tf_configure.bazelrc
+
+bazel clean
+# Run bazel test command. Double test timeouts to avoid flakes.
+bazel test --config=rocm --test_tag_filters=-no_gpu,-benchmark-test,-no_oss -k \
+ --jobs=${N_JOBS} --test_timeout 300,450,1200,3600 \
+ --build_tests_only --test_output=errors --local_test_jobs=1 \
+ --config=xla -- \
+ //tensorflow/compiler/...
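The `yes "" | $PYTHON_BIN_PATH configure.py` step above makes the configure run non-interactive: every prompt receives an empty line, so the answers fall back to the exported TF_NEED_* variables. A minimal sketch of that prompt-with-environment-fallback pattern (illustrative only, not the actual configure.py code; the `ask` helper and prompt text are made up):

    import os

    def ask(prompt, env_var, default="0"):
        # With `yes ""` on stdin, input() reads an empty line, so the
        # answer falls back to the environment variable (or a default).
        answer = input(prompt).strip()
        return answer or os.environ.get(env_var, default)

    if __name__ == "__main__":
        rocm = ask("Build with ROCm support? ", "TF_NEED_ROCM")
        print("TF_NEED_ROCM resolved to", rocm)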
diff --git a/tensorflow/tools/compatibility/testdata/test_file_v0_11.py b/tensorflow/tools/compatibility/testdata/test_file_v0_11.py
index 35a74c9664..68ba7a2630 100644
--- a/tensorflow/tools/compatibility/testdata/test_file_v0_11.py
+++ b/tensorflow/tools/compatibility/testdata/test_file_v0_11.py
@@ -94,7 +94,7 @@ class TestUpgrade(test_util.TensorFlowTestCase):
self.assertAllClose(
tf.reduce_logsumexp(a, [0, 1]).eval(), 6.45619344711)
self.assertAllEqual(
- tf.expand_dims([[1, 2], [3, 4]], dim=1).eval(),
+ tf.expand_dims([[1, 2], [3, 4]], axis=1).eval(),
[[[1, 2]], [[3, 4]]])
def testArgMinMax(self):
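The `dim=1` to `axis=1` change above is the keyword rename for `tf.expand_dims`; both spellings insert a new length-1 dimension at the given position. A small shape check of the exact value used in the test (assumes a TF 1.x eager or session context for evaluation):

    import tensorflow as tf

    x = tf.constant([[1, 2], [3, 4]])   # shape (2, 2)
    y = tf.expand_dims(x, axis=1)       # shape (2, 1, 2); older code wrote dim=1
    print(y.shape)                      # (2, 1, 2), i.e. [[[1, 2]], [[3, 4]]]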
diff --git a/tensorflow/tools/compatibility/tf_upgrade_v2.py b/tensorflow/tools/compatibility/tf_upgrade_v2.py
index 38216ce9b1..53c546b10c 100644
--- a/tensorflow/tools/compatibility/tf_upgrade_v2.py
+++ b/tensorflow/tools/compatibility/tf_upgrade_v2.py
@@ -120,10 +120,18 @@ Simple usage:
report_filename = args.report_filename
files_processed = 0
if args.input_file:
+ if not args.output_file:
+ raise ValueError(
+ "--outfile=<output file> argument is required when converting a "
+ "single file.")
files_processed, report_text, errors = upgrade.process_file(
args.input_file, args.output_file)
files_processed = 1
elif args.input_tree:
+ if not args.output_tree:
+ raise ValueError(
+ "--outtree=<output directory> argument is required when converting a "
+ "file tree.")
files_processed, report_text, errors = upgrade.process_tree(
args.input_tree, args.output_tree, args.copy_other_files)
else:
diff --git a/tensorflow/tools/dist_test/README.md b/tensorflow/tools/dist_test/README.md
index 228d5ee35d..f8ed74aaf7 100644
--- a/tensorflow/tools/dist_test/README.md
+++ b/tensorflow/tools/dist_test/README.md
@@ -23,7 +23,7 @@ You can test specify version of TensorFlow:
./local_test.sh ${whl_file_url}
```
-For example, you can find these TensorFlow python package URLs from [here](https://www.tensorflow.org/install/install_linux#the_url_of_the_tensorflow_python_package) for Ubuntu.
+For example, you can find the TensorFlow Python package URLs for Ubuntu [here](https://www.tensorflow.org/install/pip).
**2) Launch a remote k8s cluster on Google Kubernetes Engine (GKE) and run the
test suite on it**
diff --git a/tensorflow/tools/dist_test/server/BUILD b/tensorflow/tools/dist_test/server/BUILD
index 003a19a9ab..3aa53a5615 100644
--- a/tensorflow/tools/dist_test/server/BUILD
+++ b/tensorflow/tools/dist_test/server/BUILD
@@ -8,6 +8,7 @@ licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
load("//tensorflow:tensorflow.bzl", "py_test")
+load("//tensorflow:tensorflow.bzl", "py_binary")
py_binary(
name = "grpc_tensorflow_server",
diff --git a/tensorflow/tools/docker/Dockerfile.devel b/tensorflow/tools/docker/Dockerfile.devel
index 39e7bc8b66..c741e8ad0c 100644
--- a/tensorflow/tools/docker/Dockerfile.devel
+++ b/tensorflow/tools/docker/Dockerfile.devel
@@ -78,7 +78,7 @@ RUN mkdir /bazel && \
# Download and build TensorFlow.
WORKDIR /tensorflow
-RUN git clone --branch=r1.10 --depth=1 https://github.com/tensorflow/tensorflow.git .
+RUN git clone --branch=r1.11 --depth=1 https://github.com/tensorflow/tensorflow.git .
# TODO(craigcitro): Don't install the pip package, since it makes it
# more difficult to experiment with local changes. Instead, just add
diff --git a/tensorflow/tools/docker/Dockerfile.devel-gpu b/tensorflow/tools/docker/Dockerfile.devel-gpu
index e487779e7a..f544725af4 100644
--- a/tensorflow/tools/docker/Dockerfile.devel-gpu
+++ b/tensorflow/tools/docker/Dockerfile.devel-gpu
@@ -100,7 +100,7 @@ RUN mkdir /bazel && \
# Download and build TensorFlow.
WORKDIR /tensorflow
-RUN git clone --branch=r1.10 --depth=1 https://github.com/tensorflow/tensorflow.git .
+RUN git clone --branch=r1.11 --depth=1 https://github.com/tensorflow/tensorflow.git .
# Configure the build for our CUDA configuration.
ENV CI_BUILD_PYTHON python
diff --git a/tensorflow/tools/docker/Dockerfile.devel-mkl b/tensorflow/tools/docker/Dockerfile.devel-mkl
index 371451d2aa..db7c701289 100755
--- a/tensorflow/tools/docker/Dockerfile.devel-mkl
+++ b/tensorflow/tools/docker/Dockerfile.devel-mkl
@@ -3,7 +3,7 @@ FROM ubuntu:16.04
LABEL maintainer="Clayne Robison <clayne.b.robison@intel.com>"
# These parameters can be overridden by parameterized_docker_build.sh
-ARG TF_BUILD_VERSION=r1.10
+ARG TF_BUILD_VERSION=r1.11
ARG PYTHON="python"
ARG PYTHON3_DEV=""
ARG WHL_DIR="/tmp/pip"
diff --git a/tensorflow/tools/docker/jupyter_notebook_config.py b/tensorflow/tools/docker/jupyter_notebook_config.py
index 05dcefb099..4449e3501f 100644
--- a/tensorflow/tools/docker/jupyter_notebook_config.py
+++ b/tensorflow/tools/docker/jupyter_notebook_config.py
@@ -16,7 +16,7 @@ import os
from IPython.lib import passwd
c = c # pylint:disable=undefined-variable
-c.NotebookApp.ip = '*'
+c.NotebookApp.ip = '0.0.0.0' # https://github.com/jupyter/notebook/issues/3946
c.NotebookApp.port = int(os.getenv('PORT', 8888))
c.NotebookApp.open_browser = False
diff --git a/tensorflow/tools/docker/parameterized_docker_build.sh b/tensorflow/tools/docker/parameterized_docker_build.sh
index 448a3a7647..570aa8278c 100755
--- a/tensorflow/tools/docker/parameterized_docker_build.sh
+++ b/tensorflow/tools/docker/parameterized_docker_build.sh
@@ -244,7 +244,7 @@ if [[ "${TF_DOCKER_BUILD_IS_DEVEL}" == "no" ]]; then
if [[ "${TF_DOCKER_BUILD_TYPE}" == "gpu" ]]; then
export TF_BUILD_APPEND_CI_DOCKER_EXTRA_PARAMS=\
- "${TF_BUILD_APPEND_CI_DOCKER_EXTRA_PARAMS} -e TF_CUDA_COMPUTE_CAPABILITIES=3.0,3.5,5.2"
+ "${TF_BUILD_APPEND_CI_DOCKER_EXTRA_PARAMS} -e TF_CUDA_COMPUTE_CAPABILITIES=3.0,3.5,5.2,6.0"
fi
pushd "${SCRIPT_DIR}/../../../"
diff --git a/tensorflow/tools/docs/BUILD b/tensorflow/tools/docs/BUILD
index 4f7efe193f..2a858b4fd6 100644
--- a/tensorflow/tools/docs/BUILD
+++ b/tensorflow/tools/docs/BUILD
@@ -37,6 +37,7 @@ py_library(
name = "doc_controls",
srcs = ["doc_controls.py"],
srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
)
py_test(
@@ -91,9 +92,10 @@ py_binary(
":parser",
":pretty_docs",
":py_guide_parser",
- "//tensorflow/contrib/ffmpeg:ffmpeg_ops_py",
+ "//tensorflow/python:util",
"//tensorflow/tools/common:public_api",
"//tensorflow/tools/common:traverse",
+ "@six_archive//:six",
],
)
diff --git a/tensorflow/tools/docs/generate_lib.py b/tensorflow/tools/docs/generate_lib.py
index 1cd9cb7ca9..77a3ca2052 100644
--- a/tensorflow/tools/docs/generate_lib.py
+++ b/tensorflow/tools/docs/generate_lib.py
@@ -453,7 +453,11 @@ def update_id_tags_inplace(src_dir):
EXCLUDED = set(['__init__.py', 'OWNERS', 'README.txt'])
-def replace_refs(src_dir, output_dir, reference_resolver, file_pattern='*.md'):
+def replace_refs(src_dir,
+ output_dir,
+ reference_resolver,
+ file_pattern='*.md',
+ api_docs_relpath='api_docs'):
"""Fix @{} references in all files under `src_dir` matching `file_pattern`.
A matching directory structure, with the modified files is
@@ -472,12 +476,13 @@ def replace_refs(src_dir, output_dir, reference_resolver, file_pattern='*.md'):
reference_resolver: A `parser.ReferenceResolver` to make the replacements.
file_pattern: Only replace references in files matching file_pattern,
using fnmatch. Non-matching files are copied unchanged.
+ api_docs_relpath: Relative path from `src_dir` to the api_docs directory.
"""
# Iterate through all the source files and process them.
for dirpath, _, filenames in os.walk(src_dir):
+ depth = os.path.relpath(src_dir, start=dirpath)
# How to get from `dirpath` to api_docs/python/
- relative_path_to_root = os.path.relpath(
- path=os.path.join(src_dir, 'api_docs/python'), start=dirpath)
+ relative_path_to_root = os.path.join(depth, api_docs_relpath, 'python')
# Make the directory under output_dir.
new_dir = os.path.join(output_dir,
@@ -497,7 +502,8 @@ def replace_refs(src_dir, output_dir, reference_resolver, file_pattern='*.md'):
full_out_path = os.path.join(output_dir, suffix)
# Copy files that do not match the file_pattern, unmodified.
if not fnmatch.fnmatch(base_name, file_pattern):
- shutil.copyfile(full_in_path, full_out_path)
+ if full_in_path != full_out_path:
+ shutil.copyfile(full_in_path, full_out_path)
continue
with open(full_in_path, 'rb') as f:
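The rewritten path logic above builds the same prefix in two steps: `os.path.relpath` walks from the current directory back up to `src_dir`, and `os.path.join` appends the now-configurable `api_docs_relpath`. A standalone check of that composition (directory names here are invented for illustration):

    import os

    src_dir = "docs"
    dirpath = os.path.join("docs", "guide", "keras")

    depth = os.path.relpath(src_dir, start=dirpath)                   # '../..'
    relative_path_to_root = os.path.join(depth, "api_docs", "python")
    print(relative_path_to_root)                                      # '../../api_docs/python'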
diff --git a/tensorflow/tools/lib_package/BUILD b/tensorflow/tools/lib_package/BUILD
index b450bc42c5..b9f4902639 100644
--- a/tensorflow/tools/lib_package/BUILD
+++ b/tensorflow/tools/lib_package/BUILD
@@ -125,6 +125,7 @@ genrule(
"@gemmlowp//:LICENSE",
"@gif_archive//:COPYING",
"@highwayhash//:LICENSE",
+ "@icu//:icu4c/LICENSE",
"@jpeg//:LICENSE.md",
"@llvm//:LICENSE.TXT",
"@lmdb//:LICENSE",
@@ -136,16 +137,6 @@ genrule(
"@snappy//:COPYING",
"@zlib_archive//:zlib.h",
] + select({
- "//tensorflow:with_aws_support": [
- "@aws//:LICENSE",
- ],
- "//conditions:default": [],
- }) + select({
- "//tensorflow:with_gcp_support": [
- "@com_github_googlecloudplatform_google_cloud_cpp//:LICENSE",
- ],
- "//conditions:default": [],
- }) + select({
"//tensorflow:with_jemalloc_linux_x86_64": [
"@jemalloc//:COPYING",
],
@@ -170,7 +161,14 @@ genrule(
"@grpc//third_party/nanopb:LICENSE.txt",
"@grpc//third_party/address_sorting:LICENSE",
],
- ),
+ ) + select({
+ "//tensorflow:linux_s390x": [],
+ "//tensorflow:windows": [],
+ "//conditions:default": [
+ "@aws//:LICENSE",
+ "@com_github_googlecloudplatform_google_cloud_cpp//:LICENSE",
+ ],
+ }),
outs = ["include/tensorflow/c/LICENSE"],
cmd = "$(location :concat_licenses.sh) $(SRCS) >$@",
tools = [":concat_licenses.sh"],
@@ -192,6 +190,7 @@ genrule(
"@gemmlowp//:LICENSE",
"@gif_archive//:COPYING",
"@highwayhash//:LICENSE",
+ "@icu//:icu4j/main/shared/licenses/LICENSE",
"@jpeg//:LICENSE.md",
"@llvm//:LICENSE.TXT",
"@lmdb//:LICENSE",
@@ -203,16 +202,6 @@ genrule(
"@snappy//:COPYING",
"@zlib_archive//:zlib.h",
] + select({
- "//tensorflow:with_aws_support": [
- "@aws//:LICENSE",
- ],
- "//conditions:default": [],
- }) + select({
- "//tensorflow:with_gcp_support": [
- "@com_github_googlecloudplatform_google_cloud_cpp//:LICENSE",
- ],
- "//conditions:default": [],
- }) + select({
"//tensorflow:with_jemalloc_linux_x86_64": [
"@jemalloc//:COPYING",
],
@@ -230,7 +219,14 @@ genrule(
]) + if_mkl([
"//third_party/mkl:LICENSE",
"//third_party/mkl_dnn:LICENSE",
- ]),
+ ]) + select({
+ "//tensorflow:linux_s390x": [],
+ "//tensorflow:windows": [],
+ "//conditions:default": [
+ "@aws//:LICENSE",
+ "@com_github_googlecloudplatform_google_cloud_cpp//:LICENSE",
+ ],
+ }),
outs = ["include/tensorflow/jni/LICENSE"],
cmd = "$(location :concat_licenses.sh) $(SRCS) >$@",
tools = [":concat_licenses.sh"],
diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD
index 50515b04a9..7d925a8fef 100644
--- a/tensorflow/tools/pip_package/BUILD
+++ b/tensorflow/tools/pip_package/BUILD
@@ -62,11 +62,11 @@ COMMON_PIP_DEPS = [
"//tensorflow/contrib/autograph:autograph",
"//tensorflow/contrib/boosted_trees:boosted_trees_pip",
"//tensorflow/contrib/cluster_resolver:cluster_resolver_pip",
+ "//tensorflow/contrib/compiler:xla",
"//tensorflow/contrib/constrained_optimization:constrained_optimization_pip",
"//tensorflow/contrib/data/python/kernel_tests/serialization:dataset_serialization_test_base",
"//tensorflow/contrib/data/python/kernel_tests:stats_dataset_test_base",
"//tensorflow/contrib/data/python/kernel_tests:test_utils",
- "//tensorflow/contrib/data/python/ops:contrib_op_loader",
"//tensorflow/contrib/eager/python/examples:examples_pip",
"//tensorflow/contrib/eager/python:evaluator",
"//tensorflow/contrib/gan:gan",
@@ -107,6 +107,7 @@ COMMON_PIP_DEPS = [
"//tensorflow/python:meta_graph_testdata",
"//tensorflow/python:spectral_ops_test_util",
"//tensorflow/python:util_example_parser_configuration",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/debug:debug_pip",
"//tensorflow/python/eager:eager_pip",
"//tensorflow/python/kernel_tests/testdata:self_adjoint_eig_op_test_files",
@@ -114,6 +115,7 @@ COMMON_PIP_DEPS = [
"//tensorflow/python/tools:tools_pip",
"//tensorflow/python/tools/api/generator:create_python_api",
"//tensorflow/python:test_ops",
+ "//tensorflow/python:while_v2",
"//tensorflow/tools/dist_test/server:grpc_tensorflow_server",
]
@@ -150,6 +152,7 @@ filegroup(
"@gemmlowp//:LICENSE",
"@gif_archive//:COPYING",
"@highwayhash//:LICENSE",
+ "@icu//:icu4c/LICENSE",
"@jpeg//:LICENSE.md",
"@lmdb//:LICENSE",
"@local_config_sycl//sycl:LICENSE.text",
@@ -165,17 +168,6 @@ filegroup(
"@zlib_archive//:zlib.h",
"@org_python_pypi_backports_weakref//:LICENSE",
] + select({
- "//tensorflow:with_aws_support": [
- "@aws//:LICENSE",
- ],
- "//conditions:default": [],
- }) + select({
- "//tensorflow:with_gcp_support": [
- "@com_github_googleapis_googleapis//:LICENSE",
- "@com_github_googlecloudplatform_google_cloud_cpp//:LICENSE",
- ],
- "//conditions:default": [],
- }) + select({
"//tensorflow:with_jemalloc_linux_x86_64": [
"@jemalloc//:COPYING",
],
@@ -184,11 +176,6 @@ filegroup(
],
"//conditions:default": [],
}) + select({
- "//tensorflow:with_kafka_support": [
- "@kafka//:LICENSE",
- ],
- "//conditions:default": [],
- }) + select({
"//tensorflow/core/kernels:xsmm": [
"@libxsmm_archive//:LICENSE.md",
],
@@ -210,7 +197,17 @@ filegroup(
"@ngraph//:LICENSE",
"@ngraph_tf//:LICENSE",
"@nlohmann_json_lib//:LICENSE.MIT",
- ]) + tf_additional_license_deps(),
+ "@tbb//:LICENSE",
+ ]) + tf_additional_license_deps() + select({
+ "//tensorflow:linux_s390x": [],
+ "//tensorflow:windows": [],
+ "//conditions:default": [
+ "@aws//:LICENSE",
+ "@com_github_googleapis_googleapis//:LICENSE",
+ "@com_github_googlecloudplatform_google_cloud_cpp//:LICENSE",
+ "@kafka//:LICENSE",
+ ],
+ }),
)
sh_binary(
diff --git a/tensorflow/tools/pip_package/pip_smoke_test.py b/tensorflow/tools/pip_package/pip_smoke_test.py
index bfc007bc39..c6ef82ccdc 100644
--- a/tensorflow/tools/pip_package/pip_smoke_test.py
+++ b/tensorflow/tools/pip_package/pip_smoke_test.py
@@ -90,6 +90,7 @@ BLACKLIST = [
"//tensorflow/contrib/lite/python:interpreter.py",
"//tensorflow/contrib/lite/python:interpreter_test.py",
"//tensorflow/contrib/ffmpeg:test_data",
+ "//tensorflow/contrib/fused_conv:fused_conv2d_bias_activation_op_test_base",
"//tensorflow/contrib/hadoop:test_data",
"//tensorflow/contrib/factorization/examples:mnist",
"//tensorflow/contrib/factorization/examples:mnist.py",
diff --git a/tensorflow/tools/pip_package/setup.py b/tensorflow/tools/pip_package/setup.py
index 3102239a19..b95e1f5c87 100644
--- a/tensorflow/tools/pip_package/setup.py
+++ b/tensorflow/tools/pip_package/setup.py
@@ -45,7 +45,7 @@ DOCLINES = __doc__.split('\n')
# This version string is semver compatible, but incompatible with pip.
# For pip, we will remove all '-' characters from this string, and use the
# result for pip.
-_VERSION = '1.10.0'
+_VERSION = '1.11.0-rc1'
REQUIRED_PACKAGES = [
'absl-py >= 0.1.6',
@@ -57,7 +57,7 @@ REQUIRED_PACKAGES = [
'six >= 1.10.0',
'protobuf >= 3.6.0',
'setuptools <= 39.1.0',
- 'tensorboard >= 1.10.0, < 1.11.0',
+ 'tensorboard >= 1.11.0, < 1.12.0',
'termcolor >= 1.1.0',
]
@@ -86,7 +86,7 @@ else:
if 'tf_nightly' in project_name:
for i, pkg in enumerate(REQUIRED_PACKAGES):
if 'tensorboard' in pkg:
- REQUIRED_PACKAGES[i] = 'tb-nightly >= 1.11.0a0, < 1.12.0a0'
+ REQUIRED_PACKAGES[i] = 'tb-nightly >= 1.12.0a0, < 1.13.0a0'
break
# weakref.finalize and enum were introduced in Python 3.4
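As the comment retained in the hunk above notes, the semver-style `_VERSION` is made pip-compatible by removing the '-' characters, which turns the release-candidate suffix into a PEP 440 pre-release. A one-line illustration (standalone, not code from setup.py):

    _VERSION = '1.11.0-rc1'
    pip_version = _VERSION.replace('-', '')
    print(pip_version)   # 1.11.0rc1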
diff --git a/tensorflow/tools/quantization/BUILD b/tensorflow/tools/quantization/BUILD
deleted file mode 100644
index 17443a8617..0000000000
--- a/tensorflow/tools/quantization/BUILD
+++ /dev/null
@@ -1,78 +0,0 @@
-# Description:
-# Utilities for quantizing TensorFlow graphs to lower bit depths.
-
-package(default_visibility = ["//visibility:public"])
-
-licenses(["notice"]) # Apache 2.0
-
-exports_files(["LICENSE"])
-
-load("//tensorflow:tensorflow.bzl", "py_test")
-
-py_library(
- name = "quantize_graph_lib",
- srcs = ["quantize_graph.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/core:protos_all_py",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:graph_util",
- "//tensorflow/python:platform",
- "//tensorflow/python:session",
- "//tensorflow/python:tensor_shape",
- "//tensorflow/python:tensor_util",
- "//third_party/py/numpy",
- ],
-)
-
-py_binary(
- name = "quantize_graph",
- srcs = ["quantize_graph.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/core:protos_all_py",
- "//tensorflow/python", # TODO(b/34059704): remove when fixed
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client",
- "//tensorflow/python:framework",
- "//tensorflow/python:framework_for_generated_wrappers",
- "//tensorflow/python:graph_util",
- "//tensorflow/python:platform",
- "//tensorflow/python:tensor_util",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "quantize_graph_test",
- size = "small",
- srcs = ["quantize_graph_test.py"],
- srcs_version = "PY2AND3",
- tags = ["nomsan"], # http://b/32242946
- deps = [
- ":quantize_graph",
- "//tensorflow/core:protos_all_py",
- "//tensorflow/python:client",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:framework",
- "//tensorflow/python:framework_for_generated_wrappers",
- "//tensorflow/python:graph_util",
- "//tensorflow/python:platform",
- "//third_party/py/numpy",
- ],
-)
-
-py_binary(
- name = "graph_to_dot",
- srcs = ["graph_to_dot.py"],
- main = "graph_to_dot.py",
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/core:protos_all_py",
- "//tensorflow/python:platform",
- ],
-)
diff --git a/tensorflow/tools/quantization/graph_to_dot.py b/tensorflow/tools/quantization/graph_to_dot.py
deleted file mode 100644
index 81d6aa62c8..0000000000
--- a/tensorflow/tools/quantization/graph_to_dot.py
+++ /dev/null
@@ -1,68 +0,0 @@
-# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Converts a GraphDef file into a DOT format suitable for visualization.
-
-This script takes a GraphDef representing a network, and produces a DOT file
-that can then be visualized by GraphViz tools like dot and xdot.
-
-"""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import re
-
-from google.protobuf import text_format
-
-from tensorflow.core.framework import graph_pb2
-from tensorflow.python.platform import app
-from tensorflow.python.platform import flags
-from tensorflow.python.platform import gfile
-
-FLAGS = flags.FLAGS
-
-flags.DEFINE_string("graph", "", """TensorFlow 'GraphDef' file to load.""")
-flags.DEFINE_bool("input_binary", True,
- """Whether the input files are in binary format.""")
-flags.DEFINE_string("dot_output", "", """Where to write the DOT output.""")
-
-
-def main(unused_args):
- if not gfile.Exists(FLAGS.graph):
- print("Input graph file '" + FLAGS.graph + "' does not exist!")
- return -1
-
- graph = graph_pb2.GraphDef()
- with open(FLAGS.graph, "r") as f:
- if FLAGS.input_binary:
- graph.ParseFromString(f.read())
- else:
- text_format.Merge(f.read(), graph)
-
- with open(FLAGS.dot_output, "wb") as f:
- print("digraph graphname {", file=f)
- for node in graph.node:
- output_name = node.name
- print(" \"" + output_name + "\" [label=\"" + node.op + "\"];", file=f)
- for input_full_name in node.input:
- parts = input_full_name.split(":")
- input_name = re.sub(r"^\^", "", parts[0])
- print(" \"" + input_name + "\" -> \"" + output_name + "\";", file=f)
- print("}", file=f)
- print("Created DOT file '" + FLAGS.dot_output + "'.")
-
-
-if __name__ == "__main__":
- app.run()
diff --git a/tensorflow/tools/quantization/quantize_graph.py b/tensorflow/tools/quantization/quantize_graph.py
deleted file mode 100644
index 3acb532263..0000000000
--- a/tensorflow/tools/quantization/quantize_graph.py
+++ /dev/null
@@ -1,1302 +0,0 @@
-# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-r"""Transforms a float-trained graph into an equivalent quantized version.
-
-An example of command-line usage is:
-bazel build tensorflow/tools/quantization:quantize_graph \
-&& bazel-bin/tensorflow/tools/quantization/quantize_graph \
---input=tensorflow_inception_graph.pb
---output_node_names="softmax2" --print_nodes --output=/tmp/quantized_graph.pb \
---mode=eightbit --logtostderr
-
-"""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import collections
-import re
-import numpy as np
-
-from tensorflow.core.framework import attr_value_pb2
-from tensorflow.core.framework import graph_pb2
-from tensorflow.core.framework import node_def_pb2
-from tensorflow.python.client import session
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import graph_util
-from tensorflow.python.framework import importer
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_shape
-from tensorflow.python.framework import tensor_util
-from tensorflow.python.ops import array_ops
-from tensorflow.python.platform import app
-from tensorflow.python.platform import flags as flags_lib
-from tensorflow.python.platform import gfile
-
-flags = flags_lib
-FLAGS = flags.FLAGS
-
-flags.DEFINE_boolean("print_nodes", False, """Lists all nodes in the model.""")
-flags.DEFINE_string("input", "", """TensorFlow 'GraphDef' file to load.""")
-flags.DEFINE_string("output_node_names", "",
- """Output node names, comma separated.""")
-flags.DEFINE_string("output", "", """File to save the output graph to.""")
-flags.DEFINE_integer("bitdepth", 8,
- """How many bits to quantize the graph to.""")
-flags.DEFINE_string("mode", "round",
- """What transformation to apply (round, quantize,"""
- """ eightbit, weights, or weights_rounded).""")
-flags.DEFINE_string("test_input_dims", "1,224,224,3",
- """The size of the input tensor to use when testing a"""
- """ graph loaded from a file.""")
-flags.DEFINE_boolean("strip_redundant_quantization", True,
- """Removes redundant dequantize/quantize pairs.""")
-flags.DEFINE_boolean("quantized_input", False,
- "If true, assume Placeholders are quantized with values "
- "covering [--quantized_input_min,--quantized_input_max]. "
- "Only supported when --mode=eightbit")
-flags.DEFINE_float("quantized_input_min", 0,
- "The minimum of the actual input range when "
- "--quantized_input")
-flags.DEFINE_float("quantized_input_max", 1,
- "The maximum of the actual input range when "
- "--quantized_input")
-flags.DEFINE_float(
- "quantized_fallback_min", None,
- "The fallback 'min' value to use for layers which lack min-max "
- "information. Note: this should be considered a coarse tool just good "
- "enough for experimentation purposes, since graphs quantized in this way "
- "would be very inaccurate.")
-flags.DEFINE_float(
- "quantized_fallback_max", None,
- "The fallback 'max' value to use for layers which lack min-max "
- "information. Note: this should be considered a coarse tool just good "
- "enough for experimentation purposes, since graphs quantized in this way "
- "would be very inaccurate.")
-
-
-def print_input_nodes(current_node, nodes_map, indent, already_visited):
- print(" " * indent + current_node.op + ":" + current_node.name)
- already_visited[current_node.name] = True
- for input_node_name in current_node.input:
- if input_node_name in already_visited:
- continue
- input_node = nodes_map[input_node_name]
- print_input_nodes(input_node, nodes_map, indent + 1, already_visited)
-
-
-def create_node(op, name, inputs):
- new_node = node_def_pb2.NodeDef()
- new_node.op = op
- new_node.name = name
- for input_name in inputs:
- new_node.input.extend([input_name])
- return new_node
-
-
-def create_constant_node(name, value, dtype, shape=None):
- node = create_node("Const", name, [])
- set_attr_dtype(node, "dtype", dtype)
- set_attr_tensor(node, "value", value, dtype, shape)
- return node
-
-
-def copy_attr(node, key, attr_value):
- try:
- node.attr[key].CopyFrom(attr_value)
- except KeyError:
- pass
-
-
-def set_attr_dtype(node, key, value):
- try:
- node.attr[key].CopyFrom(
- attr_value_pb2.AttrValue(type=value.as_datatype_enum))
- except KeyError:
- pass
-
-
-def set_attr_shape(node, key, value):
- try:
- node.attr[key].CopyFrom(
- attr_value_pb2.AttrValue(shape=tensor_shape.as_shape(value).as_proto()))
- except KeyError:
- pass
-
-
-def set_attr_tensor(node, key, value, dtype, shape=None):
- try:
- node.attr[key].CopyFrom(
- attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto(
- value, dtype=dtype, shape=shape)))
- except KeyError:
- pass
-
-
-def set_attr_string(node, key, value):
- try:
- node.attr[key].CopyFrom(attr_value_pb2.AttrValue(s=value))
- except KeyError:
- pass
-
-
-def set_attr_int_list(node, key, value):
- list_value = attr_value_pb2.AttrValue.ListValue(i=value)
- try:
- node.attr[key].CopyFrom(attr_value_pb2.AttrValue(list=list_value))
- except KeyError:
- pass
-
-
-def set_attr_bool(node, key, value):
- try:
- node.attr[key].CopyFrom(attr_value_pb2.AttrValue(b=value))
- except KeyError:
- pass
-
-
-def set_attr_int(node, key, value):
- try:
- node.attr[key].CopyFrom(attr_value_pb2.AttrValue(i=value))
- except KeyError:
- pass
-
-
-def set_attr_float(node, key, value):
- try:
- node.attr[key].CopyFrom(attr_value_pb2.AttrValue(f=value))
- except KeyError:
- pass
-
-
-def node_name_from_input(node_name):
- """Strips off ports and other decorations to get the underlying node name."""
- if node_name.startswith("^"):
- node_name = node_name[1:]
- m = re.search(r"(.*):\d+$", node_name)
- if m:
- node_name = m.group(1)
- return node_name
-
-
-def ensure_tensor_name_has_port(node_name):
- """Makes sure that a tensor name has :0 if no explicit port exists."""
- m = re.search(r"(.*):\d+$", node_name)
- if m:
- name_with_port = node_name
- else:
- name_with_port = node_name + ":0"
- return name_with_port
-
-
-def unique_node_name_from_input(node_name):
- """Replaces invalid characters in input names to get a unique node name."""
- return node_name.replace(":", "__port__").replace("^", "__hat__")
-
-
-def quantize_array(arr, num_buckets):
- """Quantizes a numpy array.
-
- This function maps each scalar in arr to the center of one of num_buckets
- buckets. For instance,
- quantize_array([0, 0.3, 0.6, 1], 2) => [0.25, 0.25, 0.75, 0.75]
-
- Args:
- arr: The numpy array to quantize.
- num_buckets: The number of buckets to map "var" to.
- Returns:
- The quantized numpy array.
- Raises:
- ValueError: when num_buckets < 1.
- """
- if num_buckets < 1:
- raise ValueError("num_buckets must be >= 1")
- arr_max = arr.max()
- arr_min = arr.min()
- if arr_max == arr_min:
- return arr
- bucket_width = (arr_max - arr_min) / num_buckets
- # Map scalars to bucket indices. Take special care of max(arr).
- bucket_indices = np.floor((arr - arr_min) / bucket_width)
- bucket_indices[bucket_indices == num_buckets] = num_buckets - 1
- # Map each scalar to the center of a bucket.
- arr = arr_min + bucket_width * (bucket_indices + 0.5)
- return arr
-
-
-def quantize_weight_rounded(input_node):
- """Returns a replacement node for input_node containing bucketed floats."""
- input_tensor = input_node.attr["value"].tensor
- tensor_value = tensor_util.MakeNdarray(input_tensor)
- shape = input_tensor.tensor_shape
- # Currently, the parameter FLAGS.bitdepth is used to compute the
- # number of buckets as 1 << FLAGS.bitdepth, meaning the number of
- # buckets can only be a power of 2.
- # This could be fixed by introducing a new parameter, num_buckets,
- # which would allow for more flexibility in chosing the right model
- # size/accuracy tradeoff. But I didn't want to add more parameters
- # to this script than absolutely necessary.
- num_buckets = 1 << FLAGS.bitdepth
- tensor_value_rounded = quantize_array(tensor_value, num_buckets)
- tensor_shape_list = tensor_util.TensorShapeProtoToList(shape)
- return [
- create_constant_node(
- input_node.name,
- tensor_value_rounded,
- dtypes.float32,
- shape=tensor_shape_list)
- ]
-
-
-def quantize_weight_eightbit(input_node, quantization_mode):
- """Returns replacement nodes for input_node using the Dequantize op."""
- base_name = input_node.name + "_"
- quint8_const_name = base_name + "quint8_const"
- min_name = base_name + "min"
- max_name = base_name + "max"
- float_tensor = tensor_util.MakeNdarray(input_node.attr["value"].tensor)
- min_value = np.min(float_tensor.flatten())
- max_value = np.max(float_tensor.flatten())
- # Make sure that the range includes zero.
- if min_value > 0.0:
- min_value = 0.0
- # min_value == max_value is a tricky case. It can occur for general
- # tensors, and of course for scalars. The quantized ops cannot deal
- # with this case, so we set max_value to something else.
- # It's a tricky question what is the numerically best solution to
- # deal with this degeneracy.
- # TODO(petewarden): Better use a tolerance than a hard comparison?
- if min_value == max_value:
- if abs(min_value) < 0.000001:
- max_value = min_value + 1.0
- elif min_value > 0:
- max_value = 2 * min_value
- else:
- max_value = min_value / 2.0
-
- sess = session.Session()
- with sess.as_default():
- quantize_op = array_ops.quantize_v2(
- float_tensor,
- min_value,
- max_value,
- dtypes.quint8,
- mode=quantization_mode)
- quint8_tensor = quantize_op[0].eval()
- shape = tensor_util.TensorShapeProtoToList(input_node.attr["value"]
- .tensor.tensor_shape)
- quint8_const_node = create_constant_node(
- quint8_const_name, quint8_tensor, dtypes.quint8, shape=shape)
- min_node = create_constant_node(min_name, min_value, dtypes.float32)
- max_node = create_constant_node(max_name, max_value, dtypes.float32)
- dequantize_node = create_node("Dequantize", input_node.name,
- [quint8_const_name, min_name, max_name])
- set_attr_dtype(dequantize_node, "T", dtypes.quint8)
- set_attr_string(dequantize_node, "mode", quantization_mode)
- return [quint8_const_node, min_node, max_node, dequantize_node]
-
-
-EightbitizeRecursionState = collections.namedtuple(
- "EightbitizeRecursionState",
- ["already_visited", "output_node_stack", "merged_with_fake_quant"])
-
-
-class GraphRewriter(object):
- """Takes a float graph, and rewrites it in quantized form."""
-
- def __init__(self,
- input_graph,
- mode,
- quantized_input_range,
- fallback_quantization_range=None):
- """Sets up the class to rewrite a float graph.
-
- Args:
- input_graph: A float graph to transform.
- mode: A string controlling how quantization is performed -
- round, quantize, eightbit, or weights.
- quantized_input_range: if set, assume the input is
- quantized and represents the range
- [quantized_input_range[0], quantized_input_range[1]]
- fallback_quantization_range: if set, then for nodes where the quantization
- range can't be inferred from the graph, use the range
- [fallback_quantization_range[0], fallback_quantization_range[1]) instead
- of using a RequantizationRange node in the graph.
-
- Raises:
- ValueError: Two nodes with the same name were found in the graph.
- """
- self.input_graph = input_graph
- self.nodes_map = self.create_nodes_map(input_graph)
- self.output_graph = None
- self.mode = mode
- self.final_node_renames = {}
- if quantized_input_range:
- self.input_range = (quantized_input_range[0], quantized_input_range[1])
- if self.input_range[0] >= self.input_range[1]:
- raise ValueError("Invalid quantized_input_range: [%s,%s]" %
- self.input_range)
- if self.mode != "eightbit":
- raise ValueError(
- "quantized_input_range can only be specified in eightbit mode")
- else:
- self.input_range = None
-
- if fallback_quantization_range:
- self.fallback_quantization_range = [
- fallback_quantization_range[0], fallback_quantization_range[1]
- ]
- if (self.fallback_quantization_range[0] >=
- self.fallback_quantization_range[1]):
- raise ValueError("Invalid fallback_quantization_range: [%s,%s]" %
- self.fallback_quantization_range)
- if self.mode != "eightbit":
- raise ValueError("fallback_quantization_range can only be "
- "specified in eightbit mode")
- else:
- self.fallback_quantization_range = None
-
- # Data that is valid only during the recursive call to rewrite the graph.
- self.state = None
-
- def create_nodes_map(self, graph):
- """Builds a mapping of node names to their defs from the graph."""
- nodes_map = {}
- for node in graph.node:
- if node.name not in nodes_map.keys():
- nodes_map[node.name] = node
- else:
- raise ValueError("Duplicate node names detected.")
- return nodes_map
-
- def rewrite(self, output_node_names):
- """Triggers rewriting of the float graph.
-
- Args:
- output_node_names: A list of names of the nodes that produce the final
- results.
-
- Returns:
- A quantized version of the float graph.
- """
- self.output_graph = graph_pb2.GraphDef()
- output_nodes = [
- self.nodes_map[output_node_name]
- for output_node_name in output_node_names
- ]
- if self.mode == "round":
- self.already_visited = {}
- for output_node in output_nodes:
- self.round_nodes_recursively(output_node)
- elif self.mode == "quantize":
- self.already_visited = {}
- self.already_quantized = {}
- for output_node in output_nodes:
- self.quantize_nodes_recursively(output_node)
- elif self.mode == "eightbit":
- self.set_input_graph(graph_util.remove_training_nodes(
- self.input_graph, protected_nodes=output_node_names))
- output_nodes = [
- self.nodes_map[output_node_name]
- for output_node_name in output_node_names
- ]
-
- self.state = EightbitizeRecursionState(
- already_visited={}, output_node_stack=[], merged_with_fake_quant={})
- for output_node in output_nodes:
- self.eightbitize_nodes_recursively(output_node)
- self.state = None
- if self.input_range:
- self.add_output_graph_node(
- create_constant_node("quantized_input_min_value", self.input_range[
- 0], dtypes.float32, []))
- self.add_output_graph_node(
- create_constant_node("quantized_input_max_value", self.input_range[
- 1], dtypes.float32, []))
- if self.fallback_quantization_range:
- self.add_output_graph_node(
- create_constant_node("fallback_quantization_min_value",
- self.fallback_quantization_range[0],
- dtypes.float32, []))
- self.add_output_graph_node(
- create_constant_node("fallback_quantization_max_value",
- self.fallback_quantization_range[1],
- dtypes.float32, []))
- if FLAGS.strip_redundant_quantization:
- self.output_graph = self.remove_redundant_quantization(
- self.output_graph)
- self.remove_dead_nodes(output_node_names)
- self.apply_final_node_renames()
- elif self.mode == "weights":
- self.output_graph = self.quantize_weights(self.input_graph,
- b"MIN_COMBINED")
- self.remove_dead_nodes(output_node_names)
- elif self.mode == "weights_rounded":
- self.output_graph = self.quantize_weights(self.input_graph, self.mode)
- self.remove_dead_nodes(output_node_names)
- else:
- print("Bad mode - " + self.mode + ".")
- return self.output_graph
-
- def round_nodes_recursively(self, current_node):
- """The entry point for simple rounding quantization."""
- if (current_node.name in self.already_visited
- ) and self.already_visited[current_node.name]:
- return
- self.already_visited[current_node.name] = True
- for input_node_name in current_node.input:
- input_node_name = node_name_from_input(input_node_name)
- input_node = self.nodes_map[input_node_name]
- self.round_nodes_recursively(input_node)
- nodes_to_quantize = ["Conv2D", "BiasAdd", "MatMul"]
- if any(current_node.op in s for s in nodes_to_quantize):
- new_node = node_def_pb2.NodeDef()
- new_node.CopyFrom(current_node)
- new_node.name = current_node.name + "_original"
- self.add_output_graph_node(new_node)
- levels = 1 << FLAGS.bitdepth
- constant_name = current_node.name + "_round_depth"
- constant_tensor = constant_op.constant(
- levels, dtype=dtypes.int32, name=constant_name)
- constant_node = constant_tensor.op.node_def
- self.add_output_graph_node(constant_node)
- quantize_node = node_def_pb2.NodeDef()
- quantize_node.op = "RoundToSteps"
- quantize_node.name = current_node.name
- quantize_node.input.extend([current_node.name + "_original"])
- quantize_node.input.extend([constant_node.name])
- self.add_output_graph_node(quantize_node)
- else:
- new_node = node_def_pb2.NodeDef()
- new_node.CopyFrom(current_node)
- self.add_output_graph_node(new_node)
-
- def quantize_nodes_recursively(self, current_node):
- """The entry point for quantizing nodes to eight bit and back."""
- if self.already_visited[current_node.name]:
- return
- self.already_visited[current_node.name] = True
- for input_node_name in current_node.input:
- input_node_name = node_name_from_input(input_node_name)
- input_node = self.nodes_map[input_node_name]
- self.quantize_nodes_recursively(input_node)
- nodes_to_quantize = ["Conv2D", "BiasAdd", "MatMul"]
- if any(current_node.op in s for s in nodes_to_quantize):
- for input_name in current_node.input:
- input_name = node_name_from_input(input_name)
- input_node = self.nodes_map[input_name]
- self.quantize_node(input_node)
- self.quantize_node(current_node)
- else:
- new_node = node_def_pb2.NodeDef()
- new_node.CopyFrom(current_node)
- self.add_output_graph_node(new_node)
-
- def quantize_node(self, input_node):
- """Handles quantizing a single node."""
- input_name = input_node.name
- if input_name in self.already_quantized:
- return
- self.already_quantized[input_name] = True
- original_input_name = input_name + "_original"
- reshape_name = input_name + "_reshape"
- reshape_dims_name = input_name + "_reshape_dims"
- max_name = input_name + "_max"
- min_name = input_name + "_min"
- dims_name = input_name + "_dims"
- quantize_name = input_name + "_quantize"
- dequantize_name = input_name
- original_input_node = node_def_pb2.NodeDef()
- original_input_node.CopyFrom(input_node)
- original_input_node.name = original_input_name
- self.add_output_graph_node(original_input_node)
- reshape_dims_node = create_constant_node(reshape_dims_name, -1,
- dtypes.int32, [1])
- self.add_output_graph_node(reshape_dims_node)
- reshape_node = create_node("Reshape", reshape_name,
- [original_input_name, reshape_dims_name])
- set_attr_dtype(reshape_node, "T", dtypes.float32)
- self.add_output_graph_node(reshape_node)
- dims_node = create_constant_node(dims_name, 0, dtypes.int32, [1])
- self.add_output_graph_node(dims_node)
- max_node = create_node("Max", max_name, [reshape_name, dims_name])
- set_attr_dtype(max_node, "T", dtypes.float32)
- set_attr_bool(max_node, "keep_dims", False)
- self.add_output_graph_node(max_node)
- min_node = create_node("Min", min_name, [reshape_name, dims_name])
- set_attr_dtype(min_node, "T", dtypes.float32)
- set_attr_bool(min_node, "keep_dims", False)
- self.add_output_graph_node(min_node)
- quantize_node = create_node("Quantize", quantize_name,
- [original_input_name, min_name, max_name])
- set_attr_dtype(quantize_node, "T", dtypes.quint8)
- set_attr_string(quantize_node, "mode", b"MIN_FIRST")
- self.add_output_graph_node(quantize_node)
- dequantize_node = create_node("Dequantize", dequantize_name,
- [quantize_name, min_name, max_name])
- set_attr_dtype(dequantize_node, "T", dtypes.quint8)
- set_attr_string(dequantize_node, "mode", b"MIN_FIRST")
- self.add_output_graph_node(dequantize_node)
-
- def should_merge_with_fake_quant_node(self):
- """Should the current node merge with self.state.output_node_stack[-1]?"""
- if not self.state.output_node_stack:
- return False
- top = self.state.output_node_stack[-1]
- return top[1] == 0 and top[0].op in ["FakeQuantWithMinMaxVars"]
-
- def should_quantize_const(self, node):
- if not self.state.output_node_stack:
- return False
- top = self.state.output_node_stack[-1]
- if not top[2]:
- return False
- dtype = dtypes.as_dtype(node.attr["dtype"].type)
- assert dtype == dtypes.float32, (
- "Failed to quantized constant %s of type %s" % (node.name, dtype))
- return True
-
- def eightbitize_nodes_recursively(self, current_node):
- """The entry point for transforming a graph into full eight bit."""
- if current_node.name in self.state.already_visited:
- if (self.should_merge_with_fake_quant_node() or
- current_node.name in self.state.merged_with_fake_quant):
- raise ValueError("Unsupported graph structure: output of node %s "
- "is processed by a FakeQuant* node and should have "
- "no other outputs.", current_node.name)
- return
- self.state.already_visited[current_node.name] = True
-
- for i, input_node_name in enumerate(current_node.input):
- quantize_input = False
- if current_node.op in ("MatMul", "Conv2D", "BiasAdd", "MaxPool",
- "AvgPool", "Relu", "Relu6",
- "BatchNormWithGlobalNormalization"):
- quantize_input = True
- elif current_node.op == "Concat" and i > 0:
- quantize_input = (
- dtypes.as_dtype(current_node.attr["T"].type) == dtypes.float32)
- elif current_node.op == "Reshape" and i == 0:
- quantize_input = (
- dtypes.as_dtype(current_node.attr["T"].type) == dtypes.float32)
-
- self.state.output_node_stack.append((current_node, i, quantize_input))
-
- input_node_name = node_name_from_input(input_node_name)
- input_node = self.nodes_map[input_node_name]
- self.eightbitize_nodes_recursively(input_node)
-
- self.state.output_node_stack.pop()
-
- if current_node.op == "MatMul":
- self.eightbitize_mat_mul_node(current_node)
- elif current_node.op == "Conv2D":
- self.eightbitize_conv_node(current_node)
- elif current_node.op == "BiasAdd":
- self.eightbitize_bias_add_node(current_node)
- elif current_node.op == "MaxPool" or current_node.op == "AvgPool":
- self.eightbitize_single_input_tensor_node(current_node,
- self.add_pool_function)
- elif current_node.op == "Relu" or current_node.op == "Relu6":
- self.eightbitize_single_input_tensor_node(current_node,
- self.add_relu_function)
- elif (current_node.op == "Concat" and
- dtypes.as_dtype(current_node.attr["T"].type) == dtypes.float32):
- self.eightbitize_concat_node(current_node)
- elif current_node.op == "BatchNormWithGlobalNormalization":
- self.eightbitize_batch_norm_node(current_node)
- elif (current_node.op == "Reshape" and
- dtypes.as_dtype(current_node.attr["T"].type) == dtypes.float32):
- self.eightbitize_reshape_node(current_node)
- elif (self.input_range and
- current_node.op in ("Placeholder", "PlaceholderV2")):
- self.eightbitize_placeholder_node(current_node)
- elif current_node.op == "FakeQuantWithMinMaxVars":
- # It will have been merged into the underlying node.
- pass
- elif current_node.op == "Const":
- if self.should_quantize_const(current_node):
- for n in quantize_weight_eightbit(current_node, b"MIN_FIRST"):
- self.add_output_graph_node(n)
- else:
- new_node = node_def_pb2.NodeDef()
- new_node.CopyFrom(current_node)
- self.add_output_graph_node(new_node)
-
- ###################################################################
- # Note: if more cases are added here, you may need to update the op
- # name lists in the loop over children at the start of the function.
- ###################################################################
- else:
- new_node = node_def_pb2.NodeDef()
- new_node.CopyFrom(current_node)
- self.add_output_graph_node(new_node)
-
- if (self.should_merge_with_fake_quant_node() and
- current_node.name not in self.state.merged_with_fake_quant):
- raise ValueError(
- "FakeQuant* node %s failed to merge with node %s of type %s" %
- (self.state.output_node_stack[-1][0], current_node.name,
- current_node.op))
-
- def add_eightbit_prologue_nodes(self, original_node):
- """Adds input conversion nodes to handle quantizing the underlying node."""
- namespace_prefix = original_node.name + "_eightbit"
- reshape_dims_name, reduction_dims_name = self.add_common_quantization_nodes(
- namespace_prefix)
- input_names = []
- min_max_names = []
- for original_input_name in original_node.input:
- quantize_input_name, min_input_name, max_input_name = (
- self.eightbitize_input_to_node(namespace_prefix, original_input_name,
- reshape_dims_name,
- reduction_dims_name))
- input_names.append(quantize_input_name)
- min_max_names.append(min_input_name)
- min_max_names.append(max_input_name)
- all_input_names = []
- all_input_names.extend(input_names)
- all_input_names.extend(min_max_names)
- return all_input_names
-
- def add_common_quantization_nodes(self, namespace_prefix):
- """Builds constant nodes needed for quantization of inputs."""
- reshape_dims_name = namespace_prefix + "_reshape_dims"
- reduction_dims_name = namespace_prefix + "_reduction_dims"
-
- reshape_dims_node = create_constant_node(reshape_dims_name, -1,
- dtypes.int32, [1])
- self.add_output_graph_node(reshape_dims_node)
- reduction_dims_node = create_constant_node(reduction_dims_name, 0,
- dtypes.int32, [1])
- self.add_output_graph_node(reduction_dims_node)
- return reshape_dims_name, reduction_dims_name
-
- def eightbitize_input_to_node(self, namespace_prefix, original_input_name,
- reshape_dims_name, reduction_dims_name):
- """Takes one float input to an op, and converts it to quantized form."""
- unique_input_name = unique_node_name_from_input(original_input_name)
- reshape_input_name = namespace_prefix + "_reshape_" + unique_input_name
- min_input_name = namespace_prefix + "_min_" + unique_input_name
- max_input_name = namespace_prefix + "_max_" + unique_input_name
- quantize_input_name = namespace_prefix + "_quantize_" + unique_input_name
- reshape_input_node = create_node("Reshape", reshape_input_name,
- [original_input_name, reshape_dims_name])
- set_attr_dtype(reshape_input_node, "T", dtypes.float32)
- self.add_output_graph_node(reshape_input_node)
- min_input_node = create_node("Min", min_input_name,
- [reshape_input_name, reduction_dims_name])
- set_attr_dtype(min_input_node, "T", dtypes.float32)
- set_attr_bool(min_input_node, "keep_dims", False)
- self.add_output_graph_node(min_input_node)
- max_input_node = create_node("Max", max_input_name,
- [reshape_input_name, reduction_dims_name])
- set_attr_dtype(max_input_node, "T", dtypes.float32)
- set_attr_bool(max_input_node, "keep_dims", False)
- self.add_output_graph_node(max_input_node)
- quantize_input_node = create_node(
- "QuantizeV2", quantize_input_name,
- [original_input_name, min_input_name, max_input_name])
- set_attr_dtype(quantize_input_node, "T", dtypes.quint8)
- set_attr_string(quantize_input_node, "mode", b"MIN_FIRST")
- self.add_output_graph_node(quantize_input_node)
- min_output_name = quantize_input_name + ":1"
- max_output_name = quantize_input_name + ":2"
- return quantize_input_name, min_output_name, max_output_name
-
- def add_quantize_down_nodes(self, original_node, quantized_output_name):
- quantized_outputs = [
- quantized_output_name, quantized_output_name + ":1",
- quantized_output_name + ":2"
- ]
- min_max_inputs = None
- if self.should_merge_with_fake_quant_node():
- # Use the inputs to the FakeQuantWithMinMaxVars node as the inputs to
- # Requantize.
- fake_quant_node = self.state.output_node_stack[-1][0]
- min_max_inputs = [fake_quant_node.input[1], fake_quant_node.input[2]]
- assert original_node.name not in self.state.merged_with_fake_quant
- self.state.merged_with_fake_quant[original_node.name] = True
- elif self.fallback_quantization_range:
- min_max_inputs = [
- "fallback_quantization_min_value:0",
- "fallback_quantization_max_value:0"
- ]
- else:
- # Add a RequantizationRange node for finding the min and max values.
- requant_range_node = create_node(
- "RequantizationRange", original_node.name + "_eightbit_requant_range",
- quantized_outputs)
- set_attr_dtype(requant_range_node, "Tinput", dtypes.qint32)
- self.add_output_graph_node(requant_range_node)
- min_max_inputs = [
- requant_range_node.name + ":0", requant_range_node.name + ":1"
- ]
- requantize_node = create_node("Requantize",
- original_node.name + "_eightbit_requantize",
- quantized_outputs + min_max_inputs)
- set_attr_dtype(requantize_node, "Tinput", dtypes.qint32)
- set_attr_dtype(requantize_node, "out_type", dtypes.quint8)
- self.add_output_graph_node(requantize_node)
- return requantize_node.name
-
- def add_dequantize_result_node(self,
- quantized_output_name,
- original_node_name,
- min_tensor_index=1):
- min_max_inputs = [
- "%s:%s" % (quantized_output_name, min_tensor_index),
- "%s:%s" % (quantized_output_name, (min_tensor_index + 1))
- ]
- dequantize_name = original_node_name
- if self.should_merge_with_fake_quant_node():
- fake_quant_node = self.state.output_node_stack[-1][0]
- if original_node_name not in self.state.merged_with_fake_quant:
- min_max_inputs = [fake_quant_node.input[1], fake_quant_node.input[2]]
- self.state.merged_with_fake_quant[original_node_name] = True
- dequantize_name = fake_quant_node.name
-
- dequantize_node = create_node(
- "Dequantize", dequantize_name,
- [quantized_output_name, min_max_inputs[0], min_max_inputs[1]])
- set_attr_dtype(dequantize_node, "T", dtypes.quint8)
- set_attr_string(dequantize_node, "mode", b"MIN_FIRST")
- self.add_output_graph_node(dequantize_node)
-
- def eightbitize_mat_mul_node(self, original_node):
- """Replaces a MatMul node with the eight bit equivalent sub-graph."""
- quantized_mat_mul_name = original_node.name + "_eightbit_quantized_mat_mul"
- all_input_names = self.add_eightbit_prologue_nodes(original_node)
- quantized_mat_mul_node = create_node("QuantizedMatMul",
- quantized_mat_mul_name,
- all_input_names)
- set_attr_dtype(quantized_mat_mul_node, "T1", dtypes.quint8)
- set_attr_dtype(quantized_mat_mul_node, "T2", dtypes.quint8)
- set_attr_dtype(quantized_mat_mul_node, "Toutput", dtypes.qint32)
- copy_attr(quantized_mat_mul_node, "transpose_a",
- original_node.attr["transpose_a"])
- copy_attr(quantized_mat_mul_node, "transpose_b",
- original_node.attr["transpose_b"])
- self.add_output_graph_node(quantized_mat_mul_node)
- quantize_down_name = self.add_quantize_down_nodes(original_node,
- quantized_mat_mul_name)
- self.add_dequantize_result_node(quantize_down_name, original_node.name)
-
- def eightbitize_conv_node(self, original_node):
- """Replaces a Conv2D node with the eight bit equivalent sub-graph."""
- all_input_names = self.add_eightbit_prologue_nodes(original_node)
- quantized_conv_name = original_node.name + "_eightbit_quantized_conv"
- quantized_conv_node = create_node("QuantizedConv2D", quantized_conv_name,
- all_input_names)
- copy_attr(quantized_conv_node, "strides", original_node.attr["strides"])
- copy_attr(quantized_conv_node, "padding", original_node.attr["padding"])
- set_attr_dtype(quantized_conv_node, "Tinput", dtypes.quint8)
- set_attr_dtype(quantized_conv_node, "Tfilter", dtypes.quint8)
- set_attr_dtype(quantized_conv_node, "out_type", dtypes.qint32)
- self.add_output_graph_node(quantized_conv_node)
- quantize_down_name = self.add_quantize_down_nodes(original_node,
- quantized_conv_name)
- self.add_dequantize_result_node(quantize_down_name, original_node.name)
-
- def eightbitize_bias_add_node(self, original_node):
- """Replaces a BiasAdd node with the eight bit equivalent sub-graph."""
- quantized_bias_add_name = (
- original_node.name + "_eightbit_quantized_bias_add")
- all_input_names = self.add_eightbit_prologue_nodes(original_node)
- quantized_bias_add_node = create_node("QuantizedBiasAdd",
- quantized_bias_add_name,
- all_input_names)
- set_attr_dtype(quantized_bias_add_node, "T1", dtypes.quint8)
- set_attr_dtype(quantized_bias_add_node, "T2", dtypes.quint8)
- set_attr_dtype(quantized_bias_add_node, "out_type", dtypes.qint32)
- self.add_output_graph_node(quantized_bias_add_node)
- quantize_down_name = self.add_quantize_down_nodes(original_node,
- quantized_bias_add_name)
- self.add_dequantize_result_node(quantize_down_name, original_node.name)
-
- def eightbitize_single_input_tensor_node(self, original_node,
- add_op_function):
- """Replaces a single-tensor node with the eight bit equivalent sub-graph.
-
- Converts a node like this:
-
- Shape(f) Input(f)
- | |
- +--------v v
- Operation
- |
- v
- (f)
-
- Into a quantized equivalent:
-
- Input(f) ReshapeDims
- +------v v-------------+
- | Reshape
- | |
- | | ReductionDims
- | +-----+ |
- | | +---c---------+
- | v v v v-------+
- | Min Max
- | +----+ |
- v v v--------+
- Quantize
- |
- v
- QuantizedOperation
- | | |
- v v v
- Dequantize
- |
- v
- (f)
-
-
- Args:
- original_node: Float node to be converted.
- add_op_function: Function to create the actual node.
-
- Returns:
- Subgraph representing the quantized version of the original node.
-
- """
- quantized_op_name = original_node.name + "_eightbit_quantized"
- quantized_op_type = "Quantized" + original_node.op
- all_input_names = self.add_eightbit_prologue_nodes(original_node)
- quantized_op_node = create_node(quantized_op_type, quantized_op_name,
- all_input_names)
- add_op_function(original_node, quantized_op_node)
- self.add_output_graph_node(quantized_op_node)
- self.add_dequantize_result_node(quantized_op_name, original_node.name)
-
- def add_pool_function(self, original_node, quantized_op_node):
- set_attr_dtype(quantized_op_node, "T", dtypes.quint8)
- copy_attr(quantized_op_node, "ksize", original_node.attr["ksize"])
- copy_attr(quantized_op_node, "strides", original_node.attr["strides"])
- copy_attr(quantized_op_node, "padding", original_node.attr["padding"])
-
- def add_relu_function(self, unused_arg_node, quantized_op_node):
- set_attr_dtype(quantized_op_node, "Tinput", dtypes.quint8)
-
- def eightbitize_concat_node(self, original_node):
- """Replaces a Concat node with the eight bit equivalent sub-graph.
-
- Converts a node like this:
-
- Shape(f) Input0(f) Input1(f)
- | | |
- +--------v v v----------+
- Concat
- |
- v
- (f)
-
- Into a quantized equivalent:
-
- Shape(f) Input0(f) ReshapeDims Input1(f)
- | +------v v--------------+------------------v v------+
- | | Reshape Reshape |
- | | | | |
- | | | ReductionDims | |
- | | +------+ | +--------+ |
- | | | +---c---------+-----------c-----+ | |
- | | +v v v v-------+---------v v v v+ |
- | | Min Max Min Max |
- | | +----+ | | +-----+ |
- | v v v--------+ +----------v v v
- | Quantize Quantize
- | +------------------+ +----------------------+
- +-------------------------------+ | |
- v v v
- QuantizedConcat
- | | |
- v v v
- Dequantize
- |
- v
- (f)
- Args:
- original_node: Float node to be converted.
-
- Returns:
- Subgraph representing the quantized version of the original node.
-
- """
- namespace_prefix = original_node.name + "_eightbit"
- quantized_concat_name = namespace_prefix + "_quantized_concat"
- reshape_dims_name, reduction_dims_name = self.add_common_quantization_nodes(
- namespace_prefix)
- shape_input_name = original_node.input[0]
- original_inputs = original_node.input[1:]
- input_names = []
- min_names = []
- max_names = []
- for original_input_name in original_inputs:
- quantize_input_name, min_input_name, max_input_name = (
- self.eightbitize_input_to_node(namespace_prefix, original_input_name,
- reshape_dims_name,
- reduction_dims_name))
- input_names.append(quantize_input_name)
- min_names.append(min_input_name)
- max_names.append(max_input_name)
- all_input_names = [shape_input_name]
- all_input_names.extend(input_names)
- all_input_names.extend(min_names)
- all_input_names.extend(max_names)
- quantized_concat_node = create_node("QuantizedConcat",
- quantized_concat_name, all_input_names)
- set_attr_int(quantized_concat_node, "N", len(original_inputs))
- set_attr_dtype(quantized_concat_node, "T", dtypes.quint8)
- self.add_output_graph_node(quantized_concat_node)
- self.add_dequantize_result_node(quantized_concat_name, original_node.name)
-
- def eightbitize_placeholder_node(self, current_node):
- """Replaces a placeholder node with a quint8 placeholder node+dequantize."""
- name = current_node.name
-
- # Convert the placeholder into a quantized type.
- output_node = node_def_pb2.NodeDef()
- output_node.CopyFrom(current_node)
- set_attr_dtype(output_node, "dtype", dtypes.quint8)
- output_node.name += "_original_input"
- self.add_output_graph_node(output_node)
-
- # Add a dequantize to convert back to float.
- dequantize_node = create_node("Dequantize", name, [
- output_node.name, "quantized_input_min_value",
- "quantized_input_max_value"
- ])
- set_attr_dtype(dequantize_node, "T", dtypes.quint8)
- set_attr_string(dequantize_node, "mode", b"MIN_FIRST")
- self.add_output_graph_node(dequantize_node)
-
- # For the descent over the graph to work, the dequantize node must be named
- # current_node.name. However, for the feeding of the graph to work, the
- # placeholder must have the name current_node.name; so record a final set
- # of renames to apply after all processing has been done.
- self.final_node_renames[output_node.name] = name
- self.final_node_renames[dequantize_node.name] = name + "_dequantize"
-
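
A small self-contained sketch of the rename bookkeeping described above, assuming a placeholder originally named "input" (the names here are illustrative, not taken from a real graph):

# Hypothetical contents of final_node_renames after eightbitize_placeholder_node.
final_node_renames = {
    "input_original_input": "input",   # the quint8 placeholder gets its original name back
    "input": "input_dequantize",       # the Dequantize that borrowed the name during rewriting
}

# apply_final_node_renames later rewrites node names (and input references) with this table:
node_names = ["input_original_input", "input", "some_consumer"]
print([final_node_renames.get(n, n) for n in node_names])
# -> ['input', 'input_dequantize', 'some_consumer']
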
- def eightbitize_reshape_node(self, original_node):
- """Replaces a Reshape node with the eight bit equivalent sub-graph.
-
- Args:
- original_node: Float node to be converted.
-
-    Returns:
-      Nothing; the quantized sub-graph is appended to self.output_graph in
-      place of the original node.
-
- """
- namespace_prefix = original_node.name + "_eightbit"
- quantized_reshape_name = namespace_prefix + "_quantized_reshape"
- reshape_dims_name, reduction_dims_name = self.add_common_quantization_nodes(
- namespace_prefix)
- shape_input_name = original_node.input[1]
- quantize_input_name, min_input_name, max_input_name = (
- self.eightbitize_input_to_node(namespace_prefix, original_node.input[0],
- reshape_dims_name, reduction_dims_name))
- quantized_reshape_node = create_node(
- "QuantizedReshape", quantized_reshape_name,
- [quantize_input_name, shape_input_name, min_input_name, max_input_name])
- set_attr_dtype(quantized_reshape_node, "T", dtypes.quint8)
- self.add_output_graph_node(quantized_reshape_node)
- self.add_dequantize_result_node(quantized_reshape_name, original_node.name)
-
- def eightbitize_batch_norm_node(self, original_node):
-    """Replaces a BatchNormWithGlobalNormalization node with the eight bit
-    equivalent sub-graph."""
- namespace_prefix = original_node.name + "_eightbit"
- original_input_name = original_node.input[0]
- original_mean_name = original_node.input[1]
- original_variance_name = original_node.input[2]
- original_beta_name = original_node.input[3]
- original_gamma_name = original_node.input[4]
- quantized_batch_norm_name = namespace_prefix + "_quantized_batch_norm"
-
- reshape_dims_name, reduction_dims_name = self.add_common_quantization_nodes(
- namespace_prefix)
- quantize_input_name, min_input_name, max_input_name = (
- self.eightbitize_input_to_node(namespace_prefix, original_input_name,
- reshape_dims_name, reduction_dims_name))
- quantize_mean_name, min_mean_name, max_mean_name = (
- self.eightbitize_input_to_node(namespace_prefix, original_mean_name,
- reshape_dims_name, reduction_dims_name))
- quantize_variance_name, min_variance_name, max_variance_name = (
- self.eightbitize_input_to_node(namespace_prefix, original_variance_name,
- reshape_dims_name, reduction_dims_name))
- quantize_beta_name, min_beta_name, max_beta_name = (
- self.eightbitize_input_to_node(namespace_prefix, original_beta_name,
- reshape_dims_name, reduction_dims_name))
- quantize_gamma_name, min_gamma_name, max_gamma_name = (
- self.eightbitize_input_to_node(namespace_prefix, original_gamma_name,
- reshape_dims_name, reduction_dims_name))
- quantized_batch_norm_node = create_node(
- "QuantizedBatchNormWithGlobalNormalization", quantized_batch_norm_name,
- [
- quantize_input_name, min_input_name, max_input_name,
- quantize_mean_name, min_mean_name, max_mean_name,
- quantize_variance_name, min_variance_name, max_variance_name,
- quantize_beta_name, min_beta_name, max_beta_name,
- quantize_gamma_name, min_gamma_name, max_gamma_name
- ])
- set_attr_dtype(quantized_batch_norm_node, "Tinput", dtypes.quint8)
- set_attr_dtype(quantized_batch_norm_node, "out_type", dtypes.qint32)
- copy_attr(quantized_batch_norm_node, "scale_after_normalization",
- original_node.attr["scale_after_normalization"])
- copy_attr(quantized_batch_norm_node, "variance_epsilon",
- original_node.attr["variance_epsilon"])
- self.add_output_graph_node(quantized_batch_norm_node)
- quantize_down_name = self.add_quantize_down_nodes(original_node,
- quantized_batch_norm_name)
- self.add_dequantize_result_node(quantize_down_name, original_node.name)
-
- def add_output_graph_node(self, output_node):
- """Inserts one node into the new graph."""
- self.output_graph.node.extend([output_node])
-
- def remove_redundant_quantization(self, old_graph):
- """Removes unneeded pairs of quantize/dequantize ops from the graph.
-
-    This is a bit of a tricky function, because it's attempting to spot the
-    pattern of dequantizing from eight bits up to float and then immediately
-    quantizing back down to eight bits again. That pattern is introduced by
-    previous passes that do 'keyhole' conversions of individual nodes but have
-    to convert back to float to match the original output interface, since
-    they don't know that the next op can handle quantized tensors.
- It works by:
- - Looking for Quantize nodes.
- - Checking to see if their first input is a Dequantize node.
- - Seeing if their min/max inputs come from Min/Max nodes.
- - Making sure those Min/Max nodes are being fed from the same Dequantize.
- - Or that the Min is indirectly being fed from the same Dequantize as Max.
- - Making sure the Dequantize is going through a Reshape (which we add
- during the previous pass when we create the quantize sub-graph).
- - Looking for the dims Const op for the Min/Max dims.
- If all of these conditions are met, then it's a sub-graph pattern that
- we know how to optimize out (and is likely the common one we've introduced).
- We then rewire the graph to skip it entirely, and then rely on the dead node
- removal pass to get rid of any nodes that are no longer needed.
-
- Args:
- old_graph: The model we'll be stripping redundant nodes from.
-
- Returns:
- A graph with the unnecessary nodes removed.
-
- Raises:
- ValueError: Two nodes with the same name were found in the graph.
- """
- old_nodes_map = self.create_nodes_map(old_graph)
- self.output_graph = graph_pb2.GraphDef()
- inputs_to_rename = {}
- # We go through all the nodes, looking for any that match the patterns we
- # know how to optimize away.
- for node in old_graph.node:
- # We always start with a Quantize node, and examine its inputs to see if
- # they are in a form that can be removed.
- if node.op not in ["Quantize", "QuantizeV2"]:
- continue
- dequantize_node_name = node_name_from_input(node.input[0])
- if dequantize_node_name not in old_nodes_map:
- raise ValueError("Input node name '" + dequantize_node_name +
- "' not found in node '" + node.name + "'")
- dequantize_node = old_nodes_map[dequantize_node_name]
- # Do we have a Dequantize feeding in, with the same type as the Quantize?
- if dequantize_node.op != "Dequantize":
- continue
- if node.attr["T"] != dequantize_node.attr["T"]:
- continue
- # Now look at the other inputs, and ensure they're Min/Max nodes.
- min_node_name = node_name_from_input(node.input[1])
- max_node_name = node_name_from_input(node.input[2])
- min_node = old_nodes_map[min_node_name]
- max_node = old_nodes_map[max_node_name]
- is_min_right_type = (min_node.op in ["Min", "Dequantize"])
- is_max_right_type = (max_node.op in ["Max", "Dequantize"])
- if not is_min_right_type or not is_max_right_type:
- print("Didn't find expected types on inputs : %s, %s." % (min_node.op,
- max_node.op))
- continue
- min_node_input_name = node_name_from_input(min_node.input[0])
- max_node_input_name = node_name_from_input(max_node.input[0])
- # There are two different patterns for Min nodes we can recognize, one
- # where the input comes directly from the same one as the Max, and
- # another where we run it through another Min first, so check for both.
- is_same_input = False
- if min_node_input_name == max_node_input_name:
- is_same_input = True
- else:
- first_min_node_input = old_nodes_map[min_node_input_name]
- if first_min_node_input.op == "Concat":
- second_min_node_name = node_name_from_input(
- first_min_node_input.input[1])
- second_min_node = old_nodes_map[second_min_node_name]
- if second_min_node.op == "Min":
- second_min_node_input_name = node_name_from_input(
- second_min_node.input[0])
- is_same_input = (second_min_node_input_name == max_node_input_name)
- if not is_same_input:
- print("Different min/max inputs: " + min_node_input_name)
- continue
- # We recognize this pattern, so mark the graph edges to be rewired to
- # route around it entirely, since we know it's a no-op.
- dequantize_source_name = node_name_from_input(dequantize_node.input[0])
- node_tensor_name = ensure_tensor_name_has_port(node.name)
- min_tensor_name = node.name + ":1"
- max_tensor_name = node.name + ":2"
- inputs_to_rename[node_tensor_name] = dequantize_source_name
- inputs_to_rename[min_tensor_name] = dequantize_node.input[1]
- inputs_to_rename[max_tensor_name] = dequantize_node.input[2]
- # Finally we apply all the rewiring we've marked to the graph.
- for node in old_graph.node:
- for index, input_full_name in enumerate(node.input):
- input_name = ensure_tensor_name_has_port(input_full_name)
- if input_name in inputs_to_rename:
- node.input[index] = inputs_to_rename[input_name]
- self.add_output_graph_node(node)
- return self.output_graph
-
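
As a rough illustration of the rewiring step above, here is a minimal, self-contained sketch using plain Python dicts instead of GraphDef protos; the node names ("qconv", "deq", "quant", "consumer", etc.) are made up for the example:

# Edges before rewiring: "consumer" reads through a redundant Dequantize->QuantizeV2 pair.
edges = {
    "deq": ["qconv:0", "qconv:1", "qconv:2"],   # Dequantize(tensor, min, max)
    "reshape": ["deq"],
    "min_op": ["reshape"],
    "max_op": ["reshape"],
    "quant": ["deq", "min_op", "max_op"],       # QuantizeV2 fed from the same Dequantize
    "consumer": ["quant:0", "quant:1", "quant:2"],
}

# Once the pattern is recognized, the Quantize outputs are redirected to the
# original quantized tensor and to the Dequantize's own min/max inputs.
inputs_to_rename = {
    "quant:0": "qconv",
    "quant:1": "qconv:1",
    "quant:2": "qconv:2",
}
for name, inputs in edges.items():
    edges[name] = [inputs_to_rename.get(i, i) for i in inputs]

print(edges["consumer"])
# -> ['qconv', 'qconv:1', 'qconv:2']; deq/reshape/min_op/max_op/quant are no longer
#    referenced and are later stripped by remove_dead_nodes.
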
- def apply_final_node_renames(self):
- """Applies node renames in self.final_node_renames to self.output_graph."""
- old_graph = self.output_graph
- self.output_graph = graph_pb2.GraphDef()
- for node in old_graph.node:
- node.name = self.final_node_renames.get(node.name, node.name)
- for index, input_name in enumerate(node.input):
- node_name = node_name_from_input(input_name)
- input_full_name = ensure_tensor_name_has_port(input_name)
- if node_name in self.final_node_renames:
- node.input[index] = "%s%s" % (self.final_node_renames[node_name],
- input_full_name[len(node_name):])
- self.add_output_graph_node(node)
- return self.output_graph
-
- def remove_dead_nodes(self, output_names):
- """Removes nodes that are no longer needed for inference from the graph."""
- old_output_graph = self.output_graph
- self.output_graph = graph_util.extract_sub_graph(old_output_graph,
- output_names)
-
- def quantize_weights(self, input_graph, quantization_mode):
- """Quantize float Const ops.
-
- There are two modes of operations, both replace float Const ops with
- quantized values.
-    1. If quantization_mode is "weights_rounded", this function replaces float
-    Const ops with quantized float Const ops - same as the original op, but
-    with each float value mapped to the center of one of 1<<FLAGS.bitdepth
-    buckets.
- This does not change the raw model size, but compression algorithms such as
- zip (as used for compressing apks) or bzip2 will achieve a very good
- compression ratio.
- 2. For other quantization modes ("MIN_COMBINED" or "MIN_FIRST"), float
- Const ops are quantized and replaced by a tuple of four ops to perform
- the dequantization at runtime:
-    * eight-bit Const (bucket indices, same shape as original float Const op)
- * two float Const ops (min and max value of original float Const op)
- * Dequantize op to convert the eight-bit consts to float tensors.
-    The quantization mode is important because we see accuracy problems when
-    quantizing weights in some situations, depending on the algorithm used.
-    We haven't figured out exactly what the underlying cause is yet,
-    unfortunately.
-
- Args:
- input_graph: A GraphDef of the model containing float Const ops.
- quantization_mode: How to quantize and dequantize the values.
-
- Returns:
- A GraphDef of the converted graph.
-
- Raises:
- ValueError: If quantization_mode is unsupported.
- """
- output_graph = graph_pb2.GraphDef()
- for input_node in input_graph.node:
- should_quantize = False
- if input_node.op == "Const":
- dtype = dtypes.as_dtype(input_node.attr["dtype"].type)
- if dtype == dtypes.float32:
- should_quantize = True
- if should_quantize:
- if quantization_mode == "weights_rounded":
- output_graph.node.extend(quantize_weight_rounded(input_node))
- elif quantization_mode in (b"MIN_COMBINED", b"MIN_FIRST"):
- output_graph.node.extend(
- quantize_weight_eightbit(input_node, quantization_mode))
- else:
- raise ValueError("Unsupported quantization mode %s." %
- quantization_mode)
- else:
- output_node = node_def_pb2.NodeDef()
- output_node.CopyFrom(input_node)
- output_graph.node.extend([output_node])
- return output_graph
-
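
To make the "weights_rounded" description above concrete, here is a hedged, standalone sketch of the bucket-rounding arithmetic (the real implementation lives in quantize_weight_rounded / quantize_array; this only reproduces the idea):

import numpy as np

def round_to_bucket_centers(values, bitdepth=8):
    # Map each float to the center of one of (1 << bitdepth) equal-width
    # buckets spanning [min, max]. Storage stays float, but the reduced value
    # set compresses very well with zip/bzip2, as described above.
    values = np.asarray(values, dtype=np.float32)
    lo, hi = float(values.min()), float(values.max())
    if hi == lo:
        return values.copy()
    buckets = 1 << bitdepth
    width = (hi - lo) / buckets
    indices = np.clip(((values - lo) / width).astype(np.int64), 0, buckets - 1)
    return (lo + (indices + 0.5) * width).astype(np.float32)

print(round_to_bucket_centers([0.0, 0.3, 0.6, 1.0], bitdepth=1))
# -> [0.25 0.25 0.75 0.75], matching quantize_array(arr, 2) in the tests below.
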
- def set_input_graph(self, new_input_graph):
- self.input_graph = new_input_graph
- self.nodes_map = self.create_nodes_map(self.input_graph)
-
-
-def main(unused_args):
- if not gfile.Exists(FLAGS.input):
- print("Input graph file '" + FLAGS.input + "' does not exist!")
- return -1
-
- known_modes = [
- "round", "quantize", "eightbit", "weights", "test", "weights_rounded"
- ]
-  if FLAGS.mode not in known_modes:
- print("mode is '" + FLAGS.mode + "', not in " + ", ".join(known_modes) +
- ".")
- return -1
-
- tf_graph = graph_pb2.GraphDef()
- with gfile.Open(FLAGS.input, "rb") as f:
- data = f.read()
- tf_graph.ParseFromString(data)
-
- graph = ops.Graph()
- with graph.as_default():
- importer.import_graph_def(tf_graph, input_map={}, name="")
-
- quantized_input_range = None
- if FLAGS.quantized_input:
- quantized_input_range = [
- FLAGS.quantized_input_min, FLAGS.quantized_input_max
- ]
-
- fallback_quantization_range = None
- if (FLAGS.quantized_fallback_min is not None or
- FLAGS.quantized_fallback_max is not None):
- assert FLAGS.quantized_fallback_min is not None
- assert FLAGS.quantized_fallback_max is not None
- fallback_quantization_range = [
- FLAGS.quantized_fallback_min, FLAGS.quantized_fallback_max
- ]
-
- rewriter = GraphRewriter(tf_graph, FLAGS.mode, quantized_input_range,
- fallback_quantization_range)
-
- output_graph = rewriter.rewrite(FLAGS.output_node_names.split(","))
-
-  with gfile.Open(FLAGS.output, "wb") as f:
-    f.write(output_graph.SerializeToString())
-
- return 0
-
-
-if __name__ == "__main__":
- app.run()
diff --git a/tensorflow/tools/quantization/quantize_graph_test.py b/tensorflow/tools/quantization/quantize_graph_test.py
deleted file mode 100644
index 92bb5127da..0000000000
--- a/tensorflow/tools/quantization/quantize_graph_test.py
+++ /dev/null
@@ -1,966 +0,0 @@
-# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests the graph quantization script.
-
-"""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import sys
-import numpy as np
-
-from tensorflow.core.framework import graph_pb2
-from tensorflow.python.client import session
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import graph_util
-from tensorflow.python.framework import importer
-from tensorflow.python.framework import ops as ops_lib
-from tensorflow.python.platform import flags as flags_lib
-from tensorflow.python.platform import test
-from tensorflow.python.platform import tf_logging
-from tensorflow.tools.quantization import quantize_graph
-
-flags = flags_lib
-FLAGS = flags.FLAGS
-
-
-def run_graph_def(graph_def, input_map, outputs):
- graph = ops_lib.Graph()
- with graph.as_default():
- importer.import_graph_def(graph_def, input_map={}, name="")
- with session.Session(graph=graph) as sess:
- results = sess.run(outputs, feed_dict=input_map)
- return results
-
-
-def test_mat_mul(m, n, k, a, b):
- """Tests a MatMul replacement."""
- a_constant_name = "a_constant"
- b_constant_name = "b_constant"
- mat_mul_name = "mat_mul"
-
- float_graph_def = graph_pb2.GraphDef()
- a_constant = quantize_graph.create_constant_node(
- a_constant_name, value=a, dtype=dtypes.float32, shape=[m, k])
- float_graph_def.node.extend([a_constant])
- b_constant = quantize_graph.create_constant_node(
- b_constant_name, value=b, dtype=dtypes.float32, shape=[k, n])
- float_graph_def.node.extend([b_constant])
- mat_mul_node = quantize_graph.create_node("MatMul", mat_mul_name,
- [a_constant_name, b_constant_name])
- quantize_graph.set_attr_dtype(mat_mul_node, "T", dtypes.float32)
- quantize_graph.set_attr_bool(mat_mul_node, "transpose_a", False)
- quantize_graph.set_attr_bool(mat_mul_node, "transpose_b", False)
- float_graph_def.node.extend([mat_mul_node])
-
- test_graph(float_graph_def, {}, [mat_mul_name])
-
-
-def test_conv(depth, image_width, image_height, image_batch_count, filter_size,
- filter_count, stride, padding, input_values, filter_values):
- """Tests a Conv replacement."""
- input_constant_name = "input_constant"
- filter_constant_name = "filter_constant"
- conv_name = "conv"
-
- float_graph_def = graph_pb2.GraphDef()
- input_constant = quantize_graph.create_constant_node(
- input_constant_name,
- value=input_values,
- dtype=dtypes.float32,
- shape=[image_batch_count, image_height, image_width, depth])
- float_graph_def.node.extend([input_constant])
- filter_constant = quantize_graph.create_constant_node(
- filter_constant_name,
- value=filter_values,
- dtype=dtypes.float32,
- shape=[filter_size, filter_size, depth, filter_count])
- float_graph_def.node.extend([filter_constant])
- conv_node = quantize_graph.create_node(
- "Conv2D", conv_name, [input_constant_name, filter_constant_name])
- quantize_graph.set_attr_dtype(conv_node, "T", dtypes.float32)
- quantize_graph.set_attr_int_list(conv_node, "strides", [1, stride, stride, 1])
- quantize_graph.set_attr_string(conv_node, "padding", padding)
- float_graph_def.node.extend([conv_node])
-
- test_graph(float_graph_def, {}, [conv_name])
-
-
-def are_tensors_near(a, b, tolerance):
- """Tests whether two tensors are nearly identical.
-
- This is a specialized comparison function designed to help debug problems with
- quantization. It prints out information about the differences between tensors
-  on failure, paying special attention to possible biases by looking at the
-  mean difference and the mean absolute difference.
-
- Args:
- a: First comparison tensor.
- b: Second comparison tensor.
- tolerance: Float value indicating how large an error between values is ok.
-
- Returns:
- Boolean indicating whether the two inputs were close enough.
- """
- flat_a = a.flatten()
- flat_b = b.flatten()
- if len(flat_a) != len(flat_b):
- tf_logging.info("Tensors are different sizes: " + str(len(flat_a)) + " vs "
- + str(len(flat_b)))
- return False
- value_count = len(flat_a)
- how_many_different = 0
- total_difference = 0
- total_abs_difference = 0
- for index in range(value_count):
- a_value = flat_a[index]
- b_value = flat_b[index]
- difference = a_value - b_value
- total_difference += difference
- total_abs_difference += abs(difference)
- if abs(difference) > tolerance:
- how_many_different += 1
- mean_difference = total_difference / value_count
- mean_abs_difference = total_abs_difference / value_count
- proportion_different = (how_many_different * 1.0) / value_count
- if how_many_different == 0:
- return True
- else:
- tf_logging.info("Tensors have {0} different values ({1}%), with mean"
- " difference {2} and mean absolute difference {3}".format(
- how_many_different, proportion_different * 100,
- mean_difference, mean_abs_difference))
- return False
-
-
-def get_top_value(input_values):
- max_value = None
- max_index = None
- for index, value in enumerate(input_values.flatten()):
-    if max_value is None or value > max_value:
- max_value = value
- max_index = index
- return max_index, max_value
-
-
-def test_graph(float_graph_def, input_map, output_names, log_graph=False):
- """Runs the float graph through the rewriter and tests the results."""
- float_results = run_graph_def(
- float_graph_def, input_map,
- [output_name + ":0" for output_name in output_names])
- # TODO(petewarden): round test is currently failing because there is no
- # RoundToSteps op available.
- # round_rewriter = quantize_graph.GraphRewriter(float_graph_def, "round")
- # round_graph_def = round_rewriter.rewrite(output_name)
- # round_results = run_graph_def(round_graph_def, input_map,
- # [output_name + ":0"])
- # assert are_tensors_near(expected, round_results[0], 1.0)
- #
- # TODO(petewarden): Add test for "quantize" mode.
-
- eightbit_rewriter = quantize_graph.GraphRewriter(
- float_graph_def, "eightbit", quantized_input_range=None)
- eightbit_graph_def = eightbit_rewriter.rewrite(output_names)
- eightbit_results = run_graph_def(
- eightbit_graph_def, input_map,
- [output_name + ":0" for output_name in output_names])
- for expected, result in zip(float_results, eightbit_results):
- assert are_tensors_near(expected, result, 1.0)
-
- if log_graph:
- tf_logging.info("8bit:\n%s", str(eightbit_graph_def))
-
- # Test the weights_rounded mode. This uses the default bit_depth.
- weights_rounded_rewriter = quantize_graph.GraphRewriter(
- float_graph_def, "weights_rounded", quantized_input_range=None)
- weights_rounded_graph_def = weights_rounded_rewriter.rewrite(output_names)
- weights_rounded_results = run_graph_def(
- weights_rounded_graph_def, input_map,
- [output_name + ":0" for output_name in output_names])
- for expected, result in zip(float_results, weights_rounded_results):
- assert are_tensors_near(expected, result, 1.0)
-
-
-class QuantizeGraphTest(test.TestCase):
-
- def test_negative_const_problem(self):
- shape_constant_name = "shape_constant"
- shape_constant = quantize_graph.create_constant_node(
- shape_constant_name, value=-0.8, dtype=dtypes.float32, shape=[1])
- quantization_result = quantize_graph.quantize_weight_eightbit(
- shape_constant, b"MIN_COMBINED")
- self.assertEqual(4, len(quantization_result))
-
- def test_odd_padding_problem(self):
- """Tests one error case we ran into in a real graph."""
- test_conv(1, 4, 4, 1, 3, 1, 2, b"SAME",
- [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
- [1, 2, 3, 4, 5, 6, 7, 8, 9])
-
- def test_mat_mul_tiny(self):
-    # These tests are added to test the degenerate case where
- # min(matrix) == max(matrix), which used to cause problems.
- test_mat_mul(1, 1, 1, [2], [3])
- test_mat_mul(1, 2, 1, [1], [2, 3])
- test_mat_mul(1, 1, 2, [1, 1], [1, 1])
- test_mat_mul(1, 1, 2, [0, 0], [1, 1])
- # The general case.
- test_mat_mul(1, 1, 2, [1, 2], [1, 2])
-
- def test_mat_mul_small(self):
- test_mat_mul(2, 4, 3, [1, 2, 3, 4, 5, 6],
- [7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18])
-
- def test_conv(self):
- test_conv(1, 4, 3, 1, 3, 1, 1, b"SAME",
- [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
- [1, 4, 7, 2, 5, 8, 3, 6, 9])
-
- def test_reshape(self):
- """Tests that MatMul->Reshape->MatMul avoids extra quantize/dequantize."""
-
- def make_matmul(name, a, b):
- n = quantize_graph.create_node("MatMul", name, [a.name, b.name])
- quantize_graph.set_attr_dtype(n, "T", dtypes.float32)
- quantize_graph.set_attr_bool(n, "transpose_a", False)
- quantize_graph.set_attr_bool(n, "transpose_b", False)
- return n
-
- # matmul_1 = input*weight_1
- input_node = quantize_graph.create_constant_node(
- "input", value=[0, 1, 2, 3], dtype=dtypes.float32, shape=[4, 1])
- weight_1_node = quantize_graph.create_constant_node(
- "weight_1",
- value=[.5, .6, .7, .8, .9],
- dtype=dtypes.float32,
- shape=[1, 5])
- matmul_1_node = make_matmul("matmul_1", input_node, weight_1_node)
-
- # Reshape 4x5 to 10x2.
- new_shape_node = quantize_graph.create_constant_node(
- "new_shape_node", value=[10, 2], dtype=dtypes.int32, shape=[2])
- reshape_node = quantize_graph.create_node(
- "Reshape", "reshape", [matmul_1_node.name, new_shape_node.name])
- quantize_graph.set_attr_dtype(reshape_node, "T", dtypes.float32)
-
- # matmul_2_node = reshape*weight_2
- weight_2_node = quantize_graph.create_constant_node(
- "weight_2", value=[1.5, 2.5], dtype=dtypes.float32, shape=[2, 1])
- matmul_2_node = make_matmul("matmul_2", reshape_node, weight_2_node)
-
- g = graph_pb2.GraphDef()
- g.node.extend([
- input_node, weight_1_node, matmul_1_node, new_shape_node, reshape_node,
- weight_2_node, matmul_2_node
- ])
-
- # Test the graph
- test_graph(g, {}, ["matmul_2"])
-
-    # Verify that no runtime Quantize op is needed and that the Reshape itself
-    # stays quantized.
- eightbit_rewriter = quantize_graph.GraphRewriter(
- g, "eightbit", quantized_input_range=None)
- eightbit_graph_def = eightbit_rewriter.rewrite(["matmul_2"])
-
- ops = [node.op for node in eightbit_graph_def.node]
- # No quantize since all inputs are const and can be quantized up-front.
- self.assertEqual(0, ops.count("QuantizeV2") + ops.count("Quantize"))
- self.assertEqual(1, ops.count("QuantizedReshape"))
-
- # One dequantize at the end.
- self.assertEqual(1, ops.count("Dequantize"))
-
- def test_quantize_array(self):
-    # Test invalid parameters (empty array, or 0 buckets).
- self.assertRaises(ValueError, quantize_graph.quantize_array, np.array([]),
- 2)
- self.assertRaises(ValueError, quantize_graph.quantize_array,
- np.array([1, 2]), 0)
- # Test input array of length 1.
- arr = np.array([1])
- qarr = quantize_graph.quantize_array(arr, 1)
- self.assertEqual(arr, qarr)
- qarr = quantize_graph.quantize_array(arr, 2)
- self.assertEqual(arr, qarr)
- # Test input array with all elements equal.
- arr = np.array([1, 1, 1])
- qarr = quantize_graph.quantize_array(arr, 10)
- self.assertTrue((np.array([1, 1, 1]) == qarr).all())
- # Test "normal" input arrays.
- arr = np.array([0, 0.3, 0.6, 1])
- qarr = quantize_graph.quantize_array(arr, 1)
- self.assertTrue((np.array([0.5, 0.5, 0.5, 0.5]) == qarr).all())
- qarr = quantize_graph.quantize_array(arr, 2)
- self.assertTrue((np.array([0.25, 0.25, 0.75, 0.75]) == qarr).all())
- qarr = quantize_graph.quantize_array(arr.reshape((2, 2)), 2)
- self.assertTrue((np.array([[0.25, 0.25], [0.75, 0.75]]) == qarr).all())
-
- def test_non_float_concat(self):
- concat_dim = quantize_graph.create_constant_node(
- "concat_dim", value=0, dtype=dtypes.int32, shape=[])
- a = quantize_graph.create_constant_node(
- "a",
- value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
- dtype=dtypes.int32,
- shape=[2, 2, 3])
- b = quantize_graph.create_constant_node(
- "b",
- value=[13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24],
- dtype=dtypes.int32,
- shape=[2, 2, 3])
- concat = quantize_graph.create_node("Concat", "concat",
- [concat_dim.name, a.name, b.name])
- quantize_graph.set_attr_int(concat, "N", 2)
- quantize_graph.set_attr_dtype(concat, "T", dtypes.int32)
-
- g = graph_pb2.GraphDef()
- g.node.extend([concat_dim, a, b, concat])
- test_graph(g, {}, [concat.name])
-
- def test_non_float_reshape(self):
- a = quantize_graph.create_constant_node(
- "a",
- value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
- dtype=dtypes.int32,
- shape=[2, 2, 3])
- shape = quantize_graph.create_constant_node(
- "shape", value=[12], dtype=dtypes.int32, shape=[1])
- reshape = quantize_graph.create_node("Reshape", "reshape",
- [a.name, shape.name])
- quantize_graph.set_attr_dtype(reshape, "T", dtypes.int32)
-
- g = graph_pb2.GraphDef()
- g.node.extend([a, shape, reshape])
- test_graph(g, {}, [reshape.name])
-
- def test_concat(self):
- shape_constant_name = "shape_constant"
- a_constant_name = "a_constant"
- b_constant_name = "b_constant"
- concat_name = "concat"
-
- float_graph_def = graph_pb2.GraphDef()
- shape_constant = quantize_graph.create_constant_node(
- shape_constant_name, value=0, dtype=dtypes.int32, shape=[])
- float_graph_def.node.extend([shape_constant])
- a_constant = quantize_graph.create_constant_node(
- a_constant_name,
- value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
- dtype=dtypes.float32,
- shape=[2, 2, 3])
- float_graph_def.node.extend([a_constant])
- b_constant = quantize_graph.create_constant_node(
- b_constant_name,
- value=[13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24],
- dtype=dtypes.float32,
- shape=[2, 2, 3])
- float_graph_def.node.extend([b_constant])
- concat_node = quantize_graph.create_node(
- "Concat", concat_name,
- [shape_constant_name, a_constant_name, b_constant_name])
- quantize_graph.set_attr_int(concat_node, "N", 2)
- quantize_graph.set_attr_dtype(concat_node, "T", dtypes.float32)
- float_graph_def.node.extend([concat_node])
-
- test_graph(float_graph_def, {}, [concat_name])
-
- # Verify the concat is quantized.
- eightbit_rewriter = quantize_graph.GraphRewriter(
- float_graph_def, "eightbit", quantized_input_range=None)
- eightbit_graph_def = eightbit_rewriter.rewrite([concat_name])
-
- ops = [node.op for node in eightbit_graph_def.node]
- self.assertEqual(1, ops.count("QuantizedConcat"))
-
- def test_multiple_outputs(self):
- input_constant_name = "input_constant"
- split_constant_name = "split_constant"
- split_name = "split"
- concat_constant_name = "concat_constant"
- concat_name = "concat"
-
- float_graph_def = graph_pb2.GraphDef()
- input_constant = quantize_graph.create_constant_node(
- input_constant_name,
- value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
- dtype=dtypes.float32,
- shape=[2, 6])
- float_graph_def.node.extend([input_constant])
- split_constant = quantize_graph.create_constant_node(
- split_constant_name, value=1, dtype=dtypes.int32, shape=[])
- float_graph_def.node.extend([split_constant])
- split_node = quantize_graph.create_node(
- "Split", split_name, [split_constant_name, input_constant_name])
- quantize_graph.set_attr_int(split_node, "num_split", 2)
- quantize_graph.set_attr_dtype(split_node, "T", dtypes.float32)
- float_graph_def.node.extend([split_node])
- concat_constant = quantize_graph.create_constant_node(
- concat_constant_name, value=1, dtype=dtypes.int32, shape=[])
- float_graph_def.node.extend([concat_constant])
- concat_node = quantize_graph.create_node(
- "Concat", concat_name,
- [concat_constant_name, split_name + ":0", split_name + ":1"])
- quantize_graph.set_attr_int(concat_node, "N", 2)
- quantize_graph.set_attr_dtype(concat_node, "T", dtypes.float32)
- float_graph_def.node.extend([concat_node])
-
- test_graph(float_graph_def, {}, [concat_name])
-
- def test_node_name_from_input(self):
- self.assertEqual("SomeName",
- quantize_graph.node_name_from_input("^SomeName:2"))
-
- def test_unique_node_name_from_input(self):
- self.assertEqual("__hat__SomeName__port__2",
- quantize_graph.unique_node_name_from_input("^SomeName:2"))
-
- def test_identity(self):
- input_constant_name = "input_constant"
- identity_name = "identity"
- float_graph_def = graph_pb2.GraphDef()
- input_constant = quantize_graph.create_constant_node(
- input_constant_name,
- value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
- dtype=dtypes.float32,
- shape=[2, 6])
- float_graph_def.node.extend([input_constant])
- identity_node = quantize_graph.create_node("Identity", identity_name,
- [input_constant_name])
- quantize_graph.set_attr_dtype(identity_node, "T", dtypes.float32)
- float_graph_def.node.extend([identity_node])
-
- mul_name = "mul"
- mul_node = quantize_graph.create_node("Mul", mul_name,
- [identity_name, identity_name])
- quantize_graph.set_attr_dtype(mul_node, "T", dtypes.float32)
- float_graph_def.node.extend([mul_node])
-
- test_graph(float_graph_def, {}, [mul_name])
-
- def test_keep_control_edges(self):
- no_op_name = "no_op"
- a_constant_name = "a_constant"
- b_constant_name = "b_constant"
- a_check_name = "a_check"
- b_check_name = "b_check"
- a_identity_name = "a_identity"
- b_identity_name = "b_identity"
- add_name = "add"
- graph_def = graph_pb2.GraphDef()
- no_op = quantize_graph.create_node("NoOp", no_op_name, [])
- graph_def.node.extend([no_op])
- a_constant = quantize_graph.create_constant_node(
- a_constant_name, value=1, dtype=dtypes.float32, shape=[])
- graph_def.node.extend([a_constant])
- a_check_node = quantize_graph.create_node("CheckNumerics", a_check_name,
- [a_constant_name])
- graph_def.node.extend([a_check_node])
- a_identity_node = quantize_graph.create_node(
- "Identity", a_identity_name,
- [a_constant_name, "^" + a_check_name, "^" + no_op_name])
- graph_def.node.extend([a_identity_node])
- b_constant = quantize_graph.create_constant_node(
- b_constant_name, value=1, dtype=dtypes.float32, shape=[])
- graph_def.node.extend([b_constant])
- b_check_node = quantize_graph.create_node("CheckNumerics", b_check_name,
- [b_constant_name])
- graph_def.node.extend([b_check_node])
- b_identity_node = quantize_graph.create_node(
- "Identity", b_identity_name, [b_constant_name, "^" + b_check_name])
- graph_def.node.extend([b_identity_node])
- add_node = quantize_graph.create_node("Add", add_name,
- [a_identity_name, b_identity_name])
- quantize_graph.set_attr_dtype(add_node, "T", dtypes.float32)
- graph_def.node.extend([add_node])
-
- expected_output = graph_pb2.GraphDef()
- no_op = quantize_graph.create_node("NoOp", no_op_name, [])
- expected_output.node.extend([no_op])
- a_constant = quantize_graph.create_constant_node(
- a_constant_name, value=1, dtype=dtypes.float32, shape=[])
- expected_output.node.extend([a_constant])
- a_identity_node = quantize_graph.create_node(
- "Identity", a_identity_name, [a_constant_name, "^" + no_op_name])
- expected_output.node.extend([a_identity_node])
- b_constant = quantize_graph.create_constant_node(
- b_constant_name, value=1, dtype=dtypes.float32, shape=[])
- expected_output.node.extend([b_constant])
- add_node = quantize_graph.create_node("Add", add_name,
- [a_identity_name, b_constant_name])
- quantize_graph.set_attr_dtype(add_node, "T", dtypes.float32)
- expected_output.node.extend([add_node])
- expected_output.versions.CopyFrom(graph_def.versions)
- expected_output.library.CopyFrom(graph_def.library)
-
- output = graph_util.remove_training_nodes(graph_def)
- stripped_output = graph_util.extract_sub_graph(output, [add_name])
- self.assertProtoEquals(expected_output, stripped_output)
-
- def test_batch_norm(self):
- input_constant_name = "input_constant"
- mean_constant_name = "mean_constant"
- variance_constant_name = "variance_constant"
- beta_constant_name = "beta_constant"
- gamma_constant_name = "gamma_constant"
- batch_norm_name = "batch_norm"
- float_graph_def = graph_pb2.GraphDef()
- input_constant = quantize_graph.create_constant_node(
- input_constant_name,
- value=[1, 4, 2, 5, 3, 6, -1, -4, -2, -5, -3, -6],
- dtype=dtypes.float32,
- shape=[1, 1, 6, 2])
- float_graph_def.node.extend([input_constant])
- mean_constant = quantize_graph.create_constant_node(
- mean_constant_name, value=[10, 20], dtype=dtypes.float32, shape=[2])
- float_graph_def.node.extend([mean_constant])
- variance_constant = quantize_graph.create_constant_node(
- variance_constant_name,
- value=[0.25, 0.5],
- dtype=dtypes.float32,
- shape=[2])
- float_graph_def.node.extend([variance_constant])
- beta_constant = quantize_graph.create_constant_node(
- beta_constant_name, value=[0.1, 0.6], dtype=dtypes.float32, shape=[2])
- float_graph_def.node.extend([beta_constant])
- gamma_constant = quantize_graph.create_constant_node(
- gamma_constant_name, value=[0, 0], dtype=dtypes.float32, shape=[2])
- float_graph_def.node.extend([gamma_constant])
- batch_norm_node = quantize_graph.create_node(
- "BatchNormWithGlobalNormalization", batch_norm_name, [
- input_constant_name, mean_constant_name, variance_constant_name,
- beta_constant_name, gamma_constant_name
- ])
- quantize_graph.set_attr_dtype(batch_norm_node, "T", dtypes.float32)
- quantize_graph.set_attr_bool(batch_norm_node, "scale_after_normalization",
- False)
- quantize_graph.set_attr_float(batch_norm_node, "variance_epsilon", 0.001)
- float_graph_def.node.extend([batch_norm_node])
- test_graph(float_graph_def, {}, [batch_norm_name])
-
- def test_max_pool(self):
- input_constant_name = "input_constant"
- max_pool_name = "max_pool"
- float_graph_def = graph_pb2.GraphDef()
- input_constant = quantize_graph.create_constant_node(
- input_constant_name,
- value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
- dtype=dtypes.float32,
- shape=[1, 2, 6, 1])
- float_graph_def.node.extend([input_constant])
- max_pool_node = quantize_graph.create_node("MaxPool", max_pool_name,
- [input_constant_name])
- quantize_graph.set_attr_int_list(max_pool_node, "ksize", [1, 2, 2, 1])
- quantize_graph.set_attr_int_list(max_pool_node, "strides", [1, 1, 1, 1])
- quantize_graph.set_attr_string(max_pool_node, "padding", b"SAME")
- float_graph_def.node.extend([max_pool_node])
- test_graph(float_graph_def, {}, [max_pool_name])
-
- def test_avg_pool(self):
- input_constant_name = "input_constant"
- avg_pool_name = "avg_pool"
- float_graph_def = graph_pb2.GraphDef()
- input_constant = quantize_graph.create_constant_node(
- input_constant_name,
- value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
- dtype=dtypes.float32,
- shape=[1, 2, 6, 1])
- float_graph_def.node.extend([input_constant])
- avg_pool_node = quantize_graph.create_node("AvgPool", avg_pool_name,
- [input_constant_name])
- quantize_graph.set_attr_dtype(avg_pool_node, "T", dtypes.float32)
- quantize_graph.set_attr_int_list(avg_pool_node, "ksize", [1, 2, 2, 1])
- quantize_graph.set_attr_int_list(avg_pool_node, "strides", [1, 1, 1, 1])
- quantize_graph.set_attr_string(avg_pool_node, "padding", b"SAME")
- float_graph_def.node.extend([avg_pool_node])
- test_graph(float_graph_def, {}, [avg_pool_name])
-
- def test_relu(self):
- input_constant_name = "input_constant"
- relu_name = "relu"
- float_graph_def = graph_pb2.GraphDef()
- input_constant = quantize_graph.create_constant_node(
- input_constant_name,
- value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
- dtype=dtypes.float32,
- shape=[1, 2, 6, 1])
- float_graph_def.node.extend([input_constant])
- relu_node = quantize_graph.create_node("Relu", relu_name,
- [input_constant_name])
- quantize_graph.set_attr_dtype(relu_node, "T", dtypes.float32)
- float_graph_def.node.extend([relu_node])
- test_graph(float_graph_def, {}, [relu_name])
-
- def test_relu_w_fake_quant_w_min_max_vars(self):
- input_node = quantize_graph.create_constant_node(
- "input",
- value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
- dtype=dtypes.float32,
- shape=[1, 2, 6, 1])
- relu_node = quantize_graph.create_node("Relu", "relu", [input_node.name])
- quantize_graph.set_attr_dtype(relu_node, "T", dtypes.float32)
-
- min_node = quantize_graph.create_constant_node(
- "min_bias_add", value=0, dtype=dtypes.float32, shape=[])
- max_node = quantize_graph.create_constant_node(
- "max_bias_add", value=12, dtype=dtypes.float32, shape=[])
- fake_quant_node = quantize_graph.create_node(
- "FakeQuantWithMinMaxVars", "fake_quant",
- [relu_node.name, min_node.name, max_node.name])
-
- float_graph_def = graph_pb2.GraphDef()
- float_graph_def.node.extend(
- [input_node, relu_node, min_node, max_node, fake_quant_node])
- test_graph(float_graph_def, {}, [fake_quant_node.name], log_graph=True)
-
-    # Verify that no runtime Quantize op is needed and only a single
-    # Dequantize remains.
- eightbit_rewriter = quantize_graph.GraphRewriter(
- float_graph_def, "eightbit", quantized_input_range=None)
- eightbit_graph_def = eightbit_rewriter.rewrite([fake_quant_node.name])
-
- ops = [node.op for node in eightbit_graph_def.node]
- # No quantize since all inputs are const and can be quantized up-front.
- self.assertEqual(0, ops.count("QuantizeV2") + ops.count("Quantize"))
-
- # One dequantize at the end.
- self.assertEqual(1, ops.count("Dequantize"))
-
- def test_relu6(self):
- input_constant_name = "input_constant"
- relu6_name = "relu6"
- float_graph_def = graph_pb2.GraphDef()
- input_constant = quantize_graph.create_constant_node(
- input_constant_name,
- value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
- dtype=dtypes.float32,
- shape=[1, 2, 6, 1])
- float_graph_def.node.extend([input_constant])
- relu6_node = quantize_graph.create_node("Relu6", relu6_name,
- [input_constant_name])
- quantize_graph.set_attr_dtype(relu6_node, "T", dtypes.float32)
- float_graph_def.node.extend([relu6_node])
- test_graph(float_graph_def, {}, [relu6_name])
-
- def test_bias_add(self):
- input_constant_name = "input_constant"
- offset_constant_name = "offset_constant"
- bias_add_name = "bias_add"
- float_graph_def = graph_pb2.GraphDef()
- input_constant = quantize_graph.create_constant_node(
- input_constant_name,
- value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
- dtype=dtypes.float32,
- shape=[1, 1, 2, 6])
- float_graph_def.node.extend([input_constant])
- offset_constant = quantize_graph.create_constant_node(
- offset_constant_name,
- value=[1, 2, 3, 4, 5, 6],
- dtype=dtypes.float32,
- shape=[6])
- float_graph_def.node.extend([offset_constant])
- bias_add_node = quantize_graph.create_node(
- "BiasAdd", bias_add_name, [input_constant_name, offset_constant_name])
- quantize_graph.set_attr_dtype(bias_add_node, "T", dtypes.float32)
- float_graph_def.node.extend([bias_add_node])
- test_graph(float_graph_def, {}, [bias_add_name])
-
- def test_quantized_input_range_errors(self):
- with self.assertRaises(ValueError):
- # Invalid mode.
- quantize_graph.GraphRewriter(graph_pb2.GraphDef(), "weights_rounded",
- [0, 1])
- with self.assertRaises(ValueError):
- # Invalid range.
- quantize_graph.GraphRewriter(graph_pb2.GraphDef(), "eightbit", [0, -1])
-
- def test_quantized_input_range_bias_add(self):
- input_shape = [1, 1, 2, 6]
- input_n = quantize_graph.create_node("Placeholder", "input", [])
- quantize_graph.set_attr_dtype(input_n, "dtype", dtypes.float32)
- quantize_graph.set_attr_shape(input_n, "shape", input_shape)
- offset_n = quantize_graph.create_constant_node(
- "offset", value=[1, 2, 3, 4, 5, 6], dtype=dtypes.float32, shape=[6])
- bias_add_n = quantize_graph.create_node("BiasAdd", "bias_add",
- [input_n.name, offset_n.name])
- quantize_graph.set_attr_dtype(bias_add_n, "T", dtypes.float32)
-
- float_graph_def = graph_pb2.GraphDef()
- float_graph_def.node.extend([input_n, offset_n, bias_add_n])
-
- input_map = {
- input_n.name + ":0":
- np.reshape([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], input_shape)
- }
- self._RunTestsForQuantizedInputRange(float_graph_def, input_map,
- [bias_add_n.name], [-1, 20.])
- self._RunTestsForQuantizedInputRange(float_graph_def, input_map,
- [bias_add_n.name], [0, 12.])
-
- def test_quantized_input_range_mat_mul(self):
- shapes = [[3, 2], [2, 4]]
- inputs = []
- for i, shape in enumerate(shapes):
- node = quantize_graph.create_node("Placeholder", "input_%s" % i, [])
- quantize_graph.set_attr_dtype(node, "dtype", dtypes.float32)
- quantize_graph.set_attr_shape(node, "shape", shape)
- inputs.append(node)
- mat_mul_node = quantize_graph.create_node("MatMul", "mat_mul",
- [n.name for n in inputs])
- quantize_graph.set_attr_dtype(mat_mul_node, "T", dtypes.float32)
-
- float_graph_def = graph_pb2.GraphDef()
- float_graph_def.node.extend(inputs + [mat_mul_node])
-
- input_map = {
- inputs[0].name + ":0":
- np.reshape([1, 2, 3, 4, 5, 6], shapes[0]),
- inputs[1].name + ":0":
- np.reshape([.8, .7, .6, .5, .4, .3, .2, .1], shapes[1])
- }
- self._RunTestsForQuantizedInputRange(float_graph_def, input_map,
- [mat_mul_node.name], [-1, 20.])
- self._RunTestsForQuantizedInputRange(float_graph_def, input_map,
- [mat_mul_node.name], [0, 6.])
-
- def _RunTestsForQuantizedInputRange(self, float_graph_def, input_map,
- output_names, input_range):
- if sys.version_info[0] == 3:
- # uint8->quint8 conversion for numpy is not working currently.
- return
-
- quantized_input_map = {}
- for k, v in input_map.items():
- arr = [
- int(
- round((n - input_range[0]) * 255 / (input_range[1] - input_range[
- 0]))) for n in v.flat
- ]
- arr = np.array(arr, np.uint8)
- arr = arr.reshape(v.shape)
- arr = arr.astype(dtypes.quint8.as_numpy_dtype)
- quantized_input_map[k] = arr
- output_tensors = [output_name + ":0" for output_name in output_names]
- float_results = run_graph_def(float_graph_def, input_map, output_tensors)
-
- # Quantize treating the input as quantized in range <input_range>.
- rewriter = quantize_graph.GraphRewriter(float_graph_def, "eightbit",
- input_range)
- graph_def = rewriter.rewrite(output_names)
- results = run_graph_def(graph_def, quantized_input_map, output_tensors)
- for expected, result in zip(float_results, results):
- assert are_tensors_near(expected, result, .5)
- ops = [node.op for node in graph_def.node]
- self.assertEqual(0, ops.count("QuantizeV2") + ops.count("Quantize"))
- self.assertEqual(len(output_names), ops.count("Dequantize"))
-
- # Quantize without treating input as quantized.
- rewriter = quantize_graph.GraphRewriter(
- float_graph_def, "eightbit", quantized_input_range=None)
- graph_def = rewriter.rewrite(output_names)
- results = run_graph_def(graph_def, input_map, output_tensors)
- for expected, result in zip(float_results, results):
- assert are_tensors_near(expected, result, .5)
- ops = [node.op for node in graph_def.node]
- self.assertEqual(
- len(input_map), ops.count("QuantizeV2") + ops.count("Quantize"))
- self.assertEqual(len(output_names), ops.count("Dequantize"))
-
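
As a quick worked example of the uint8 mapping used in the helper above, assuming input_range = [0, 12] as in test_quantized_input_range_bias_add:

import numpy as np

input_range = [0.0, 12.0]
v = np.array([1.0, 6.0, 12.0])
quantized = np.array(
    [int(round((n - input_range[0]) * 255 / (input_range[1] - input_range[0])))
     for n in v.flat], np.uint8)
print(quantized)  # -> [ 21 128 255]
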
- def test_bias_add_w_fake_quant_w_min_max_vars(self):
- input_node = quantize_graph.create_constant_node(
- "input",
- value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
- dtype=dtypes.float32,
- shape=[1, 1, 2, 5])
- offset_node = quantize_graph.create_constant_node(
- "offset", value=[1, 2, 3, 4, 5], dtype=dtypes.float32, shape=[5])
- bias_add_node = quantize_graph.create_node(
- "BiasAdd", "bias_add", [input_node.name, offset_node.name])
- quantize_graph.set_attr_dtype(bias_add_node, "T", dtypes.float32)
-
- min_node = quantize_graph.create_constant_node(
- "min_bias_add", value=-.5, dtype=dtypes.float32, shape=[])
- max_node = quantize_graph.create_constant_node(
- "max_bias_add", value=15.5, dtype=dtypes.float32, shape=[])
- fake_quant_node = quantize_graph.create_node(
- "FakeQuantWithMinMaxVars", "fake_quant",
- [bias_add_node.name, min_node.name, max_node.name])
-
- float_graph_def = graph_pb2.GraphDef()
- float_graph_def.node.extend([
- input_node, offset_node, bias_add_node, min_node, max_node,
- fake_quant_node
- ])
- test_graph(float_graph_def, {}, [fake_quant_node.name], log_graph=True)
-
-    # Verify that no runtime Quantize op is needed and only a single
-    # Dequantize remains.
- # Pass in fallback_quantization_range, although it will have no effect
- # because the FakeQuantWithMinMaxVars are used instead.
- eightbit_rewriter = quantize_graph.GraphRewriter(
- float_graph_def,
- "eightbit",
- quantized_input_range=None,
- fallback_quantization_range=[-100, 100])
- eightbit_graph_def = eightbit_rewriter.rewrite([fake_quant_node.name])
-
- ops = [node.op for node in eightbit_graph_def.node]
- node_names = [node.name for node in eightbit_graph_def.node]
- # No quantize since all inputs are const and can be quantized up-front.
- self.assertEqual(0, ops.count("QuantizeV2") + ops.count("Quantize"))
-
- # One dequantize at the end.
- self.assertEqual(1, ops.count("Dequantize"))
-
- # The fallback constants are not in the graph.
- self.assertEqual(0, node_names.count("fallback_quantization_min_value"))
- self.assertEqual(0, node_names.count("fallback_quantization_max_value"))
-
- def test_bias_add_w_fallback_min_max_vars(self):
- input_node = quantize_graph.create_constant_node(
- "input",
- value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
- dtype=dtypes.float32,
- shape=[1, 1, 2, 5])
- offset_node = quantize_graph.create_constant_node(
- "offset", value=[1, 2, 3, 4, 5], dtype=dtypes.float32, shape=[5])
- bias_add_node = quantize_graph.create_node(
- "BiasAdd", "bias_add", [input_node.name, offset_node.name])
- quantize_graph.set_attr_dtype(bias_add_node, "T", dtypes.float32)
-
- float_graph_def = graph_pb2.GraphDef()
- float_graph_def.node.extend([input_node, offset_node, bias_add_node])
- test_graph(float_graph_def, {}, [bias_add_node.name], log_graph=True)
-
-    # Verify that no runtime Quantize op is needed, only a single Dequantize
-    # remains, and no RequantizationRange op is generated.
- eightbit_rewriter = quantize_graph.GraphRewriter(
- float_graph_def,
- "eightbit",
- quantized_input_range=None,
- fallback_quantization_range=[-.5, 15.5])
- eightbit_graph_def = eightbit_rewriter.rewrite([bias_add_node.name])
-
- ops = [node.op for node in eightbit_graph_def.node]
- node_names = [node.name for node in eightbit_graph_def.node]
- # No quantize since all inputs are const and can be quantized up-front.
- self.assertEqual(0, ops.count("QuantizeV2") + ops.count("Quantize"))
-
- # One dequantize at the end.
- self.assertEqual(1, ops.count("Dequantize"))
-
- # No RequantizationRange
- self.assertEqual(0, ops.count("RequantizationRange"))
-
- # The fallback constants are in the graph.
- self.assertEqual(1, node_names.count("fallback_quantization_min_value"))
- self.assertEqual(1, node_names.count("fallback_quantization_max_value"))
-
- def test_remove_redundant_quantization(self):
- a_constant_name = "a_constant"
- a_constant_min_name = "a_constant_min"
- a_constant_max_name = "a_constant_max"
- a_dequantize_name = "a_dequantize"
- a_quantize_name = "a_quantize"
- b_constant_name = "b_constant"
- b_constant_min_name = "b_constant_min"
- b_constant_max_name = "b_constant_max"
- b_dequantize_name = "b_dequantize"
- b_quantize_name = "b_quantize"
- mat_mul_name = "mat_mul"
- graph_def = graph_pb2.GraphDef()
- a_constant = quantize_graph.create_constant_node(
- a_constant_name, value=(0,), dtype=dtypes.quint8, shape=[])
- graph_def.node.extend([a_constant])
- a_constant_min = quantize_graph.create_constant_node(
- a_constant_min_name, value=2, dtype=dtypes.float32, shape=[])
- graph_def.node.extend([a_constant_min])
- a_constant_max = quantize_graph.create_constant_node(
- a_constant_max_name, value=2, dtype=dtypes.float32, shape=[])
- graph_def.node.extend([a_constant_max])
- a_dequantize_node = quantize_graph.create_node(
- "Dequantize", a_dequantize_name,
- [a_constant_name, a_constant_min_name, a_constant_max_name])
- quantize_graph.set_attr_dtype(a_dequantize_node, "T", dtypes.uint8)
- graph_def.node.extend([a_dequantize_node])
- a_quantize_node = quantize_graph.create_node(
- "QuantizeV2", a_quantize_name,
- [a_dequantize_name, a_dequantize_name + ":1", a_dequantize_name + ":2"])
- quantize_graph.set_attr_dtype(a_quantize_node, "T", dtypes.uint8)
- graph_def.node.extend([a_quantize_node])
- b_constant = quantize_graph.create_constant_node(
- b_constant_name, value=(0,), dtype=dtypes.quint8, shape=[])
- graph_def.node.extend([b_constant])
- b_constant_min = quantize_graph.create_constant_node(
- b_constant_min_name, value=3, dtype=dtypes.float32, shape=[])
- graph_def.node.extend([b_constant_min])
- b_constant_max = quantize_graph.create_constant_node(
- b_constant_max_name, value=3, dtype=dtypes.float32, shape=[])
- graph_def.node.extend([b_constant_max])
- b_dequantize_node = quantize_graph.create_node(
- "Dequantize", b_dequantize_name,
- [b_constant_name, b_constant_min_name, b_constant_max_name])
- quantize_graph.set_attr_dtype(b_dequantize_node, "T", dtypes.uint8)
- graph_def.node.extend([b_dequantize_node])
- b_quantize_node = quantize_graph.create_node(
- "QuantizeV2", b_quantize_name,
- [b_dequantize_name, b_dequantize_name + ":1", b_dequantize_name + ":2"])
- quantize_graph.set_attr_dtype(b_quantize_node, "T", dtypes.uint8)
- graph_def.node.extend([b_quantize_node])
- mat_mul_node = quantize_graph.create_node("QuantizedMatMul", mat_mul_name, [
- a_quantize_name, b_quantize_name, a_quantize_name + ":1",
- a_quantize_name + ":2", b_quantize_name + ":1", b_quantize_name + ":2"
- ])
- quantize_graph.set_attr_dtype(mat_mul_node, "T1", dtypes.uint8)
- quantize_graph.set_attr_dtype(mat_mul_node, "T2", dtypes.int32)
- graph_def.node.extend([mat_mul_node])
-
- expected_output = graph_pb2.GraphDef()
- a_constant = quantize_graph.create_constant_node(
- a_constant_name, value=(0,), dtype=dtypes.quint8, shape=[])
- expected_output.node.extend([a_constant])
- a_constant_min = quantize_graph.create_constant_node(
- a_constant_min_name, value=2, dtype=dtypes.float32, shape=[])
- expected_output.node.extend([a_constant_min])
- a_constant_max = quantize_graph.create_constant_node(
- a_constant_max_name, value=2, dtype=dtypes.float32, shape=[])
- expected_output.node.extend([a_constant_max])
- b_constant = quantize_graph.create_constant_node(
- b_constant_name, value=(0,), dtype=dtypes.quint8, shape=[])
- expected_output.node.extend([b_constant])
- b_constant_min = quantize_graph.create_constant_node(
- b_constant_min_name, value=3, dtype=dtypes.float32, shape=[])
- expected_output.node.extend([b_constant_min])
- b_constant_max = quantize_graph.create_constant_node(
- b_constant_max_name, value=3, dtype=dtypes.float32, shape=[])
- expected_output.node.extend([b_constant_max])
- mat_mul_node = quantize_graph.create_node("QuantizedMatMul", mat_mul_name, [
- a_constant_name, b_constant_name, a_constant_min_name,
- a_constant_max_name, b_constant_min_name, b_constant_max_name
- ])
- quantize_graph.set_attr_dtype(mat_mul_node, "T1", dtypes.uint8)
- quantize_graph.set_attr_dtype(mat_mul_node, "T2", dtypes.int32)
- expected_output.node.extend([mat_mul_node])
- expected_output.versions.CopyFrom(graph_def.versions)
- expected_output.library.CopyFrom(graph_def.library)
-
- rewriter = quantize_graph.GraphRewriter(
- graph_def, [mat_mul_name], quantized_input_range=None)
- output = rewriter.remove_redundant_quantization(graph_def)
- stripped_output = graph_util.extract_sub_graph(output, [mat_mul_name])
- self.assertProtoEquals(expected_output, stripped_output)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/tools/test/check_futures_test.py b/tensorflow/tools/test/check_futures_test.py
index 9181c9bd4a..a883ce221f 100644
--- a/tensorflow/tools/test/check_futures_test.py
+++ b/tensorflow/tools/test/check_futures_test.py
@@ -37,6 +37,7 @@ BASE_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '../..'))
FUTURES_PATTERN = re.compile(r'^from __future__ import (\w+)\s*$')
FUTURES_PATTERN_2 = re.compile(
r'^from __future__ import (\w+), (\w+), (\w+)\s*$')
+FUTURES_PATTERN_3 = re.compile(r'^from __future__ import (\w+) as \w+\s*$')
REQUIRED_FUTURES = frozenset(['absolute_import', 'division', 'print_function'])
WHITELIST = [
@@ -59,6 +60,8 @@ def check_file(path, old_division):
for line in open(path, encoding='utf-8') if six.PY3 else open(path):
count += 1
m = FUTURES_PATTERN.match(line)
+ if not m:
+ m = FUTURES_PATTERN_3.match(line)
if m:
futures.add(m.group(1))
else:
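
A brief illustration (not part of the patch itself) of the aliased-import lines that the new FUTURES_PATTERN_3 accepts and the original pattern missed:

import re

FUTURES_PATTERN_3 = re.compile(r'^from __future__ import (\w+) as \w+\s*$')

m = FUTURES_PATTERN_3.match('from __future__ import print_function as print_fn\n')
print(m.group(1))  # -> print_function; aliased __future__ imports now count.
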
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index 4ca083c8a3..70bade060e 100755
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -1,6 +1,7 @@
# TensorFlow external dependencies that can be loaded in WORKSPACE files.
load("//third_party/gpus:cuda_configure.bzl", "cuda_configure")
+load("//third_party/gpus:rocm_configure.bzl", "rocm_configure")
load("//third_party/tensorrt:tensorrt_configure.bzl", "tensorrt_configure")
load("//third_party:nccl/nccl_configure.bzl", "nccl_configure")
load("//third_party/mkl:build_defs.bzl", "mkl_repository")
@@ -20,9 +21,11 @@ load(
"def_file_filter_configure",
)
load("//third_party/flatbuffers:workspace.bzl", flatbuffers = "repo")
+load("//third_party/icu:workspace.bzl", icu = "repo")
def initialize_third_party():
flatbuffers()
+ icu()
# Sanitize a dependency so that it works correctly from code that includes
# TensorFlow as a submodule.
@@ -43,6 +46,7 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
sycl_configure(name = "local_config_sycl")
syslibs_configure(name = "local_config_syslibs")
python_configure(name = "local_config_python")
+ rocm_configure(name = "local_config_rocm")
initialize_third_party()
@@ -53,39 +57,39 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
# Point //external/local_config_arm_compiler to //external/arm_compiler
arm_compiler_configure(
name = "local_config_arm_compiler",
- remote_config_repo = "../arm_compiler",
build_file = clean_dep("//third_party/toolchains/cpus/arm:BUILD"),
+ remote_config_repo = "../arm_compiler",
)
mkl_repository(
name = "mkl_linux",
+ build_file = clean_dep("//third_party/mkl:mkl.BUILD"),
+ sha256 = "e2233534a9d15c387e22260997af4312a39e9f86f791768409be273b5453c4e6",
+ strip_prefix = "mklml_lnx_2019.0.20180710",
urls = [
"https://mirror.bazel.build/github.com/intel/mkl-dnn/releases/download/v0.16/mklml_lnx_2019.0.20180710.tgz",
"https://github.com/intel/mkl-dnn/releases/download/v0.16/mklml_lnx_2019.0.20180710.tgz",
],
- sha256 = "e2233534a9d15c387e22260997af4312a39e9f86f791768409be273b5453c4e6",
- strip_prefix = "mklml_lnx_2019.0.20180710",
- build_file = clean_dep("//third_party/mkl:mkl.BUILD"),
)
mkl_repository(
name = "mkl_windows",
+ build_file = clean_dep("//third_party/mkl:mkl.BUILD"),
+ sha256 = "3fdcff17b018a0082491adf3ba143358265336a801646e46e0191ec8d58d24a2",
+ strip_prefix = "mklml_win_2019.0.20180710",
urls = [
"https://mirror.bazel.build/github.com/intel/mkl-dnn/releases/download/v0.16/mklml_win_2019.0.20180710.zip",
"https://github.com/intel/mkl-dnn/releases/download/v0.16/mklml_win_2019.0.20180710.zip",
],
- sha256 = "3fdcff17b018a0082491adf3ba143358265336a801646e46e0191ec8d58d24a2",
- strip_prefix = "mklml_win_2019.0.20180710",
- build_file = clean_dep("//third_party/mkl:mkl.BUILD"),
)
mkl_repository(
name = "mkl_darwin",
+ build_file = clean_dep("//third_party/mkl:mkl.BUILD"),
+ sha256 = "411a30014a938eb83fb9f37b3dbe8e371b106fc1dd621fc23123cadc72737ce6",
+ strip_prefix = "mklml_mac_2019.0.20180710",
urls = [
"https://mirror.bazel.build/github.com/intel/mkl-dnn/releases/download/v0.16/mklml_mac_2019.0.20180710.tgz",
"https://github.com/intel/mkl-dnn/releases/download/v0.16/mklml_mac_2019.0.20180710.tgz",
],
- sha256 = "411a30014a938eb83fb9f37b3dbe8e371b106fc1dd621fc23123cadc72737ce6",
- strip_prefix = "mklml_mac_2019.0.20180710",
- build_file = clean_dep("//third_party/mkl:mkl.BUILD"),
)
if path_prefix:
@@ -94,39 +98,40 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
tf_http_archive(
name = "mkl_dnn",
+ build_file = clean_dep("//third_party/mkl_dnn:mkldnn.BUILD"),
+ sha256 = "363cc9239eacf8e7917753c6d8c94f767e4cd049160d0654a61ef32d5e1b3049",
+ strip_prefix = "mkl-dnn-4e333787e0d66a1dca1218e99a891d493dbc8ef1",
urls = [
"https://mirror.bazel.build/github.com/intel/mkl-dnn/archive/4e333787e0d66a1dca1218e99a891d493dbc8ef1.tar.gz",
"https://github.com/intel/mkl-dnn/archive/4e333787e0d66a1dca1218e99a891d493dbc8ef1.tar.gz",
],
- sha256 = "363cc9239eacf8e7917753c6d8c94f767e4cd049160d0654a61ef32d5e1b3049",
- strip_prefix = "mkl-dnn-4e333787e0d66a1dca1218e99a891d493dbc8ef1",
- build_file = clean_dep("//third_party/mkl_dnn:mkldnn.BUILD"),
)
tf_http_archive(
name = "com_google_absl",
+ build_file = clean_dep("//third_party:com_google_absl.BUILD"),
+ sha256 = "278a1af58b633be886fe81bf7061dca6b5fea99566850d1319fffdaa1a061792",
+ strip_prefix = "abseil-cpp-e291c279e458761e77a69b09b129d3d1e81f1e80",
urls = [
- "https://mirror.bazel.build/github.com/abseil/abseil-cpp/archive/8ff1374008259719b54a8cb128ef951c02da164c.tar.gz",
- "https://github.com/abseil/abseil-cpp/archive/8ff1374008259719b54a8cb128ef951c02da164c.tar.gz",
+ "https://mirror.bazel.build/github.com/abseil/abseil-cpp/archive/e291c279e458761e77a69b09b129d3d1e81f1e80.tar.gz",
+ "https://github.com/abseil/abseil-cpp/archive/e291c279e458761e77a69b09b129d3d1e81f1e80.tar.gz",
],
- sha256 = "006931f9705484041eed65189038f87931a87cff200bb296f94b3d42339c4cd9",
- strip_prefix = "abseil-cpp-8ff1374008259719b54a8cb128ef951c02da164c",
- build_file = clean_dep("//third_party:com_google_absl.BUILD"),
)
tf_http_archive(
name = "eigen_archive",
+ build_file = clean_dep("//third_party:eigen.BUILD"),
+ sha256 = "d956415d784fa4e42b6a2a45c32556d6aec9d0a3d8ef48baee2522ab762556a9",
+ strip_prefix = "eigen-eigen-fd6845384b86",
urls = [
"https://mirror.bazel.build/bitbucket.org/eigen/eigen/get/fd6845384b86.tar.gz",
"https://bitbucket.org/eigen/eigen/get/fd6845384b86.tar.gz",
],
- sha256 = "d956415d784fa4e42b6a2a45c32556d6aec9d0a3d8ef48baee2522ab762556a9",
- strip_prefix = "eigen-eigen-fd6845384b86",
- build_file = clean_dep("//third_party:eigen.BUILD"),
)
tf_http_archive(
name = "arm_compiler",
+ build_file = clean_dep("//:arm_compiler.BUILD"),
sha256 = "970285762565c7890c6c087d262b0a18286e7d0384f13a37786d8521773bc969",
strip_prefix = "tools-0e906ebc527eab1cdbf7adabff5b474da9562e9f/arm-bcm2708/arm-rpi-4.9.3-linux-gnueabihf",
urls = [
@@ -135,223 +140,233 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
# remove the whitelist entry in third_party/repo.bzl.
# "https://github.com/raspberrypi/tools/archive/0e906ebc527eab1cdbf7adabff5b474da9562e9f.tar.gz",
],
- build_file = clean_dep("//:arm_compiler.BUILD"),
)
tf_http_archive(
name = "libxsmm_archive",
+ build_file = clean_dep("//third_party:libxsmm.BUILD"),
+ sha256 = "cd8532021352b4a0290d209f7f9bfd7c2411e08286a893af3577a43457287bfa",
+ strip_prefix = "libxsmm-1.9",
urls = [
"https://mirror.bazel.build/github.com/hfp/libxsmm/archive/1.9.tar.gz",
"https://github.com/hfp/libxsmm/archive/1.9.tar.gz",
],
- sha256 = "cd8532021352b4a0290d209f7f9bfd7c2411e08286a893af3577a43457287bfa",
- strip_prefix = "libxsmm-1.9",
- build_file = clean_dep("//third_party:libxsmm.BUILD"),
)
tf_http_archive(
name = "ortools_archive",
+ build_file = clean_dep("//third_party:ortools.BUILD"),
+ sha256 = "d025a95f78b5fc5eaa4da5f395f23d11c23cf7dbd5069f1f627f002de87b86b9",
+ strip_prefix = "or-tools-6.7.2/src",
urls = [
"https://mirror.bazel.build/github.com/google/or-tools/archive/v6.7.2.tar.gz",
"https://github.com/google/or-tools/archive/v6.7.2.tar.gz",
],
- sha256 = "d025a95f78b5fc5eaa4da5f395f23d11c23cf7dbd5069f1f627f002de87b86b9",
- strip_prefix = "or-tools-6.7.2/src",
- build_file = clean_dep("//third_party:ortools.BUILD"),
)
tf_http_archive(
name = "com_googlesource_code_re2",
+ sha256 = "803c7811146edeef8f91064de37c6f19136ff01a2a8cdb3230e940b2fd9f07fe",
+ strip_prefix = "re2-2018-07-01",
+ system_build_file = clean_dep("//third_party/systemlibs:re2.BUILD"),
urls = [
"https://mirror.bazel.build/github.com/google/re2/archive/2018-07-01.tar.gz",
"https://github.com/google/re2/archive/2018-07-01.tar.gz",
],
- sha256 = "803c7811146edeef8f91064de37c6f19136ff01a2a8cdb3230e940b2fd9f07fe",
- strip_prefix = "re2-2018-07-01",
- system_build_file = clean_dep("//third_party/systemlibs:re2.BUILD"),
)
tf_http_archive(
name = "com_github_googlecloudplatform_google_cloud_cpp",
+ sha256 = "fdd3b3aecce60987e5525e55bf3a21d68a8695320bd5b980775af6507eec3944",
+ strip_prefix = "google-cloud-cpp-14760a86c4ffab9943b476305c4fe927ad95db1c",
+ system_build_file = clean_dep("//third_party/systemlibs:google_cloud_cpp.BUILD"),
+ system_link_files = {
+ "//third_party/systemlibs:google_cloud_cpp.google.cloud.bigtable.BUILD": "google/cloud/bigtable/BUILD",
+ },
urls = [
"https://mirror.bazel.build/github.com/GoogleCloudPlatform/google-cloud-cpp/archive/14760a86c4ffab9943b476305c4fe927ad95db1c.tar.gz",
"https://github.com/GoogleCloudPlatform/google-cloud-cpp/archive/14760a86c4ffab9943b476305c4fe927ad95db1c.tar.gz",
],
- sha256 = "fdd3b3aecce60987e5525e55bf3a21d68a8695320bd5b980775af6507eec3944",
- strip_prefix = "google-cloud-cpp-14760a86c4ffab9943b476305c4fe927ad95db1c",
)
tf_http_archive(
name = "com_github_googleapis_googleapis",
+ build_file = clean_dep("//third_party:googleapis.BUILD"),
+ sha256 = "824870d87a176f26bcef663e92051f532fac756d1a06b404055dc078425f4378",
+ strip_prefix = "googleapis-f81082ea1e2f85c43649bee26e0d9871d4b41cdb",
+ system_build_file = clean_dep("//third_party/systemlibs:googleapis.BUILD"),
urls = [
"https://mirror.bazel.build/github.com/googleapis/googleapis/archive/f81082ea1e2f85c43649bee26e0d9871d4b41cdb.zip",
"https://github.com/googleapis/googleapis/archive/f81082ea1e2f85c43649bee26e0d9871d4b41cdb.zip",
],
- sha256 = "824870d87a176f26bcef663e92051f532fac756d1a06b404055dc078425f4378",
- strip_prefix = "googleapis-f81082ea1e2f85c43649bee26e0d9871d4b41cdb",
- build_file = clean_dep("//third_party:googleapis.BUILD"),
)
tf_http_archive(
name = "gemmlowp",
+ sha256 = "b87faa7294dfcc5d678f22a59d2c01ca94ea1e2a3b488c38a95a67889ed0a658",
+ strip_prefix = "gemmlowp-38ebac7b059e84692f53e5938f97a9943c120d98",
urls = [
"https://mirror.bazel.build/github.com/google/gemmlowp/archive/38ebac7b059e84692f53e5938f97a9943c120d98.zip",
"https://github.com/google/gemmlowp/archive/38ebac7b059e84692f53e5938f97a9943c120d98.zip",
],
- sha256 = "b87faa7294dfcc5d678f22a59d2c01ca94ea1e2a3b488c38a95a67889ed0a658",
- strip_prefix = "gemmlowp-38ebac7b059e84692f53e5938f97a9943c120d98",
)
tf_http_archive(
name = "farmhash_archive",
+ build_file = clean_dep("//third_party:farmhash.BUILD"),
+ sha256 = "6560547c63e4af82b0f202cb710ceabb3f21347a4b996db565a411da5b17aba0",
+ strip_prefix = "farmhash-816a4ae622e964763ca0862d9dbd19324a1eaf45",
urls = [
"https://mirror.bazel.build/github.com/google/farmhash/archive/816a4ae622e964763ca0862d9dbd19324a1eaf45.tar.gz",
"https://github.com/google/farmhash/archive/816a4ae622e964763ca0862d9dbd19324a1eaf45.tar.gz",
],
- sha256 = "6560547c63e4af82b0f202cb710ceabb3f21347a4b996db565a411da5b17aba0",
- strip_prefix = "farmhash-816a4ae622e964763ca0862d9dbd19324a1eaf45",
- build_file = clean_dep("//third_party:farmhash.BUILD"),
)
tf_http_archive(
name = "highwayhash",
+ build_file = clean_dep("//third_party:highwayhash.BUILD"),
+ sha256 = "9c3e0e87d581feeb0c18d814d98f170ff23e62967a2bd6855847f0b2fe598a37",
+ strip_prefix = "highwayhash-fd3d9af80465e4383162e4a7c5e2f406e82dd968",
urls = [
"http://mirror.bazel.build/github.com/google/highwayhash/archive/fd3d9af80465e4383162e4a7c5e2f406e82dd968.tar.gz",
"https://github.com/google/highwayhash/archive/fd3d9af80465e4383162e4a7c5e2f406e82dd968.tar.gz",
],
- sha256 = "9c3e0e87d581feeb0c18d814d98f170ff23e62967a2bd6855847f0b2fe598a37",
- strip_prefix = "highwayhash-fd3d9af80465e4383162e4a7c5e2f406e82dd968",
- build_file = clean_dep("//third_party:highwayhash.BUILD"),
)
tf_http_archive(
name = "nasm",
+ build_file = clean_dep("//third_party:nasm.BUILD"),
+ sha256 = "63ec86477ad3f0f6292325fd89e1d93aea2e2fd490070863f17d48f7cd387011",
+ strip_prefix = "nasm-2.13.03",
+ system_build_file = clean_dep("//third_party/systemlibs:nasm.BUILD"),
urls = [
"https://mirror.bazel.build/www.nasm.us/pub/nasm/releasebuilds/2.13.03/nasm-2.13.03.tar.bz2",
"http://pkgs.fedoraproject.org/repo/pkgs/nasm/nasm-2.13.03.tar.bz2/sha512/d7a6b4cee8dfd603d8d4c976e5287b5cc542fa0b466ff989b743276a6e28114e64289bf02a7819eca63142a5278aa6eed57773007e5f589e15768e6456a8919d/nasm-2.13.03.tar.bz2",
"http://www.nasm.us/pub/nasm/releasebuilds/2.13.03/nasm-2.13.03.tar.bz2",
],
- sha256 = "63ec86477ad3f0f6292325fd89e1d93aea2e2fd490070863f17d48f7cd387011",
- strip_prefix = "nasm-2.13.03",
- build_file = clean_dep("//third_party:nasm.BUILD"),
- system_build_file = clean_dep("//third_party/systemlibs:nasm.BUILD"),
)
tf_http_archive(
name = "jpeg",
+ build_file = clean_dep("//third_party/jpeg:jpeg.BUILD"),
+ sha256 = "f892fff427ab3adffc289363eac26d197ce3ccacefe5f5822377348a8166069b",
+ strip_prefix = "libjpeg-turbo-2.0.0",
+ system_build_file = clean_dep("//third_party/systemlibs:jpeg.BUILD"),
urls = [
"https://mirror.bazel.build/github.com/libjpeg-turbo/libjpeg-turbo/archive/2.0.0.tar.gz",
"https://github.com/libjpeg-turbo/libjpeg-turbo/archive/2.0.0.tar.gz",
],
- sha256 = "f892fff427ab3adffc289363eac26d197ce3ccacefe5f5822377348a8166069b",
- strip_prefix = "libjpeg-turbo-2.0.0",
- build_file = clean_dep("//third_party/jpeg:jpeg.BUILD"),
- system_build_file = clean_dep("//third_party/systemlibs:jpeg.BUILD"),
)
tf_http_archive(
name = "png_archive",
+ build_file = clean_dep("//third_party:png.BUILD"),
+ patch_file = clean_dep("//third_party:png_fix_rpi.patch"),
+ sha256 = "e45ce5f68b1d80e2cb9a2b601605b374bdf51e1798ef1c2c2bd62131dfcf9eef",
+ strip_prefix = "libpng-1.6.34",
+ system_build_file = clean_dep("//third_party/systemlibs:png.BUILD"),
urls = [
"https://mirror.bazel.build/github.com/glennrp/libpng/archive/v1.6.34.tar.gz",
"https://github.com/glennrp/libpng/archive/v1.6.34.tar.gz",
],
- sha256 = "e45ce5f68b1d80e2cb9a2b601605b374bdf51e1798ef1c2c2bd62131dfcf9eef",
- strip_prefix = "libpng-1.6.34",
- build_file = clean_dep("//third_party:png.BUILD"),
- patch_file = clean_dep("//third_party:png_fix_rpi.patch"),
- system_build_file = clean_dep("//third_party/systemlibs:png.BUILD"),
)
tf_http_archive(
name = "org_sqlite",
+ build_file = clean_dep("//third_party:sqlite.BUILD"),
+ sha256 = "ad68c1216c3a474cf360c7581a4001e952515b3649342100f2d7ca7c8e313da6",
+ strip_prefix = "sqlite-amalgamation-3240000",
+ system_build_file = clean_dep("//third_party/systemlibs:sqlite.BUILD"),
urls = [
"https://mirror.bazel.build/www.sqlite.org/2018/sqlite-amalgamation-3240000.zip",
"https://www.sqlite.org/2018/sqlite-amalgamation-3240000.zip",
],
- sha256 = "ad68c1216c3a474cf360c7581a4001e952515b3649342100f2d7ca7c8e313da6",
- strip_prefix = "sqlite-amalgamation-3240000",
- build_file = clean_dep("//third_party:sqlite.BUILD"),
- system_build_file = clean_dep("//third_party/systemlibs:sqlite.BUILD"),
)
tf_http_archive(
name = "gif_archive",
+ build_file = clean_dep("//third_party:gif.BUILD"),
+ sha256 = "34a7377ba834397db019e8eb122e551a49c98f49df75ec3fcc92b9a794a4f6d1",
+ strip_prefix = "giflib-5.1.4",
+ system_build_file = clean_dep("//third_party/systemlibs:gif.BUILD"),
urls = [
"https://mirror.bazel.build/ufpr.dl.sourceforge.net/project/giflib/giflib-5.1.4.tar.gz",
"http://pilotfiber.dl.sourceforge.net/project/giflib/giflib-5.1.4.tar.gz",
],
- sha256 = "34a7377ba834397db019e8eb122e551a49c98f49df75ec3fcc92b9a794a4f6d1",
- strip_prefix = "giflib-5.1.4",
- build_file = clean_dep("//third_party:gif.BUILD"),
- system_build_file = clean_dep("//third_party/systemlibs:gif.BUILD"),
)
tf_http_archive(
name = "six_archive",
+ build_file = clean_dep("//third_party:six.BUILD"),
+ sha256 = "105f8d68616f8248e24bf0e9372ef04d3cc10104f1980f54d57b2ce73a5ad56a",
+ strip_prefix = "six-1.10.0",
+ system_build_file = clean_dep("//third_party/systemlibs:six.BUILD"),
urls = [
"https://mirror.bazel.build/pypi.python.org/packages/source/s/six/six-1.10.0.tar.gz",
"https://pypi.python.org/packages/source/s/six/six-1.10.0.tar.gz",
],
- sha256 = "105f8d68616f8248e24bf0e9372ef04d3cc10104f1980f54d57b2ce73a5ad56a",
- strip_prefix = "six-1.10.0",
- build_file = clean_dep("//third_party:six.BUILD"),
- system_build_file = clean_dep("//third_party/systemlibs:six.BUILD"),
)
tf_http_archive(
name = "astor_archive",
+ build_file = clean_dep("//third_party:astor.BUILD"),
+ sha256 = "ff6d2e2962d834acb125cc4dcc80c54a8c17c253f4cc9d9c43b5102a560bb75d",
+ strip_prefix = "astor-0.6.2",
+ system_build_file = clean_dep("//third_party/systemlibs:astor.BUILD"),
urls = [
"https://mirror.bazel.build/pypi.python.org/packages/d8/be/c4276b3199ec3feee2a88bc64810fbea8f26d961e0a4cd9c68387a9f35de/astor-0.6.2.tar.gz",
"https://pypi.python.org/packages/d8/be/c4276b3199ec3feee2a88bc64810fbea8f26d961e0a4cd9c68387a9f35de/astor-0.6.2.tar.gz",
],
- sha256 = "ff6d2e2962d834acb125cc4dcc80c54a8c17c253f4cc9d9c43b5102a560bb75d",
- strip_prefix = "astor-0.6.2",
- build_file = clean_dep("//third_party:astor.BUILD"),
- system_build_file = clean_dep("//third_party/systemlibs:astor.BUILD"),
)
tf_http_archive(
name = "gast_archive",
+ build_file = clean_dep("//third_party:gast.BUILD"),
+ sha256 = "7068908321ecd2774f145193c4b34a11305bd104b4551b09273dfd1d6a374930",
+ strip_prefix = "gast-0.2.0",
+ system_build_file = clean_dep("//third_party/systemlibs:gast.BUILD"),
urls = [
"https://mirror.bazel.build/pypi.python.org/packages/5c/78/ff794fcae2ce8aa6323e789d1f8b3b7765f601e7702726f430e814822b96/gast-0.2.0.tar.gz",
"https://pypi.python.org/packages/5c/78/ff794fcae2ce8aa6323e789d1f8b3b7765f601e7702726f430e814822b96/gast-0.2.0.tar.gz",
],
- sha256 = "7068908321ecd2774f145193c4b34a11305bd104b4551b09273dfd1d6a374930",
- strip_prefix = "gast-0.2.0",
- build_file = clean_dep("//third_party:gast.BUILD"),
)
tf_http_archive(
name = "termcolor_archive",
+ build_file = clean_dep("//third_party:termcolor.BUILD"),
+ sha256 = "1d6d69ce66211143803fbc56652b41d73b4a400a2891d7bf7a1cdf4c02de613b",
+ strip_prefix = "termcolor-1.1.0",
+ system_build_file = clean_dep("//third_party/systemlibs:termcolor.BUILD"),
urls = [
"https://mirror.bazel.build/pypi.python.org/packages/8a/48/a76be51647d0eb9f10e2a4511bf3ffb8cc1e6b14e9e4fab46173aa79f981/termcolor-1.1.0.tar.gz",
"https://pypi.python.org/packages/8a/48/a76be51647d0eb9f10e2a4511bf3ffb8cc1e6b14e9e4fab46173aa79f981/termcolor-1.1.0.tar.gz",
],
- sha256 = "1d6d69ce66211143803fbc56652b41d73b4a400a2891d7bf7a1cdf4c02de613b",
- strip_prefix = "termcolor-1.1.0",
- build_file = clean_dep("//third_party:termcolor.BUILD"),
- system_build_file = clean_dep("//third_party/systemlibs:termcolor.BUILD"),
)
tf_http_archive(
name = "absl_py",
+ sha256 = "95160f778a62c7a60ddeadc7bf2d83f85a23a27359814aca12cf949e896fa82c",
+ strip_prefix = "abseil-py-pypi-v0.2.2",
+ system_build_file = clean_dep("//third_party/systemlibs:absl_py.BUILD"),
+ system_link_files = {
+ "//third_party/systemlibs:absl_py.absl.flags.BUILD": "absl/flags/BUILD",
+ "//third_party/systemlibs:absl_py.absl.testing.BUILD": "absl/testing/BUILD",
+ },
urls = [
"https://mirror.bazel.build/github.com/abseil/abseil-py/archive/pypi-v0.2.2.tar.gz",
"https://github.com/abseil/abseil-py/archive/pypi-v0.2.2.tar.gz",
],
- sha256 = "95160f778a62c7a60ddeadc7bf2d83f85a23a27359814aca12cf949e896fa82c",
- strip_prefix = "abseil-py-pypi-v0.2.2",
)
tf_http_archive(
name = "org_python_pypi_backports_weakref",
+ build_file = clean_dep("//third_party:backports_weakref.BUILD"),
+ sha256 = "8813bf712a66b3d8b85dc289e1104ed220f1878cf981e2fe756dfaabe9a82892",
+ strip_prefix = "backports.weakref-1.0rc1/src",
urls = [
"https://mirror.bazel.build/pypi.python.org/packages/bc/cc/3cdb0a02e7e96f6c70bd971bc8a90b8463fda83e264fa9c5c1c98ceabd81/backports.weakref-1.0rc1.tar.gz",
"https://pypi.python.org/packages/bc/cc/3cdb0a02e7e96f6c70bd971bc8a90b8463fda83e264fa9c5c1c98ceabd81/backports.weakref-1.0rc1.tar.gz",
],
- sha256 = "8813bf712a66b3d8b85dc289e1104ed220f1878cf981e2fe756dfaabe9a82892",
- strip_prefix = "backports.weakref-1.0rc1/src",
- build_file = clean_dep("//third_party:backports_weakref.BUILD"),
)
filegroup_external(
@@ -374,9 +389,9 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
tf_http_archive(
name = "protobuf_archive",
- urls = PROTOBUF_URLS,
sha256 = PROTOBUF_SHA256,
strip_prefix = PROTOBUF_STRIP_PREFIX,
+ urls = PROTOBUF_URLS,
)
# We need to import the protobuf library under the names com_google_protobuf
@@ -384,221 +399,222 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
# Unfortunately there is no way to alias http_archives at the moment.
tf_http_archive(
name = "com_google_protobuf",
- urls = PROTOBUF_URLS,
sha256 = PROTOBUF_SHA256,
strip_prefix = PROTOBUF_STRIP_PREFIX,
+ urls = PROTOBUF_URLS,
)
tf_http_archive(
name = "com_google_protobuf_cc",
- urls = PROTOBUF_URLS,
sha256 = PROTOBUF_SHA256,
strip_prefix = PROTOBUF_STRIP_PREFIX,
+ urls = PROTOBUF_URLS,
)
tf_http_archive(
name = "nsync",
+ sha256 = "692f9b30e219f71a6371b98edd39cef3cbda35ac3abc4cd99ce19db430a5591a",
+ strip_prefix = "nsync-1.20.1",
+ system_build_file = clean_dep("//third_party/systemlibs:nsync.BUILD"),
urls = [
"https://mirror.bazel.build/github.com/google/nsync/archive/1.20.1.tar.gz",
"https://github.com/google/nsync/archive/1.20.1.tar.gz",
],
- sha256 = "692f9b30e219f71a6371b98edd39cef3cbda35ac3abc4cd99ce19db430a5591a",
- strip_prefix = "nsync-1.20.1",
- system_build_file = clean_dep("//third_party/systemlibs:nsync.BUILD"),
)
tf_http_archive(
name = "com_google_googletest",
+ sha256 = "353ab86e35cea1cd386115279cf4b16695bbf21b897bfbf2721cf4cb5f64ade8",
+ strip_prefix = "googletest-997d343dd680e541ef96ce71ee54a91daf2577a0",
urls = [
"https://mirror.bazel.build/github.com/google/googletest/archive/997d343dd680e541ef96ce71ee54a91daf2577a0.zip",
"https://github.com/google/googletest/archive/997d343dd680e541ef96ce71ee54a91daf2577a0.zip",
],
- sha256 = "353ab86e35cea1cd386115279cf4b16695bbf21b897bfbf2721cf4cb5f64ade8",
- strip_prefix = "googletest-997d343dd680e541ef96ce71ee54a91daf2577a0",
)
tf_http_archive(
name = "com_github_gflags_gflags",
+ sha256 = "ae27cdbcd6a2f935baa78e4f21f675649271634c092b1be01469440495609d0e",
+ strip_prefix = "gflags-2.2.1",
urls = [
"https://mirror.bazel.build/github.com/gflags/gflags/archive/v2.2.1.tar.gz",
"https://github.com/gflags/gflags/archive/v2.2.1.tar.gz",
],
- sha256 = "ae27cdbcd6a2f935baa78e4f21f675649271634c092b1be01469440495609d0e",
- strip_prefix = "gflags-2.2.1",
)
tf_http_archive(
name = "pcre",
+ build_file = clean_dep("//third_party:pcre.BUILD"),
sha256 = "69acbc2fbdefb955d42a4c606dfde800c2885711d2979e356c0636efde9ec3b5",
+ strip_prefix = "pcre-8.42",
+ system_build_file = clean_dep("//third_party/systemlibs:pcre.BUILD"),
urls = [
"https://mirror.bazel.build/ftp.exim.org/pub/pcre/pcre-8.42.tar.gz",
"http://ftp.exim.org/pub/pcre/pcre-8.42.tar.gz",
],
- strip_prefix = "pcre-8.42",
- build_file = clean_dep("//third_party:pcre.BUILD"),
- system_build_file = clean_dep("//third_party/systemlibs:pcre.BUILD"),
)
tf_http_archive(
name = "swig",
+ build_file = clean_dep("//third_party:swig.BUILD"),
sha256 = "58a475dbbd4a4d7075e5fe86d4e54c9edde39847cdb96a3053d87cb64a23a453",
+ strip_prefix = "swig-3.0.8",
+ system_build_file = clean_dep("//third_party/systemlibs:swig.BUILD"),
urls = [
"https://mirror.bazel.build/ufpr.dl.sourceforge.net/project/swig/swig/swig-3.0.8/swig-3.0.8.tar.gz",
"http://ufpr.dl.sourceforge.net/project/swig/swig/swig-3.0.8/swig-3.0.8.tar.gz",
"http://pilotfiber.dl.sourceforge.net/project/swig/swig/swig-3.0.8/swig-3.0.8.tar.gz",
],
- strip_prefix = "swig-3.0.8",
- build_file = clean_dep("//third_party:swig.BUILD"),
- system_build_file = clean_dep("//third_party/systemlibs:swig.BUILD"),
)
tf_http_archive(
name = "curl",
+ build_file = clean_dep("//third_party:curl.BUILD"),
sha256 = "e9c37986337743f37fd14fe8737f246e97aec94b39d1b71e8a5973f72a9fc4f5",
+ strip_prefix = "curl-7.60.0",
+ system_build_file = clean_dep("//third_party/systemlibs:curl.BUILD"),
urls = [
"https://mirror.bazel.build/curl.haxx.se/download/curl-7.60.0.tar.gz",
"https://curl.haxx.se/download/curl-7.60.0.tar.gz",
],
- strip_prefix = "curl-7.60.0",
- build_file = clean_dep("//third_party:curl.BUILD"),
- system_build_file = clean_dep("//third_party/systemlibs:curl.BUILD"),
)
tf_http_archive(
name = "grpc",
+ sha256 = "50db9cf2221354485eb7c3bd55a4c27190caef7048a2a1a15fbe60a498f98b44",
+ strip_prefix = "grpc-1.13.0",
+ system_build_file = clean_dep("//third_party/systemlibs:grpc.BUILD"),
urls = [
"https://mirror.bazel.build/github.com/grpc/grpc/archive/v1.13.0.tar.gz",
"https://github.com/grpc/grpc/archive/v1.13.0.tar.gz",
],
- sha256 = "50db9cf2221354485eb7c3bd55a4c27190caef7048a2a1a15fbe60a498f98b44",
- strip_prefix = "grpc-1.13.0",
- system_build_file = clean_dep("//third_party/systemlibs:grpc.BUILD"),
)
tf_http_archive(
name = "linenoise",
+ build_file = clean_dep("//third_party:linenoise.BUILD"),
sha256 = "7f51f45887a3d31b4ce4fa5965210a5e64637ceac12720cfce7954d6a2e812f7",
+ strip_prefix = "linenoise-c894b9e59f02203dbe4e2be657572cf88c4230c3",
urls = [
"https://mirror.bazel.build/github.com/antirez/linenoise/archive/c894b9e59f02203dbe4e2be657572cf88c4230c3.tar.gz",
"https://github.com/antirez/linenoise/archive/c894b9e59f02203dbe4e2be657572cf88c4230c3.tar.gz",
],
- strip_prefix = "linenoise-c894b9e59f02203dbe4e2be657572cf88c4230c3",
- build_file = clean_dep("//third_party:linenoise.BUILD"),
)
# TODO(phawkins): currently, this rule uses an unofficial LLVM mirror.
# Switch to an official source of snapshots if/when possible.
tf_http_archive(
name = "llvm",
+ build_file = clean_dep("//third_party/llvm:llvm.autogenerated.BUILD"),
+ sha256 = "a4f8bfe7e3e69069934a87e612a1d4d3b8b6af13e0f1213a42a6046e1bcd50d8",
+ strip_prefix = "llvm-d3429e96fe1e45b1dc0106463832523f37faf271",
urls = [
- "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/ad72545325c087661feb3512efa54ebe5f888736.tar.gz",
- "https://github.com/llvm-mirror/llvm/archive/ad72545325c087661feb3512efa54ebe5f888736.tar.gz",
+ "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/d3429e96fe1e45b1dc0106463832523f37faf271.tar.gz",
+ "https://github.com/llvm-mirror/llvm/archive/d3429e96fe1e45b1dc0106463832523f37faf271.tar.gz",
],
- sha256 = "66ed69443af00fbf9b912edbb6bc0fa796a12766b5e9ad504eb6b20f813dc163",
- strip_prefix = "llvm-ad72545325c087661feb3512efa54ebe5f888736",
- build_file = clean_dep("//third_party/llvm:llvm.autogenerated.BUILD"),
)
tf_http_archive(
name = "lmdb",
+ build_file = clean_dep("//third_party:lmdb.BUILD"),
+ sha256 = "f3927859882eb608868c8c31586bb7eb84562a40a6bf5cc3e13b6b564641ea28",
+ strip_prefix = "lmdb-LMDB_0.9.22/libraries/liblmdb",
+ system_build_file = clean_dep("//third_party/systemlibs:lmdb.BUILD"),
urls = [
"https://mirror.bazel.build/github.com/LMDB/lmdb/archive/LMDB_0.9.22.tar.gz",
"https://github.com/LMDB/lmdb/archive/LMDB_0.9.22.tar.gz",
],
- sha256 = "f3927859882eb608868c8c31586bb7eb84562a40a6bf5cc3e13b6b564641ea28",
- strip_prefix = "lmdb-LMDB_0.9.22/libraries/liblmdb",
- build_file = clean_dep("//third_party:lmdb.BUILD"),
- system_build_file = clean_dep("//third_party/systemlibs:lmdb.BUILD"),
)
tf_http_archive(
name = "jsoncpp_git",
+ build_file = clean_dep("//third_party:jsoncpp.BUILD"),
+ sha256 = "c49deac9e0933bcb7044f08516861a2d560988540b23de2ac1ad443b219afdb6",
+ strip_prefix = "jsoncpp-1.8.4",
+ system_build_file = clean_dep("//third_party/systemlibs:jsoncpp.BUILD"),
urls = [
"https://mirror.bazel.build/github.com/open-source-parsers/jsoncpp/archive/1.8.4.tar.gz",
"https://github.com/open-source-parsers/jsoncpp/archive/1.8.4.tar.gz",
],
- sha256 = "c49deac9e0933bcb7044f08516861a2d560988540b23de2ac1ad443b219afdb6",
- strip_prefix = "jsoncpp-1.8.4",
- build_file = clean_dep("//third_party:jsoncpp.BUILD"),
- system_build_file = clean_dep("//third_party/systemlibs:jsoncpp.BUILD"),
)
tf_http_archive(
name = "boringssl",
+ sha256 = "1188e29000013ed6517168600fc35a010d58c5d321846d6a6dfee74e4c788b45",
+ strip_prefix = "boringssl-7f634429a04abc48e2eb041c81c5235816c96514",
+ system_build_file = clean_dep("//third_party/systemlibs:boringssl.BUILD"),
urls = [
"https://mirror.bazel.build/github.com/google/boringssl/archive/7f634429a04abc48e2eb041c81c5235816c96514.tar.gz",
"https://github.com/google/boringssl/archive/7f634429a04abc48e2eb041c81c5235816c96514.tar.gz",
],
- sha256 = "1188e29000013ed6517168600fc35a010d58c5d321846d6a6dfee74e4c788b45",
- strip_prefix = "boringssl-7f634429a04abc48e2eb041c81c5235816c96514",
)
tf_http_archive(
name = "zlib_archive",
+ build_file = clean_dep("//third_party:zlib.BUILD"),
+ sha256 = "c3e5e9fdd5004dcb542feda5ee4f0ff0744628baf8ed2dd5d66f8ca1197cb1a1",
+ strip_prefix = "zlib-1.2.11",
+ system_build_file = clean_dep("//third_party/systemlibs:zlib.BUILD"),
urls = [
"https://mirror.bazel.build/zlib.net/zlib-1.2.11.tar.gz",
"https://zlib.net/zlib-1.2.11.tar.gz",
],
- sha256 = "c3e5e9fdd5004dcb542feda5ee4f0ff0744628baf8ed2dd5d66f8ca1197cb1a1",
- strip_prefix = "zlib-1.2.11",
- build_file = clean_dep("//third_party:zlib.BUILD"),
- system_build_file = clean_dep("//third_party/systemlibs:zlib.BUILD"),
)
tf_http_archive(
name = "fft2d",
+ build_file = clean_dep("//third_party/fft2d:fft2d.BUILD"),
+ sha256 = "52bb637c70b971958ec79c9c8752b1df5ff0218a4db4510e60826e0cb79b5296",
urls = [
"https://mirror.bazel.build/www.kurims.kyoto-u.ac.jp/~ooura/fft.tgz",
"http://www.kurims.kyoto-u.ac.jp/~ooura/fft.tgz",
],
- sha256 = "52bb637c70b971958ec79c9c8752b1df5ff0218a4db4510e60826e0cb79b5296",
- build_file = clean_dep("//third_party/fft2d:fft2d.BUILD"),
)
tf_http_archive(
name = "snappy",
+ build_file = clean_dep("//third_party:snappy.BUILD"),
+ sha256 = "3dfa02e873ff51a11ee02b9ca391807f0c8ea0529a4924afa645fbf97163f9d4",
+ strip_prefix = "snappy-1.1.7",
+ system_build_file = clean_dep("//third_party/systemlibs:snappy.BUILD"),
urls = [
"https://mirror.bazel.build/github.com/google/snappy/archive/1.1.7.tar.gz",
"https://github.com/google/snappy/archive/1.1.7.tar.gz",
],
- sha256 = "3dfa02e873ff51a11ee02b9ca391807f0c8ea0529a4924afa645fbf97163f9d4",
- strip_prefix = "snappy-1.1.7",
- build_file = clean_dep("//third_party:snappy.BUILD"),
- system_build_file = clean_dep("//third_party/systemlibs:snappy.BUILD"),
)
tf_http_archive(
name = "nccl_archive",
+ build_file = clean_dep("//third_party:nccl/nccl_archive.BUILD"),
+ sha256 = "2ca86fb6179ecbff789cc67c836139c1bbc0324ed8c04643405a30bf26325176",
+ strip_prefix = "nccl-03d856977ecbaac87e598c0c4bafca96761b9ac7",
urls = [
"https://mirror.bazel.build/github.com/nvidia/nccl/archive/03d856977ecbaac87e598c0c4bafca96761b9ac7.tar.gz",
"https://github.com/nvidia/nccl/archive/03d856977ecbaac87e598c0c4bafca96761b9ac7.tar.gz",
],
- sha256 = "2ca86fb6179ecbff789cc67c836139c1bbc0324ed8c04643405a30bf26325176",
- strip_prefix = "nccl-03d856977ecbaac87e598c0c4bafca96761b9ac7",
- build_file = clean_dep("//third_party:nccl/nccl_archive.BUILD"),
)
tf_http_archive(
name = "kafka",
+ build_file = clean_dep("//third_party:kafka/BUILD"),
+ patch_file = clean_dep("//third_party/kafka:config.patch"),
+ sha256 = "cc6ebbcd0a826eec1b8ce1f625ffe71b53ef3290f8192b6cae38412a958f4fd3",
+ strip_prefix = "librdkafka-0.11.5",
urls = [
"https://mirror.bazel.build/github.com/edenhill/librdkafka/archive/v0.11.5.tar.gz",
"https://github.com/edenhill/librdkafka/archive/v0.11.5.tar.gz",
],
- sha256 = "cc6ebbcd0a826eec1b8ce1f625ffe71b53ef3290f8192b6cae38412a958f4fd3",
- strip_prefix = "librdkafka-0.11.5",
- build_file = clean_dep("//third_party:kafka/BUILD"),
- patch_file = clean_dep("//third_party/kafka:config.patch"),
)
tf_http_archive(
name = "aws",
+ build_file = clean_dep("//third_party:aws.BUILD"),
+ sha256 = "b888d8ce5fc10254c3dd6c9020c7764dd53cf39cf011249d0b4deda895de1b7c",
+ strip_prefix = "aws-sdk-cpp-1.3.15",
urls = [
"https://mirror.bazel.build/github.com/aws/aws-sdk-cpp/archive/1.3.15.tar.gz",
"https://github.com/aws/aws-sdk-cpp/archive/1.3.15.tar.gz",
],
- sha256 = "b888d8ce5fc10254c3dd6c9020c7764dd53cf39cf011249d0b4deda895de1b7c",
- strip_prefix = "aws-sdk-cpp-1.3.15",
- build_file = clean_dep("//third_party:aws.BUILD"),
)
java_import_external(
@@ -628,14 +644,14 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
tf_http_archive(
name = "jemalloc",
+ build_file = clean_dep("//third_party:jemalloc.BUILD"),
+ sha256 = "3c8f25c02e806c3ce0ab5fb7da1817f89fc9732709024e2a81b6b82f7cc792a8",
+ strip_prefix = "jemalloc-4.4.0",
+ system_build_file = clean_dep("//third_party/systemlibs:jemalloc.BUILD"),
urls = [
"https://mirror.bazel.build/github.com/jemalloc/jemalloc/archive/4.4.0.tar.gz",
"https://github.com/jemalloc/jemalloc/archive/4.4.0.tar.gz",
],
- sha256 = "3c8f25c02e806c3ce0ab5fb7da1817f89fc9732709024e2a81b6b82f7cc792a8",
- strip_prefix = "jemalloc-4.4.0",
- build_file = clean_dep("//third_party:jemalloc.BUILD"),
- system_build_file = clean_dep("//third_party/systemlibs:jemalloc.BUILD"),
)
java_import_external(
@@ -684,183 +700,196 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
tf_http_archive(
name = "com_google_pprof",
+ build_file = clean_dep("//third_party:pprof.BUILD"),
+ sha256 = "e0928ca4aa10ea1e0551e2d7ce4d1d7ea2d84b2abbdef082b0da84268791d0c4",
+ strip_prefix = "pprof-c0fb62ec88c411cc91194465e54db2632845b650",
urls = [
"https://mirror.bazel.build/github.com/google/pprof/archive/c0fb62ec88c411cc91194465e54db2632845b650.tar.gz",
"https://github.com/google/pprof/archive/c0fb62ec88c411cc91194465e54db2632845b650.tar.gz",
],
- sha256 = "e0928ca4aa10ea1e0551e2d7ce4d1d7ea2d84b2abbdef082b0da84268791d0c4",
- strip_prefix = "pprof-c0fb62ec88c411cc91194465e54db2632845b650",
- build_file = clean_dep("//third_party:pprof.BUILD"),
)
tf_http_archive(
name = "cub_archive",
+ build_file = clean_dep("//third_party:cub.BUILD"),
+ sha256 = "6bfa06ab52a650ae7ee6963143a0bbc667d6504822cbd9670369b598f18c58c3",
+ strip_prefix = "cub-1.8.0",
urls = [
"https://mirror.bazel.build/github.com/NVlabs/cub/archive/1.8.0.zip",
"https://github.com/NVlabs/cub/archive/1.8.0.zip",
],
- sha256 = "6bfa06ab52a650ae7ee6963143a0bbc667d6504822cbd9670369b598f18c58c3",
- strip_prefix = "cub-1.8.0",
- build_file = clean_dep("//third_party:cub.BUILD"),
)
tf_http_archive(
name = "cython",
+ build_file = clean_dep("//third_party:cython.BUILD"),
+ delete = ["BUILD.bazel"],
sha256 = "bccc9aa050ea02595b2440188813b936eaf345e85fb9692790cecfe095cf91aa",
+ strip_prefix = "cython-0.28.4",
+ system_build_file = clean_dep("//third_party/systemlibs:cython.BUILD"),
urls = [
"https://mirror.bazel.build/github.com/cython/cython/archive/0.28.4.tar.gz",
"https://github.com/cython/cython/archive/0.28.4.tar.gz",
],
- strip_prefix = "cython-0.28.4",
- build_file = clean_dep("//third_party:cython.BUILD"),
- delete = ["BUILD.bazel"],
- system_build_file = clean_dep("//third_party/systemlibs:cython.BUILD"),
)
tf_http_archive(
name = "bazel_toolchains",
+ sha256 = "3b604699685c5c65dd3f6f17425570a4b2f00ddba2f750db15acc72e55bb098b",
+ strip_prefix = "bazel-toolchains-37acf1841ab1475c98a152cb9e446460c8ae29e1",
urls = [
"https://mirror.bazel.build/github.com/bazelbuild/bazel-toolchains/archive/37acf1841ab1475c98a152cb9e446460c8ae29e1.tar.gz",
"https://github.com/bazelbuild/bazel-toolchains/archive/37acf1841ab1475c98a152cb9e446460c8ae29e1.tar.gz",
],
- strip_prefix = "bazel-toolchains-37acf1841ab1475c98a152cb9e446460c8ae29e1",
- sha256 = "3b604699685c5c65dd3f6f17425570a4b2f00ddba2f750db15acc72e55bb098b",
)
tf_http_archive(
name = "arm_neon_2_x86_sse",
+ build_file = clean_dep("//third_party:arm_neon_2_x86_sse.BUILD"),
sha256 = "c8d90aa4357f8079d427e87a6f4c493da1fa4140aee926c05902d7ec1533d9a5",
strip_prefix = "ARM_NEON_2_x86_SSE-0f77d9d182265259b135dad949230ecbf1a2633d",
urls = [
"https://mirror.bazel.build/github.com/intel/ARM_NEON_2_x86_SSE/archive/0f77d9d182265259b135dad949230ecbf1a2633d.tar.gz",
"https://github.com/intel/ARM_NEON_2_x86_SSE/archive/0f77d9d182265259b135dad949230ecbf1a2633d.tar.gz",
],
- build_file = clean_dep("//third_party:arm_neon_2_x86_sse.BUILD"),
)
- native.new_http_archive(
+ tf_http_archive(
name = "double_conversion",
+ build_file = clean_dep("//third_party:double_conversion.BUILD"),
+ sha256 = "2f7fbffac0d98d201ad0586f686034371a6d152ca67508ab611adc2386ad30de",
+ strip_prefix = "double-conversion-3992066a95b823efc8ccc1baf82a1cfc73f6e9b8",
+ system_build_file = clean_dep("//third_party/systemlibs:double_conversion.BUILD"),
urls = [
+ "https://mirror.bazel.build/github.com/google/double-conversion/archive/3992066a95b823efc8ccc1baf82a1cfc73f6e9b8.zip",
"https://github.com/google/double-conversion/archive/3992066a95b823efc8ccc1baf82a1cfc73f6e9b8.zip",
],
- sha256 = "2f7fbffac0d98d201ad0586f686034371a6d152ca67508ab611adc2386ad30de",
- strip_prefix = "double-conversion-3992066a95b823efc8ccc1baf82a1cfc73f6e9b8",
- build_file = clean_dep("//third_party:double_conversion.BUILD"),
)
tf_http_archive(
name = "tflite_mobilenet",
+ build_file = clean_dep("//third_party:tflite_mobilenet.BUILD"),
sha256 = "23f814d1c076bdf03715dfb6cab3713aa4fbdf040fd5448c43196bd2e97a4c1b",
urls = [
"https://mirror.bazel.build/storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip",
"https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip",
],
- build_file = clean_dep("//third_party:tflite_mobilenet.BUILD"),
)
tf_http_archive(
name = "tflite_mobilenet_ssd",
+ build_file = str(Label("//third_party:tflite_mobilenet.BUILD")),
sha256 = "767057f2837a46d97882734b03428e8dd640b93236052b312b2f0e45613c1cf0",
urls = [
"https://mirror.bazel.build/storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_ssd_tflite_v1.zip",
"https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_ssd_tflite_v1.zip",
],
- build_file = str(Label("//third_party:tflite_mobilenet.BUILD")),
)
tf_http_archive(
name = "tflite_mobilenet_ssd_quant",
+ build_file = str(Label("//third_party:tflite_mobilenet.BUILD")),
sha256 = "a809cd290b4d6a2e8a9d5dad076e0bd695b8091974e0eed1052b480b2f21b6dc",
urls = [
"https://mirror.bazel.build/storage.googleapis.com/download.tensorflow.org/models/tflite/coco_ssd_mobilenet_v1_0.75_quant_2018_06_29.zip",
"https://storage.googleapis.com/download.tensorflow.org/models/tflite/coco_ssd_mobilenet_v1_0.75_quant_2018_06_29.zip",
],
- build_file = str(Label("//third_party:tflite_mobilenet.BUILD")),
)
tf_http_archive(
name = "tflite_mobilenet_ssd_quant_protobuf",
+ build_file = str(Label("//third_party:tflite_mobilenet.BUILD")),
sha256 = "09280972c5777f1aa775ef67cb4ac5d5ed21970acd8535aeca62450ef14f0d79",
+ strip_prefix = "ssd_mobilenet_v1_quantized_300x300_coco14_sync_2018_07_18",
urls = [
"https://mirror.bazel.build/storage.googleapis.com/download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_quantized_300x300_coco14_sync_2018_07_18.tar.gz",
"http://storage.googleapis.com/download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_quantized_300x300_coco14_sync_2018_07_18.tar.gz",
],
- strip_prefix = "ssd_mobilenet_v1_quantized_300x300_coco14_sync_2018_07_18",
- build_file = str(Label("//third_party:tflite_mobilenet.BUILD")),
)
tf_http_archive(
name = "tflite_conv_actions_frozen",
+ build_file = str(Label("//third_party:tflite_mobilenet.BUILD")),
sha256 = "d947b38cba389b5e2d0bfc3ea6cc49c784e187b41a071387b3742d1acac7691e",
urls = [
"https://mirror.bazel.build/storage.googleapis.com/download.tensorflow.org/models/tflite/conv_actions_tflite.zip",
"https://storage.googleapis.com/download.tensorflow.org/models/tflite/conv_actions_tflite.zip",
],
- build_file = str(Label("//third_party:tflite_mobilenet.BUILD")),
)
tf_http_archive(
name = "tflite_smartreply",
+ build_file = clean_dep("//third_party:tflite_smartreply.BUILD"),
sha256 = "8980151b85a87a9c1a3bb1ed4748119e4a85abd3cb5744d83da4d4bd0fbeef7c",
urls = [
"https://mirror.bazel.build/storage.googleapis.com/download.tensorflow.org/models/tflite/smartreply_1.0_2017_11_01.zip",
"https://storage.googleapis.com/download.tensorflow.org/models/tflite/smartreply_1.0_2017_11_01.zip",
],
- build_file = clean_dep("//third_party:tflite_smartreply.BUILD"),
)
tf_http_archive(
name = "tflite_ovic_testdata",
+ build_file = clean_dep("//third_party:tflite_ovic_testdata.BUILD"),
sha256 = "a9a705d8d519220178e2e65d383fdb21da37fdb31d1e909b0a1acdac46479e9c",
+ strip_prefix = "ovic",
urls = [
"https://mirror.bazel.build/storage.googleapis.com/download.tensorflow.org/data/ovic.zip",
"https://storage.googleapis.com/download.tensorflow.org/data/ovic.zip",
],
- build_file = clean_dep("//third_party:tflite_ovic_testdata.BUILD"),
- strip_prefix = "ovic",
)
tf_http_archive(
name = "build_bazel_rules_android",
sha256 = "cd06d15dd8bb59926e4d65f9003bfc20f9da4b2519985c27e190cddc8b7a7806",
+ strip_prefix = "rules_android-0.1.1",
urls = [
"https://mirror.bazel.build/github.com/bazelbuild/rules_android/archive/v0.1.1.zip",
"https://github.com/bazelbuild/rules_android/archive/v0.1.1.zip",
],
- strip_prefix = "rules_android-0.1.1",
)
tf_http_archive(
- name = "ngraph",
+ name = "tbb",
+ build_file = clean_dep("//third_party/ngraph:tbb.BUILD"),
+ sha256 = "724686f90bcda78f13b76f297d964008737ccd6399328143c1c0093e73ae6a13",
+ strip_prefix = "tbb-tbb_2018",
urls = [
- "https://mirror.bazel.build/github.com/NervanaSystems/ngraph/archive/v0.5.0.tar.gz",
- "https://github.com/NervanaSystems/ngraph/archive/v0.5.0.tar.gz",
+ "https://mirror.bazel.build/github.com/01org/tbb/archive/tbb_2018.zip",
+ "https://github.com/01org/tbb/archive/tbb_2018.zip",
],
- sha256 = "cb35d3d98836f615408afd18371fb13e3400711247e0d822ba7f306c45e9bb2c",
- strip_prefix = "ngraph-0.5.0",
+ )
+
+ tf_http_archive(
+ name = "ngraph",
build_file = clean_dep("//third_party/ngraph:ngraph.BUILD"),
+ sha256 = "bf9dcc88e5c66021e3aac80491a231711211540d613bf9b6bd28db3f5bb86b62",
+ strip_prefix = "ngraph-0.8.1",
+ urls = [
+ "https://mirror.bazel.build/github.com/NervanaSystems/ngraph/archive/v0.8.1.tar.gz",
+ "https://github.com/NervanaSystems/ngraph/archive/v0.8.1.tar.gz",
+ ],
)
tf_http_archive(
name = "nlohmann_json_lib",
+ build_file = clean_dep("//third_party/ngraph:nlohmann_json.BUILD"),
+ sha256 = "9f3549824af3ca7e9707a2503959886362801fb4926b869789d6929098a79e47",
+ strip_prefix = "json-3.1.1",
urls = [
"https://mirror.bazel.build/github.com/nlohmann/json/archive/v3.1.1.tar.gz",
"https://github.com/nlohmann/json/archive/v3.1.1.tar.gz",
],
- sha256 = "9f3549824af3ca7e9707a2503959886362801fb4926b869789d6929098a79e47",
- strip_prefix = "json-3.1.1",
- build_file = clean_dep("//third_party/ngraph:nlohmann_json.BUILD"),
)
tf_http_archive(
name = "ngraph_tf",
+ build_file = clean_dep("//third_party/ngraph:ngraph_tf.BUILD"),
+ sha256 = "402f84c748c113780a60f35f39aab118435285543aee4900d712b76fbf8a21ee",
+ strip_prefix = "ngraph-tf-0.6.1",
urls = [
- "https://mirror.bazel.build/github.com/NervanaSystems/ngraph-tf/archive/v0.3.0-rc1.tar.gz",
- "https://github.com/NervanaSystems/ngraph-tf/archive/v0.3.0-rc1.tar.gz",
+ "https://mirror.bazel.build/github.com/NervanaSystems/ngraph-tf/archive/v0.6.1.tar.gz",
+ "https://github.com/NervanaSystems/ngraph-tf/archive/v0.6.1.tar.gz",
],
- sha256 = "7919332cb15120101c3e05c1b969a5e029a6411581312583c8f80b6aaaa83072",
- strip_prefix = "ngraph-tf-0.3.0-rc1",
- build_file = clean_dep("//third_party/ngraph:ngraph_tf.BUILD"),
)
##############################################################################
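
Editor's note, not part of the patch: most of the churn in workspace.bzl above is mechanical keyword-argument sorting (build_file, sha256, strip_prefix, urls, ...) plus dependency bumps (abseil, LLVM, flatbuffers, nGraph) and the new rocm_configure/icu hooks. A hypothetical sketch of the call shape the reordering converges on; the repository name, URLs, and checksum below are placeholders, not real dependencies.

    tf_http_archive(
        name = "example_archive",  # placeholder repository name
        build_file = clean_dep("//third_party:example.BUILD"),
        sha256 = "0000000000000000000000000000000000000000000000000000000000000000",
        strip_prefix = "example-1.0",
        urls = [
            "https://mirror.bazel.build/example.org/example-1.0.tar.gz",
            "https://example.org/example-1.0.tar.gz",
        ],
    )

    # Targets from the repository generated by rocm_configure(name = "local_config_rocm")
    # are then addressable elsewhere as, e.g., "@local_config_rocm//rocm:rocm_headers".
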
diff --git a/third_party/cub.BUILD b/third_party/cub.BUILD
index 29159c9dad..a04347b21e 100644
--- a/third_party/cub.BUILD
+++ b/third_party/cub.BUILD
@@ -20,6 +20,7 @@ filegroup(
cc_library(
name = "cub",
hdrs = if_cuda([":cub_header_files"]),
+ include_prefix = "third_party",
deps = [
"@local_config_cuda//cuda:cuda_headers",
],
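
A brief aside on include_prefix (the consumer target below is hypothetical, not from the patch): the new attribute remaps the headers exported by @cub_archive//:cub so that they are included under a third_party/ prefix, e.g. #include "third_party/cub/cub.cuh", which matches how TensorFlow sources typically spell the include.

    cc_library(
        name = "uses_cub",           # hypothetical consumer
        srcs = ["uses_cub.cu.cc"],   # hypothetical source; would contain
                                     # #include "third_party/cub/cub.cuh"
        deps = ["@cub_archive//:cub"],
    )
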
diff --git a/third_party/eigen3/BUILD b/third_party/eigen3/BUILD
index 203991b50f..f072f2545a 100644
--- a/third_party/eigen3/BUILD
+++ b/third_party/eigen3/BUILD
@@ -66,19 +66,13 @@ genrule(
outs = ["include"],
cmd = """
mkdir $@
- for f in $(locations @eigen_archive//:eigen_header_files) ; do
+ for f in $(SRCS); do
d="$${f%/*}"
d="$${d#*external/eigen_archive/}"
mkdir -p "$@/$${d}"
cp "$${f}" "$@/$${d}/"
done
-
- for f in $(locations :eigen_third_party_header_files) ; do
- d="$${f%/*}"
-
- mkdir -p "$@/$${d}"
- cp "$${f}" "$@/$${d}/"
- done
""",
+ tags = ["manual"],
)
diff --git a/third_party/flatbuffers/BUILD.bazel b/third_party/flatbuffers/BUILD.bazel
index 934c0d9650..d0be482fda 100644
--- a/third_party/flatbuffers/BUILD.bazel
+++ b/third_party/flatbuffers/BUILD.bazel
@@ -108,11 +108,14 @@ cc_binary(
"grpc/src/compiler/schema_interface.h",
"src/flatc_main.cpp",
"src/idl_gen_cpp.cpp",
+ "src/idl_gen_dart.cpp",
"src/idl_gen_general.cpp",
"src/idl_gen_go.cpp",
"src/idl_gen_grpc.cpp",
"src/idl_gen_js.cpp",
"src/idl_gen_json_schema.cpp",
+ "src/idl_gen_lobster.cpp",
+ "src/idl_gen_lua.cpp",
"src/idl_gen_php.cpp",
"src/idl_gen_python.cpp",
"src/idl_gen_text.cpp",
diff --git a/third_party/flatbuffers/workspace.bzl b/third_party/flatbuffers/workspace.bzl
index 3aeef96a72..7613767fc4 100644
--- a/third_party/flatbuffers/workspace.bzl
+++ b/third_party/flatbuffers/workspace.bzl
@@ -5,11 +5,11 @@ load("//third_party:repo.bzl", "third_party_http_archive")
def repo():
third_party_http_archive(
name = "flatbuffers",
- strip_prefix = "flatbuffers-1.9.0",
- sha256 = "5ca5491e4260cacae30f1a5786d109230db3f3a6e5a0eb45d0d0608293d247e3",
+ strip_prefix = "flatbuffers-1f5eae5d6a135ff6811724f6c57f911d1f46bb15",
+ sha256 = "b2bb0311ca40b12ebe36671bdda350b10c7728caf0cfe2d432ea3b6e409016f3",
urls = [
- "https://mirror.bazel.build/github.com/google/flatbuffers/archive/v1.9.0.tar.gz",
- "https://github.com/google/flatbuffers/archive/v1.9.0.tar.gz",
+ "https://mirror.bazel.build/github.com/google/flatbuffers/archive/1f5eae5d6a135ff6811724f6c57f911d1f46bb15.tar.gz",
+ "https://github.com/google/flatbuffers/archive/1f5eae5d6a135ff6811724f6c57f911d1f46bb15.tar.gz",
],
build_file = "//third_party/flatbuffers:BUILD.bazel",
system_build_file = "//third_party/flatbuffers:BUILD.system",
diff --git a/third_party/gpus/crosstool/CROSSTOOL_hipcc.tpl b/third_party/gpus/crosstool/CROSSTOOL_hipcc.tpl
new file mode 100644
index 0000000000..0e175b3ef6
--- /dev/null
+++ b/third_party/gpus/crosstool/CROSSTOOL_hipcc.tpl
@@ -0,0 +1,158 @@
+major_version: "local"
+minor_version: ""
+default_target_cpu: "same_as_host"
+
+default_toolchain {
+ cpu: "k8"
+ toolchain_identifier: "local_linux"
+}
+default_toolchain {
+ cpu: "piii"
+ toolchain_identifier: "local_linux"
+}
+default_toolchain {
+ cpu: "arm"
+ toolchain_identifier: "local_linux"
+}
+default_toolchain {
+ cpu: "ppc"
+ toolchain_identifier: "local_linux"
+}
+
+toolchain {
+ abi_version: "local"
+ abi_libc_version: "local"
+ builtin_sysroot: ""
+ compiler: "compiler"
+ host_system_name: "local"
+ needsPic: true
+ supports_gold_linker: false
+ supports_incremental_linker: false
+ supports_fission: false
+ supports_interface_shared_objects: false
+ supports_normalizing_ar: false
+ supports_start_end_lib: false
+ supports_thin_archives: false
+ target_libc: "local"
+ target_cpu: "local"
+ target_system_name: "local"
+ toolchain_identifier: "local_linux"
+
+ tool_path { name: "ar" path: "/usr/bin/ar" }
+ tool_path { name: "compat-ld" path: "/usr/bin/ld" }
+ tool_path { name: "cpp" path: "/usr/bin/cpp" }
+ tool_path { name: "dwp" path: "/usr/bin/dwp" }
+ # As part of the TensorFlow release, we place some ROCm-related compilation
+ # files in @local_config_rocm//crosstool/clang/bin, and this relative
+  # path, combined with the rest of our Bazel configuration, causes our
+  # compilation to use those files.
+ tool_path { name: "gcc" path: "clang/bin/crosstool_wrapper_driver_rocm" }
+ # Use "-std=c++11" for hipcc. For consistency, force both the host compiler
+ # and the device compiler to use "-std=c++11".
+ cxx_flag: "-std=c++11"
+ linker_flag: "-Wl,-no-as-needed"
+ linker_flag: "-lstdc++"
+ #linker_flag: "-B/usr/bin/"
+ linker_flag: "-B/opt/rocm/hcc/compiler/bin"
+
+%{host_compiler_includes}
+ tool_path { name: "gcov" path: "/usr/bin/gcov" }
+
+  # C(++) compiles invoke the compiler (as that is the one that knows where
+ # to find libraries), but we provide LD so other rules can invoke the linker.
+ tool_path { name: "ld" path: "/usr/bin/ld" }
+
+ tool_path { name: "nm" path: "/usr/bin/nm" }
+ tool_path { name: "objcopy" path: "/usr/bin/objcopy" }
+ objcopy_embed_flag: "-I"
+ objcopy_embed_flag: "binary"
+ tool_path { name: "objdump" path: "/usr/bin/objdump" }
+ tool_path { name: "strip" path: "/usr/bin/strip" }
+
+ # Anticipated future default.
+ unfiltered_cxx_flag: "-no-canonical-prefixes"
+
+ # Make C++ compilation deterministic. Use linkstamping instead of these
+ # compiler symbols.
+ unfiltered_cxx_flag: "-Wno-builtin-macro-redefined"
+ unfiltered_cxx_flag: "-D__DATE__=\"redacted\""
+ unfiltered_cxx_flag: "-D__TIMESTAMP__=\"redacted\""
+ unfiltered_cxx_flag: "-D__TIME__=\"redacted\""
+ unfiltered_cxx_flag: "-D__HIP_PLATFORM_HCC__"
+ # The macro EIGEN_USE_HIP is used to tell Eigen to use the HIP platform headers
+ # It needs to be always set when compiling Eigen headers
+ # (irrespective of whether the source file is being compiled via HIPCC)
+ # so adding -DEIGEN_USE_HIP as a default CXX flag here
+ unfiltered_cxx_flag: "-DEIGEN_USE_HIP"
+
+
+ # Security hardening on by default.
+ # Conservative choice; -D_FORTIFY_SOURCE=2 may be unsafe in some cases.
+ # We need to undef it before redefining it as some distributions now have
+ # it enabled by default.
+ #compiler_flag: "-U_FORTIFY_SOURCE"
+ #compiler_flag: "-D_FORTIFY_SOURCE=1"
+ #compiler_flag: "-fstack-protector"
+ #compiler_flag: "-fPIE"
+ #linker_flag: "-pie"
+ #linker_flag: "-Wl,-z,relro,-z,now"
+
+ # Enable coloring even if there's no attached terminal. Bazel removes the
+ # escape sequences if --nocolor is specified. This isn't supported by gcc
+ # on Ubuntu 14.04.
+ # compiler_flag: "-fcolor-diagnostics"
+
+ # All warnings are enabled. Maybe enable -Werror as well?
+ compiler_flag: "-Wall"
+ # Enable a few more warnings that aren't part of -Wall.
+ compiler_flag: "-Wunused-but-set-parameter"
+ # But disable some that are problematic.
+ compiler_flag: "-Wno-free-nonheap-object" # has false positives
+
+ # Keep stack frames for debugging, even in opt mode.
+ compiler_flag: "-fno-omit-frame-pointer"
+
+ # Anticipated future default.
+ linker_flag: "-no-canonical-prefixes"
+ unfiltered_cxx_flag: "-fno-canonical-system-headers"
+ # Have gcc return the exit code from ld.
+ linker_flag: "-pass-exit-codes"
+ # Stamp the binary with a unique identifier.
+ linker_flag: "-Wl,--build-id=md5"
+ linker_flag: "-Wl,--hash-style=gnu"
+ # Gold linker only? Can we enable this by default?
+ # linker_flag: "-Wl,--warn-execstack"
+ # linker_flag: "-Wl,--detect-odr-violations"
+
+ # Include directory for ROCm headers.
+%{rocm_include_path}
+
+ compilation_mode_flags {
+ mode: DBG
+ # Enable debug symbols.
+ compiler_flag: "-g"
+ }
+ compilation_mode_flags {
+ mode: OPT
+
+ # No debug symbols.
+ # Maybe we should enable https://gcc.gnu.org/wiki/DebugFission for opt or
+ # even generally? However, that can't happen here, as it requires special
+ # handling in Bazel.
+ compiler_flag: "-g0"
+
+ # Conservative choice for -O
+ # -O3 can increase binary size and even slow down the resulting binaries.
+ # Profile first and / or use FDO if you need better performance than this.
+ compiler_flag: "-O2"
+
+ # Disable assertions
+ compiler_flag: "-DNDEBUG"
+
+ # Removal of unused code and data at link time (can this increase binary size in some cases?).
+ compiler_flag: "-ffunction-sections"
+ compiler_flag: "-fdata-sections"
+ linker_flag: "-Wl,--gc-sections"
+ }
+ linking_mode_flags { mode: DYNAMIC }
+}
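
A note on the %{host_compiler_includes} and %{rocm_include_path} placeholders above: they are expected to be filled in when the repository rule instantiates this template. The sketch below is illustrative only; the function name, destination path, and the substituted cxx_builtin_include_directory values are assumptions, not taken from this patch.

    def _expand_crosstool(repository_ctx):
        # Hypothetical expansion step; the real substitutions are computed by rocm_configure.
        repository_ctx.template(
            "crosstool/CROSSTOOL",
            Label("//third_party/gpus/crosstool:CROSSTOOL_hipcc.tpl"),
            {
                "%{host_compiler_includes}": '  cxx_builtin_include_directory: "/usr/include"',
                "%{rocm_include_path}": '  cxx_builtin_include_directory: "/opt/rocm/include"',
            },
        )
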
diff --git a/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl b/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl
new file mode 100755
index 0000000000..824238022b
--- /dev/null
+++ b/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl
@@ -0,0 +1,241 @@
+#!/usr/bin/env python
+"""Crosstool wrapper for compiling ROCm programs.
+
+SYNOPSIS:
+ crosstool_wrapper_driver_rocm [options passed in by cc_library()
+ or cc_binary() rule]
+
+DESCRIPTION:
+ This script is expected to be called by the cc_library() or cc_binary() bazel
+ rules. When the option "-x rocm" is present in the list of arguments passed
+  to this script, it invokes the hipcc compiler. Most arguments are passed
+  as-is, in a single string, to hipcc via --compiler-options. When "-x rocm"
+  is not present, this wrapper invokes gcc with the input arguments unchanged.
+"""
+
+from __future__ import print_function
+
+__author__ = 'whchung@gmail.com (Wen-Heng (Jack) Chung)'
+
+from argparse import ArgumentParser
+import os
+import subprocess
+import re
+import sys
+import pipes
+
+# Template values set by rocm_configure.bzl.
+CPU_COMPILER = ('%{cpu_compiler}')
+GCC_HOST_COMPILER_PATH = ('%{gcc_host_compiler_path}')
+
+HIPCC_PATH = '%{hipcc_path}'
+PREFIX_DIR = os.path.dirname(GCC_HOST_COMPILER_PATH)
+
+def Log(s):
+ print('gpus/crosstool: {0}'.format(s))
+
+
+def GetOptionValue(argv, option):
+ """Extract the list of values for option from the argv list.
+
+ Args:
+ argv: A list of strings, possibly the argv passed to main().
+ option: The option whose value to extract, without the leading '-'.
+
+ Returns:
+    A list of values, either directly following the option
+    (e.g., -opt val1 val2) or values collected from multiple occurrences of
+    the option (e.g., -opt val1 -opt val2).
+ """
+
+ parser = ArgumentParser()
+ parser.add_argument('-' + option, nargs='*', action='append')
+ args, _ = parser.parse_known_args(argv)
+ if not args or not vars(args)[option]:
+ return []
+ else:
+ return sum(vars(args)[option], [])
+
+
+def GetHostCompilerOptions(argv):
+ """Collect the -isystem, -iquote, and --sysroot option values from argv.
+
+ Args:
+ argv: A list of strings, possibly the argv passed to main().
+
+ Returns:
+ The string that can be used as the --compiler-options to hipcc.
+ """
+
+ parser = ArgumentParser()
+ parser.add_argument('-isystem', nargs='*', action='append')
+ parser.add_argument('-iquote', nargs='*', action='append')
+ parser.add_argument('--sysroot', nargs=1)
+ parser.add_argument('-g', nargs='*', action='append')
+ parser.add_argument('-fno-canonical-system-headers', action='store_true')
+
+ args, _ = parser.parse_known_args(argv)
+
+ opts = ''
+
+ if args.isystem:
+ opts += ' -isystem ' + ' -isystem '.join(sum(args.isystem, []))
+ if args.iquote:
+ opts += ' -iquote ' + ' -iquote '.join(sum(args.iquote, []))
+ if args.g:
+ opts += ' -g' + ' -g'.join(sum(args.g, []))
+ #if args.fno_canonical_system_headers:
+ # opts += ' -fno-canonical-system-headers'
+ if args.sysroot:
+ opts += ' --sysroot ' + args.sysroot[0]
+
+ return opts
+
+def GetHipccOptions(argv):
+ """Collect the -hipcc_options values from argv.
+
+ Args:
+ argv: A list of strings, possibly the argv passed to main().
+
+ Returns:
+ The string that can be passed directly to hipcc.
+ """
+
+ parser = ArgumentParser()
+ parser.add_argument('-hipcc_options', nargs='*', action='append')
+
+ args, _ = parser.parse_known_args(argv)
+
+ if args.hipcc_options:
+ options = _update_options(sum(args.hipcc_options, []))
+ return ' '.join(['--'+a for a in options])
+ return ''
+
+
+def InvokeHipcc(argv, log=False):
+ """Call hipcc with arguments assembled from argv.
+
+ Args:
+ argv: A list of strings, possibly the argv passed to main().
+ log: True if logging is requested.
+
+ Returns:
+ The return value of calling os.system('hipcc ' + args)
+ """
+
+ host_compiler_options = GetHostCompilerOptions(argv)
+ hipcc_compiler_options = GetHipccOptions(argv)
+ opt_option = GetOptionValue(argv, 'O')
+ m_options = GetOptionValue(argv, 'm')
+ m_options = ''.join([' -m' + m for m in m_options if m in ['32', '64']])
+ include_options = GetOptionValue(argv, 'I')
+ out_file = GetOptionValue(argv, 'o')
+ depfiles = GetOptionValue(argv, 'MF')
+ defines = GetOptionValue(argv, 'D')
+ defines = ''.join([' -D' + define for define in defines])
+ undefines = GetOptionValue(argv, 'U')
+ undefines = ''.join([' -U' + define for define in undefines])
+ std_options = GetOptionValue(argv, 'std')
+ hipcc_allowed_std_options = ["c++11"]
+ std_options = ''.join([' -std=' + define
+ for define in std_options if define in hipcc_allowed_std_options])
+
+  # The list of source files gets passed after the -c option. I don't know of
+ # any other reliable way to just get the list of source files to be compiled.
+ src_files = GetOptionValue(argv, 'c')
+
+ if len(src_files) == 0:
+ return 1
+ if len(out_file) != 1:
+ return 1
+
+ opt = (' -O2' if (len(opt_option) > 0 and int(opt_option[0]) > 0)
+ else ' -g')
+
+ includes = (' -I ' + ' -I '.join(include_options)
+ if len(include_options) > 0
+ else '')
+
+  # Unfortunately, there are other options that have the -c prefix too.
+  # So allow only the arguments that look like C/C++ files.
+ src_files = [f for f in src_files if
+ re.search('\.cpp$|\.cc$|\.c$|\.cxx$|\.C$', f)]
+ srcs = ' '.join(src_files)
+ out = ' -o ' + out_file[0]
+
+ hipccopts = ' '
+ hipccopts += ' ' + hipcc_compiler_options
+ hipccopts += undefines
+ hipccopts += defines
+ hipccopts += std_options
+ hipccopts += m_options
+
+ if depfiles:
+ # Generate the dependency file
+ depfile = depfiles[0]
+ cmd = (HIPCC_PATH + ' ' + hipccopts +
+ host_compiler_options +
+ ' ' + GCC_HOST_COMPILER_PATH +
+ ' -I .' + includes + ' ' + srcs + ' -M -o ' + depfile)
+ if log: Log(cmd)
+ exit_status = os.system(cmd)
+ if exit_status != 0:
+ return exit_status
+
+ cmd = (HIPCC_PATH + ' ' + hipccopts +
+ host_compiler_options + ' -fPIC' +
+ ' ' + GCC_HOST_COMPILER_PATH +
+ ' -I .' + opt + includes + ' -c ' + srcs + out)
+
+ # TODO(zhengxq): for some reason, 'gcc' needs this help to find 'as'.
+ # Need to investigate and fix.
+ cmd = 'PATH=' + PREFIX_DIR + ':$PATH ' + cmd
+ if log: Log(cmd)
+ return os.system(cmd)
+
+
+def main():
+ # ignore PWD env var
+ os.environ['PWD']=''
+
+ parser = ArgumentParser()
+ parser.add_argument('-x', nargs=1)
+ parser.add_argument('--rocm_log', action='store_true')
+ parser.add_argument('-pass-exit-codes', action='store_true')
+ args, leftover = parser.parse_known_args(sys.argv[1:])
+
+ if args.x and args.x[0] == 'rocm':
+ if args.rocm_log: Log('-x rocm')
+ leftover = [pipes.quote(s) for s in leftover]
+ if args.rocm_log: Log('using hipcc')
+ return InvokeHipcc(leftover, log=args.rocm_log)
+
+ # XXX use hipcc to link
+ if args.pass_exit_codes:
+ gpu_compiler_flags = [flag for flag in sys.argv[1:]
+ if not flag.startswith(('-pass-exit-codes'))]
+
+ # special handling for $ORIGIN
+ # - guard every argument with ''
+ modified_gpu_compiler_flags = []
+ for flag in gpu_compiler_flags:
+ modified_gpu_compiler_flags.append("'" + flag + "'")
+
+ if args.rocm_log: Log('Link with hipcc: %s' % (' '.join([HIPCC_PATH] + modified_gpu_compiler_flags)))
+ return subprocess.call([HIPCC_PATH] + modified_gpu_compiler_flags)
+
+ # Strip our flags before passing through to the CPU compiler for files which
+ # are not -x rocm. We can't just pass 'leftover' because it also strips -x.
+ # We not only want to pass -x to the CPU compiler, but also keep it in its
+ # relative location in the argv list (the compiler is actually sensitive to
+ # this).
+ cpu_compiler_flags = [flag for flag in sys.argv[1:]
+ if not flag.startswith(('--rocm_log'))]
+
+  # XXX: SE code needs to be built with gcc, but needs this macro defined
+ cpu_compiler_flags.append("-D__HIP_PLATFORM_HCC__")
+
+ return subprocess.call([CPU_COMPILER] + cpu_compiler_flags)
+
+if __name__ == '__main__':
+ sys.exit(main())
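Taken as a whole, the wrapper does three things: compile device code via hipcc when Bazel passes `-x rocm`, link via hipcc when `-pass-exit-codes` is present, and otherwise fall through to the host compiler with `-D__HIP_PLATFORM_HCC__` added. A minimal standalone sketch of that dispatch follows; it is illustrative only, uses placeholder paths for HIPCC_PATH and CPU_COMPILER, and omits the hipcc option translation done by InvokeHipcc:

```python
"""Minimal sketch of the wrapper's dispatch logic -- illustrative only."""
import subprocess
import sys
from argparse import ArgumentParser

HIPCC_PATH = "/opt/rocm/bin/hipcc"  # placeholder; the real value is templated in
CPU_COMPILER = "/usr/bin/gcc"       # placeholder host compiler

def dispatch(argv):
    parser = ArgumentParser()
    parser.add_argument("-x", nargs=1)
    parser.add_argument("-pass-exit-codes", action="store_true")
    args, leftover = parser.parse_known_args(argv)

    if args.x and args.x[0] == "rocm":
        # Device compilation: hand the remaining flags to hipcc
        # (the real wrapper rewrites them in InvokeHipcc first).
        return subprocess.call([HIPCC_PATH] + leftover)

    if args.pass_exit_codes:
        # Link step: let hipcc drive the link so device objects are handled.
        flags = [f for f in argv if not f.startswith("-pass-exit-codes")]
        return subprocess.call([HIPCC_PATH] + flags)

    # Host compilation: keep -x in place and define the HIP platform macro.
    flags = [f for f in argv if not f.startswith("--rocm_log")]
    return subprocess.call([CPU_COMPILER] + flags + ["-D__HIP_PLATFORM_HCC__"])

if __name__ == "__main__":
    sys.exit(dispatch(sys.argv[1:]))
```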
diff --git a/third_party/gpus/cuda_configure.bzl b/third_party/gpus/cuda_configure.bzl
index 5648b1525a..f5fdd3a75e 100644
--- a/third_party/gpus/cuda_configure.bzl
+++ b/third_party/gpus/cuda_configure.bzl
@@ -48,6 +48,7 @@ _DEFAULT_CUDA_COMPUTE_CAPABILITIES = ["3.5", "5.2"]
CUDA_LIB_PATHS = [
"lib64/",
"lib64/stubs/",
+ "lib/powerpc64le-linux-gnu/",
"lib/x86_64-linux-gnu/",
"lib/x64/",
"lib/",
@@ -70,6 +71,7 @@ CUPTI_HEADER_PATHS = [
# the other CUDA libraries but rather in a special extras/CUPTI directory.
CUPTI_LIB_PATHS = [
"extras/CUPTI/lib64/",
+ "lib/powerpc64le-linux-gnu/",
"lib/x86_64-linux-gnu/",
"lib64/",
"extras/CUPTI/libx64/",
diff --git a/third_party/gpus/rocm/BUILD b/third_party/gpus/rocm/BUILD
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/third_party/gpus/rocm/BUILD
diff --git a/third_party/gpus/rocm/BUILD.tpl b/third_party/gpus/rocm/BUILD.tpl
new file mode 100644
index 0000000000..8258bb3589
--- /dev/null
+++ b/third_party/gpus/rocm/BUILD.tpl
@@ -0,0 +1,99 @@
+licenses(["restricted"]) # MPL2, portions GPL v3, LGPL v3, BSD-like
+
+package(default_visibility = ["//visibility:public"])
+
+config_setting(
+ name = "using_hipcc",
+ values = {
+ "define": "using_rocm_hipcc=true",
+ },
+)
+
+cc_library(
+ name = "rocm_headers",
+ hdrs = [
+ "rocm/rocm_config.h",
+ %{rocm_headers}
+ ],
+ includes = [
+ ".",
+ "rocm/include",
+ ],
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "hip",
+ srcs = ["rocm/lib/%{hip_lib}"],
+ data = ["rocm/lib/%{hip_lib}"],
+ includes = [
+ ".",
+ "rocm/include",
+ ],
+ linkstatic = 1,
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "rocblas",
+ srcs = ["rocm/lib/%{rocblas_lib}"],
+ data = ["rocm/lib/%{rocblas_lib}"],
+ includes = [
+ ".",
+ "rocm/include",
+ ],
+ linkstatic = 1,
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "rocfft",
+ srcs = ["rocm/lib/%{rocfft_lib}"],
+ data = ["rocm/lib/%{rocfft_lib}"],
+ includes = [
+ ".",
+ "rocm/include",
+ ],
+ linkstatic = 1,
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "hiprand",
+ srcs = ["rocm/lib/%{hiprand_lib}"],
+ data = ["rocm/lib/%{hiprand_lib}"],
+ includes = [
+ ".",
+ "rocm/include",
+ "rocm/include/rocrand",
+ ],
+ linkstatic = 1,
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "miopen",
+ srcs = ["rocm/lib/%{miopen_lib}"],
+ data = ["rocm/lib/%{miopen_lib}"],
+ includes = [
+ ".",
+ "rocm/include",
+ ],
+ linkstatic = 1,
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "rocm",
+ visibility = ["//visibility:public"],
+ deps = [
+ ":rocm_headers",
+ ":hip",
+ ":rocblas",
+ ":rocfft",
+ ":hiprand",
+ ":miopen",
+ ],
+)
+
+%{rocm_include_genrules}
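A consumer of this template sees the generated targets under `@local_config_rocm//rocm`. A hedged sketch of a BUILD rule that depends on them only when hipcc is selected (the target and source names below are illustrative, not part of this patch):

```python
cc_library(
    name = "my_rocm_kernels",           # hypothetical target
    srcs = ["my_kernels.cc"],           # hypothetical source
    deps = select({
        "@local_config_rocm//rocm:using_hipcc": [
            "@local_config_rocm//rocm:rocm_headers",
            "@local_config_rocm//rocm:rocm",
        ],
        "//conditions:default": [],
    }),
)
```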
diff --git a/third_party/gpus/rocm/build_defs.bzl.tpl b/third_party/gpus/rocm/build_defs.bzl.tpl
new file mode 100644
index 0000000000..08c59f95a0
--- /dev/null
+++ b/third_party/gpus/rocm/build_defs.bzl.tpl
@@ -0,0 +1,45 @@
+# Macros for building ROCm code.
+def if_rocm(if_true, if_false = []):
+ """Shorthand for select()'ing on whether we're building with ROCm.
+
+ Returns a select statement which evaluates to if_true if we're building
+ with ROCm enabled. Otherwise, the select statement evaluates to if_false.
+
+ """
+ return select({
+ "@local_config_rocm//rocm:using_hipcc": if_true,
+ "//conditions:default": if_false
+ })
+
+
+def rocm_default_copts():
+ """Default options for all ROCm compilations."""
+ return if_rocm(["-x", "rocm"] + %{rocm_extra_copts})
+
+def rocm_copts(opts = []):
+ """Gets the appropriate set of copts for (maybe) ROCm compilation.
+
+ If we're doing ROCm compilation, returns copts for our particular ROCm
+ compiler. If we're not doing ROCm compilation, returns an empty list.
+
+ """
+ return rocm_default_copts() + select({
+ "//conditions:default": [],
+ "@local_config_rocm//rocm:using_hipcc": ([
+ "",
+ ]),
+ }) + if_rocm_is_configured(opts)
+
+def rocm_is_configured():
+ """Returns true if ROCm was enabled during the configure process."""
+ return %{rocm_is_configured}
+
+def if_rocm_is_configured(x):
+ """Tests if the ROCm was enabled during the configure process.
+
+ Unlike if_rocm(), this does not require that we are building with
+ --config=rocm. Used to allow non-ROCm code to depend on ROCm libraries.
+ """
+ if rocm_is_configured():
+ return x
+ return []
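Once generated, these macros are loaded from `@local_config_rocm//rocm:build_defs.bzl`. A small illustrative BUILD rule (names and the define are hypothetical) might combine them as follows:

```python
load(
    "@local_config_rocm//rocm:build_defs.bzl",
    "if_rocm",
    "if_rocm_is_configured",
    "rocm_copts",
)

cc_library(
    name = "gpu_helpers",                # hypothetical target
    srcs = ["gpu_helpers.cc"],           # hypothetical source
    copts = rocm_copts(),
    defines = if_rocm(["MY_LIB_WITH_ROCM"]),  # only when building with --config=rocm
    deps = if_rocm_is_configured(["@local_config_rocm//rocm:rocm_headers"]),
)
```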
diff --git a/third_party/gpus/rocm/rocm_config.h.tpl b/third_party/gpus/rocm/rocm_config.h.tpl
new file mode 100644
index 0000000000..c5f25a845c
--- /dev/null
+++ b/third_party/gpus/rocm/rocm_config.h.tpl
@@ -0,0 +1,21 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef ROCM_ROCM_CONFIG_H_
+#define ROCM_ROCM_CONFIG_H_
+
+#define TF_ROCM_TOOLKIT_PATH "/opt/rocm"
+
+#endif // ROCM_ROCM_CONFIG_H_
diff --git a/third_party/gpus/rocm_configure.bzl b/third_party/gpus/rocm_configure.bzl
new file mode 100644
index 0000000000..9108639b0b
--- /dev/null
+++ b/third_party/gpus/rocm_configure.bzl
@@ -0,0 +1,784 @@
+# -*- Python -*-
+"""Repository rule for ROCm autoconfiguration.
+
+`rocm_configure` depends on the following environment variables:
+
+ * `TF_NEED_ROCM`: Whether to enable building with ROCm.
+ * `GCC_HOST_COMPILER_PATH`: The GCC host compiler path
+ * `ROCM_TOOLKIT_PATH`: The path to the ROCm toolkit. Default is
+ `/opt/rocm`.
+ * `TF_ROCM_VERSION`: The version of the ROCm toolkit. If this is blank, then
+ use the system default.
+ * `TF_MIOPEN_VERSION`: The version of the MIOpen library.
+ * `TF_ROCM_AMDGPU_TARGETS`: The AMDGPU targets. Default is
+ `gfx803,gfx900`.
+"""
+
+_GCC_HOST_COMPILER_PATH = "GCC_HOST_COMPILER_PATH"
+_ROCM_TOOLKIT_PATH = "ROCM_TOOLKIT_PATH"
+_TF_ROCM_VERSION = "TF_ROCM_VERSION"
+_TF_MIOPEN_VERSION = "TF_MIOPEN_VERSION"
+_TF_ROCM_AMDGPU_TARGETS = "TF_ROCM_AMDGPU_TARGETS"
+_TF_ROCM_CONFIG_REPO = "TF_ROCM_CONFIG_REPO"
+
+_DEFAULT_ROCM_VERSION = ""
+_DEFAULT_MIOPEN_VERSION = ""
+_DEFAULT_ROCM_TOOLKIT_PATH = "/opt/rocm"
+_DEFAULT_ROCM_AMDGPU_TARGETS = ["gfx803", "gfx900"]
+
+def find_cc(repository_ctx):
+ """Find the C++ compiler."""
+
+ # Return a dummy value for GCC detection here to avoid error
+ target_cc_name = "gcc"
+ cc_path_envvar = _GCC_HOST_COMPILER_PATH
+ cc_name = target_cc_name
+
+ if cc_path_envvar in repository_ctx.os.environ:
+ cc_name_from_env = repository_ctx.os.environ[cc_path_envvar].strip()
+ if cc_name_from_env:
+ cc_name = cc_name_from_env
+ if cc_name.startswith("/"):
+ # Absolute path, maybe we should make this supported by our which function.
+ return cc_name
+ cc = repository_ctx.which(cc_name)
+ if cc == None:
+ fail(("Cannot find {}, either correct your path or set the {}" +
+ " environment variable").format(target_cc_name, cc_path_envvar))
+ return cc
+
+_INC_DIR_MARKER_BEGIN = "#include <...>"
+
+def _cxx_inc_convert(path):
+ """Convert path returned by cc -E xc++ in a complete path."""
+ path = path.strip()
+ return path
+
+def _get_cxx_inc_directories_impl(repository_ctx, cc, lang_is_cpp):
+ """Compute the list of default C or C++ include directories."""
+ if lang_is_cpp:
+ lang = "c++"
+ else:
+ lang = "c"
+
+ # TODO: We pass -no-canonical-prefixes here to match the compiler flags,
+ # but in rocm_clang CROSSTOOL file that is a `feature` and we should
+ # handle the case when it's disabled and no flag is passed
+ result = repository_ctx.execute([
+ cc,
+ "-no-canonical-prefixes",
+ "-E",
+ "-x" + lang,
+ "-",
+ "-v",
+ ])
+ index1 = result.stderr.find(_INC_DIR_MARKER_BEGIN)
+ if index1 == -1:
+ return []
+ index1 = result.stderr.find("\n", index1)
+ if index1 == -1:
+ return []
+ index2 = result.stderr.rfind("\n ")
+ if index2 == -1 or index2 < index1:
+ return []
+ index2 = result.stderr.find("\n", index2 + 1)
+ if index2 == -1:
+ inc_dirs = result.stderr[index1 + 1:]
+ else:
+ inc_dirs = result.stderr[index1 + 1:index2].strip()
+
+ return [
+ str(repository_ctx.path(_cxx_inc_convert(p)))
+ for p in inc_dirs.split("\n")
+ ]
+
+def get_cxx_inc_directories(repository_ctx, cc):
+ """Compute the list of default C and C++ include directories."""
+
+ # For some reason `clang -xc` sometimes returns include paths that are
+ # different from the ones from `clang -xc++`. (Symlink and a dir)
+ # So we run the compiler with both `-xc` and `-xc++` and merge resulting lists
+ includes_cpp = _get_cxx_inc_directories_impl(repository_ctx, cc, True)
+ includes_c = _get_cxx_inc_directories_impl(repository_ctx, cc, False)
+
+ includes_cpp_set = depset(includes_cpp)
+ return includes_cpp + [
+ inc
+ for inc in includes_c
+ if inc not in includes_cpp_set
+ ]
+
+def auto_configure_fail(msg):
+ """Output failure message when rocm configuration fails."""
+ red = "\033[0;31m"
+ no_color = "\033[0m"
+ fail("\n%sROCm Configuration Error:%s %s\n" % (red, no_color, msg))
+
+# END cc_configure common functions (see TODO above).
+
+def _host_compiler_includes(repository_ctx, cc):
+ """Generates the cxx_builtin_include_directory entries for gcc inc dirs.
+
+ Args:
+ repository_ctx: The repository context.
+ cc: The path to the gcc host compiler.
+
+ Returns:
+ A string containing the cxx_builtin_include_directory for each of the gcc
+ host compiler include directories, which can be added to the CROSSTOOL
+ file.
+ """
+ inc_dirs = get_cxx_inc_directories(repository_ctx, cc)
+
+ # Add numpy headers
+ inc_dirs.append("/usr/lib/python2.7/dist-packages/numpy/core/include")
+
+ entries = []
+ for inc_dir in inc_dirs:
+ entries.append(" cxx_builtin_include_directory: \"%s\"" % inc_dir)
+
+ # define TENSORFLOW_USE_ROCM
+ entries.append(" unfiltered_cxx_flag: \"-DTENSORFLOW_USE_ROCM\"")
+
+ return "\n".join(entries)
+
+def _rocm_include_path(repository_ctx, rocm_config):
+ """Generates the cxx_builtin_include_directory entries for rocm inc dirs.
+
+ Args:
+ repository_ctx: The repository context.
+      rocm_config: The ROCm config as returned by _get_rocm_config.
+
+ Returns:
+      A string containing the cxx_builtin_include_directory entry for each of
+      the ROCm include directories, which can be added to the CROSSTOOL
+      file.
+ """
+ inc_dirs = []
+
+ # general ROCm include path
+ inc_dirs.append(rocm_config.rocm_toolkit_path + "/include")
+
+ # Add HSA headers
+ inc_dirs.append("/opt/rocm/hsa/include")
+
+ # Add HIP headers
+ inc_dirs.append("/opt/rocm/include/hip")
+ inc_dirs.append("/opt/rocm/include/hip/hcc_detail")
+
+ # Add rocrand and hiprand headers
+ inc_dirs.append("/opt/rocm/rocrand/include")
+ inc_dirs.append("/opt/rocm/hiprand/include")
+
+ # Add rocfft headers
+ inc_dirs.append("/opt/rocm/rocfft/include")
+
+ # Add rocBLAS headers
+ inc_dirs.append("/opt/rocm/rocblas/include")
+
+ # Add MIOpen headers
+ inc_dirs.append("/opt/rocm/miopen/include")
+
+ # Add hcc headers
+ inc_dirs.append("/opt/rocm/hcc/include")
+ inc_dirs.append("/opt/rocm/hcc/compiler/lib/clang/7.0.0/include/")
+ inc_dirs.append("/opt/rocm/hcc/lib/clang/7.0.0/include")
+
+ # Newer hcc builds use/are based off of clang 8.0.0.
+ inc_dirs.append("/opt/rocm/hcc/compiler/lib/clang/8.0.0/include/")
+ inc_dirs.append("/opt/rocm/hcc/lib/clang/8.0.0/include")
+
+ inc_entries = []
+ for inc_dir in inc_dirs:
+ inc_entries.append(" cxx_builtin_include_directory: \"%s\"" % inc_dir)
+ return "\n".join(inc_entries)
+
+def _enable_rocm(repository_ctx):
+ if "TF_NEED_ROCM" in repository_ctx.os.environ:
+ enable_rocm = repository_ctx.os.environ["TF_NEED_ROCM"].strip()
+ return enable_rocm == "1"
+ return False
+
+def _rocm_toolkit_path(repository_ctx):
+ """Finds the rocm toolkit directory.
+
+ Args:
+ repository_ctx: The repository context.
+
+ Returns:
+ A speculative real path of the rocm toolkit install directory.
+ """
+ rocm_toolkit_path = _DEFAULT_ROCM_TOOLKIT_PATH
+ if _ROCM_TOOLKIT_PATH in repository_ctx.os.environ:
+ rocm_toolkit_path = repository_ctx.os.environ[_ROCM_TOOLKIT_PATH].strip()
+ if not repository_ctx.path(rocm_toolkit_path).exists:
+ auto_configure_fail("Cannot find rocm toolkit path.")
+ return str(repository_ctx.path(rocm_toolkit_path).realpath)
+
+def _amdgpu_targets(repository_ctx):
+ """Returns a list of strings representing AMDGPU targets."""
+ if _TF_ROCM_AMDGPU_TARGETS not in repository_ctx.os.environ:
+ return _DEFAULT_ROCM_AMDGPU_TARGETS
+ amdgpu_targets_str = repository_ctx.os.environ[_TF_ROCM_AMDGPU_TARGETS]
+ amdgpu_targets = amdgpu_targets_str.split(",")
+ for amdgpu_target in amdgpu_targets:
+ if amdgpu_target[:3] != "gfx" or not amdgpu_target[3:].isdigit():
+ auto_configure_fail("Invalid AMDGPU target: %s" % amdgpu_target)
+ return amdgpu_targets
+
+def _cpu_value(repository_ctx):
+ """Returns the name of the host operating system.
+
+ Args:
+ repository_ctx: The repository context.
+
+ Returns:
+ A string containing the name of the host operating system.
+ """
+ os_name = repository_ctx.os.name.lower()
+ if os_name.startswith("mac os"):
+ return "Darwin"
+ if os_name.find("windows") != -1:
+ return "Windows"
+ result = repository_ctx.execute(["uname", "-s"])
+ return result.stdout.strip()
+
+def _lib_name(lib, cpu_value, version = "", static = False):
+ """Constructs the platform-specific name of a library.
+
+ Args:
+ lib: The name of the library, such as "hip"
+ cpu_value: The name of the host operating system.
+ version: The version of the library.
+      static: True if the library is static, or False if it is a shared object.
+
+ Returns:
+ The platform-specific name of the library.
+ """
+    if cpu_value == "Linux":
+ if static:
+ return "lib%s.a" % lib
+ else:
+ if version:
+ version = ".%s" % version
+ return "lib%s.so%s" % (lib, version)
+ elif cpu_value == "Windows":
+ return "%s.lib" % lib
+ elif cpu_value == "Darwin":
+ if static:
+ return "lib%s.a" % lib
+ elif version:
+ version = ".%s" % version
+ return "lib%s%s.dylib" % (lib, version)
+ else:
+ auto_configure_fail("Invalid cpu_value: %s" % cpu_value)
+
+def _find_rocm_lib(
+ lib,
+ repository_ctx,
+ cpu_value,
+ basedir,
+ version = "",
+ static = False):
+ """Finds the given ROCm libraries on the system.
+
+ Args:
+ lib: The name of the library, such as "hip"
+ repository_ctx: The repository context.
+ cpu_value: The name of the host operating system.
+ basedir: The install directory of ROCm.
+ version: The version of the library.
+ static: True if static library, False if shared object.
+
+ Returns:
+ Returns a struct with the following fields:
+ file_name: The basename of the library found on the system.
+ path: The full path to the library.
+ """
+ file_name = _lib_name(lib, cpu_value, version, static)
+ if cpu_value == "Linux":
+ path = repository_ctx.path("%s/lib64/%s" % (basedir, file_name))
+ if path.exists:
+ return struct(file_name = file_name, path = str(path.realpath))
+ path = repository_ctx.path("%s/lib64/stubs/%s" % (basedir, file_name))
+ if path.exists:
+ return struct(file_name = file_name, path = str(path.realpath))
+ path = repository_ctx.path(
+ "%s/lib/x86_64-linux-gnu/%s" % (basedir, file_name),
+ )
+ if path.exists:
+ return struct(file_name = file_name, path = str(path.realpath))
+
+ path = repository_ctx.path("%s/lib/%s" % (basedir, file_name))
+ if path.exists:
+ return struct(file_name = file_name, path = str(path.realpath))
+ path = repository_ctx.path("%s/%s" % (basedir, file_name))
+ if path.exists:
+ return struct(file_name = file_name, path = str(path.realpath))
+
+ auto_configure_fail("Cannot find rocm library %s" % file_name)
+
+def _find_libs(repository_ctx, rocm_config):
+ """Returns the ROCm libraries on the system.
+
+ Args:
+ repository_ctx: The repository context.
+ rocm_config: The ROCm config as returned by _get_rocm_config
+
+ Returns:
+ Map of library names to structs of filename and path as returned by
+ _find_rocm_lib.
+ """
+ cpu_value = rocm_config.cpu_value
+ return {
+ "hip": _find_rocm_lib(
+ "hip_hcc",
+ repository_ctx,
+ cpu_value,
+ rocm_config.rocm_toolkit_path,
+ ),
+ "rocblas": _find_rocm_lib(
+ "rocblas",
+ repository_ctx,
+ cpu_value,
+ rocm_config.rocm_toolkit_path + "/rocblas",
+ ),
+ "rocfft": _find_rocm_lib(
+ "rocfft",
+ repository_ctx,
+ cpu_value,
+ rocm_config.rocm_toolkit_path + "/rocfft",
+ ),
+ "hiprand": _find_rocm_lib(
+ "hiprand",
+ repository_ctx,
+ cpu_value,
+ rocm_config.rocm_toolkit_path + "/hiprand",
+ ),
+ "miopen": _find_rocm_lib(
+ "MIOpen",
+ repository_ctx,
+ cpu_value,
+ rocm_config.rocm_toolkit_path + "/miopen",
+ ),
+ }
+
+def _get_rocm_config(repository_ctx):
+ """Detects and returns information about the ROCm installation on the system.
+
+ Args:
+ repository_ctx: The repository context.
+
+ Returns:
+ A struct containing the following fields:
+ rocm_toolkit_path: The ROCm toolkit installation directory.
+ amdgpu_targets: A list of the system's AMDGPU targets.
+ cpu_value: The name of the host operating system.
+ """
+ cpu_value = _cpu_value(repository_ctx)
+ rocm_toolkit_path = _rocm_toolkit_path(repository_ctx)
+ return struct(
+ rocm_toolkit_path = rocm_toolkit_path,
+ amdgpu_targets = _amdgpu_targets(repository_ctx),
+ cpu_value = cpu_value,
+ )
+
+def _tpl(repository_ctx, tpl, substitutions = {}, out = None):
+ if not out:
+ out = tpl.replace(":", "/")
+ repository_ctx.template(
+ out,
+ Label("//third_party/gpus/%s.tpl" % tpl),
+ substitutions,
+ )
+
+def _file(repository_ctx, label):
+ repository_ctx.template(
+ label.replace(":", "/"),
+ Label("//third_party/gpus/%s.tpl" % label),
+ {},
+ )
+
+_DUMMY_CROSSTOOL_BZL_FILE = """
+def error_gpu_disabled():
+ fail("ERROR: Building with --config=rocm but TensorFlow is not configured " +
+ "to build with GPU support. Please re-run ./configure and enter 'Y' " +
+ "at the prompt to build with GPU support.")
+
+ native.genrule(
+ name = "error_gen_crosstool",
+ outs = ["CROSSTOOL"],
+ cmd = "echo 'Should not be run.' && exit 1",
+ )
+
+ native.filegroup(
+ name = "crosstool",
+ srcs = [":CROSSTOOL"],
+ output_licenses = ["unencumbered"],
+ )
+"""
+
+_DUMMY_CROSSTOOL_BUILD_FILE = """
+load("//crosstool:error_gpu_disabled.bzl", "error_gpu_disabled")
+
+error_gpu_disabled()
+"""
+
+def _create_dummy_repository(repository_ctx):
+ cpu_value = _cpu_value(repository_ctx)
+
+ # Set up BUILD file for rocm/.
+ _tpl(
+ repository_ctx,
+ "rocm:build_defs.bzl",
+ {
+ "%{rocm_is_configured}": "False",
+ "%{rocm_extra_copts}": "[]",
+ },
+ )
+ _tpl(
+ repository_ctx,
+ "rocm:BUILD",
+ {
+ "%{hip_lib}": _lib_name("hip", cpu_value),
+ "%{rocblas_lib}": _lib_name("rocblas", cpu_value),
+ "%{miopen_lib}": _lib_name("miopen", cpu_value),
+ "%{rocfft_lib}": _lib_name("rocfft", cpu_value),
+ "%{hiprand_lib}": _lib_name("hiprand", cpu_value),
+ "%{rocm_include_genrules}": "",
+ "%{rocm_headers}": "",
+ },
+ )
+
+ # Create dummy files for the ROCm toolkit since they are still required by
+ # tensorflow/core/platform/default/build_config:rocm.
+ repository_ctx.file("rocm/hip/include/hip/hip_runtime.h", "")
+
+ # Set up rocm_config.h, which is used by
+ # tensorflow/stream_executor/dso_loader.cc.
+ _tpl(
+ repository_ctx,
+ "rocm:rocm_config.h",
+ {
+ "%{rocm_toolkit_path}": _DEFAULT_ROCM_TOOLKIT_PATH,
+ },
+ "rocm/rocm/rocm_config.h",
+ )
+
+ # If rocm_configure is not configured to build with GPU support, and the user
+ # attempts to build with --config=rocm, add a dummy build rule to intercept
+ # this and fail with an actionable error message.
+ repository_ctx.file(
+ "crosstool/error_gpu_disabled.bzl",
+ _DUMMY_CROSSTOOL_BZL_FILE,
+ )
+ repository_ctx.file("crosstool/BUILD", _DUMMY_CROSSTOOL_BUILD_FILE)
+
+def _execute(
+ repository_ctx,
+ cmdline,
+ error_msg = None,
+ error_details = None,
+ empty_stdout_fine = False):
+ """Executes an arbitrary shell command.
+
+ Args:
+ repository_ctx: the repository_ctx object
+ cmdline: list of strings, the command to execute
+ error_msg: string, a summary of the error if the command fails
+ error_details: string, details about the error or steps to fix it
+ empty_stdout_fine: bool, if True, an empty stdout result is fine, otherwise
+ it's an error
+    Returns:
+ the result of repository_ctx.execute(cmdline)
+ """
+ result = repository_ctx.execute(cmdline)
+ if result.stderr or not (empty_stdout_fine or result.stdout):
+ auto_configure_fail(
+ "\n".join([
+ error_msg.strip() if error_msg else "Repository command failed",
+ result.stderr.strip(),
+ error_details if error_details else "",
+ ]),
+ )
+ return result
+
+def _norm_path(path):
+ """Returns a path with '/' and remove the trailing slash."""
+ path = path.replace("\\", "/")
+ if path[-1] == "/":
+ path = path[:-1]
+ return path
+
+def _symlink_genrule_for_dir(
+ repository_ctx,
+ src_dir,
+ dest_dir,
+ genrule_name,
+ src_files = [],
+ dest_files = []):
+ """Returns a genrule to symlink(or copy if on Windows) a set of files.
+
+ If src_dir is passed, files will be read from the given directory; otherwise
+ we assume files are in src_files and dest_files
+ """
+ if src_dir != None:
+ src_dir = _norm_path(src_dir)
+ dest_dir = _norm_path(dest_dir)
+ files = _read_dir(repository_ctx, src_dir)
+
+ # Create a list with the src_dir stripped to use for outputs.
+ dest_files = files.replace(src_dir, "").splitlines()
+ src_files = files.splitlines()
+ command = []
+
+ # We clear folders that might have been generated previously to avoid
+ # undesired inclusions
+ command.append('if [ -d "$(@D)/include" ]; then rm $(@D)/include -drf; fi')
+ command.append('if [ -d "$(@D)/lib" ]; then rm $(@D)/lib -drf; fi')
+ outs = []
+ for i in range(len(dest_files)):
+ if dest_files[i] != "":
+ # If we have only one file to link we do not want to use the dest_dir, as
+ # $(@D) will include the full path to the file.
+ dest = "$(@D)/" + dest_dir + dest_files[i] if len(dest_files) != 1 else "$(@D)/" + dest_files[i]
+
+            # Symlink each file into place (on Windows, where symlinks are not supported, this would need to be a copy).
+ cmd = "ln -s"
+ command.append(cmd + ' "%s" "%s"' % (src_files[i], dest))
+ outs.append(' "' + dest_dir + dest_files[i] + '",')
+ genrule = _genrule(
+ src_dir,
+ genrule_name,
+ " && ".join(command),
+ "\n".join(outs),
+ )
+ return genrule
+
+def _genrule(src_dir, genrule_name, command, outs):
+ """Returns a string with a genrule.
+
+ Genrule executes the given command and produces the given outputs.
+ """
+ return (
+ "genrule(\n" +
+ ' name = "' +
+ genrule_name + '",\n' +
+ " outs = [\n" +
+ outs +
+ "\n ],\n" +
+ ' cmd = """\n' +
+ command +
+ '\n """,\n' +
+ ")\n"
+ )
+
+def _read_dir(repository_ctx, src_dir):
+ """Returns a string with all files in a directory.
+
+ Finds all files inside a directory, traversing subfolders and following
+ symlinks. The returned string contains the full path of all files
+ separated by line breaks.
+ """
+ find_result = _execute(
+ repository_ctx,
+ ["find", src_dir, "-follow", "-type", "f"],
+ empty_stdout_fine = True,
+ )
+ result = find_result.stdout
+ return result
+
+def _compute_rocm_extra_copts(repository_ctx, amdgpu_targets):
+ if False:
+ amdgpu_target_flags = ["--amdgpu-target=" +
+ amdgpu_target for amdgpu_target in amdgpu_targets]
+ else:
+ # AMDGPU targets are handled in the "crosstool_wrapper_driver_is_not_gcc"
+ amdgpu_target_flags = []
+ return str(amdgpu_target_flags)
+
+def _create_local_rocm_repository(repository_ctx):
+ """Creates the repository containing files set up to build with ROCm."""
+ rocm_config = _get_rocm_config(repository_ctx)
+
+ # Set up symbolic links for the rocm toolkit by creating genrules to do
+ # symlinking. We create one genrule for each directory we want to track under
+ # rocm_toolkit_path
+ rocm_toolkit_path = rocm_config.rocm_toolkit_path
+ rocm_include_path = rocm_toolkit_path + "/include"
+ genrules = [_symlink_genrule_for_dir(
+ repository_ctx,
+ rocm_include_path,
+ "rocm/include",
+ "rocm-include",
+ )]
+ genrules.append(_symlink_genrule_for_dir(
+ repository_ctx,
+ rocm_toolkit_path + "/rocfft/include",
+ "rocm/include/rocfft",
+ "rocfft-include",
+ ))
+ genrules.append(_symlink_genrule_for_dir(
+ repository_ctx,
+ rocm_toolkit_path + "/rocblas/include",
+ "rocm/include/rocblas",
+ "rocblas-include",
+ ))
+ genrules.append(_symlink_genrule_for_dir(
+ repository_ctx,
+ rocm_toolkit_path + "/miopen/include",
+ "rocm/include/miopen",
+ "miopen-include",
+ ))
+
+ rocm_libs = _find_libs(repository_ctx, rocm_config)
+ rocm_lib_src = []
+ rocm_lib_dest = []
+ for lib in rocm_libs.values():
+ rocm_lib_src.append(lib.path)
+ rocm_lib_dest.append("rocm/lib/" + lib.file_name)
+ genrules.append(_symlink_genrule_for_dir(
+ repository_ctx,
+ None,
+ "",
+ "rocm-lib",
+ rocm_lib_src,
+ rocm_lib_dest,
+ ))
+
+ included_files = _read_dir(repository_ctx, rocm_include_path).replace(
+ rocm_include_path,
+ "",
+ ).splitlines()
+
+ # Set up BUILD file for rocm/
+ _tpl(
+ repository_ctx,
+ "rocm:build_defs.bzl",
+ {
+ "%{rocm_is_configured}": "True",
+ "%{rocm_extra_copts}": _compute_rocm_extra_copts(
+ repository_ctx,
+ rocm_config.amdgpu_targets,
+ ),
+ },
+ )
+ _tpl(
+ repository_ctx,
+ "rocm:BUILD",
+ {
+ "%{hip_lib}": rocm_libs["hip"].file_name,
+ "%{rocblas_lib}": rocm_libs["rocblas"].file_name,
+ "%{rocfft_lib}": rocm_libs["rocfft"].file_name,
+ "%{hiprand_lib}": rocm_libs["hiprand"].file_name,
+ "%{miopen_lib}": rocm_libs["miopen"].file_name,
+ "%{rocm_include_genrules}": "\n".join(genrules),
+ "%{rocm_headers}": ('":rocm-include",\n' +
+ '":rocfft-include",\n' +
+ '":rocblas-include",\n' +
+ '":miopen-include",'),
+ },
+ )
+
+ # Set up crosstool/
+ _tpl(repository_ctx, "crosstool:BUILD", {"%{linker_files}": ":empty", "%{win_linker_files}": ":empty"})
+ cc = find_cc(repository_ctx)
+ host_compiler_includes = _host_compiler_includes(repository_ctx, cc)
+ rocm_defines = {
+ "%{rocm_include_path}": _rocm_include_path(
+ repository_ctx,
+ rocm_config,
+ ),
+ "%{host_compiler_includes}": host_compiler_includes,
+ "%{clang_path}": str(cc),
+ }
+
+ _tpl(repository_ctx, "crosstool:CROSSTOOL_hipcc", rocm_defines, out = "crosstool/CROSSTOOL")
+
+ _tpl(
+ repository_ctx,
+ "crosstool:clang/bin/crosstool_wrapper_driver_rocm",
+ {
+ "%{cpu_compiler}": str(cc),
+ "%{hipcc_path}": "/opt/rocm/bin/hipcc",
+ "%{gcc_host_compiler_path}": str(cc),
+ "%{rocm_amdgpu_targets}": ",".join(
+ ["\"%s\"" % c for c in rocm_config.amdgpu_targets],
+ ),
+ },
+ )
+
+ # Set up rocm_config.h, which is used by
+ # tensorflow/stream_executor/dso_loader.cc.
+ _tpl(
+ repository_ctx,
+ "rocm:rocm_config.h",
+ {
+ "%{rocm_amdgpu_targets}": ",".join(
+ ["\"%s\"" % c for c in rocm_config.amdgpu_targets],
+ ),
+ "%{rocm_toolkit_path}": rocm_config.rocm_toolkit_path,
+ },
+ "rocm/rocm/rocm_config.h",
+ )
+
+def _create_remote_rocm_repository(repository_ctx, remote_config_repo):
+ """Creates pointers to a remotely configured repo set up to build with ROCm."""
+ _tpl(
+ repository_ctx,
+ "rocm:build_defs.bzl",
+ {
+ "%{rocm_is_configured}": "True",
+ "%{rocm_extra_copts}": _compute_rocm_extra_copts(
+ repository_ctx, #_compute_capabilities(repository_ctx)
+ ),
+ },
+ )
+ _tpl(
+ repository_ctx,
+ "rocm:remote.BUILD",
+ {
+ "%{remote_rocm_repo}": remote_config_repo,
+ },
+ "rocm/BUILD",
+ )
+ _tpl(repository_ctx, "crosstool:remote.BUILD", {
+ "%{remote_rocm_repo}": remote_config_repo,
+ }, "crosstool/BUILD")
+
+def _rocm_autoconf_impl(repository_ctx):
+ """Implementation of the rocm_autoconf repository rule."""
+ if not _enable_rocm(repository_ctx):
+ _create_dummy_repository(repository_ctx)
+ elif _TF_ROCM_CONFIG_REPO in repository_ctx.os.environ:
+ _create_remote_rocm_repository(
+ repository_ctx,
+ repository_ctx.os.environ[_TF_ROCM_CONFIG_REPO],
+ )
+ else:
+ _create_local_rocm_repository(repository_ctx)
+
+rocm_configure = repository_rule(
+ implementation = _rocm_autoconf_impl,
+ environ = [
+ _GCC_HOST_COMPILER_PATH,
+ "TF_NEED_ROCM",
+ _ROCM_TOOLKIT_PATH,
+ _TF_ROCM_VERSION,
+ _TF_MIOPEN_VERSION,
+ _TF_ROCM_AMDGPU_TARGETS,
+ _TF_ROCM_CONFIG_REPO,
+ ],
+)
+
+"""Detects and configures the local ROCm toolchain.
+
+Add the following to your WORKSPACE file:
+
+```python
+rocm_configure(name = "local_config_rocm")
+```
+
+Args:
+ name: A unique name for this workspace rule.
+"""
diff --git a/third_party/icu/BUILD b/third_party/icu/BUILD
new file mode 100644
index 0000000000..82bab3ffd9
--- /dev/null
+++ b/third_party/icu/BUILD
@@ -0,0 +1 @@
+# This empty BUILD file is required to make Bazel treat this directory as a package.
diff --git a/third_party/icu/BUILD.bazel b/third_party/icu/BUILD.bazel
new file mode 100644
index 0000000000..36d6b9006b
--- /dev/null
+++ b/third_party/icu/BUILD.bazel
@@ -0,0 +1,88 @@
+package(
+ default_visibility = ["//visibility:public"],
+)
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files([
+ "icu4c/LICENSE",
+ "icu4j/main/shared/licenses/LICENSE",
+])
+
+cc_library(
+ name = "headers",
+ hdrs = glob(["icu4c/source/common/unicode/*.h"]),
+ includes = [
+ "icu4c/source/common",
+ ],
+ deps = [
+ ],
+)
+
+cc_library(
+ name = "common",
+ hdrs = glob(["icu4c/source/common/unicode/*.h"]),
+ includes = [
+ "icu4c/source/common",
+ ],
+ deps = [
+ ":icuuc",
+ ],
+)
+
+cc_library(
+ name = "icuuc",
+ srcs = glob(
+ [
+ "icu4c/source/common/*.c",
+ "icu4c/source/common/*.cpp",
+ "icu4c/source/stubdata/*.cpp",
+ ],
+ ),
+ hdrs = glob([
+ "icu4c/source/common/*.h",
+ ]),
+ copts = [
+ "-DU_COMMON_IMPLEMENTATION",
+ "-DU_HAVE_STD_ATOMICS",
+ ] + select({
+ ":android": [
+ "-fdata-sections",
+ "-DGOOGLE_VENDOR_SRC_BRANCH",
+ "-DU_HAVE_NL_LANGINFO_CODESET=0",
+ "-Wno-deprecated-declarations",
+ ],
+ ":apple": [
+ "-DGOOGLE_VENDOR_SRC_BRANCH",
+ "-Wno-shorten-64-to-32",
+ "-Wno-unused-variable",
+ ],
+ ":windows": [
+ "/utf-8",
+ "/DLOCALE_ALLOW_NEUTRAL_NAMES=0",
+ ],
+ "//conditions:default": [],
+ }),
+ tags = ["requires-rtti"],
+ visibility = [
+ "//visibility:private",
+ ],
+ deps = [
+ ":headers",
+ ],
+)
+
+config_setting(
+ name = "android",
+ values = {"crosstool_top": "//external:android/crosstool"},
+)
+
+config_setting(
+ name = "apple",
+ values = {"cpu": "darwin"},
+)
+
+config_setting(
+ name = "windows",
+ values = {"cpu": "x64_windows"},
+)
diff --git a/third_party/icu/workspace.bzl b/third_party/icu/workspace.bzl
new file mode 100644
index 0000000000..bfebf4219b
--- /dev/null
+++ b/third_party/icu/workspace.bzl
@@ -0,0 +1,15 @@
+"""Loads a lightweight subset of the ICU library for Unicode processing."""
+
+load("//third_party:repo.bzl", "third_party_http_archive")
+
+def repo():
+ third_party_http_archive(
+ name = "icu",
+ strip_prefix = "icu-release-62-1",
+ sha256 = "e15ffd84606323cbad5515bf9ecdf8061cc3bf80fb883b9e6aa162e485aa9761",
+ urls = [
+ "https://mirror.bazel.build/github.com/unicode-org/icu/archive/release-62-1.tar.gz",
+ "https://github.com/unicode-org/icu/archive/release-62-1.tar.gz",
+ ],
+ build_file = "//third_party/icu:BUILD.bazel",
+ )
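The `repo()` helper follows the usual `workspace.bzl` pattern: it is loaded and invoked from the main workspace setup. The actual call site is not part of this patch; a sketch of what it could look like:

```python
# Hypothetical call site -- not taken from this patch.
load("//third_party/icu:workspace.bzl", icu = "repo")

def initialize_third_party():
    icu()
```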
diff --git a/third_party/mkl/BUILD b/third_party/mkl/BUILD
index efff7fd51b..15a3e5cfa7 100644
--- a/third_party/mkl/BUILD
+++ b/third_party/mkl/BUILD
@@ -1,26 +1,26 @@
licenses(["notice"]) # 3-Clause BSD
config_setting(
- name = "using_mkl",
+ name = "build_with_mkl",
define_values = {
- "using_mkl": "true",
+ "build_with_mkl": "true",
},
visibility = ["//visibility:public"],
)
config_setting(
- name = "using_mkl_ml_only",
+ name = "build_with_mkl_ml_only",
define_values = {
- "using_mkl": "true",
- "using_mkl_ml_only": "true",
+ "build_with_mkl": "true",
+ "build_with_mkl_ml_only": "true",
},
visibility = ["//visibility:public"],
)
config_setting(
- name = "using_mkl_lnx_x64",
+ name = "build_with_mkl_lnx_x64",
define_values = {
- "using_mkl": "true",
+ "build_with_mkl": "true",
},
values = {
"cpu": "k8",
@@ -28,6 +28,15 @@ config_setting(
visibility = ["//visibility:public"],
)
+config_setting(
+ name = "enable_mkl",
+ define_values = {
+ "enable_mkl": "true",
+ "build_with_mkl": "true",
+ },
+ visibility = ["//visibility:public"],
+)
+
load(
"//third_party/mkl:build_defs.bzl",
"if_mkl",
diff --git a/third_party/mkl/build_defs.bzl b/third_party/mkl/build_defs.bzl
index b645c0fc5c..10c2d90c84 100644
--- a/third_party/mkl/build_defs.bzl
+++ b/third_party/mkl/build_defs.bzl
@@ -1,9 +1,11 @@
# -*- Python -*-
"""Skylark macros for MKL.
-if_mkl is a conditional to check if MKL is enabled or not.
-if_mkl_ml is a conditional to check if MKL-ML is enabled.
+
+if_mkl is a conditional to check if we are building with MKL.
+if_mkl_ml is a conditional to check if we are building with MKL-ML.
if_mkl_ml_only is a conditional to check for MKL-ML-only (no MKL-DNN) mode.
if_mkl_lnx_x64 is a conditional to check for MKL
+if_enable_mkl is a conditional to check if building with MKL and MKL is enabled.
mkl_repository is a repository rule for creating MKL repository rule that can
be pointed to either a local folder, or download it from the internet.
@@ -24,7 +26,7 @@ def if_mkl(if_true, if_false = []):
a select evaluating to either if_true or if_false as appropriate.
"""
return select({
- str(Label("//third_party/mkl:using_mkl")): if_true,
+ str(Label("//third_party/mkl:build_with_mkl")): if_true,
"//conditions:default": if_false,
})
@@ -40,8 +42,8 @@ def if_mkl_ml(if_true, if_false = []):
a select evaluating to either if_true or if_false as appropriate.
"""
return select({
- str(Label("//third_party/mkl_dnn:using_mkl_dnn_only")): if_false,
- str(Label("//third_party/mkl:using_mkl")): if_true,
+ str(Label("//third_party/mkl_dnn:build_with_mkl_dnn_only")): if_false,
+ str(Label("//third_party/mkl:build_with_mkl")): if_true,
"//conditions:default": if_false,
})
@@ -56,12 +58,12 @@ def if_mkl_ml_only(if_true, if_false = []):
a select evaluating to either if_true or if_false as appropriate.
"""
return select({
- str(Label("//third_party/mkl:using_mkl_ml_only")): if_true,
+ str(Label("//third_party/mkl:build_with_mkl_ml_only")): if_true,
"//conditions:default": if_false,
})
def if_mkl_lnx_x64(if_true, if_false = []):
- """Shorthand to select() on if MKL is on and the target is Linux x86-64.
+ """Shorthand to select() if building with MKL and the target is Linux x86-64.
Args:
if_true: expression to evaluate if building with MKL is enabled and the
@@ -73,7 +75,24 @@ def if_mkl_lnx_x64(if_true, if_false = []):
a select evaluating to either if_true or if_false as appropriate.
"""
return select({
- str(Label("//third_party/mkl:using_mkl_lnx_x64")): if_true,
+ str(Label("//third_party/mkl:build_with_mkl_lnx_x64")): if_true,
+ "//conditions:default": if_false,
+ })
+
+def if_enable_mkl(if_true, if_false = []):
+ """Shorthand to select() if we are building with MKL and MKL is enabled.
+
+ This is only effective when built with MKL.
+
+ Args:
+ if_true: expression to evaluate if building with MKL and MKL is enabled
+ if_false: expression to evaluate if building without MKL or MKL is not enabled.
+
+ Returns:
+ A select evaluating to either if_true or if_false as appropriate.
+ """
+ return select({
+ str(Label("//third_party/mkl:enable_mkl")): if_true,
"//conditions:default": if_false,
})
@@ -87,9 +106,9 @@ def mkl_deps():
inclusion in the deps attribute of rules.
"""
return select({
- str(Label("//third_party/mkl_dnn:using_mkl_dnn_only")): ["@mkl_dnn"],
- str(Label("//third_party/mkl:using_mkl_ml_only")): ["//third_party/mkl:intel_binary_blob"],
- str(Label("//third_party/mkl:using_mkl")): [
+ str(Label("//third_party/mkl_dnn:build_with_mkl_dnn_only")): ["@mkl_dnn"],
+ str(Label("//third_party/mkl:build_with_mkl_ml_only")): ["//third_party/mkl:intel_binary_blob"],
+ str(Label("//third_party/mkl:build_with_mkl")): [
"//third_party/mkl:intel_binary_blob",
"@mkl_dnn",
],
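With the renamed settings, a rule that wants MKL code paths only when MKL is both built in and enabled can combine `if_enable_mkl` with `mkl_deps`. An illustrative rule (the target, source, and define names are hypothetical, not from this patch):

```python
load(
    "//third_party/mkl:build_defs.bzl",
    "if_enable_mkl",
    "mkl_deps",
)

cc_library(
    name = "matmul_kernel",                     # hypothetical target
    srcs = ["matmul_kernel.cc"],                # hypothetical source
    copts = if_enable_mkl(["-DENABLE_MKL"]),    # only with --define=enable_mkl=true
    deps = mkl_deps(),
)
```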
diff --git a/third_party/mkl_dnn/BUILD b/third_party/mkl_dnn/BUILD
index 3e567fa9fc..58ecda55e6 100644
--- a/third_party/mkl_dnn/BUILD
+++ b/third_party/mkl_dnn/BUILD
@@ -3,10 +3,10 @@ licenses(["notice"])
exports_files(["LICENSE"])
config_setting(
- name = "using_mkl_dnn_only",
+ name = "build_with_mkl_dnn_only",
define_values = {
- "using_mkl": "true",
- "using_mkl_dnn_only": "true",
+ "build_with_mkl": "true",
+ "build_with_mkl_dnn_only": "true",
},
visibility = ["//visibility:public"],
)
diff --git a/third_party/mkl_dnn/build_defs.bzl b/third_party/mkl_dnn/build_defs.bzl
index 7ce2a7d9b0..6388f31971 100644
--- a/third_party/mkl_dnn/build_defs.bzl
+++ b/third_party/mkl_dnn/build_defs.bzl
@@ -8,6 +8,6 @@ def if_mkl_open_source_only(if_true, if_false = []):
"""
return select({
- str(Label("//third_party/mkl_dnn:using_mkl_dnn_only")): if_true,
+ str(Label("//third_party/mkl_dnn:build_with_mkl_dnn_only")): if_true,
"//conditions:default": if_false,
})
diff --git a/third_party/ngraph/ngraph.BUILD b/third_party/ngraph/ngraph.BUILD
index 31aa3cee51..6602a480af 100644
--- a/third_party/ngraph/ngraph.BUILD
+++ b/third_party/ngraph/ngraph.BUILD
@@ -3,6 +3,121 @@ licenses(["notice"]) # 3-Clause BSD
exports_files(["LICENSE"])
cc_library(
+ name = "ngraph_headers",
+ hdrs = glob(["src/ngraph/**/*.hpp"]),
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "ngraph_cpu_backend",
+ srcs = [
+ "src/ngraph/runtime/cpu/builder/add.cpp",
+ "src/ngraph/runtime/cpu/builder/allreduce.cpp",
+ "src/ngraph/runtime/cpu/builder/argmax.cpp",
+ "src/ngraph/runtime/cpu/builder/argmin.cpp",
+ "src/ngraph/runtime/cpu/builder/avg_pool.cpp",
+ "src/ngraph/runtime/cpu/builder/batch_norm.cpp",
+ "src/ngraph/runtime/cpu/builder/bounded_relu.cpp",
+ "src/ngraph/runtime/cpu/builder/broadcast.cpp",
+ "src/ngraph/runtime/cpu/builder/concat.cpp",
+ "src/ngraph/runtime/cpu/builder/convert.cpp",
+ "src/ngraph/runtime/cpu/builder/convert_layout.cpp",
+ "src/ngraph/runtime/cpu/builder/convolution.cpp",
+ "src/ngraph/runtime/cpu/builder/dot.cpp",
+ "src/ngraph/runtime/cpu/builder/function_call.cpp",
+ "src/ngraph/runtime/cpu/builder/lrn.cpp",
+ "src/ngraph/runtime/cpu/builder/lstm.cpp",
+ "src/ngraph/runtime/cpu/builder/matmul_bias.cpp",
+ "src/ngraph/runtime/cpu/builder/max.cpp",
+ "src/ngraph/runtime/cpu/builder/max_pool.cpp",
+ "src/ngraph/runtime/cpu/builder/min.cpp",
+ "src/ngraph/runtime/cpu/builder/one_hot.cpp",
+ "src/ngraph/runtime/cpu/builder/pad.cpp",
+ "src/ngraph/runtime/cpu/builder/product.cpp",
+ "src/ngraph/runtime/cpu/builder/quantize.cpp",
+ "src/ngraph/runtime/cpu/builder/quantized_avg_pool.cpp",
+ "src/ngraph/runtime/cpu/builder/quantized_max_pool.cpp",
+ "src/ngraph/runtime/cpu/builder/reduce_function.cpp",
+ "src/ngraph/runtime/cpu/builder/reduce_function_window.cpp",
+ "src/ngraph/runtime/cpu/builder/relu.cpp",
+ "src/ngraph/runtime/cpu/builder/replace_slice.cpp",
+ "src/ngraph/runtime/cpu/builder/reshape.cpp",
+ "src/ngraph/runtime/cpu/builder/reverse.cpp",
+ "src/ngraph/runtime/cpu/builder/reverse_sequence.cpp",
+ "src/ngraph/runtime/cpu/builder/rnn.cpp",
+ "src/ngraph/runtime/cpu/builder/select.cpp",
+ "src/ngraph/runtime/cpu/builder/select_and_scatter.cpp",
+ "src/ngraph/runtime/cpu/builder/sigmoid.cpp",
+ "src/ngraph/runtime/cpu/builder/slice.cpp",
+ "src/ngraph/runtime/cpu/builder/softmax.cpp",
+ "src/ngraph/runtime/cpu/builder/sum.cpp",
+ "src/ngraph/runtime/cpu/builder/topk.cpp",
+ "src/ngraph/runtime/cpu/cpu_backend.cpp",
+ "src/ngraph/runtime/cpu/cpu_builder.cpp",
+ "src/ngraph/runtime/cpu/cpu_call_frame.cpp",
+ "src/ngraph/runtime/cpu/cpu_external_function.cpp",
+ "src/ngraph/runtime/cpu/cpu_kernels.cpp",
+ "src/ngraph/runtime/cpu/cpu_layout_descriptor.cpp",
+ "src/ngraph/runtime/cpu/cpu_tensor_view.cpp",
+ "src/ngraph/runtime/cpu/cpu_tensor_view_wrapper.cpp",
+ "src/ngraph/runtime/cpu/cpu_tracing.cpp",
+ "src/ngraph/runtime/cpu/kernel/eigen_thread_pool.cpp",
+ "src/ngraph/runtime/cpu/kernel/pad.cpp",
+ "src/ngraph/runtime/cpu/kernel/reduce_max.cpp",
+ "src/ngraph/runtime/cpu/kernel/reduce_sum.cpp",
+ "src/ngraph/runtime/cpu/kernel/reshape.cpp",
+ "src/ngraph/runtime/cpu/mkldnn_emitter.cpp",
+ "src/ngraph/runtime/cpu/mkldnn_invoke.cpp",
+ "src/ngraph/runtime/cpu/mkldnn_utils.cpp",
+ "src/ngraph/runtime/cpu/op/batch_dot.cpp",
+ "src/ngraph/runtime/cpu/op/batch_norm_relu.cpp",
+ "src/ngraph/runtime/cpu/op/bounded_relu.cpp",
+ "src/ngraph/runtime/cpu/op/conv_add.cpp",
+ "src/ngraph/runtime/cpu/op/conv_bias.cpp",
+ "src/ngraph/runtime/cpu/op/conv_relu.cpp",
+ "src/ngraph/runtime/cpu/op/convert_layout.cpp",
+ "src/ngraph/runtime/cpu/op/dequantize.cpp",
+ "src/ngraph/runtime/cpu/op/group_conv.cpp",
+ "src/ngraph/runtime/cpu/op/loop_kernel.cpp",
+ "src/ngraph/runtime/cpu/op/lstm.cpp",
+ "src/ngraph/runtime/cpu/op/matmul_bias.cpp",
+ "src/ngraph/runtime/cpu/op/max_pool_with_indices.cpp",
+ "src/ngraph/runtime/cpu/op/quantize.cpp",
+ "src/ngraph/runtime/cpu/op/quantized_avg_pool.cpp",
+ "src/ngraph/runtime/cpu/op/quantized_max_pool.cpp",
+ "src/ngraph/runtime/cpu/op/rnn.cpp",
+ "src/ngraph/runtime/cpu/op/sigmoid_mul.cpp",
+ "src/ngraph/runtime/cpu/pass/cpu_assignment.cpp",
+ "src/ngraph/runtime/cpu/pass/cpu_collapse_dims.cpp",
+ "src/ngraph/runtime/cpu/pass/cpu_concat_inputs.cpp",
+ "src/ngraph/runtime/cpu/pass/cpu_fusion.cpp",
+ "src/ngraph/runtime/cpu/pass/cpu_layout.cpp",
+ "src/ngraph/runtime/cpu/pass/cpu_loop_kernel_fusion.cpp",
+ "src/ngraph/runtime/cpu/pass/cpu_mat_fusion.cpp",
+ "src/ngraph/runtime/cpu/pass/cpu_post_layout_optimizations.cpp",
+ "src/ngraph/runtime/cpu/pass/cpu_rnn_fusion.cpp",
+ "src/ngraph/runtime/cpu/pass/cpu_workspace_insertion.cpp",
+ ],
+ hdrs = glob(["src/ngraph/runtime/cpu/**/*.hpp"]) + glob([]),
+ deps = [
+ ":ngraph_headers",
+ "@eigen_archive//:eigen",
+ "@nlohmann_json_lib",
+ "@tbb",
+ "@mkl_dnn//:mkl_dnn",
+ ],
+ copts = [
+ "-I external/ngraph/src",
+ "-I external/nlohmann_json_lib/include/",
+ '-D SHARED_LIB_EXT=\\".so\\"',
+ '-D NGRAPH_VERSION=\\"0.8.1\\"',
+ "-D NGRAPH_DEX_ONLY",
+ ],
+ visibility = ["//visibility:public"],
+ alwayslink = 1,
+)
+
+cc_library(
name = "ngraph_core",
srcs = glob([
"src/ngraph/*.cpp",
@@ -18,11 +133,10 @@ cc_library(
"src/ngraph/pass/*.hpp",
"src/ngraph/runtime/*.cpp",
"src/ngraph/type/*.cpp",
- "src/ngraph/runtime/interpreter/*.cpp",
- "src/ngraph/runtime/interpreter/*.hpp",
]),
- hdrs = glob(["src/ngraph/**/*.hpp"]),
deps = [
+ ":ngraph_headers",
+ ":ngraph_cpu_backend",
"@eigen_archive//:eigen",
"@nlohmann_json_lib",
],
@@ -30,7 +144,7 @@ cc_library(
"-I external/ngraph/src",
"-I external/nlohmann_json_lib/include/",
'-D SHARED_LIB_EXT=\\".so\\"',
- '-D NGRAPH_VERSION=\\"0.5.0\\"',
+ '-D NGRAPH_VERSION=\\"0.8.1\\"',
],
visibility = ["//visibility:public"],
alwayslink = 1,
diff --git a/third_party/ngraph/ngraph_tf.BUILD b/third_party/ngraph/ngraph_tf.BUILD
index 4d96ccf2f2..dbedca0a03 100644
--- a/third_party/ngraph/ngraph_tf.BUILD
+++ b/third_party/ngraph/ngraph_tf.BUILD
@@ -8,46 +8,44 @@ load(
)
cc_library(
- name = "ngraph_libs_linux",
- srcs = [
- "lib/libiomp5.so",
- "lib/libmklml_intel.so",
- ],
- visibility = ["//visibility:public"],
-)
-
-cc_library(
name = "ngraph_tf",
srcs = [
- "src/ngraph_builder.h",
+ "src/ngraph_api.cc",
+ "src/ngraph_api.h",
+ "src/ngraph_assign_clusters.cc",
+ "src/ngraph_assign_clusters.h",
"src/ngraph_builder.cc",
- "src/ngraph_cluster.h",
- "src/ngraph_cluster.cc",
- "src/ngraph_cluster_manager.h",
+ "src/ngraph_builder.h",
+ "src/ngraph_capture_variables.cc",
+ "src/ngraph_capture_variables.h",
"src/ngraph_cluster_manager.cc",
- "src/ngraph_confirm_pass.cc",
- "src/ngraph_device.cc",
+ "src/ngraph_cluster_manager.h",
+ "src/ngraph_conversions.h",
+ "src/ngraph_deassign_clusters.cc",
+ "src/ngraph_deassign_clusters.h",
+ "src/ngraph_encapsulate_clusters.cc",
+ "src/ngraph_encapsulate_clusters.h",
"src/ngraph_encapsulate_op.cc",
- "src/ngraph_encapsulate_pass.cc",
- "src/ngraph_freshness_tracker.h",
"src/ngraph_freshness_tracker.cc",
- "src/ngraph_graph_rewrite_passes.cc",
- "src/ngraph_liberate_pass.cc",
- "src/ngraph_op_kernels.cc",
- "src/ngraph_stub_ops.cc",
- "src/ngraph_utils.h",
+ "src/ngraph_freshness_tracker.h",
+ "src/ngraph_mark_for_clustering.cc",
+ "src/ngraph_mark_for_clustering.h",
+ "src/ngraph_rewrite_for_tracking.cc",
+ "src/ngraph_rewrite_for_tracking.h",
+ "src/ngraph_rewrite_pass.cc",
+ "src/ngraph_tracked_variable.cc",
"src/ngraph_utils.cc",
- "src/ngraph_send_recv_ops.cc",
- "src/ngraph_variable_ops.cc",
+ "src/ngraph_utils.h",
+ "src/ngraph_version_utils.h",
+ "src/tf_deadness_analysis.cc",
+ "src/tf_deadness_analysis.h",
"src/tf_graphcycles.cc",
+ "src/tf_graphcycles.h",
"logging/ngraph_log.h",
"logging/ngraph_log.cc",
"logging/tf_graph_writer.h",
"logging/tf_graph_writer.cc",
],
- hdrs = [
- "src/tf_graphcycles.h",
- ],
deps = [
"@org_tensorflow//tensorflow/core:protos_all_proto_text",
"@org_tensorflow//tensorflow/core:framework_headers_lib",
@@ -58,7 +56,6 @@ cc_library(
"-I external/ngraph_tf/src",
"-I external/ngraph_tf/logging",
"-I external/ngraph/src",
- "-D NGRAPH_EMBEDDED_IN_TENSORFLOW=1",
],
alwayslink = 1,
visibility = ["//visibility:public"],
@@ -68,8 +65,19 @@ tf_cc_test(
name = "ngraph_tf_tests",
size = "small",
srcs = [
- "test/tf_exec.cpp",
+ "test/conversions.cpp",
+ "test/graph_rewrites/assign_clusters.cc",
+ "test/graph_rewrites/deadness_test.cc",
"test/main.cpp",
+ "test/opexecuter.cpp",
+ "test/opexecuter.h",
+ "test/padding.cpp",
+ "test/test_array_ops.cpp",
+ "test/test_math_ops.cpp",
+ "test/test_nn_ops.cpp",
+ "test/test_utilities.cpp",
+ "test/test_utilities.h",
+ "test/tf_exec.cpp",
],
deps = [
":ngraph_tf",
@@ -80,7 +88,6 @@ tf_cc_test(
],
extra_copts = [
"-fexceptions ",
- "-D NGRAPH_EMBEDDED_IN_TENSORFLOW=1",
"-I external/ngraph_tf/src",
"-I external/ngraph_tf/logging",
"-I external/ngraph/src",
diff --git a/third_party/ngraph/tbb.BUILD b/third_party/ngraph/tbb.BUILD
new file mode 100644
index 0000000000..04e6544ffb
--- /dev/null
+++ b/third_party/ngraph/tbb.BUILD
@@ -0,0 +1,63 @@
+licenses(["notice"]) # 3-Clause BSD
+
+exports_files(["LICENSE"])
+
+# Taken from: https://github.com/rnburn/satyr/blob/master/bazel/tbb.BUILD
+# License for this BUILD file: MIT
+# See: https://github.com/rnburn/satyr/blob/master/LICENSE
+#
+# License for TBB: Apache 2.0
+# See: https://github.com/01org/tbb/blob/tbb_2018/LICENSE
+
+genrule(
+ name = "build_tbb",
+ srcs = glob(["**"]) + [
+ "@local_config_cc//:toolchain",
+ ],
+ cmd = """
+ set -e
+ WORK_DIR=$$PWD
+ DEST_DIR=$$PWD/$(@D)
+ export PATH=$$(dirname $(AR)):$$PATH
+ export CXXFLAGS=$(CC_FLAGS)
+ export NM=$(NM)
+ export AR=$(AR)
+ cd $$(dirname $(location :Makefile))
+
+ #TBB's build needs some help to figure out what compiler it's using
+ if $$CXX --version | grep clang &> /dev/null; then
+ COMPILER_OPT="compiler=clang"
+ else
+ COMPILER_OPT="compiler=gcc"
+
+ # # Workaround for TBB bug
+ # # See https://github.com/01org/tbb/issues/59
+ # CXXFLAGS="$$CXXFLAGS -flifetime-dse=1"
+ fi
+
+ # uses extra_inc=big_iron.inc to specify that static libraries are
+ # built. See https://software.intel.com/en-us/forums/intel-threading-building-blocks/topic/297792
+ make tbb_build_prefix="build" \
+ extra_inc=big_iron.inc \
+ $$COMPILER_OPT; \
+
+ echo cp build/build_{release,debug}/*.a $$DEST_DIR
+ cp build/build_{release,debug}/*.a $$DEST_DIR
+ cd $$WORK_DIR
+ """,
+ outs = [
+ "libtbb.a",
+ "libtbbmalloc.a",
+ ],
+)
+
+cc_library(
+ name = "tbb",
+ hdrs = glob([
+ "include/serial/**",
+ "include/tbb/**/**",
+ ]),
+ srcs = ["libtbb.a"],
+ includes = ["include"],
+ visibility = ["//visibility:public"],
+)
diff --git a/third_party/repo.bzl b/third_party/repo.bzl
index 7d1aa5dce9..6e30618d39 100644
--- a/third_party/repo.bzl
+++ b/third_party/repo.bzl
@@ -119,6 +119,10 @@ def _tf_http_archive(ctx):
"%prefix%": ".." if _repos_are_siblings() else "external",
}, False)
+ if use_syslib:
+ for internal_src, external_dest in ctx.attr.system_link_files.items():
+ ctx.symlink(Label(internal_src), ctx.path(external_dest))
+
tf_http_archive = repository_rule(
implementation = _tf_http_archive,
attrs = {
@@ -130,6 +134,7 @@ tf_http_archive = repository_rule(
"patch_file": attr.label(),
"build_file": attr.label(),
"system_build_file": attr.label(),
+ "system_link_files": attr.string_dict(),
},
environ = [
"TF_SYSTEM_LIBS",
@@ -180,7 +185,16 @@ def _third_party_http_archive(ctx):
_apply_patch(ctx, ctx.attr.patch_file)
ctx.symlink(Label(ctx.attr.build_file), buildfile_path)
+ link_dict = dict()
+ if use_syslib:
+ link_dict.update(ctx.attr.system_link_files)
+
for internal_src, external_dest in ctx.attr.link_files.items():
+ # if syslib and link exists in both, use the system one
+ if external_dest not in link_dict.values():
+ link_dict[internal_src] = external_dest
+
+ for internal_src, external_dest in link_dict.items():
ctx.symlink(Label(internal_src), ctx.path(external_dest))
# Downloads and creates Bazel repos for dependencies.
@@ -201,6 +215,7 @@ third_party_http_archive = repository_rule(
"system_build_file": attr.string(mandatory = False),
"patch_file": attr.label(),
"link_files": attr.string_dict(),
+ "system_link_files": attr.string_dict(),
},
environ = [
"TF_SYSTEM_LIBS",
diff --git a/third_party/systemlibs/absl_py.BUILD b/third_party/systemlibs/absl_py.BUILD
new file mode 100644
index 0000000000..fe756e1be2
--- /dev/null
+++ b/third_party/systemlibs/absl_py.BUILD
@@ -0,0 +1 @@
+licenses(["notice"]) # Apache 2.0
diff --git a/third_party/systemlibs/absl_py.absl.flags.BUILD b/third_party/systemlibs/absl_py.absl.flags.BUILD
new file mode 100644
index 0000000000..95ec92b887
--- /dev/null
+++ b/third_party/systemlibs/absl_py.absl.flags.BUILD
@@ -0,0 +1,11 @@
+licenses(["notice"]) # Apache 2.0
+
+package(default_visibility = ["//visibility:public"])
+
+filegroup(
+ name = "LICENSE",
+)
+
+py_library(
+ name = "flags",
+)
diff --git a/third_party/systemlibs/absl_py.absl.testing.BUILD b/third_party/systemlibs/absl_py.absl.testing.BUILD
new file mode 100644
index 0000000000..c1b794c1e9
--- /dev/null
+++ b/third_party/systemlibs/absl_py.absl.testing.BUILD
@@ -0,0 +1,7 @@
+licenses(["notice"]) # Apache 2.0
+
+py_library(
+ name = "parameterized",
+ testonly = 1,
+ visibility = ["//visibility:public"],
+)
diff --git a/third_party/systemlibs/boringssl.BUILD b/third_party/systemlibs/boringssl.BUILD
new file mode 100644
index 0000000000..bc4c533403
--- /dev/null
+++ b/third_party/systemlibs/boringssl.BUILD
@@ -0,0 +1,21 @@
+licenses(["notice"])
+
+filegroup(
+ name = "LICENSE",
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "crypto",
+ linkopts = ["-lcrypto"],
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "ssl",
+ linkopts = ["-lssl"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":crypto",
+ ],
+)
diff --git a/third_party/systemlibs/double_conversion.BUILD b/third_party/systemlibs/double_conversion.BUILD
new file mode 100644
index 0000000000..568460181a
--- /dev/null
+++ b/third_party/systemlibs/double_conversion.BUILD
@@ -0,0 +1,12 @@
+licenses(["notice"])
+
+filegroup(
+ name = "LICENSE",
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "double-conversion",
+ linkopts = ["-ldouble-conversion"],
+ visibility = ["//visibility:public"],
+)
diff --git a/third_party/systemlibs/gast.BUILD b/third_party/systemlibs/gast.BUILD
new file mode 100644
index 0000000000..c6e1d0c4e0
--- /dev/null
+++ b/third_party/systemlibs/gast.BUILD
@@ -0,0 +1,12 @@
+licenses(["notice"]) # BSD 3-clause
+
+filegroup(
+ name = "PKG-INFO",
+ visibility = ["//visibility:public"],
+)
+
+py_library(
+ name = "gast",
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
+)
diff --git a/third_party/systemlibs/google_cloud_cpp.BUILD b/third_party/systemlibs/google_cloud_cpp.BUILD
new file mode 100644
index 0000000000..cbe6e10ba5
--- /dev/null
+++ b/third_party/systemlibs/google_cloud_cpp.BUILD
@@ -0,0 +1,6 @@
+licenses(["notice"]) # Apache 2.0
+
+filegroup(
+ name = "LICENSE",
+ visibility = ["//visibility:public"],
+)
diff --git a/third_party/systemlibs/google_cloud_cpp.google.cloud.bigtable.BUILD b/third_party/systemlibs/google_cloud_cpp.google.cloud.bigtable.BUILD
new file mode 100644
index 0000000000..b59d565390
--- /dev/null
+++ b/third_party/systemlibs/google_cloud_cpp.google.cloud.bigtable.BUILD
@@ -0,0 +1,7 @@
+licenses(["notice"]) # Apache 2.0
+
+cc_library(
+ name = "bigtable_client",
+ linkopts = ["-lbigtable_client"],
+ visibility = ["//visibility:public"],
+)
diff --git a/third_party/systemlibs/googleapis.BUILD b/third_party/systemlibs/googleapis.BUILD
new file mode 100644
index 0000000000..7687745df9
--- /dev/null
+++ b/third_party/systemlibs/googleapis.BUILD
@@ -0,0 +1,12 @@
+licenses(["notice"]) # Apache 2.0
+
+filegroup(
+ name = "LICENSE",
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "bigtable_protos",
+ linkopts = ["-lbigtable_protos"],
+ visibility = ["//visibility:public"],
+)
diff --git a/third_party/systemlibs/jsoncpp.BUILD b/third_party/systemlibs/jsoncpp.BUILD
index cf91917cfb..526fd0c418 100644
--- a/third_party/systemlibs/jsoncpp.BUILD
+++ b/third_party/systemlibs/jsoncpp.BUILD
@@ -23,7 +23,7 @@ genrule(
cmd = """
for i in $(OUTS); do
i=$${i##*/}
- ln -vsf /usr/include/jsoncpp/json/$$i $(@D)/include/json/$$i
+ ln -sf $(INCLUDEDIR)/jsoncpp/json/$$i $(@D)/include/json/$$i
done
""",
)
diff --git a/third_party/systemlibs/syslibs_configure.bzl b/third_party/systemlibs/syslibs_configure.bzl
index 8b09c9ac1f..8b0ab39eaf 100644
--- a/third_party/systemlibs/syslibs_configure.bzl
+++ b/third_party/systemlibs/syslibs_configure.bzl
@@ -10,11 +10,17 @@
_TF_SYSTEM_LIBS = "TF_SYSTEM_LIBS"
VALID_LIBS = [
+ "absl_py",
"astor_archive",
+ "boringssl",
+ "com_github_googleapis_googleapis",
+ "com_github_googlecloudplatform_google_cloud_cpp",
"com_googlesource_code_re2",
"curl",
"cython",
+ "double_conversion",
"flatbuffers",
+ "gast_archive",
"gif_archive",
"grpc",
"jemalloc",
diff --git a/third_party/toolchains/BUILD b/third_party/toolchains/BUILD
index 4303751452..7256a7d96e 100644
--- a/third_party/toolchains/BUILD
+++ b/third_party/toolchains/BUILD
@@ -32,6 +32,6 @@ platform(
remote_execution_properties = """
properties: {
name: "container-image"
- value:"docker://gcr.io/asci-toolchain/nosla-cuda9.0-cudnn7-ubuntu14.04@sha256:ae58329b961e7c17d89725bf8fd72dfbd5850f4f3313de58e0cafbf5b0343735"
+ value:"docker://gcr.io/asci-toolchain/nosla-cuda9.0-cudnn7-ubuntu14.04@sha256:06b585f42eed3b2030e9566b8f88f48d7472fa0f47e59765bc115376c8801bdf"
}""",
)
diff --git a/tools/bazel.rc b/tools/bazel.rc
index 601e07ffdd..3734fab715 100644
--- a/tools/bazel.rc
+++ b/tools/bazel.rc
@@ -24,12 +24,13 @@ build --define framework_shared_object=true
# Please note that MKL on MacOS or windows is still not supported.
# If you would like to use a local MKL instead of downloading, please set the
# environment variable "TF_MKL_ROOT" every time before build.
-build:mkl --define=using_mkl=true
+build:mkl --define=build_with_mkl=true --define=enable_mkl=true
build:mkl -c opt
# This config option is used to enable MKL-DNN open source library only,
# without depending on MKL binary version.
-build:mkl_open_source_only --define=using_mkl_dnn_only=true
+build:mkl_open_source_only --define=build_with_mkl_dnn_only=true
+build:mkl_open_source_only --define=build_with_mkl=true --define=enable_mkl=true
build:download_clang --crosstool_top=@local_config_download_clang//:toolchain
build:download_clang --define=using_clang=true
@@ -42,6 +43,9 @@ build:download_clang_use_lld --linkopt='-fuse-ld=lld'
build:cuda --crosstool_top=@local_config_cuda//crosstool:toolchain
build:cuda --define=using_cuda=true --define=using_cuda_nvcc=true
+build:rocm --crosstool_top=@local_config_rocm//crosstool:toolchain
+build:rocm --define=using_rocm=true --define=using_rocm_hipcc=true
+
build:cuda_clang --crosstool_top=@local_config_cuda//crosstool:toolchain
build:cuda_clang --define=using_cuda=true --define=using_cuda_clang=true --define=using_clang=true
@@ -57,6 +61,11 @@ build:sycl_asan --define=using_sycl=true --define=using_trisycl=false --copt -fn
build:sycl_trisycl --crosstool_top=@local_config_sycl//crosstool:toolchain
build:sycl_trisycl --define=using_sycl=true --define=using_trisycl=true
+# Options extracted from configure script
+build:gdr --define=with_gdr_support=true
+build:ngraph --define=with_ngraph_support=true
+build:verbs --define=with_verbs_support=true
+
build --define=use_fast_cpp_protos=true
build --define=allow_oversize_protos=true
build --define=grpc_no_ares=true
@@ -65,5 +74,13 @@ build --spawn_strategy=standalone
build --genrule_strategy=standalone
build -c opt
+# Other build flags.
+build --define=grpc_no_ares=true
+
# Modular TF build options
build:dynamic_kernels --define=dynamic_loaded_kernels=true
+
+# Default paths for TF_SYSTEM_LIBS
+build --define=PREFIX=/usr
+build --define=LIBDIR=$(PREFIX)/lib
+build --define=INCLUDEDIR=$(PREFIX)/include